Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions litesearch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
__version__ = "0.0.5"
__version__ = "0.1.0"
from .postfix import usearch_fix
usearch_fix()
from .core import *
from .data import *
from .utils import *
3 changes: 3 additions & 0 deletions litesearch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def setup_db(pth_or_uri:str=':memory:', # the database name or URL
_db = Database(pth_or_uri, **kw)
if wal: _db.enable_wal()
if not sem_search: return _db
# Lazy initialization: apply usearch fix only when semantic search is enabled
from .postfix import usearch_fix
usearch_fix()
from usearch import sqlite_path
_db.conn.enableloadextension(True)
_db.conn.loadextension(sqlite_path())
Expand Down
25 changes: 16 additions & 9 deletions litesearch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def ext_im(self: Document, it=None):
if pix0.alpha: pix0 = Pixmap(pix0, 0) # remove alpha channel
mask = Pixmap(self.extract_image(smask)['image'])
try: pix = Pixmap(pix0, mask)
except: pix = Pixmap(self.extract_image(xref)['image'])
except (RuntimeError, ValueError, KeyError): pix = Pixmap(self.extract_image(xref)['image'])
ext = 'pam' if pix0.n > 3 else 'png'
return dict(ext=ext, colorspace=pix.colorspace.n, image=pix.tobytes(ext))
if '/ColorSpace' in self.xref_object(xref, compressed=True):
Expand Down Expand Up @@ -91,22 +91,28 @@ def pdf_ingest(

# %% ../nbs/02_data.ipynb 7
def clean(q:str # query to be passed for fts search
):
'''Clean the query by removing * and returning None for empty queries.'''
return q.replace('*', '') if q.strip() else None
) -> str:
'''Clean the query by removing * and returning empty string for empty queries.'''
if not q or not q.strip():
return ''
return q.replace('*', '')

def add_wc(q:str # query to be passed for fts search
):
) -> str:
'''Add wild card * to each word in the query.'''
if not q or not q.strip():
return ''
return ' '.join(map(lambda w: w + '*', q.split(' ')))

def mk_wider(q:str # query to be passed for fts search
):
) -> str:
'''Widen the query by joining words with OR operator.'''
if not q or not q.strip():
return ''
return ' OR '.join(map(lambda w: f'{w}', q.split(' ')))

def kw(q:str # query to be passed for fts search
):
) -> str:
'''Extract keywords from the query using YAKE library.'''
from yake import KeywordExtractor as KW
return ' '.join((set(concat([k.split(' ') for k, s in KW().extract_keywords(q)]))))
Expand All @@ -115,10 +121,11 @@ def pre(q:str, # query to be passed for fts search
wc=True, # add wild card to each word
wide=True, # widen the query with OR operator
extract_kw=True # extract keywords from the query
):
) -> str:
'''Preprocess the query for fts search.'''
q = clean(q)
if not q.strip(): return ''
if not q:
return ''
if extract_kw: q = kw(q)
if wc: q = add_wc(q)
if wide: q = mk_wider(q)
Expand Down
14 changes: 13 additions & 1 deletion litesearch/postfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,36 @@
import os, subprocess, platform

# %% ../nbs/00_postfix.ipynb 5
_usearch_fix_applied = False

def usearch_fix():
"""Apply usearch macOS fix if on macOS."""
"""Apply usearch macOS fix if on macOS. Safe to call multiple times."""
global _usearch_fix_applied
if _usearch_fix_applied:
return # Already applied

print('Applying usearch macOS fix if required...')
try:
from usearch import sqlite_path
dylib_path = sqlite_path()+'.dylib'
print('usearch dylib path: ', dylib_path)
if platform.system() != "Darwin":
print('Not on macOS, skipping usearch fix.')
_usearch_fix_applied = True
return
cmd = ['install_name_tool', '-add_rpath', '/usr/lib', dylib_path]
r = subprocess.run(cmd, capture_output=True, text=True, check=True)
if r.returncode == 0: print(f'✓ Applied usearch fix: Added /usr/lib rpath to {dylib_path}')
else: print(f'✗ Failed to apply fix: {r.stderr}')
_usearch_fix_applied = True
except ImportError as ie:
print('Warning: usearch not installed or import failed. you might need to install libsqlite3-dev. '
'For macs do `brew install libsqlite3-dev`. For linux `apt install libsqlite3-dev`. ', ie)
except subprocess.CalledProcessError as e:
# rpath already exists is not an error
if 'duplicate' in str(e.stderr):
_usearch_fix_applied = True
return
print(f'✗ install_name_tool failed: {e}')
print(f'Command output: {e.output}')
print(f'Command stderr: {e.stderr}')
Expand Down
9 changes: 5 additions & 4 deletions litesearch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def __init__(self,
prompt = prompt or (model_dict.prompt if model_dict else AttrDictDefault())
store_attr()
try: self.md = download_model(repo_id=repo_id, md=md, token=hf_token)
except Exception as ex: print(f'model download failed: {ex}. hint: is hf_token set')
except Exception as ex:
raise RuntimeError(f'Failed to download model {repo_id}: {ex}. Hint: is HF_TOKEN set?') from ex
self._load_enc()
def _load_enc(self):
try:
Expand All @@ -56,8 +57,7 @@ def _load_enc(self):
self._load_tok()
self.sess = ort.InferenceSession(onnx_p, sess_opt, providers=["CPUExecutionProvider"])
except Exception as ex:
print(f'Encoding setup errored out with exception: {ex}')
self.sess = None
raise RuntimeError(f'Failed to initialize ONNX session: {ex}') from ex
def _load_tok(self):
cfg = json.load(open(os.path.join(self.md, "config.json")))
tok_cfg = json.load(open(os.path.join(self.md, "tokenizer_config.json")))
Expand All @@ -80,7 +80,8 @@ def _mp(self, mout: np.ndarray, msk: np.ndarray):
sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
return sum_embeddings / sum_mask
def encode(self, lns:list, **kw):
if not self.sess: print('ONNX session not initialized properly. Fix error during initialisation'); return None
if not self.sess:
raise RuntimeError('ONNX session not initialized. Check initialization errors.')
if not lns: return np.zeros((0, self.sess.get_outputs()[0].shape[-1]), dtype=self.dtype)
ids, msk = self._enc(lns)
if ids.ndim ==1: ids, msk = np.expand_dims(ids, axis=0), np.expand_dims(msk, axis=0)
Expand Down
28 changes: 2 additions & 26 deletions nbs/00_postfix.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,31 +58,7 @@
"id": "aaef4f5a5e9a9eed",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def usearch_fix():\n",
" \"\"\"Apply usearch macOS fix if on macOS.\"\"\"\n",
" print('Applying usearch macOS fix if required...')\n",
" try:\n",
" from usearch import sqlite_path\n",
" dylib_path = sqlite_path()+'.dylib'\n",
" print('usearch dylib path: ', dylib_path)\n",
" if platform.system() != \"Darwin\": \n",
" print('Not on macOS, skipping usearch fix.')\n",
" return\n",
" cmd = ['install_name_tool', '-add_rpath', '/usr/lib', dylib_path]\n",
" r = subprocess.run(cmd, capture_output=True, text=True, check=True)\n",
" if r.returncode == 0: print(f'✓ Applied usearch fix: Added /usr/lib rpath to {dylib_path}')\n",
" else: print(f'✗ Failed to apply fix: {r.stderr}')\n",
" except ImportError as ie:\n",
" print('Warning: usearch not installed or import failed. you might need to install libsqlite3-dev. '\n",
" 'For macs do `brew install libsqlite3-dev`. For linux `apt install libsqlite3-dev`. ', ie)\n",
" except subprocess.CalledProcessError as e: \n",
" print(f'✗ install_name_tool failed: {e}')\n",
" print(f'Command output: {e.output}')\n",
" print(f'Command stderr: {e.stderr}')\n",
" except Exception as e: print(f'Unexpected error during fix: {e}')"
]
"source": "#| export\n_usearch_fix_applied = False\n\ndef usearch_fix():\n \"\"\"Apply usearch macOS fix if on macOS. Safe to call multiple times.\"\"\"\n global _usearch_fix_applied\n if _usearch_fix_applied:\n return # Already applied\n \n print('Applying usearch macOS fix if required...')\n try:\n from usearch import sqlite_path\n dylib_path = sqlite_path()+'.dylib'\n print('usearch dylib path: ', dylib_path)\n if platform.system() != \"Darwin\": \n print('Not on macOS, skipping usearch fix.')\n _usearch_fix_applied = True\n return\n cmd = ['install_name_tool', '-add_rpath', '/usr/lib', dylib_path]\n r = subprocess.run(cmd, capture_output=True, text=True, check=True)\n if r.returncode == 0: print(f'✓ Applied usearch fix: Added /usr/lib rpath to {dylib_path}')\n else: print(f'✗ Failed to apply fix: {r.stderr}')\n _usearch_fix_applied = True\n except ImportError as ie:\n print('Warning: usearch not installed or import failed. you might need to install libsqlite3-dev. '\n 'For macs do `brew install libsqlite3-dev`. For linux `apt install libsqlite3-dev`. ', ie)\n except subprocess.CalledProcessError as e: \n # rpath already exists is not an error\n if 'duplicate' in str(e.stderr):\n _usearch_fix_applied = True\n return\n print(f'✗ install_name_tool failed: {e}')\n print(f'Command output: {e.output}')\n print(f'Command stderr: {e.stderr}')\n except Exception as e: print(f'Unexpected error during fix: {e}')"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -126,4 +102,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
22 changes: 2 additions & 20 deletions nbs/01_core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,25 +100,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def setup_db(pth_or_uri:str=':memory:', # the database name or URL\n",
" wal:bool=True, # use WAL mode\n",
" sem_search:bool=True, # enable usearch extensions\n",
" **kw, # additional args to pass to apswutils database\n",
" ) -> Database:\n",
" '''Set up a database connection and load usearch extensions. You can refer usearch docs on sqlite plugins here: <https://unum-cloud.github.io/USearch/sqlite/index.html>'''\n",
"\n",
" if isinstance(pth_or_uri, (str, Path)): Path(pth_or_uri).parent.mkdir(exist_ok=True)\n",
" _db = Database(pth_or_uri, **kw)\n",
" if wal: _db.enable_wal()\n",
" if not sem_search: return _db\n",
" from usearch import sqlite_path\n",
" _db.conn.enableloadextension(True)\n",
" _db.conn.loadextension(sqlite_path())\n",
" _db.conn.enableloadextension(False)\n",
" return _db"
]
"source": "#| export\ndef setup_db(pth_or_uri:str=':memory:', # the database name or URL\n wal:bool=True, # use WAL mode\n sem_search:bool=True, # enable usearch extensions\n **kw, # additional args to pass to apswutils database\n ) -> Database:\n '''Set up a database connection and load usearch extensions. You can refer usearch docs on sqlite plugins here: <https://unum-cloud.github.io/USearch/sqlite/index.html>'''\n\n if isinstance(pth_or_uri, (str, Path)): Path(pth_or_uri).parent.mkdir(exist_ok=True)\n _db = Database(pth_or_uri, **kw)\n if wal: _db.enable_wal()\n if not sem_search: return _db\n # Lazy initialization: apply usearch fix only when semantic search is enabled\n from .postfix import usearch_fix\n usearch_fix()\n from usearch import sqlite_path\n _db.conn.enableloadextension(True)\n _db.conn.loadextension(sqlite_path())\n _db.conn.enableloadextension(False)\n return _db"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -200,4 +182,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
73 changes: 3 additions & 70 deletions nbs/02_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,39 +55,7 @@
"id": "fc014ece1162af13",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"@patch\n",
"def get_texts(self: Document, st=0, end=-1, **kw):\n",
"\treturn L(self[st:end]).map(lambda p: p.get_text(**kw))\n",
"\n",
"@patch\n",
"def get_links(self: Document, st=0, end=-1):\n",
"\treturn L(self[st:end]).map(lambda p: p.get_links()).concat()\n",
"\n",
"@patch\n",
"def ext_im(self: Document, it=None):\n",
" if not it: return None\n",
" assert isinstance(it, tuple) and len(it) > 2, 'Invalid image tuple'\n",
" xref, smask = it[0], it[1]\n",
" if smask > 0:\n",
" pix0 = Pixmap(self.extract_image(xref)['image'])\n",
" if pix0.alpha: pix0 = Pixmap(pix0, 0) # remove alpha channel\n",
" mask = Pixmap(self.extract_image(smask)['image'])\n",
" try: pix = Pixmap(pix0, mask)\n",
" except: pix = Pixmap(self.extract_image(xref)['image'])\n",
" ext = 'pam' if pix0.n > 3 else 'png'\n",
" return dict(ext=ext, colorspace=pix.colorspace.n, image=pix.tobytes(ext))\n",
" if '/ColorSpace' in self.xref_object(xref, compressed=True):\n",
" pix = Pixmap(csRGB, Pixmap(self, xref))\n",
" return dict(ext='png', colorspace=3, image=pix.tobytes('png'))\n",
" return self.extract_image(xref)\n",
"\n",
"@patch\n",
"def ext_imgs(self: Document, st=0, end=-1):\n",
"\tf = lambda p: [ext_im(self,it) for it in p.get_images(full=True)]\n",
"\treturn L(self[st:end]).map(f).concat()"
]
"source": "#| export\n@patch\ndef get_texts(self: Document, st=0, end=-1, **kw):\n\treturn L(self[st:end]).map(lambda p: p.get_text(**kw))\n\n@patch\ndef get_links(self: Document, st=0, end=-1):\n\treturn L(self[st:end]).map(lambda p: p.get_links()).concat()\n\n@patch\ndef ext_im(self: Document, it=None):\n if not it: return None\n assert isinstance(it, tuple) and len(it) > 2, 'Invalid image tuple'\n xref, smask = it[0], it[1]\n if smask > 0:\n pix0 = Pixmap(self.extract_image(xref)['image'])\n if pix0.alpha: pix0 = Pixmap(pix0, 0) # remove alpha channel\n mask = Pixmap(self.extract_image(smask)['image'])\n try: pix = Pixmap(pix0, mask)\n except (RuntimeError, ValueError, KeyError): pix = Pixmap(self.extract_image(xref)['image'])\n ext = 'pam' if pix0.n > 3 else 'png'\n return dict(ext=ext, colorspace=pix.colorspace.n, image=pix.tobytes(ext))\n if '/ColorSpace' in self.xref_object(xref, compressed=True):\n pix = Pixmap(csRGB, Pixmap(self, xref))\n return dict(ext='png', colorspace=3, image=pix.tobytes('png'))\n return self.extract_image(xref)\n\n@patch\ndef ext_imgs(self: Document, st=0, end=-1):\n\tf = lambda p: [ext_im(self,it) for it in p.get_images(full=True)]\n\treturn L(self[st:end]).map(f).concat()"
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -158,42 +126,7 @@
"id": "f887a2d1e48c7e1d",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def clean(q:str # query to be passed for fts search\n",
" ):\n",
" '''Clean the query by removing * and returning None for empty queries.'''\n",
" return q.replace('*', '') if q.strip() else None\n",
"\n",
"def add_wc(q:str # query to be passed for fts search\n",
" ):\n",
" '''Add wild card * to each word in the query.'''\n",
" return ' '.join(map(lambda w: w + '*', q.split(' ')))\n",
"\n",
"def mk_wider(q:str # query to be passed for fts search\n",
" ):\n",
" '''Widen the query by joining words with OR operator.'''\n",
" return ' OR '.join(map(lambda w: f'{w}', q.split(' ')))\n",
"\n",
"def kw(q:str # query to be passed for fts search\n",
" ):\n",
" '''Extract keywords from the query using YAKE library.'''\n",
" from yake import KeywordExtractor as KW\n",
" return ' '.join((set(concat([k.split(' ') for k, s in KW().extract_keywords(q)]))))\n",
"\n",
"def pre(q:str, # query to be passed for fts search\n",
" wc=True, # add wild card to each word\n",
" wide=True, # widen the query with OR operator\n",
" extract_kw=True # extract keywords from the query\n",
" ):\n",
" '''Preprocess the query for fts search.'''\n",
" q = clean(q)\n",
" if not q.strip(): return ''\n",
" if extract_kw: q = kw(q)\n",
" if wc: q = add_wc(q)\n",
" if wide: q = mk_wider(q)\n",
" return q"
]
"source": "#| export\ndef clean(q:str # query to be passed for fts search\n ) -> str:\n '''Clean the query by removing * and returning empty string for empty queries.'''\n if not q or not q.strip():\n return ''\n return q.replace('*', '')\n\ndef add_wc(q:str # query to be passed for fts search\n ) -> str:\n '''Add wild card * to each word in the query.'''\n if not q or not q.strip():\n return ''\n return ' '.join(map(lambda w: w + '*', q.split(' ')))\n\ndef mk_wider(q:str # query to be passed for fts search\n ) -> str:\n '''Widen the query by joining words with OR operator.'''\n if not q or not q.strip():\n return ''\n return ' OR '.join(map(lambda w: f'{w}', q.split(' ')))\n\ndef kw(q:str # query to be passed for fts search\n ) -> str:\n '''Extract keywords from the query using YAKE library.'''\n from yake import KeywordExtractor as KW\n return ' '.join((set(concat([k.split(' ') for k, s in KW().extract_keywords(q)]))))\n\ndef pre(q:str, # query to be passed for fts search\n wc=True, # add wild card to each word\n wide=True, # widen the query with OR operator\n extract_kw=True # extract keywords from the query\n ) -> str:\n '''Preprocess the query for fts search.'''\n q = clean(q)\n if not q:\n return ''\n if extract_kw: q = kw(q)\n if wc: q = add_wc(q)\n if wide: q = mk_wider(q)\n return q"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -224,4 +157,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
Loading