Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 60 additions & 5 deletions lotus/sem_ops/sem_index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
from typing import Any

import pandas as pd
Expand All @@ -9,12 +10,13 @@
@pd.api.extensions.register_dataframe_accessor("sem_index")
class SemIndexDataframe:
"""
Create a vecgtor similarity index over a column in the DataFrame. Indexing is required for columns used in sem_search, sem_cluster_by, and sem_sim_join.
Create a vector similarity index over a column in the DataFrame. Indexing is required for columns used in sem_search, sem_cluster_by, and sem_sim_join.
When using retrieval-based cascades for sem_filter and sem_join, indexing is required for the columns used in the semantic operation.

Args:
col_name (str): The column name to index.
index_dir (str): The directory to save the index.
index_dir (str): The directory to save the index. Required to prevent column name collisions.
override (bool): If True, recreate index even if it exists and data is consistent. Defaults to False.

Returns:
pd.DataFrame: The DataFrame with the index directory saved.
Expand Down Expand Up @@ -46,6 +48,9 @@ class SemIndexDataframe:
title
0 Machine learning tutorial
1 Data science guide

# Example 3: force recreation of index with override=True
>>> df.sem_index('title', 'title_index', override=True) ## recreates index even if it exists
"""

def __init__(self, pandas_obj: Any) -> None:
Expand All @@ -59,7 +64,18 @@ def _validate(obj: Any) -> None:
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame:
def __call__(self, col_name: str, index_dir: str, override: bool = False) -> pd.DataFrame:
"""
Create or load a semantic index for the specified column.

Args:
col_name: Name of the column to index
index_dir: Directory where the index should be stored/loaded from
override: If True, recreate index even if it exists and data is consistent

Returns:
DataFrame with index directory stored in attrs
"""
lotus.logger.warning(
"Do not reset the dataframe index to ensure proper functionality of get_vectors_from_index"
)
Expand All @@ -71,7 +87,46 @@ def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame:
"The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()"
)

embeddings = rm(self._obj[col_name].tolist())
vs.index(self._obj[col_name], embeddings, index_dir)
# Get data from column
data = self._obj[col_name].tolist()
model_name = getattr(rm, "model", None)

# Check if index exists and data is consistent.
index_exists = vs.index_exists(index_dir)
data_consistent = False

if index_exists:
data_consistent = vs.is_data_consistent(index_dir, data, model_name)

# Determine if we need to create a new index.
should_create_index = not index_exists or not data_consistent or override

if should_create_index:
# Index does not exist, data is inconsistent, or override requested. Creating new index.
if index_exists and not data_consistent and not override:
raise ValueError(
f"Index exists at {index_dir} but data is inconsistent. "
f"Set override=True to recreate the index or use a different index_dir."
)

# Create data hash for consistency checking.
content = str(sorted(data))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hash creation is better as a function in the vector store class so that it's not repeated everywhere.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just pass the data to the index function and handle hash creation inside the class itself.
Creating a common function like hash_data, used by both store_metadata and is_data_consistent, would simplify things.

if model_name:
content += f"__{model_name}"
data_hash = hashlib.sha256(content.encode()).hexdigest()[:32]

# Create new index.
embeddings = rm(data)
vs.index(self._obj[col_name], embeddings, index_dir)

# Store metadata for data consistency checking (FAISS only).
if hasattr(vs, "_store_metadata"):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be called directly in index function.
Definitely not good to call a private function outside class.

vs._store_metadata(index_dir, data_hash)
lotus.logger.info(f"Created new index at {index_dir}")
else:
# Load existing index.
vs.load_index(index_dir)
lotus.logger.info(f"Loaded existing index from {index_dir}")

self._obj.attrs["index_dirs"][col_name] = index_dir
return self._obj
11 changes: 9 additions & 2 deletions lotus/sem_ops/sem_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import CascadeArgs, ReasoningStrategy, SemanticJoinOutput
from lotus.utils import show_safe_mode
from lotus.utils import get_index_cache, show_safe_mode

from .cascade_utils import calibrate_sem_sim_join, importance_sampling, learn_cascade_thresholds
from .sem_filter import sem_filter
Expand Down Expand Up @@ -355,7 +355,14 @@ def run_sem_sim_join(l1: pd.Series, l2: pd.Series, col1_label: str, col2_label:
lotus.logger.error("l1 must be a pandas Series or DataFrame")

l2_df = l2.to_frame(name=col2_label)
l2_df = l2_df.sem_index(col2_label, f"{col2_label}_index")

# Use get_index_cache to create deterministic cache directory for l2 data
rm = lotus.settings.rm
model_name = getattr(rm, "model", None)
cache_dir = get_index_cache(col2_label, l2.tolist(), model_name)

# This will reuse existing index if data is consistent, preventing duplicate creation
l2_df = l2_df.sem_index(col2_label, cache_dir)

K = len(l2)
# Run sem_sim_join as helper on the sampled data
Expand Down
31 changes: 31 additions & 0 deletions lotus/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import time
from io import BytesIO
from pathlib import Path
from typing import Callable

import numpy as np
Expand Down Expand Up @@ -132,3 +133,33 @@ def show_safe_mode(estimated_cost, estimated_LM_calls):
except KeyboardInterrupt:
print("\nExecution cancelled by user")
exit(0)


def get_index_cache(col_name: str, data: list, model_name: str | None = None) -> str:
"""
Get cache path for semantic index. Uses a stable cache directory based on column name and model.
Data consistency is handled by sem_index() which checks if the cached data matches the current data.

Args:
col_name: Name of the column being indexed
data: Data to index (used for consistency checking in sem_index, not for cache path)
model_name: Name of embedding model (different models = different embeddings)

Returns:
Path a string of the cache directory in ~/.cache/lotus/indices/{col_name}_{model}

Example:
>>> data = ['Tech', 'Food', 'Sports']
>>> path = get_index_cache('category', data, 'text-embedding-3-small')
>>> # Returns: "~/.cache/lotus/indices/category_text-embedding-3-small"
"""
# Create stable cache directory based on column name and model
cache_name = col_name
if model_name:
cache_name += f"_{model_name}"

# build cache location at ~/.cache/lotus/
cache_path = Path("~/.cache/lotus/indices").expanduser() / cache_name
cache_path.mkdir(parents=True, exist_ok=True)

return str(cache_path)
15 changes: 14 additions & 1 deletion lotus/vector_store/faiss_vs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import pickle
from typing import Any
Expand All @@ -19,7 +20,13 @@ def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODU
self.faiss_index: faiss.Index | None = None
self.vecs: NDArray[np.float64] | None = None

def index(self, docs: list[str], embeddings: NDArray[np.float64], index_dir: str, **kwargs: dict[str, Any]) -> None:
def index(
self,
docs: list[str],
embeddings: NDArray[np.float64],
index_dir: str,
**kwargs: dict[str, Any],
) -> None:
self.faiss_index = faiss.index_factory(embeddings.shape[1], self.factory_string, self.metric)
self.faiss_index.add(embeddings)
self.index_dir = index_dir
Expand All @@ -29,6 +36,12 @@ def index(self, docs: list[str], embeddings: NDArray[np.float64], index_dir: str
pickle.dump(embeddings, fp)
faiss.write_index(self.faiss_index, f"{index_dir}/index")

def _store_metadata(self, index_dir: str, data_hash: str):
"""Store metadata for data consistency checking (FAISS-specific)"""
metadata = {"data_hash": data_hash}
with open(f"{index_dir}/metadata.json", "w") as fp:
json.dump(metadata, fp)

def load_index(self, index_dir: str) -> None:
self.index_dir = index_dir
self.faiss_index = faiss.read_index(f"{index_dir}/index")
Expand Down
36 changes: 36 additions & 0 deletions lotus/vector_store/vs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import hashlib
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import numpy as np
Expand Down Expand Up @@ -56,3 +59,36 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f
Retrieve vectors from a stored index given specific ids.
"""
pass

def index_exists(self, index_dir: str) -> bool:
    """
    Report whether an index is already stored at ``index_dir``.

    Default implementation requires both the directory itself and the
    conventional ``index`` file inside it to exist. Subclasses can override
    this for vector-store-specific layouts.
    """
    root = Path(index_dir)
    if not root.exists():
        return False
    return (root / "index").exists()

def is_data_consistent(self, index_dir: str, data: list, model_name: str | None = None) -> bool:
"""
Check if the cached index is consistent with the current data.
Default implementation compares data hash stored in metadata.
Subclasses can override for more sophisticated consistency checks.
"""
metadata_path = Path(index_dir) / "metadata.json"
if not metadata_path.exists():
return False

try:
with open(metadata_path, "r") as f:
metadata = json.load(f)

# Create hash of current data
content = str(sorted(data))
if model_name:
content += f"__{model_name}"
current_hash = hashlib.sha256(content.encode()).hexdigest()[:32]

return metadata.get("data_hash") == current_hash
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should verify the index configs as well.

The data in vector db depends on the original data, model to create embeddings, and the vector store configs.

except (json.JSONDecodeError, KeyError):
return False
Loading