217 lines
6.5 KiB
Python
217 lines
6.5 KiB
Python
import mimetypes
|
|
import os
|
|
import pathlib
|
|
import socketserver
|
|
import sqlite3
|
|
import sys
|
|
import threading
|
|
import traceback
|
|
from xmlrpc.server import SimpleXMLRPCDispatcher, SimpleXMLRPCRequestHandler
|
|
import pillow_avif
|
|
import sqlite_vec
|
|
from watchdog.observers import Observer
|
|
from watchdog.events import FileSystemEventHandler, FileOpenedEvent
|
|
import model
|
|
|
|
|
|
# Open (or create) the on-disk index database. check_same_thread=False lets
# threads other than the creating one use this connection; access to the
# shared cursor is serialized by the callers through `lock`, created below.
print("Connecting to DB")
con = sqlite3.connect("index.db", check_same_thread=False)
# Extension loading is only needed while sqlite-vec registers itself; switch
# it off again afterwards so later SQL cannot load arbitrary shared libraries.
con.enable_load_extension(True)
try:
    sqlite_vec.load(con)
finally:
    con.enable_load_extension(False)
cur = con.cursor()
# idx: one row per indexed filesystem entry (id is the entry's inode number,
# parent is the inode number of the directory it is filed under).
cur.execute(
    "CREATE TABLE IF NOT EXISTS idx (id INTEGER PRIMARY KEY, parent INTEGER, time INTEGER, path TEXT)"
)
# emb: one 1024-float embedding per id, compared with cosine distance.
cur.execute(
    "CREATE VIRTUAL TABLE IF NOT EXISTS emb USING vec0(id INTEGER PRIMARY KEY, embedding float[1024] distance_metric=cosine)"
)
con.commit()
# Held around every batch of statements issued on the shared cursor.
lock = threading.Lock()
|
|
|
|
|
|
def get_parent(path):
    """Return the inode number of the directory *path* is filed under.

    Paths listed in the module-level ``watchdirs`` set are reported as
    children of ``/``; every other path uses its real containing directory.
    """
    containing_dir = "/" if path in watchdirs else pathlib.Path(path).parent
    return os.stat(containing_dir).st_ino
|
|
|
|
|
|
class EventHandler(FileSystemEventHandler):
    """Apply watchdog filesystem events to the ``idx``/``emb`` tables."""

    def dispatch(self, event):
        """Forward *event* to its ``on_*`` handler while holding ``lock``.

        ``FileOpenedEvent`` instances are ignored. Everything else is printed
        and dispatched under the module-level lock, so the handlers below may
        use the shared cursor freely.
        """
        if isinstance(event, FileOpenedEvent):
            return
        with lock:
            print(event)
            super().dispatch(event)

    def on_created(self, event):
        """Index the newly created path."""
        index(event.src_path, get_parent(event.src_path))

    def on_modified(self, event):
        """Re-index a modified file; directory modifications are ignored."""
        if not event.is_directory:
            self.on_created(event)

    def on_deleted(self, event):
        """Unindex the entry stored under the deleted path, if exactly one."""
        res = cur.execute("SELECT id FROM idx WHERE path = ?", (event.src_path,))
        ids = res.fetchall()
        if len(ids) == 1:
            unindex(ids[0][0])

    def on_moved(self, event):
        """Record the new location of a moved entry and of its descendants."""
        # inode doesn't change after move
        s = os.stat(event.dest_path)
        cur.execute(
            "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
            (s.st_ino, get_parent(event.dest_path), s.st_mtime, event.dest_path),
        )
        # Rewrite only the leading "<src>/" of descendant paths. A blanket
        # replace(path, src, dest) would also rewrite unrelated rows that
        # merely contain src as a substring (e.g. moving /a/foo would
        # corrupt /a/foobar/x).
        old_prefix = event.src_path + os.sep
        cur.execute(
            "UPDATE idx SET path = ? || substr(path, ?) WHERE substr(path, 1, ?) = ?",
            (
                event.dest_path,
                len(event.src_path) + 1,  # substr() is 1-based: keep from the separator on
                len(old_prefix),
                old_prefix,
            ),
        )
        con.commit()
|
|
|
|
|
|
def _embed_file(path):
    """Return an embedding for *path*, or None when none can be produced.

    The embedder is chosen from the guessed MIME type (audio, image, or video
    smaller than 2**25 bytes). Embedding failures are printed and reported as
    None so one bad file does not abort the whole indexing pass.
    """
    mime = mimetypes.guess_type(path)[0]
    if not isinstance(mime, str):
        return None
    try:
        if mime.startswith("audio"):
            return model.embed_audio(path)
        if mime.startswith("image"):
            return model.embed_image(path)
        if mime.startswith("video") and os.path.getsize(path) < 2**25:
            return model.embed_video(path)
    except Exception:
        # Not a bare except: KeyboardInterrupt/SystemExit must still propagate.
        print(traceback.format_exc())
    return None


def index(path, parent):
    """Bring the index up to date for *path* and, for directories, its subtree.

    Args:
        path: Filesystem path to index. Nonexistent paths and names starting
            with "." are skipped.
        parent: Inode number of the directory this entry is filed under. A
            falsy value marks the synthetic root whose children are the
            entries of ``watchdirs``.

    Files get an embedding row when they are new, modified, or missing from
    ``emb``; files that cannot be embedded are removed from the index.
    Directories are stored in ``idx`` only and then recursed into.
    """
    if not os.path.exists(path) or os.path.basename(path).startswith("."):
        # Skip nonexistent or hidden files
        return

    print("Indexing", path, parent)
    s = os.stat(path)
    if os.path.isfile(path):
        db_vals = cur.execute(
            "SELECT time, parent, path FROM idx WHERE id = ?", (s.st_ino,)
        ).fetchall()
        if len(db_vals) == 1 and (s.st_mtime, parent, path) == db_vals[0]:
            # Already in DB, unmodified
            return

        if (
            not db_vals
            # Column 0 is the stored mtime (column 1 is the parent inode).
            or s.st_mtime != db_vals[0][0]
            or cur.execute("SELECT 1 FROM emb WHERE id = ?", (s.st_ino,)).fetchone()
            is None
        ):
            # Modified or not in emb
            emb = _embed_file(path)
            if emb is None:
                # Might be in index but no longer valid
                unindex(s.st_ino)
                return

            # sqlite-vec doesn't support INSERT OR REPLACE?
            cur.execute("DELETE FROM emb WHERE id = ?", (s.st_ino,))
            cur.execute("INSERT INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy()))

    cur.execute(
        "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
        (s.st_ino, parent, s.st_mtime, path),
    )
    con.commit()

    if os.path.isdir(path):
        children = os.listdir(path) if parent else watchdirs

        # Find and unindex dead children
        children_id = set()
        for child in children:
            try:
                children_id.add(os.stat(os.path.join(path, child)).st_ino)
            except OSError:
                # Broken symlink, or the entry vanished after listdir(); the
                # recursive index() call below skips such paths as well.
                continue
        db_children_id = cur.execute(
            "SELECT id FROM idx WHERE parent = ?", (s.st_ino,)
        ).fetchall()
        for db_child_id in db_children_id:
            if db_child_id[0] not in children_id:
                # Don't unemb because might be a move
                unindex(db_child_id[0], False)

        # Index live children
        for child in children:
            if not parent:
                # Children of the synthetic root are the watched directories.
                observer.schedule(event_handler, child, recursive=True)
            index(os.path.join(path, child), s.st_ino)
|
|
|
|
|
|
def unindex(id, unemb=True):
    """Remove entry *id* and everything filed beneath it from the index.

    Args:
        id: Inode number (``idx.id``) of the entry to remove.
        unemb: When true, also delete the entry's row in ``emb``. Descendants
            are always removed with their embeddings.
    """
    print("Unindexing", id)
    child_rows = cur.execute(
        "SELECT id FROM idx WHERE parent = ?", (id,)
    ).fetchall()
    # Children first, so no row is left pointing at a missing parent.
    for (child_id,) in child_rows:
        unindex(child_id)
    cur.execute("DELETE FROM idx WHERE id = ?", (id,))
    if unemb:
        cur.execute("DELETE FROM emb WHERE id = ?", (id,))
    con.commit()
|
|
|
|
|
|
def search(text, limit):
    """Return the paths of the *limit* indexed entries closest to *text*.

    Args:
        text: Query string, embedded with ``model.embed_text``.
        limit: Number of nearest neighbours requested from sqlite-vec (``k``).

    Returns:
        A list of paths ordered by increasing distance. It can be shorter
        than *limit* because embeddings without an ``idx`` row are dropped by
        the join.
    """
    print("Search", text, limit)
    emb = model.embed_text(text).cpu().numpy()
    # Use a private cursor (con.execute creates one) rather than the shared
    # module cursor: this function does not hold `lock`, and re-executing on
    # the shared cursor would discard a result set that a filesystem-event
    # handler on another thread may still be reading (and vice versa).
    res = con.execute(
        "SELECT idx.path FROM emb JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance",
        (emb, limit),
    )
    return [row[0] for row in res.fetchall()]
|
|
|
|
|
|
print("Indexing files")
# Directories to watch: every command-line argument, made absolute.
watchdirs = {os.path.abspath(arg) for arg in sys.argv[1:]}
observer = Observer()
observer.start()
event_handler = EventHandler()
# Hold the lock for the whole initial scan so event handlers wait until the
# index is consistent.
with lock:
    # Pretend that / is the parent of all indexed dirs
    index("/", 0)
    # Clean up emb: drop embeddings whose idx row no longer exists
    cur.execute("DELETE FROM emb WHERE id NOT IN (SELECT id FROM idx)")
    con.commit()
|
|
|
|
|
|
class UnixStreamXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler variant used by the Unix-socket server."""

    # Overrides the base handler's setting so no Nagle-related socket option
    # is applied to the connection.
    disable_nagle_algorithm = False

    def address_string(self):
        """Return the client address unchanged for use in log lines."""
        peer = self.client_address
        return peer
|
|
|
|
|
|
class UnixStreamXMLRPCServer(socketserver.UnixStreamServer, SimpleXMLRPCDispatcher):
    """XML-RPC server that listens on a Unix stream socket.

    Combines ``socketserver.UnixStreamServer`` with the XML-RPC dispatcher
    and always serves requests with ``UnixStreamXMLRPCRequestHandler``.
    """

    def __init__(
        self,
        addr,
        log_requests=True,
        allow_none=True,
        encoding=None,
        bind_and_activate=True,
        use_builtin_types=True,
    ):
        """Set up the dispatcher, then the socket server on *addr*.

        Args:
            addr: Filesystem path of the Unix socket.
            log_requests: Stored as ``logRequests`` for the request handler.
            allow_none: Passed to the XML-RPC dispatcher.
            encoding: Passed to the XML-RPC dispatcher.
            bind_and_activate: Passed to the socket server.
            use_builtin_types: Passed to the XML-RPC dispatcher.
        """
        self.logRequests = log_requests
        SimpleXMLRPCDispatcher.__init__(
            self,
            allow_none=allow_none,
            encoding=encoding,
            use_builtin_types=use_builtin_types,
        )
        socketserver.UnixStreamServer.__init__(
            self,
            addr,
            UnixStreamXMLRPCRequestHandler,
            bind_and_activate=bind_and_activate,
        )
|
|
|
|
|
|
print("Starting RPC server")
# Remove a socket file left over from an earlier run before binding again.
_socket_file = pathlib.Path("search.sock")
_socket_file.unlink(missing_ok=True)
server = UnixStreamXMLRPCServer("search.sock")
# Expose search() to RPC clients under its own name.
server.register_function(search)
# Blocks here, serving requests on the main thread.
server.serve_forever()
|