diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py
index dc65b178..0668bbe0 100644
--- a/.github/tests/rm_tests.py
+++ b/.github/tests/rm_tests.py
@@ -5,7 +5,7 @@
 import lotus
 from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM
-from lotus.vector_store import FaissVS
+from lotus.vector_store import ChromaVS, FaissVS
 
 ################################################################################
 # Setup
@@ -33,6 +33,7 @@
 VECTOR_STORE_TO_CLS = {
     'local': FaissVS,
+    'chroma': ChromaVS,
 }
 
@@ -254,6 +255,9 @@ def test_vs_sim_join(setup_models, setup_vs, vs, model):
 )
 @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
 def test_vs_dedup(setup_models, setup_vs, vs):
+    curr_threshold = 0.85
+    if vs == "chroma":
+        curr_threshold = 0.9
     rm = setup_models["intfloat/e5-small-v2"]
     my_vs = setup_vs[vs]
     lotus.settings.configure(rm=rm, vs=my_vs)
@@ -266,7 +270,7 @@ def test_vs_dedup(setup_models, setup_vs, vs):
         ]
     }
     df = pd.DataFrame(data)
-    df = df.sem_index("Text", "fourthindexdir").sem_dedup("Text", threshold=0.85)
+    df = df.sem_index("Text", "fourthindexdir").sem_dedup("Text", threshold=curr_threshold)
     kept = df["Text"].tolist()
     kept.sort()
     assert len(kept) == 2, kept
@@ -320,8 +324,9 @@ def test_search(setup_models):
     df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1)
     assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]
 
+@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
 @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small"))
-def test_filtered_vector_search(setup_models, model):
+def test_filtered_vector_search(setup_models, setup_vs, vs, model):
     """
     Test filtered vector search.
 
@@ -336,7 +341,7 @@ def test_filtered_vector_search(setup_models, model):
     expected to pick out the culinary course "Gourmet Cooking Advanced".
""" rm = setup_models[model] - vs = FaissVS() + vs = setup_vs[vs] lotus.settings.configure(rm=rm, vs=vs) data = { diff --git a/lotus/models/rm.py b/lotus/models/rm.py index 37a73c22..11855735 100644 --- a/lotus/models/rm.py +++ b/lotus/models/rm.py @@ -34,4 +34,4 @@ def convert_query_to_query_vector(self, queries: Union[pd.Series, str, Image.Ima queries = queries.tolist() # Create embeddings for text queries query_vectors = self._embed(queries) - return query_vectors \ No newline at end of file + return query_vectors diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py index f470c2cc..4cdf973d 100644 --- a/lotus/sem_ops/sem_search.py +++ b/lotus/sem_ops/sem_search.py @@ -61,7 +61,6 @@ def __call__( df_idxs = self._obj.index cur_min = len(df_idxs) - K = min(K, cur_min) search_K = K diff --git a/lotus/vector_store/__init__.py b/lotus/vector_store/__init__.py index f1f130da..327ab06a 100644 --- a/lotus/vector_store/__init__.py +++ b/lotus/vector_store/__init__.py @@ -1,4 +1,5 @@ from lotus.vector_store.vs import VS from lotus.vector_store.faiss_vs import FaissVS +from lotus.vector_store.chroma_vs import ChromaVS -__all__ = ["VS", "FaissVS"] +__all__ = ["VS", "FaissVS", "ChromaVS"] diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py new file mode 100644 index 00000000..64ec9be2 --- /dev/null +++ b/lotus/vector_store/chroma_vs.py @@ -0,0 +1,181 @@ +from typing import Any, List, Mapping, Optional, Union, cast + +import numpy as np +import pandas as pd +from chromadb import Where +from numpy.typing import NDArray +from tqdm import tqdm + +from lotus.types import RMOutput +from lotus.vector_store.vs import VS + +try: + from chromadb import Client, ClientAPI + from chromadb.api import Collection + from chromadb.api.types import IncludeEnum + from chromadb.errors import InvalidDimensionException +except ImportError as err: + raise ImportError( + "The chromadb library is required to use ChromaVS. 
+    ) from err
+
+
+class ChromaVS(VS):
+    def __init__(self, max_batch_size: int = 64):
+        """Initialize with an in-memory ChromaDB client."""
+        super().__init__()
+        self.client: ClientAPI = Client()
+        self.collection: Collection | None = None
+        self.index_dir = None
+        self.max_batch_size = max_batch_size
+
+    def __del__(self):
+        return
+
+    def index(self, docs: Any, embeddings: Any, index_dir: str, **kwargs: dict[str, Any]):
+        """Create a collection and add documents with their embeddings"""
+        self.index_dir = index_dir
+
+        # Create the collection without an embedding function (embeddings are provided directly)
+        self.collection = self.client.get_or_create_collection(
+            name=index_dir,
+            metadata={"hnsw:space": "cosine"},  # Use cosine similarity for consistency
+        )
+
+        # Convert docs to a list if it's a pandas Series
+        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
+
+        # Prepare documents for addition
+        ids = [str(i) for i in range(len(docs_list))]
+        metadatas: list[Mapping[str, Union[str, int, float, bool]]] = [{"doc_id": int(i)} for i in range(len(docs_list))]
+
+        # Add documents in batches
+        batch_size = 100
+        for i in tqdm(range(0, len(docs_list), batch_size), desc="Uploading to ChromaDB"):
+            end_idx = min(i + batch_size, len(docs_list))
+            try:
+                self.collection.add(
+                    ids=ids[i:end_idx],
+                    documents=docs_list[i:end_idx],
+                    embeddings=embeddings[i:end_idx].tolist(),
+                    metadatas=metadatas[i:end_idx],
+                )
+            except InvalidDimensionException:
+                # Embedding dimension changed: delete the stale collection, recreate it, then retry the add
+                self.client.delete_collection(index_dir)
+                self.collection = self.client.get_or_create_collection(
+                    name=index_dir,
+                    metadata={"hnsw:space": "cosine"},
+                )
+                self.collection.add(
+                    ids=ids[i:end_idx],
+                    documents=docs_list[i:end_idx],
+                    embeddings=embeddings[i:end_idx].tolist(),
+                    metadatas=metadatas[i:end_idx],
+                )
+
+    def load_index(self, index_dir: str):
+        """Load an existing collection"""
+        try:
+            self.collection = self.client.get_collection(index_dir)
+            self.index_dir = index_dir
+        except ValueError as e:
+            raise ValueError(f"Collection {index_dir} not found") from e
+
+    def __call__(
+        self,
+        query_vectors,
+        K: int,
+        ids: Optional[list[int]] = None,
+        **kwargs: dict[str, Any],
+    ) -> RMOutput:
+        """
+        Perform vector search using ChromaDB with optional filtering by document IDs.
+
+        Args:
+            query_vectors: Pre-embedded query vectors.
+            K (int): Number of nearest neighbors to retrieve.
+            ids (Optional[list[int]]): If provided, the search will be limited to documents with these ids.
+            **kwargs: Additional parameters.
+
+        Returns:
+            RMOutput: Contains the distances and indices of the nearest neighbors.
+        """
+        if self.collection is None:
+            raise ValueError("No collection loaded. Call load_index first.")
+
+        all_distances: list[list[float]] = []
+        all_indices: list[list[int]] = []
+
+        # Process each query vector.
+        for query_vector in query_vectors:
+            # Prepare the where clause by casting ids to a list of allowed types.
+            where_clause: Optional[dict[str, dict[str, List[Union[str, int, float, bool]]]]] = None
+            if ids:
+                where_clause = {"doc_id": {"$in": cast(List[Union[str, int, float, bool]], ids)}}
+
+            results = self.collection.query(
+                query_embeddings=[query_vector.tolist()],
+                n_results=K,
+                include=[IncludeEnum.metadatas, IncludeEnum.distances],
+                where=cast(Where, where_clause),
+            )
+
+            distances: list[float] = []
+            indices: list[int] = []
+
+            # Retrieve and cast search results to help the type checker.
+            metadatas = results.get("metadatas")
+            dists = results.get("distances")
+            if metadatas is not None and dists is not None:
+                metadatas = cast(
+                    List[List[Mapping[str, Union[str, int, float, bool]]]], metadatas
+                )
+                dists = cast(List[List[float]], dists)
+                for metadata, distance in zip(metadatas[0], dists[0]):
+                    if metadata is not None and distance is not None:
+                        indices.append(int(metadata["doc_id"]))
+                        # Map the reported distance to a similarity score in [0, 1].
+                        distances.append(1 - (distance / 2))
+
+            # Pad results if fewer than K matches are returned.
+            while len(indices) < K:
+                indices.append(-1)
+                distances.append(0.0)
+
+            all_indices.append(indices)
+            all_distances.append(distances)
+
+        return RMOutput(
+            distances=np.array(all_distances, dtype=np.float32).tolist(),
+            indices=np.array(all_indices, dtype=np.int64).tolist(),
+        )
+
+    def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
+        """Retrieve vectors for specific document IDs"""
+        if self.collection is None or self.index_dir != index_dir:
+            self.load_index(index_dir)
+
+        if self.collection is None:  # Guard against load_index leaving no collection
+            raise ValueError(f"Failed to load collection {index_dir}")
+
+        # Convert integer ids to strings for ChromaDB
+        str_ids = [str(id) for id in ids]
+
+        # Get embeddings from ChromaDB
+        results = self.collection.get(
+            ids=str_ids,
+            include=[IncludeEnum.embeddings],
+        )
+
+        if results['embeddings'] is None:
+            raise ValueError("No vectors found for the given ids", results['embeddings'])
+
+        return np.array(results['embeddings'], dtype=np.float64)
+
\ No newline at end of file
diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index a3632219..01087757 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -33,7 +33,7 @@ def __call__(
         self,
         query_vectors: Any,
         K: int,
-        ids: Optional[list[Any]] = None,
+        ids: Optional[list[int]] = None,
         **kwargs: dict[str, Any],
     ) -> RMOutput:
         """
@@ -52,7 +52,7 @@ def __call__(
         pass
 
     @abstractmethod
-    def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
         """
         Retrieve vectors from a stored index given specific ids.
         """
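For reviewers, a minimal usage sketch of the new backend, mirroring the patterns in the updated tests. It is illustrative only: the embedding model, the column values, and the `demo_index` collection name are placeholders, and `ChromaVS` still has to be selected explicitly since `FaissVS` remains the default vector store.

```python
import pandas as pd

import lotus
from lotus.models import SentenceTransformersRM
from lotus.vector_store import ChromaVS

# Embed with a local model and store the vectors in a ChromaDB collection.
# "intfloat/e5-small-v2" and "demo_index" are illustrative values, not defaults.
rm = SentenceTransformersRM(model="intfloat/e5-small-v2")
lotus.settings.configure(rm=rm, vs=ChromaVS())

df = pd.DataFrame(
    {"Course Name": ["Optimization Methods in Engineering", "Gourmet Cooking Advanced"]}
)

# sem_index builds the collection; sem_search queries it through ChromaVS.__call__.
df = df.sem_index("Course Name", "demo_index")
top = df.sem_search("Course Name", "Optimization", K=1)
print(top["Course Name"].tolist())
```

Note that `ChromaVS` uses chromadb's default `Client()`, so the collection lives in memory for the life of the process and `index_dir` serves only as the collection name, not an on-disk path.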
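On the raised dedup threshold for the chroma backend: the collection is created with `hnsw:space = "cosine"`, so, assuming Chroma reports cosine distance `d = 1 - cos`, the score produced in `ChromaVS.__call__` is `1 - d / 2 = (1 + cos) / 2` rather than raw cosine similarity. That rescaling compresses scores toward 1, which appears to be why `test_vs_dedup` bumps the threshold from 0.85 to 0.9 for chroma. A rough sanity check under that assumption:

```python
# Hypothetical check of the score rescaling used in ChromaVS.__call__,
# assuming Chroma returns cosine distance d = 1 - cos for a cosine-space collection.
def chroma_score(cos: float) -> float:
    d = 1 - cos       # cosine distance reported by Chroma (assumption)
    return 1 - d / 2  # mapping applied in ChromaVS.__call__

# A pair at cosine similarity 0.85 (the FAISS dedup threshold) scores ~0.925 here,
# so a chroma threshold of 0.9 flags roughly the same pairs as duplicates.
print(chroma_score(0.85))  # 0.925
print(chroma_score(0.80))  # 0.9
```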