# search/server.py -- 217 lines, 6.5 KiB (Python)

import mimetypes
import os
import pathlib
import socketserver
import sqlite3
import sys
import threading
import traceback
from xmlrpc.server import SimpleXMLRPCDispatcher, SimpleXMLRPCRequestHandler
import pillow_avif
import sqlite_vec
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler, FileOpenedEvent
import model
print("Connecting to DB")
# check_same_thread=False: the connection is used both from the main thread
# and from the watchdog observer thread.
con = sqlite3.connect("index.db", check_same_thread=False)
# Extension loading is only needed long enough to load sqlite-vec; turn it
# back off so SQL executed on this connection cannot call load_extension().
con.enable_load_extension(True)
sqlite_vec.load(con)
con.enable_load_extension(False)
cur = con.cursor()
# idx: one row per indexed file or directory, keyed by inode number.
#   parent: inode of the containing directory (0 for the synthetic "/" root)
#   time:   st_mtime when the entry was indexed
#   path:   filesystem path of the entry
cur.execute(
    "CREATE TABLE IF NOT EXISTS idx (id INTEGER PRIMARY KEY, parent INTEGER, time INTEGER, path TEXT)"
)
# emb: sqlite-vec table of 1024-dimensional embeddings (cosine distance),
# keyed by the same inode ids as idx.
cur.execute(
    "CREATE VIRTUAL TABLE IF NOT EXISTS emb USING vec0(id INTEGER PRIMARY KEY, embedding float[1024] distance_metric=cosine)"
)
con.commit()
# Held around shared use of con/cur (watchdog event handling, startup indexing).
lock = threading.Lock()
def get_parent(path):
    """Return the inode number of the directory treated as *path*'s parent.

    Watch directories are hung under a synthetic "/" root, so for them the
    parent is "/" rather than their real filesystem parent.
    """
    parent_dir = "/" if path in watchdirs else pathlib.Path(path).parent
    return os.stat(parent_dir).st_ino
class EventHandler(FileSystemEventHandler):
    """Mirror watchdog filesystem events into the idx/emb tables.

    Every handler runs with `lock` held (see dispatch) because the tables are
    reached through the module-level `con`/`cur`, which other code also uses.
    """

    def dispatch(self, event):
        """Log *event* and route it to the matching on_* handler under `lock`."""
        # FileOpenedEvent is ignored entirely (not even logged).
        if isinstance(event, FileOpenedEvent):
            return
        with lock:
            print(event)
            try:
                super().dispatch(event)
            except Exception:
                # Otherwise the exception would propagate into watchdog's
                # observer thread. Log it, discard any half-applied DB writes,
                # and keep handling later events.
                print(traceback.format_exc())
                con.rollback()

    def on_created(self, event):
        """Index a newly created file or directory (directories recursively)."""
        index(event.src_path, get_parent(event.src_path))

    def on_modified(self, event):
        """Re-index a modified file; directory modifications are ignored."""
        if not event.is_directory:
            self.on_created(event)

    def on_deleted(self, event):
        """Drop every idx row recorded at the deleted path, with its subtree."""
        # A path can end up recorded under more than one inode (e.g. a file
        # replaced by renaming another file over it), so remove them all.
        rows = cur.execute(
            "SELECT id FROM idx WHERE path = ?", (event.src_path,)
        ).fetchall()
        for (row_id,) in rows:
            unindex(row_id)

    def on_moved(self, event):
        """Update the index after a file or directory was renamed/moved."""
        src, dest = event.src_path, event.dest_path
        if event.is_directory:
            # Rewrite the stored paths of the directory's descendants. Match
            # only the "src/" prefix: SQL replace() would also rewrite
            # siblings such as "src2/..." and any path merely containing src.
            prefix = src + os.sep
            cur.execute(
                "UPDATE idx SET path = ? || substr(path, ?) WHERE substr(path, 1, ?) = ?",
                (dest, len(prefix), len(prefix), prefix),
            )
            con.commit()
        # Inodes survive a move, so for already-indexed entries index() only
        # records the new parent/path. It also embeds entries that were never
        # indexed under their old name (e.g. written to a temporary or hidden
        # name and renamed into place) and skips hidden destinations.
        index(dest, get_parent(dest))
        # Whatever is still recorded at the old path was not re-recorded at
        # the new one (e.g. moved to a hidden name): treat it as deleted.
        stale = cur.execute("SELECT id FROM idx WHERE path = ?", (src,)).fetchall()
        for (stale_id,) in stale:
            unindex(stale_id)
def _embed(path):
    """Embed *path* with the model matching its guessed MIME type.

    Returns None when the type is not audio/image/video, when a video is
    2**25 bytes (32 MiB) or larger, or when the model raises (the traceback
    is printed). Otherwise returns whatever the model's embed_* call returns.
    """
    mime = mimetypes.guess_type(path)[0]
    if not isinstance(mime, str):
        return None
    try:
        if mime.startswith("audio"):
            return model.embed_audio(path)
        if mime.startswith("image"):
            return model.embed_image(path)
        if mime.startswith("video") and os.path.getsize(path) < 2**25:
            return model.embed_video(path)
    except Exception:
        print(traceback.format_exc())
    return None


def _live_inodes(path, children):
    """Return the inode numbers of the *children* of *path* that can be stat()ed.

    Broken symlinks and entries that vanished after listing are left out
    instead of aborting the whole scan.
    """
    inodes = set()
    for child in children:
        try:
            inodes.add(os.stat(os.path.join(path, child)).st_ino)
        except OSError:
            continue
    return inodes


def index(path, parent):
    """Index *path* (recursively for directories) under parent inode *parent*.

    Nonexistent and hidden (dot-prefixed) paths are skipped. Files get an idx
    row only if an embedding could be produced for them; directories always
    get an idx row. *parent* is 0 only for the synthetic "/" root, whose
    children are the watch directories, each of which is also scheduled with
    the observer.
    """
    if not os.path.exists(path) or os.path.basename(path).startswith("."):
        # Skip nonexistent or hidden files
        return
    print("Indexing", path, parent)
    s = os.stat(path)
    if os.path.isfile(path):
        db_vals = cur.execute(
            "SELECT time, parent, path FROM idx WHERE id = ?", (s.st_ino,)
        ).fetchall()
        if len(db_vals) == 1 and (s.st_mtime, parent, path) == db_vals[0]:
            # Already in DB, unmodified
            return
        # Re-embed only when the file is new, its mtime changed (column 0 is
        # `time`), or its embedding is missing; a file that merely moved just
        # needs its idx row rewritten below.
        needs_emb = (
            not db_vals
            or s.st_mtime != db_vals[0][0]
            or not cur.execute(
                "SELECT 1 FROM emb WHERE id = ?", (s.st_ino,)
            ).fetchall()
        )
        if needs_emb:
            emb = _embed(path)
            if emb is None:
                # Might be in index but no longer valid
                unindex(s.st_ino)
                return
            # Replace via DELETE + INSERT (sqlite-vec presumably lacks
            # INSERT OR REPLACE support for vec0 tables).
            cur.execute("DELETE FROM emb WHERE id = ?", (s.st_ino,))
            cur.execute("INSERT INTO emb VALUES (?, ?)", (s.st_ino, emb.cpu().numpy()))
    cur.execute(
        "INSERT OR REPLACE INTO idx VALUES (?, ?, ?, ?)",
        (s.st_ino, parent, s.st_mtime, path),
    )
    con.commit()
    if not os.path.isdir(path):
        return
    # parent == 0 marks the synthetic "/" root, whose children are the watch dirs.
    children = os.listdir(path) if parent else watchdirs
    # Find and unindex children recorded in the DB that are no longer present.
    live = _live_inodes(path, children)
    db_children = cur.execute(
        "SELECT id FROM idx WHERE parent = ?", (s.st_ino,)
    ).fetchall()
    for (child_id,) in db_children:
        if child_id not in live:
            # Don't unemb because might be a move
            unindex(child_id, False)
    # Index live children
    for child in children:
        if not parent:
            observer.schedule(event_handler, child, recursive=True)
        index(os.path.join(path, child), s.st_ino)
def unindex(id, unemb=True):
    """Remove inode *id* and, recursively, every idx row whose parent chain leads to it.

    When *unemb* is false the embeddings of the whole subtree are kept, so a
    directory that was merely moved can be re-indexed without re-embedding
    its files; embeddings left without an idx row are purged by the startup
    cleanup. All deletions are committed together at the end.
    """
    _unindex_rows(id, unemb)
    con.commit()


def _unindex_rows(id, unemb):
    """Delete idx (and, if *unemb*, emb) rows for *id* and its descendants, without committing."""
    print("Unindexing", id)
    children = cur.execute("SELECT id FROM idx WHERE parent = ?", (id,)).fetchall()
    for (child_id,) in children:
        # Propagate unemb: descendants of a possibly-moved entry are moved too.
        _unindex_rows(child_id, unemb)
    cur.execute("DELETE FROM idx WHERE id = ?", (id,))
    if unemb:
        cur.execute("DELETE FROM emb WHERE id = ?", (id,))
def search(text, limit):
    """Return paths of up to *limit* indexed entries nearest to the query *text*.

    Registered as the XML-RPC `search` function. Results are ordered by
    ascending distance (the emb table uses the cosine metric) between the
    text embedding and the stored embeddings.
    """
    print("Search", text, limit)
    # Embed outside the lock so model work does not hold up the watcher.
    query = model.embed_text(text).cpu().numpy()
    # This runs in the RPC thread while the watchdog thread executes and
    # fetches through the same shared cursor; hold `lock` so neither
    # clobbers the other's pending result set.
    with lock:
        rows = cur.execute(
            "SELECT idx.path FROM emb JOIN idx ON emb.id = idx.id WHERE embedding MATCH ? AND k = ? ORDER BY distance",
            (query, limit),
        ).fetchall()
    return [row[0] for row in rows]
print("Indexing files")
# Each command-line argument names a directory to index and watch.
watchdirs = set(map(os.path.abspath, sys.argv[1:]))
if not watchdirs:
    # With no watch dirs, indexing "/" would unindex every stored entry and
    # the cleanup below would then delete every embedding.
    sys.exit(f"usage: {sys.argv[0]} DIR [DIR ...]")
observer = Observer()
observer.start()
event_handler = EventHandler()
with lock:
    # Pretend that / is the parent of all indexed dirs
    index("/", 0)
    # Clean up emb: drop embeddings whose inode no longer has an idx row.
    # Kept under the lock because the observer thread is already running.
    cur.execute("DELETE FROM emb WHERE id NOT IN (SELECT id FROM idx)")
    con.commit()
class UnixStreamXMLRPCRequestHandler(SimpleXMLRPCRequestHandler):
    """XML-RPC request handler adapted for AF_UNIX stream sockets."""

    def address_string(self):
        # Unix-socket peers have no (host, port) tuple, so report the raw
        # client address instead of indexing into it.
        return self.client_address

    # TCP_NODELAY is a TCP-only option; don't try to set it on a Unix socket.
    disable_nagle_algorithm = False
class UnixStreamXMLRPCServer(socketserver.UnixStreamServer, SimpleXMLRPCDispatcher):
    """XML-RPC server that listens on a Unix domain stream socket at *addr*.

    Combines a Unix stream socket server with the stdlib XML-RPC dispatcher.
    Here allow_none and use_builtin_types default to True, and request
    logging is on by default.
    """

    def __init__(
        self,
        addr,
        log_requests=True,
        allow_none=True,
        encoding=None,
        bind_and_activate=True,
        use_builtin_types=True,
    ):
        # Consulted by the request handler when deciding whether to log.
        self.logRequests = log_requests
        SimpleXMLRPCDispatcher.__init__(
            self,
            allow_none=allow_none,
            encoding=encoding,
            use_builtin_types=use_builtin_types,
        )
        socketserver.UnixStreamServer.__init__(
            self,
            server_address=addr,
            RequestHandlerClass=UnixStreamXMLRPCRequestHandler,
            bind_and_activate=bind_and_activate,
        )
print("Starting RPC server")
# Remove a socket file left over from a previous run; binding a Unix socket
# fails if its path already exists.
try:
    os.unlink("search.sock")
except FileNotFoundError:
    pass
server = UnixStreamXMLRPCServer("search.sock")
server.register_function(search, "search")
server.serve_forever()