diff --git a/lotus/sem_ops/sem_index.py b/lotus/sem_ops/sem_index.py
index b2ad72cd..61d2c328 100644
--- a/lotus/sem_ops/sem_index.py
+++ b/lotus/sem_ops/sem_index.py
@@ -1,3 +1,4 @@
+import hashlib
 from typing import Any
 
 import pandas as pd
@@ -9,12 +10,13 @@
 @pd.api.extensions.register_dataframe_accessor("sem_index")
 class SemIndexDataframe:
     """
-    Create a vecgtor similarity index over a column in the DataFrame. Indexing is required for columns used in sem_search, sem_cluster_by, and sem_sim_join.
+    Create a vector similarity index over a column in the DataFrame. Indexing is required for columns used in sem_search, sem_cluster_by, and sem_sim_join.
     When using retrieval-based cascades for sem_filter and sem_join, indexing is required for the columns used in the semantic operation.
 
     Args:
         col_name (str): The column name to index.
-        index_dir (str): The directory to save the index.
+        index_dir (str): The directory to save the index. Required to prevent column name collisions.
+        override (bool): If True, recreate the index even if it exists and the data is consistent. Defaults to False.
 
     Returns:
         pd.DataFrame: The DataFrame with the index directory saved.
@@ -46,6 +48,9 @@ class SemIndexDataframe:
         title
         0  Machine learning tutorial
         1  Data science guide
+
+        # Example 3: force recreation of the index with override=True
+        >>> df.sem_index('title', 'title_index', override=True)  # recreates the index even if it exists
     """
 
     def __init__(self, pandas_obj: Any) -> None:
@@ -59,7 +64,18 @@ def _validate(obj: Any) -> None:
             raise AttributeError("Must be a DataFrame")
 
     @operator_cache
-    def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame:
+    def __call__(self, col_name: str, index_dir: str, override: bool = False) -> pd.DataFrame:
+        """
+        Create or load a semantic index for the specified column.
+
+        Args:
+            col_name: Name of the column to index.
+            index_dir: Directory where the index should be stored/loaded from.
+            override: If True, recreate the index even if it exists and the data is consistent.
+
+        Returns:
+            DataFrame with the index directory stored in attrs.
+        """
         lotus.logger.warning(
             "Do not reset the dataframe index to ensure proper functionality of get_vectors_from_index"
         )
@@ -71,7 +87,46 @@ def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame:
                 "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()"
             )
 
-        embeddings = rm(self._obj[col_name].tolist())
-        vs.index(self._obj[col_name], embeddings, index_dir)
+        # Get data from column
+        data = self._obj[col_name].tolist()
+        model_name = getattr(rm, "model", None)
+
+        # Check if index exists and data is consistent.
+        index_exists = vs.index_exists(index_dir)
+        data_consistent = False
+
+        if index_exists:
+            data_consistent = vs.is_data_consistent(index_dir, data, model_name)
+
+        # Determine if we need to create a new index.
+        should_create_index = not index_exists or not data_consistent or override
+
+        if should_create_index:
+            # Index does not exist, data is inconsistent, or override requested. Create a new index.
+            if index_exists and not data_consistent and not override:
+                raise ValueError(
+                    f"Index exists at {index_dir} but data is inconsistent. "
+                    f"Set override=True to recreate the index or use a different index_dir."
+                )
+
+            # Create data hash for consistency checking (order-preserving: index ids map to row positions).
+            content = str(data)
+            if model_name:
+                content += f"__{model_name}"
+            data_hash = hashlib.sha256(content.encode()).hexdigest()[:32]
+
+            # Create new index.
+            embeddings = rm(data)
+            vs.index(self._obj[col_name], embeddings, index_dir)
+
+            # Store metadata for data consistency checking (FAISS only).
+            if hasattr(vs, "_store_metadata"):
+                vs._store_metadata(index_dir, data_hash)
+            lotus.logger.info(f"Created new index at {index_dir}")
+        else:
+            # Load existing index.
+            vs.load_index(index_dir)
+            lotus.logger.info(f"Loaded existing index from {index_dir}")
+
         self._obj.attrs["index_dirs"][col_name] = index_dir
         return self._obj
diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py
index 14a4881a..1fad84d7 100644
--- a/lotus/sem_ops/sem_join.py
+++ b/lotus/sem_ops/sem_join.py
@@ -7,7 +7,7 @@
 from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import CascadeArgs, ReasoningStrategy, SemanticJoinOutput
-from lotus.utils import show_safe_mode
+from lotus.utils import get_index_cache, show_safe_mode
 
 from .cascade_utils import calibrate_sem_sim_join, importance_sampling, learn_cascade_thresholds
 from .sem_filter import sem_filter
@@ -355,7 +355,14 @@ def run_sem_sim_join(l1: pd.Series, l2: pd.Series, col1_label: str, col2_label:
         lotus.logger.error("l1 must be a pandas Series or DataFrame")
 
     l2_df = l2.to_frame(name=col2_label)
-    l2_df = l2_df.sem_index(col2_label, f"{col2_label}_index")
+
+    # Use get_index_cache to create a deterministic cache directory for the l2 data
+    rm = lotus.settings.rm
+    model_name = getattr(rm, "model", None)
+    cache_dir = get_index_cache(col2_label, l2.tolist(), model_name)
+
+    # This reuses the existing index if the data is consistent, preventing duplicate creation
+    l2_df = l2_df.sem_index(col2_label, cache_dir)
     K = len(l2)
 
     # Run sem_sim_join as helper on the sampled data
diff --git a/lotus/utils.py b/lotus/utils.py
index bf75748b..a82a9ae6 100644
--- a/lotus/utils.py
+++ b/lotus/utils.py
@@ -1,6 +1,7 @@
 import base64
 import time
 from io import BytesIO
+from pathlib import Path
 from typing import Callable
 
 import numpy as np
@@ -132,3 +133,33 @@ def show_safe_mode(estimated_cost, estimated_LM_calls):
     except KeyboardInterrupt:
         print("\nExecution cancelled by user")
         exit(0)
+
+
+def get_index_cache(col_name: str, data: list, model_name: str | None = None) -> str:
+    """
+    Get the cache path for a semantic index. Uses a stable cache directory based on column name and model.
+    Data consistency is handled by sem_index(), which checks if the cached data matches the current data.
+
+    Args:
+        col_name: Name of the column being indexed
+        data: Data to index (used for consistency checking in sem_index, not for the cache path)
+        model_name: Name of the embedding model (different models produce different embeddings)
+
+    Returns:
+        str: Path to the cache directory, ~/.cache/lotus/indices/{col_name}_{model_name}
+
+    Example:
+        >>> data = ['Tech', 'Food', 'Sports']
+        >>> path = get_index_cache('category', data, 'text-embedding-3-small')
+        >>> # Returns the expanded form of "~/.cache/lotus/indices/category_text-embedding-3-small"
+    """
+    # Create a stable cache directory name based on column name and model
+    cache_name = col_name
+    if model_name:
+        cache_name += f"_{model_name}"
+
+    # Build the cache location at ~/.cache/lotus/indices
+    cache_path = Path("~/.cache/lotus/indices").expanduser() / cache_name
+    cache_path.mkdir(parents=True, exist_ok=True)
+
+    return str(cache_path)
diff --git a/lotus/vector_store/faiss_vs.py b/lotus/vector_store/faiss_vs.py
index 6891d192..af12d6a0 100644
--- a/lotus/vector_store/faiss_vs.py
+++ b/lotus/vector_store/faiss_vs.py
@@ -1,3 +1,4 @@
+import json
 import os
 import pickle
 from typing import Any
@@ -19,7 +20,13 @@ def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODU
         self.faiss_index: faiss.Index | None = None
         self.vecs: NDArray[np.float64] | None = None
 
-    def index(self, docs: list[str], embeddings: NDArray[np.float64], index_dir: str, **kwargs: dict[str, Any]) -> None:
+    def index(
+        self,
+        docs: list[str],
+        embeddings: NDArray[np.float64],
+        index_dir: str,
+        **kwargs: dict[str, Any],
+    ) -> None:
         self.faiss_index = faiss.index_factory(embeddings.shape[1], self.factory_string, self.metric)
         self.faiss_index.add(embeddings)
         self.index_dir = index_dir
@@ -29,6 +36,12 @@ def index(self, docs: list[str], embeddings: NDArray[np.float64], index_dir: str
             pickle.dump(embeddings, fp)
         faiss.write_index(self.faiss_index, f"{index_dir}/index")
 
+    def _store_metadata(self, index_dir: str, data_hash: str) -> None:
+        """Store metadata for data consistency checking (FAISS-specific)."""
+        metadata = {"data_hash": data_hash}
+        with open(f"{index_dir}/metadata.json", "w") as fp:
+            json.dump(metadata, fp)
+
     def load_index(self, index_dir: str) -> None:
         self.index_dir = index_dir
         self.faiss_index = faiss.read_index(f"{index_dir}/index")
diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index 5cf18987..22cb9930 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -1,4 +1,7 @@
+import hashlib
+import json
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Any
 
 import numpy as np
@@ -56,3 +59,36 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f
         Retrieve vectors from a stored index given specific ids.
         """
         pass
+
+    def index_exists(self, index_dir: str) -> bool:
+        """
+        Check if an index exists at the given directory.
+        The default implementation checks for common index files.
+        Subclasses can override for vector-store-specific checks.
+        """
+        index_path = Path(index_dir)
+        return index_path.exists() and (index_path / "index").exists()
+
+    def is_data_consistent(self, index_dir: str, data: list, model_name: str | None = None) -> bool:
+        """
+        Check if the cached index is consistent with the current data.
+        The default implementation compares the data hash stored in metadata.
+        Subclasses can override for more sophisticated consistency checks.
+ """ + metadata_path = Path(index_dir) / "metadata.json" + if not metadata_path.exists(): + return False + + try: + with open(metadata_path, "r") as f: + metadata = json.load(f) + + # Create hash of current data + content = str(sorted(data)) + if model_name: + content += f"__{model_name}" + current_hash = hashlib.sha256(content.encode()).hexdigest()[:32] + + return metadata.get("data_hash") == current_hash + except (json.JSONDecodeError, KeyError): + return False diff --git a/tests/test_index_cache.py b/tests/test_index_cache.py new file mode 100644 index 00000000..9db72d24 --- /dev/null +++ b/tests/test_index_cache.py @@ -0,0 +1,220 @@ +import shutil +from pathlib import Path + +import pandas as pd +import pytest + +import lotus +from lotus.models import SentenceTransformersRM +from lotus.sem_ops.sem_join import run_sem_sim_join +from lotus.utils import get_index_cache +from lotus.vector_store import FaissVS +from tests.base_test import BaseTest + +# Set logger level to DEBUG +lotus.logger.setLevel("DEBUG") + + +@pytest.fixture(scope="module") +def vs(): + return FaissVS() + + +@pytest.fixture(scope="module") +def rm(): + return SentenceTransformersRM(model="sentence-transformers/all-MiniLM-L6-v2") + + +@pytest.fixture +def sample_df(): + """ + Sample DataFrame for testing + """ + return pd.DataFrame({"category": ["Sports", "Food", "Video Games", "STEM"]}) + + +@pytest.fixture +def sample_df_2(): + """ + Second different DataFrame for additional testing + """ + return pd.DataFrame({"category": ["Sports", "STEM"]}) + + +class TestIndexCache(BaseTest): + """Test suite for index caching functionality""" + + @pytest.fixture(autouse=True) + def setup_vs(self, rm, vs): + lotus.settings.configure(rm=rm, vs=vs) + + def test_required_index_dir_parameter(self, sample_df): + """ + Test that index_dir parameter is required to prevent column name collisions + """ + df = sample_df.copy() + + # explicit index_dir provided + df_indexed = df.sem_index("category", "explicit_index_dir") + + # Should use the explicit directory + cache_dir = df_indexed.attrs["index_dirs"]["category"] + assert cache_dir == "explicit_index_dir" + assert Path(cache_dir).exists() + + # cleanup + if Path("explicit_index_dir").exists(): + shutil.rmtree("explicit_index_dir") + + def test_user_specified_index_directories(self, sample_df, sample_df_2): + """ + Test that user-specified index directories work correctly + """ + df1 = sample_df.copy() + df2 = sample_df_2.copy() + + df_indexed = df1.sem_index("category", "custom_index_name") + df_indexed_smart = df2.sem_index("category", "custom_index_name_2") + + dir_one = df_indexed.attrs["index_dirs"]["category"] + dir_two = df_indexed_smart.attrs["index_dirs"]["category"] + + assert dir_one == "custom_index_name" + assert dir_two == "custom_index_name_2" + assert Path(dir_one).exists() + assert Path(dir_two).exists() + assert dir_one != dir_two + + # delete added paths + if Path("custom_index_name").exists(): + shutil.rmtree("custom_index_name") + if Path("custom_index_name_2").exists(): + shutil.rmtree("custom_index_name_2") + + def test_cache_directory_location(self, sample_df): + """Test that cache is created in specified directory""" + df = sample_df.copy() + df_indexed = df.sem_index("category", "test_cache_dir") + + cache_dir = df_indexed.attrs["index_dirs"]["category"] + assert cache_dir == "test_cache_dir" + assert Path(cache_dir).exists() + + # cleanup + if Path("test_cache_dir").exists(): + shutil.rmtree("test_cache_dir") + + def test_cache_files_exist(self, sample_df): + 
""" + Test that cache creates both 'index' and 'vecs' files (FAISS creates this) + """ + df = sample_df.copy() + df_indexed = df.sem_index("category", "test_files_dir") + cache_dir = df_indexed.attrs["index_dirs"]["category"] + + assert Path(cache_dir).exists() + assert (Path(cache_dir) / "index").exists() + assert (Path(cache_dir) / "vecs").exists() + assert (Path(cache_dir) / "metadata.json").exists() + + # cleanup + if Path("test_files_dir").exists(): + shutil.rmtree("test_files_dir") + + def test_data_consistency(self, sample_df, sample_df_2): + """ + Test that calling sem_index with same data reuses the cache, while + different data creates new cache + """ + # same data should reuse cache + df1 = sample_df.copy() + df2 = sample_df.copy() + df3 = sample_df_2.copy() + + # get the dirs + first_cache_dir = df1.sem_index("category", "test_consistency_1").attrs["index_dirs"]["category"] + second_cache_dir = df2.sem_index("category", "test_consistency_1").attrs["index_dirs"]["category"] + third_cache_dir = df3.sem_index("category", "test_consistency_2").attrs["index_dirs"]["category"] + + # same data reuses cache, different data creates new + assert first_cache_dir == second_cache_dir + assert first_cache_dir != third_cache_dir + assert second_cache_dir != third_cache_dir + + # cleanup + if Path("test_consistency_1").exists(): + shutil.rmtree("test_consistency_1") + if Path("test_consistency_2").exists(): + shutil.rmtree("test_consistency_2") + + def test_override_flag(self, sample_df): + """ + Test that override flag forces recreation of index + """ + df = sample_df.copy() + + # Create initial index + df.sem_index("category", "test_override") + + # Verify index exists + assert Path("test_override").exists() + assert (Path("test_override") / "index").exists() + assert (Path("test_override") / "metadata.json").exists() + + # Create index with override=True - should not raise error + # and should recreate the index + df.sem_index("category", "test_override", override=True) + + # Verify index still exists and is valid + assert Path("test_override").exists() + assert (Path("test_override") / "index").exists() + assert (Path("test_override") / "metadata.json").exists() + + # cleanup + if Path("test_override").exists(): + shutil.rmtree("test_override") + + def test_data_inconsistency_raises_error(self, sample_df, sample_df_2): + """ + Test that inconsistent data raises error without override + """ + df1 = sample_df.copy() + df2 = sample_df_2.copy() # Different data + + # Create index with first dataset + df1.sem_index("category", "test_inconsistent") + + # Try to use same directory with different data - should raise error + with pytest.raises(ValueError, match="data is inconsistent"): + df2.sem_index("category", "test_inconsistent") + + # cleanup + if Path("test_inconsistent").exists(): + shutil.rmtree("test_inconsistent") + + def test_sem_join_no_duplicate_index_creation(self, sample_df): + """ + Test that sem_join_cascade doesn't create duplicate indices + """ + df1 = sample_df.copy() + df2 = sample_df.copy() + + # Track number of index creations by checking logs or index files + # First call + run_sem_sim_join(df1["category"], df2["category"], "cat1", "cat2") + + # Get cache directory + rm = lotus.settings.rm + model_name = getattr(rm, "model", None) + cache_dir = get_index_cache("cat2", df2["category"].tolist(), model_name) + initial_mtime = Path(cache_dir).stat().st_mtime + + # Second call with same l2 data - should hit cache + run_sem_sim_join(df1["category"], df2["category"], "cat1_mapped", 
"cat2") + + # Verify cache directory wasn't recreated (same mtime) + assert Path(cache_dir).stat().st_mtime == initial_mtime + + # cleanup + if Path(cache_dir).exists(): + shutil.rmtree(cache_dir)