diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 2ebb8320..9ddff784 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -8,6 +8,7 @@ from lotus.cache import CacheConfig, CacheFactory, CacheType from lotus.models import LM, SentenceTransformersRM from lotus.types import CascadeArgs +from lotus.vector_store import FaissVS ################################################################################ # Setup @@ -289,7 +290,8 @@ def test_filter_cascade(setup_models): def test_join_cascade(setup_models): models = setup_models rm = SentenceTransformersRM(model="intfloat/e5-base-v2") - lotus.settings.configure(lm=models["gpt-4o-mini"], rm=rm) + vs = FaissVS() + lotus.settings.configure(lm=models["gpt-4o-mini"], rm=rm, vs=vs) data1 = { "School": [ diff --git a/.github/tests/multimodality_tests.py b/.github/tests/multimodality_tests.py index c6311d81..64ce0bc9 100644 --- a/.github/tests/multimodality_tests.py +++ b/.github/tests/multimodality_tests.py @@ -6,6 +6,7 @@ import lotus from lotus.dtype_extensions import ImageArray from lotus.models import LM, SentenceTransformersRM +from lotus.vector_store import FaissVS ################################################################################ # Setup @@ -160,7 +161,8 @@ def test_topk_with_groupby_operation(setup_models, model): @pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) def test_search_operation(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) image_url = [ "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", @@ -180,7 +182,8 @@ def test_search_operation(setup_models, model): @pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) def test_sim_join_operation_image_index(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) image_url = [ "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", @@ -205,7 +208,8 @@ def test_sim_join_operation_image_index(setup_models, model): @pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) def test_sim_join_operation_text_index(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) image_url = [ "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 2c00e116..31e6c0c0 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -5,6 +5,7 @@ import lotus from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM +from lotus.vector_store import ChromaVS, FaissVS, PineconeVS, QdrantVS, WeaviateVS ################################################################################ # Setup @@ -30,6 +31,14 @@ "text-embedding-3-small": LiteLLMRM, } +VECTOR_STORE_TO_CLS = { + 'local': FaissVS, + 'weaviate':WeaviateVS, + 'pinecone': PineconeVS, + 'chroma': ChromaVS, + 'qdrant': QdrantVS +} + def get_enabled(*candidate_models: str) -> list[str]: return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] @@ -41,16 +50,28 @@ def setup_models(): for model_name in ENABLED_MODEL_NAMES: models[model_name] = MODEL_NAME_TO_CLS[model_name](model=model_name) + + return models +@pytest.fixture(scope='session') +def setup_vs(): + vs_model = {} + + for vs in 
VECTOR_STORE_TO_CLS: + vs_model[vs] = VECTOR_STORE_TO_CLS[vs]() + + return vs_model + ################################################################################ # RM Only Tests ################################################################################ @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_cluster_by(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) data = { "Course Name": [ @@ -79,7 +100,9 @@ def test_cluster_by(setup_models, model): @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_search_rm_only(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + + lotus.settings.configure(rm=rm, vs=vs) data = { "Course Name": [ @@ -98,7 +121,8 @@ def test_search_rm_only(setup_models, model): @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_sim_join(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) data1 = { "Course Name": [ @@ -124,7 +148,8 @@ def test_sim_join(setup_models, model): ) def test_dedup(setup_models): rm = setup_models["intfloat/e5-small-v2"] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm,vs=vs) data = { "Text": [ "Probability and Random Processes", @@ -142,6 +167,113 @@ def test_dedup(setup_models): assert "Probability" in kept[1], kept + +################################################################################ +# VS Only Tests +################################################################################ + + +@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_vs_cluster_by(setup_models, setup_vs, vs, model): + rm = setup_models[model] + my_vs = setup_vs[vs] + lotus.settings.configure(rm=rm, vs=my_vs) + + data = { + "Course Name": [ + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Course Name", "indexdir") + df = df.sem_cluster_by("Course Name", 2) + groups = df.groupby("cluster_id")["Course Name"].apply(set).to_dict() + assert len(groups) == 2, groups + if "Cooking" in groups[0]: + cooking_group = groups[0] + probability_group = groups[1] + else: + cooking_group = groups[1] + probability_group = groups[0] + + assert cooking_group == {"Cooking", "Food Sciences"}, groups + assert probability_group == {"Probability and Random Processes", "Optimization Methods in Engineering"}, groups + +@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_vs_search_rm_only(setup_models, setup_vs, vs, model): + rm = setup_models[model] + my_vs = setup_vs[vs] + lotus.settings.configure(rm=rm, vs=my_vs) + + data = { + "Course Name": [ + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Course Name", "secondindexdir") + df = df.sem_search("Course Name", "Optimization", K=1) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] + +@pytest.mark.parametrize("vs", 
VECTOR_STORE_TO_CLS.keys()) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_vs_sim_join(setup_models, setup_vs, vs, model): + rm = setup_models[model] + my_vs = setup_vs[vs] + lotus.settings.configure(rm=rm, vs=my_vs) + + data1 = { + "Course Name": [ + "History of the Atlantic World", + "Riemannian Geometry", + ] + } + + data2 = {"Skill": ["Math", "History"]} + + df1 = pd.DataFrame(data1) + df2 = pd.DataFrame(data2).sem_index("Skill", "thirdindexdir") + joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) + joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) + expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")} + assert joined_pairs == expected_pairs, joined_pairs + + +# TODO: threshold is hardcoded for intfloat/e5-small-v2 +@pytest.mark.skipif( + "intfloat/e5-small-v2" not in ENABLED_MODEL_NAMES, + reason="Skipping test because intfloat/e5-small-v2 is not enabled", +) +@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) +def test_vs_dedup(setup_models, setup_vs, vs): + rm = setup_models["intfloat/e5-small-v2"] + my_vs = setup_vs[vs] + lotus.settings.configure(rm=rm, vs=my_vs) + data = { + "Text": [ + "Probability and Random Processes", + "Probability and Markov Chains", + "Harry Potter", + "Harry James Potter", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Text", "fourthindexdir").sem_dedup("Text", threshold=0.85) + kept = df["Text"].tolist() + kept.sort() + assert len(kept) == 2, kept + assert "Harry" in kept[0], kept + assert "Probability" in kept[1], kept + + ################################################################################ # Reranker Only Tests ################################################################################ @@ -171,8 +303,9 @@ def test_search_reranker_only(setup_models, model): def test_search(setup_models): models = setup_models rm = models["intfloat/e5-small-v2"] + vs = FaissVS() reranker = models["mixedbread-ai/mxbai-rerank-xsmall-v1"] - lotus.settings.configure(rm=rm, reranker=reranker) + lotus.settings.configure(rm=rm, vs = vs, reranker=reranker) data = { "Course Name": [ diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 494ded8c..b5c1e595 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -155,7 +155,7 @@ jobs: rm_test: name: Retrieval Model Tests runs-on: ubuntu-latest - timeout-minutes: 5 + timeout-minutes: 10 steps: - name: Checkout code diff --git a/examples/op_examples/cluster.py b/examples/op_examples/cluster.py index e117b249..e37e08f8 100644 --- a/examples/op_examples/cluster.py +++ b/examples/op_examples/cluster.py @@ -2,11 +2,13 @@ import lotus from lotus.models import LM, SentenceTransformersRM +from lotus.vector_store import FaissVS lm = LM(model="gpt-4o-mini") rm = SentenceTransformersRM(model="intfloat/e5-base-v2") +vs = FaissVS() -lotus.settings.configure(lm=lm, rm=rm) +lotus.settings.configure(lm=lm, rm=rm, vs=vs) data = { "Course Name": [ "Probability and Random Processes", diff --git a/examples/op_examples/dedup.py b/examples/op_examples/dedup.py index 1494df95..7a866918 100644 --- a/examples/op_examples/dedup.py +++ b/examples/op_examples/dedup.py @@ -2,10 +2,11 @@ import lotus from lotus.models import SentenceTransformersRM +from lotus.vector_store import FaissVS rm = SentenceTransformersRM(model="intfloat/e5-base-v2") - -lotus.settings.configure(rm=rm) +vs = FaissVS() +lotus.settings.configure(rm=rm, vs=vs) 
data = {
    "Text": [
        "Probability and Random Processes",
diff --git a/examples/op_examples/join_cascade.py b/examples/op_examples/join_cascade.py
index 05e3e562..517ee6b7 100644
--- a/examples/op_examples/join_cascade.py
+++ b/examples/op_examples/join_cascade.py
@@ -3,11 +3,13 @@
 import lotus
 from lotus.models import LM, SentenceTransformersRM
 from lotus.types import CascadeArgs
+from lotus.vector_store import FaissVS

 lm = LM(model="gpt-4o-mini")
 rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
+vs = FaissVS()

-lotus.settings.configure(lm=lm, rm=rm)
+lotus.settings.configure(lm=lm, rm=rm, vs=vs)
 data = {
     "Course Name": [
         "Digital Design and Integrated Circuits",
diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py
index 932b170b..ed82015c 100644
--- a/examples/op_examples/partition.py
+++ b/examples/op_examples/partition.py
@@ -2,11 +2,13 @@
 import lotus
 from lotus.models import LM, SentenceTransformersRM
+from lotus.vector_store import FaissVS

 lm = LM(max_tokens=2048)
 rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
+vs = FaissVS()

-lotus.settings.configure(lm=lm, rm=rm)
+lotus.settings.configure(lm=lm, rm=rm, vs=vs)
 data = {
     "Course Name": [
         "Probability and Random Processes",
diff --git a/examples/op_examples/search.py b/examples/op_examples/search.py
index c9382aae..49466ac3 100644
--- a/examples/op_examples/search.py
+++ b/examples/op_examples/search.py
@@ -2,12 +2,14 @@
 import lotus
 from lotus.models import LM, CrossEncoderReranker, SentenceTransformersRM
+from lotus.vector_store import FaissVS

 lm = LM(model="gpt-4o-mini")
 rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
 reranker = CrossEncoderReranker(model="mixedbread-ai/mxbai-rerank-large-v1")
+vs = FaissVS()

-lotus.settings.configure(lm=lm, rm=rm, reranker=reranker)
+lotus.settings.configure(lm=lm, rm=rm, reranker=reranker, vs=vs)
 data = {
     "Course Name": [
         "Probability and Random Processes",
diff --git a/examples/op_examples/sim_join.py b/examples/op_examples/sim_join.py
index 108ccb28..6a6bfc27 100644
--- a/examples/op_examples/sim_join.py
+++ b/examples/op_examples/sim_join.py
@@ -2,11 +2,13 @@
 import lotus
 from lotus.models import LM, LiteLLMRM
+from lotus.vector_store import FaissVS

 lm = LM(model="gpt-4o-mini")
 rm = LiteLLMRM(model="text-embedding-3-small")
+vs = FaissVS()

-lotus.settings.configure(lm=lm, rm=rm)
+lotus.settings.configure(lm=lm, rm=rm, vs=vs)
 data = {
     "Course Name": [
         "History of the Atlantic World",
diff --git a/lotus/models/colbertv2_rm.py b/lotus/models/colbertv2_rm.py
index 2bd8ed99..8018221f 100644
--- a/lotus/models/colbertv2_rm.py
+++ b/lotus/models/colbertv2_rm.py
@@ -6,7 +6,6 @@
 from numpy.typing import NDArray
 from PIL import Image

-from lotus.models.rm import RM
 from lotus.types import RMOutput

 try:
@@ -16,7 +15,7 @@
     pass


-class ColBERTv2RM(RM):
+class ColBERTv2RM:
     def __init__(self) -> None:
         self.docs: list[str] | None = None
         self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2}
@@ -46,6 +45,7 @@
     def load_index(self, index_dir: str) -> None:

     def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
         raise NotImplementedError("This method is not implemented for ColBERTv2RM")

+    # TODO: move this top-k search into the VS layer once ColBERTv2 is ported to the new RM/VS split
     def __call__(
         self,
         queries: str | Image.Image | list | NDArray[np.float64],
diff --git a/lotus/models/litellm_rm.py b/lotus/models/litellm_rm.py
index a4486dc6..b0ee07be 100644
--- a/lotus/models/litellm_rm.py
+++ b/lotus/models/litellm_rm.py
@@ -7,10 +7,10 @@
 from tqdm import tqdm

 from lotus.dtype_extensions import convert_to_base_data
-from lotus.models.faiss_rm import FaissRM
+from lotus.models.rm import RM


-class LiteLLMRM(FaissRM):
+class LiteLLMRM(RM):
     def __init__(
         self,
         model: str = "text-embedding-3-small",
@@ -18,7 +18,7 @@
         factory_string: str = "Flat",
         metric=faiss.METRIC_INNER_PRODUCT,
     ):
-        super().__init__(factory_string, metric)
+        super().__init__()
         self.model: str = model
         self.max_batch_size: int = max_batch_size
diff --git a/lotus/models/rm.py b/lotus/models/rm.py
index c58a0c38..72324db8 100644
--- a/lotus/models/rm.py
+++ b/lotus/models/rm.py
@@ -1,42 +1,38 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Union

 import numpy as np
 import pandas as pd
 from numpy.typing import NDArray
 from PIL import Image

-from lotus.types import RMOutput
-

 class RM(ABC):
-    """Abstract class for retriever models."""
+    """Abstract class for retriever models; an RM is now responsible only for generating embeddings."""

     def __init__(self) -> None:
-        self.index_dir: str | None = None
+        pass

     @abstractmethod
-    def index(self, docs: pd.Series, index_dir: str, **kwargs: dict[str, Any]) -> None:
-        """Create index and store it to a directory.
-
-        Args:
-            docs (list[str]): A list of documents to index.
-            index_dir (str): The directory to save the index in.
-        """
-        pass
+    def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]:
+        pass

-    @abstractmethod
-    def load_index(self, index_dir: str) -> None:
-        """Load the index into memory.
-
-        Args:
-            index_dir (str): The directory of where the index is stored.
-        """
-        pass
+    def __call__(self, docs: pd.Series | list) -> NDArray[np.float64]:
+        return self._embed(docs)

-    @abstractmethod
-    def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
-        """Get the vectors from the index.
-
-        Args:
-            index_dir (str): Directory of the index.
-
-        Returns:
-            NDArray[np.float64]: The vectors matching the specified ids.
-        """
-        pass
+    def convert_query_to_query_vector(
+        self, queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]]
+    ) -> NDArray[np.float64]:
+        if isinstance(queries, (str, Image.Image)):
+            queries = [queries]

-    def __call__(
-        self,
-        queries: str | Image.Image | list | NDArray[np.float64],
-        K: int,
-        **kwargs: dict[str, Any],
-    ) -> RMOutput:
-        """Run top-k search on the index.
-
-        Args:
-            queries (str | list[str] | NDArray[np.float64]): Either a query or a list of queries or a 2D FP32 array.
-            K (int): The K to retrieve.
-
-        Returns:
-            RMOutput: An RMOutput object containing the distances and indices of the top-k vectors.
-        """
-        pass
+        # Pre-computed numpy vectors pass through unchanged
+        if isinstance(queries, np.ndarray):
+            return queries
+
+        if isinstance(queries, pd.Series):
+            queries = queries.tolist()
+        # Create embeddings for text queries
+        return self._embed(queries)
diff --git a/lotus/models/sentence_transformers_rm.py b/lotus/models/sentence_transformers_rm.py
index 76db6f3f..23559e21 100644
--- a/lotus/models/sentence_transformers_rm.py
+++ b/lotus/models/sentence_transformers_rm.py
@@ -1,4 +1,3 @@
-import faiss
 import numpy as np
 import pandas as pd
 import torch
@@ -7,20 +6,18 @@
 from tqdm import tqdm

 from lotus.dtype_extensions import convert_to_base_data
-from lotus.models.faiss_rm import FaissRM
+from lotus.models.rm import RM


-class SentenceTransformersRM(FaissRM):
+class SentenceTransformersRM(RM):
     def __init__(
         self,
         model: str = "intfloat/e5-base-v2",
         max_batch_size: int = 64,
         normalize_embeddings: bool = True,
         device: str | None = None,
-        factory_string: str = "Flat",
-        metric=faiss.METRIC_INNER_PRODUCT,
     ):
-        super().__init__(factory_string, metric)
+        super().__init__()
         self.model: str = model
         self.max_batch_size: int = max_batch_size
         self.normalize_embeddings: bool = normalize_embeddings
diff --git a/lotus/sem_ops/sem_cluster_by.py b/lotus/sem_ops/sem_cluster_by.py
index fc8a9c6f..43e5324b 100644
--- a/lotus/sem_ops/sem_cluster_by.py
+++ b/lotus/sem_ops/sem_cluster_by.py
@@ -42,9 +42,11 @@
         Returns:
             pd.DataFrame: The DataFrame with the cluster assignments.
         """
-        if lotus.settings.rm is None:
+        rm = lotus.settings.rm
+        vs = lotus.settings.vs
+        if rm is None or vs is None:
             raise ValueError(
-                "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
+                "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model and vector store using lotus.settings.configure()"
             )

         cluster_fn = lotus.utils.cluster(col_name, ncentroids)
diff --git a/lotus/sem_ops/sem_dedup.py b/lotus/sem_ops/sem_dedup.py
index 3360714d..cbc0ce58 100644
--- a/lotus/sem_ops/sem_dedup.py
+++ b/lotus/sem_ops/sem_dedup.py
@@ -36,9 +36,11 @@
         Returns:
             pd.DataFrame: The DataFrame with duplicates removed.
         """
-        if lotus.settings.rm is None:
+        rm = lotus.settings.rm
+        vs = lotus.settings.vs
+        if rm is None or vs is None:
             raise ValueError(
-                "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
+                "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model and vector store using lotus.settings.configure()"
             )

         joined_df = self._obj.sem_sim_join(self._obj, col_name, col_name, len(self._obj), lsuffix="_l", rsuffix="_r")
diff --git a/lotus/sem_ops/sem_index.py b/lotus/sem_ops/sem_index.py
index ae8d7753..12015768 100644
--- a/lotus/sem_ops/sem_index.py
+++ b/lotus/sem_ops/sem_index.py
@@ -32,12 +32,16 @@
         Returns:
             pd.DataFrame: The DataFrame with the index directory saved.
         """
-        if lotus.settings.rm is None:
+        rm = lotus.settings.rm
+        vs = lotus.settings.vs
+        if rm is None or vs is None:
             raise ValueError(
-                "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
+                "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. 
Please configure a valid retrieval model using lotus.settings.configure()" ) - rm = lotus.settings.rm - rm.index(self._obj[col_name], index_dir) + + + embeddings = rm(self._obj[col_name]) + vs.index(self._obj[col_name], embeddings, index_dir) self._obj.attrs["index_dirs"][col_name] = index_dir return self._obj diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py index de2df357..6ae0e5ef 100644 --- a/lotus/sem_ops/sem_search.py +++ b/lotus/sem_ops/sem_search.py @@ -5,6 +5,7 @@ import lotus from lotus.cache import operator_cache from lotus.types import RerankerOutput, RMOutput +from lotus.vector_store.pinecone_vs import PineconeVS @pd.api.extensions.register_dataframe_accessor("sem_search") @@ -48,24 +49,29 @@ def __call__( if K is not None: # get retriever model and index rm = lotus.settings.rm - if rm is None: + vs = lotus.settings.vs + if rm is None or vs is None : raise ValueError( - "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" + "The retrieval model must be an instance of RM, and the vector store should be an instance of VS. Please configure a valid retrieval model and vector store using lotus.settings.configure()" ) col_index_dir = self._obj.attrs["index_dirs"][col_name] - if rm.index_dir != col_index_dir: - rm.load_index(col_index_dir) - assert rm.index_dir == col_index_dir + if vs.index_dir != col_index_dir: + vs.load_index(col_index_dir) + assert vs.index_dir == col_index_dir df_idxs = self._obj.index - K = min(K, len(df_idxs)) + cur_min = len(df_idxs) + if isinstance(vs, PineconeVS): + cur_min = min(cur_min, 10000) + K = min(K, cur_min) search_K = K while True: - rm_output: RMOutput = rm(query, search_K) - doc_idxs = rm_output.indices[0] - scores = rm_output.distances[0] + query_vectors = rm.convert_query_to_query_vector(query) + vs_output: RMOutput = vs(query_vectors, search_K) + doc_idxs = vs_output.indices[0] + scores = vs_output.distances[0] assert len(doc_idxs) == len(scores) postfiltered_doc_idxs = [] diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index 47d3cbe3..3de4a853 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -6,6 +6,7 @@ from lotus.cache import operator_cache from lotus.models import RM from lotus.types import RMOutput +from lotus.vector_store import VS @pd.api.extensions.register_dataframe_accessor("sem_sim_join") @@ -51,20 +52,21 @@ def __call__( raise ValueError("Other Series must have a name") other = pd.DataFrame({other.name: other}) - rm = lotus.settings.rm - if not isinstance(rm, RM): + rm = lotus.settings.rm + vs = lotus.settings.vs + if not isinstance(rm, RM) or not isinstance(vs, VS): raise ValueError( - "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" + "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. 
Please configure a valid retrieval model or vector store using lotus.settings.configure()" ) - + # load query embeddings from index if they exist if left_on in self._obj.attrs.get("index_dirs", []): query_index_dir = self._obj.attrs["index_dirs"][left_on] - if rm.index_dir != query_index_dir: - rm.load_index(query_index_dir) - assert rm.index_dir == query_index_dir + if vs.index_dir != query_index_dir: + vs.load_index(query_index_dir) + assert vs.index_dir == query_index_dir try: - queries = rm.get_vectors_from_index(query_index_dir, self._obj.index) + queries = vs.get_vectors_from_index(query_index_dir, self._obj.index) except NotImplementedError: queries = self._obj[left_on] else: @@ -75,13 +77,15 @@ def __call__( col_index_dir = other.attrs["index_dirs"][right_on] except KeyError: raise ValueError(f"Index directory for column {right_on} not found in DataFrame") - if rm.index_dir != col_index_dir: - rm.load_index(col_index_dir) - assert rm.index_dir == col_index_dir + if vs.index_dir != col_index_dir: + vs.load_index(col_index_dir) + assert vs.index_dir == col_index_dir + + query_vectors = rm.convert_query_to_query_vector(queries) - rm_output: RMOutput = rm(queries, K) - distances = rm_output.distances - indices = rm_output.indices + vs_output: RMOutput = vs(query_vectors, K) + distances = vs_output.distances + indices = vs_output.indices other_index_set = set(other.index) join_results = [] diff --git a/lotus/settings.py b/lotus/settings.py index f7277a1c..571e6155 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -8,11 +8,12 @@ class Settings: # Models lm: lotus.models.LM | None = None - rm: lotus.models.RM | None = None + rm: lotus.models.RM | None = None # supposed to only generate embeddings helper_lm: lotus.models.LM | None = None reranker: lotus.models.Reranker | None = None vs: lotus.vector_store.VS | None = None + # Cache settings enable_cache: bool = False @@ -23,6 +24,8 @@ class Settings: parallel_groupby_max_threads: int = 8 def configure(self, **kwargs): + + for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") @@ -30,6 +33,7 @@ def configure(self, **kwargs): def __str__(self): return str(vars(self)) + settings = Settings() diff --git a/lotus/utils.py b/lotus/utils.py index b3e68ac9..da4dc82a 100644 --- a/lotus/utils.py +++ b/lotus/utils.py @@ -42,9 +42,10 @@ def ret( # get rmodel and index rm = lotus.settings.rm - if rm is None: + vs = lotus.settings.vs + if rm is None or vs is None: raise ValueError( - "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" + "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. 
Please configure a valid retrieval model using lotus.settings.configure()" ) try: @@ -52,12 +53,12 @@ def ret( except KeyError: raise ValueError(f"Index directory for column {col_name} not found in DataFrame") - if rm.index_dir != col_index_dir: - rm.load_index(col_index_dir) - assert rm.index_dir == col_index_dir + if vs.index_dir != col_index_dir: + vs.load_index(col_index_dir) + assert vs.index_dir == col_index_dir ids = df.index.tolist() # assumes df index hasn't been resest and corresponds to faiss index ids - vec_set = rm.get_vectors_from_index(col_index_dir, ids) + vec_set = vs.get_vectors_from_index(col_index_dir, ids) d = vec_set.shape[1] kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose) kmeans.train(vec_set) diff --git a/lotus/vector_store/__init__.py b/lotus/vector_store/__init__.py index 34c41998..b1efbeac 100644 --- a/lotus/vector_store/__init__.py +++ b/lotus/vector_store/__init__.py @@ -1,7 +1,8 @@ from lotus.vector_store.vs import VS +from lotus.vector_store.faiss_vs import FaissVS from lotus.vector_store.weaviate_vs import WeaviateVS from lotus.vector_store.pinecone_vs import PineconeVS from lotus.vector_store.chroma_vs import ChromaVS from lotus.vector_store.qdrant_vs import QdrantVS -__all__ = ["VS", "WeaviateVS", "PineconeVS", "ChromaVS", "QdrantVS"] +__all__ = ["VS", "FaissVS", "WeaviateVS", "PineconeVS", "ChromaVS", "QdrantVS"] diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 298f05e9..ad3009a8 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,16 +1,174 @@ +from typing import Any, Mapping, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from tqdm import tqdm + +from lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + from chromadb import Client, ClientAPI + from chromadb.api import Collection + from chromadb.api.types import IncludeEnum + from chromadb.errors import InvalidDimensionException +except ImportError as err: + raise ImportError( + "The chromadb library is required to use ChromaVS. 
Install it with `pip install chromadb`" + ) from err class ChromaVS(VS): - def __init__(self): + def __init__(self, max_batch_size: int = 64): + + client: ClientAPI = Client() + + """Initialize with ChromaDB client and embedding model""" + super() + self.client = client + self.collection: Collection | None = None + self.index_dir = None + self.max_batch_size = max_batch_size + + def __del__(self): + return + + def index(self, docs: Any, embeddings: Any, index_dir: str, **kwargs: dict[str, Any]): + """Create a collection and add documents with their embeddings""" + self.index_dir = index_dir + + # Create collection without embedding function (we'll provide embeddings directly) + self.collection = self.client.get_or_create_collection( + name=index_dir, + metadata={"hnsw:space": "cosine"} # Use cosine similarity for consistency + ) + + # Convert docs to list if it's a pandas Series + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + + # Prepare documents for addition + ids = [str(i) for i in range(len(docs_list))] + metadatas: list[Mapping[str, Union[str, int, float, bool]]] = [{"doc_id": int(i)} for i in range(len(docs_list))] + + # Add documents in batches + batch_size = 100 + for i in tqdm(range(0, len(docs_list), batch_size), desc="Uploading to ChromaDB"): + end_idx = min(i + batch_size, len(docs_list)) + try: + self.collection.add( + ids=ids[i:end_idx], + documents=docs_list[i:end_idx], + embeddings=embeddings[i:end_idx].tolist(), + metadatas=metadatas[i:end_idx] + ) + except InvalidDimensionException: + # delete, recreate, then add + self.client.delete_collection(index_dir) + # Create collection without embedding function (we'll provide embeddings directly) + self.collection = self.client.get_or_create_collection( + name=index_dir, + metadata={"hnsw:space": "cosine"} # Use cosine similarity for consistency + ) + self.collection.add( + ids=ids[i:end_idx], + documents=docs_list[i:end_idx], + embeddings=embeddings[i:end_idx].tolist(), + metadatas=metadatas[i:end_idx] + ) + + def load_index(self, index_dir: str): + """Load an existing collection""" try: - import chromadb - except ImportError: - chromadb = None + self.collection = self.client.get_collection(index_dir) + self.index_dir = index_dir + except ValueError as e: + raise ValueError(f"Collection {index_dir} not found") from e + + def __call__( + self, + query_vectors, + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using ChromaDB""" + if self.collection is None: + raise ValueError("No collection loaded. Call load_index first.") + + + """ + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._batch_embed(queries) + + """ + # Perform searches + all_distances = [] + all_indices = [] - if chromadb is None: - raise ImportError( - "The chromadb library is required to use ChromaVS. 
Install it with `pip install chromadb`", + for query_vector in query_vectors: + results = self.collection.query( + query_embeddings=[query_vector.tolist()], + n_results=K, + include=[IncludeEnum.metadatas, IncludeEnum.distances] ) - pass + + # Extract distances and indices + distances = [] + indices = [] + + if results['metadatas'] and results['distances']: + for metadata, distance in zip(results['metadatas'][0], results['distances'][0]): + indices.append(metadata['doc_id']) + # ChromaDB returns squared L2 distances, convert to cosine similarity + # similarity = 1 - (distance / 2) # Convert L2 distance to cosine similarity + distances.append(1 - (distance / 2)) + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32).tolist(), + indices=np.array(all_indices, dtype=np.int64).tolist() + ) + + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + if self.collection is None or self.index_dir != index_dir: + self.load_index(index_dir) + + + if self.collection is None: # Add this check after load_index + raise ValueError(f"Failed to load collection {index_dir}") + + + # Convert integer ids to strings for ChromaDB + str_ids = [str(id) for id in ids] + + # Get embeddings from ChromaDB + results = self.collection.get( + ids=str_ids, + include=[IncludeEnum.embeddings] + ) + + if results['embeddings'] is None: + raise ValueError("No vectors found for the given ids", results['embeddings']) + + return np.array(results['embeddings'], dtype=np.float64) + + \ No newline at end of file diff --git a/lotus/models/faiss_rm.py b/lotus/vector_store/faiss_vs.py similarity index 70% rename from lotus/models/faiss_rm.py rename to lotus/vector_store/faiss_vs.py index ace1fd92..9f682469 100644 --- a/lotus/models/faiss_rm.py +++ b/lotus/vector_store/faiss_vs.py @@ -1,19 +1,17 @@ import os import pickle -from abc import abstractmethod from typing import Any import faiss import numpy as np import pandas as pd from numpy.typing import NDArray -from PIL import Image -from lotus.models.rm import RM from lotus.types import RMOutput +from lotus.vector_store.vs import VS -class FaissRM(RM): +class FaissVS(VS): def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODUCT): super().__init__() self.factory_string = factory_string @@ -22,15 +20,14 @@ def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODU self.faiss_index: faiss.Index | None = None self.vecs: NDArray[np.float64] | None = None - def index(self, docs: pd.Series, index_dir: str, **kwargs: dict[str, Any]) -> None: - vecs = self._embed(docs) - self.faiss_index = faiss.index_factory(vecs.shape[1], self.factory_string, self.metric) - self.faiss_index.add(vecs) + def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]) -> None: + self.faiss_index = faiss.index_factory(embeddings.shape[1], self.factory_string, self.metric) + self.faiss_index.add(embeddings) self.index_dir = index_dir os.makedirs(index_dir, exist_ok=True) with open(f"{index_dir}/vecs", "wb") as fp: - pickle.dump(vecs, fp) + pickle.dump(embeddings, fp) faiss.write_index(self.faiss_index, f"{index_dir}/index") def load_index(self, index_dir: str) -> None: @@ -45,8 +42,11 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> 
NDArray[np.float64]:
         return vecs[ids]

     def __call__(
-        self, queries: pd.Series | str | Image.Image | list | NDArray[np.float64], K: int, **kwargs: dict[str, Any]
+        self, query_vectors: NDArray[np.float64], K: int, **kwargs: dict[str, Any]
     ) -> RMOutput:
-        if isinstance(queries, str) or isinstance(queries, Image.Image):
-            queries = [queries]
@@ -54,13 +51,7 @@
-            embedded_queries = self._embed(queries)
-        else:
-            embedded_queries = np.asarray(queries, dtype=np.float32)
-
+        # Query embedding now happens in the RM; the vector store only receives pre-computed vectors.
         if self.faiss_index is None:
             raise ValueError("Index not loaded")

-        distances, indices = self.faiss_index.search(embedded_queries, K)
+        distances, indices = self.faiss_index.search(query_vectors, K)
         return RMOutput(distances=distances, indices=indices)

-    @abstractmethod
-    def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]:
-        pass
diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py
index fdb89800..f5c4ff32 100644
--- a/lotus/vector_store/pinecone_vs.py
+++ b/lotus/vector_store/pinecone_vs.py
@@ -1,17 +1,166 @@
+import os
+from typing import Any
+
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+from tqdm import tqdm
+
+from lotus.types import RMOutput
 from lotus.vector_store.vs import VS

+try:
+    from pinecone import Index, Pinecone, ServerlessSpec
+except ImportError as err:
+    raise ImportError(
+        "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`",
+    ) from err


 class PineconeVS(VS):
-    def __init__(self):
-        try:
-            import pinecone
-        except ImportError:
-            pinecone = None
-
-        if pinecone is None:
-            raise ImportError(
-                "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`",
-            )
-        pass
+    def __init__(self, max_batch_size: int = 64):
+        """Initialize the Pinecone client; the API key is read from the environment rather than hardcoded in source control."""
+        super().__init__()
+        self.pinecone = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
+        self.pc_index: Index | None = None
+        self.max_batch_size = max_batch_size
+
+    def index(self, docs: pd.Series, embeddings: Any, index_dir: str, **kwargs: dict[str, Any]):
+        """Create an index and add documents to it"""
+        self.index_dir = index_dir
+
+        dimension = embeddings.shape[1]
+
+        # Check if index already exists
+        if index_dir not in self.pinecone.list_indexes().names():
+            # Create new index with the correct dimension
+            self.pinecone.create_index(
+                name=index_dir,
+                dimension=dimension,
+                metric="cosine",
+                spec=ServerlessSpec(
+                    cloud='aws',
+                    region='us-east-1'
+                )
+            )
+        elif self.pinecone.describe_index(index_dir).dimension != dimension:
+            # Resolve any potential dimension-mismatch errors by recreating the index
+            self.pinecone.delete_index(index_dir)
+            self.pinecone.create_index(
+                name=index_dir,
+                dimension=dimension,
+                metric="cosine",
+                spec=ServerlessSpec(
+                    cloud='aws',
+                    region='us-east-1'
+                )
+            )
+
+        # Connect to index
+        self.pc_index = self.pinecone.Index(index_dir)
+
+        # Convert docs to list if it's a pandas Series
+        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
+
+        # Prepare vectors for upsert
+        vectors = []
+        for idx, (embedding, doc) in enumerate(zip(embeddings, docs_list)):
+            vectors.append({
+                "id": str(idx),
+                "values": embedding.tolist(),  # Pinecone expects lists, not numpy arrays
+                "metadata": {
+                    "content": doc,
+                    "doc_id": idx
+                }
+            })
+
+        # Upsert in batches of 100
+        batch_size = 100
+        for i in tqdm(range(0, len(vectors), batch_size), desc="Uploading to Pinecone"):
+            batch = vectors[i:i + batch_size]
+            self.pc_index.upsert(vectors=batch)
+
+    def load_index(self, index_dir: str):
+        """Connect to an existing Pinecone index"""
+        if index_dir not in self.pinecone.list_indexes().names():
+            raise ValueError(f"Index {index_dir} not found")
+
+        self.index_dir = index_dir
+        self.pc_index = self.pinecone.Index(index_dir)
+
+    def __call__(
+        self,
+        query_vectors,
+        K: int,
+        **kwargs: dict[str, Any]
+    ) -> RMOutput:
+        """Perform vector search using Pinecone"""
+        if self.pc_index is None:
+            raise ValueError("No index loaded. Call load_index first.")
+
+        # Query conversion happens in the RM; this method receives pre-computed query vectors.
+        all_distances = []
+        all_indices = []
+
+        for query_vector in query_vectors:
+            # Query Pinecone
+            results = self.pc_index.query(
+                vector=query_vector.tolist(),
+                top_k=K,
+                include_metadata=True,
+                **kwargs
+            )
+
+            # Extract distances and indices
+            distances = []
+            indices = []
+
+            for match in results.matches:
+                indices.append(int(match.metadata["doc_id"]))
+                distances.append(match.score)
+
+            # Pad results if fewer than K matches
+            while len(indices) < K:
+                indices.append(-1)  # Use -1 for padding
+                distances.append(0.0)
+
+            all_distances.append(distances)
+            all_indices.append(indices)
+
+        return RMOutput(
+            distances=np.array(all_distances, dtype=np.float32).tolist(),
+            indices=np.array(all_indices, dtype=np.int64).tolist()
+        )
+
+    def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
+        """Retrieve vectors for specific document IDs"""
+        if self.pc_index is None or self.index_dir != index_dir:
+            self.load_index(index_dir)
+
+        if self.pc_index is None:  # guard for type-checkers after load_index
+            raise ValueError("Failed to initialize Pinecone index")
+
+        # Fetch vectors from Pinecone
+        vectors = []
+        for doc_id in ids:
+            response = self.pc_index.fetch(ids=[str(doc_id)])
+            if str(doc_id) in response.vectors:
+                vector = response.vectors[str(doc_id)].values
+                vectors.append(vector)
+            else:
+                raise ValueError(f"Document with id {doc_id} not found")
+
+        return np.array(vectors, dtype=np.float64)
diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py
index 5ded2d7d..50f1017f 100644
--- a/lotus/vector_store/qdrant_vs.py
+++ b/lotus/vector_store/qdrant_vs.py
@@ -1,16 +1,177 @@
+import os
+from typing import Any
+
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+from tqdm import tqdm
+
+from lotus.types import RMOutput
 from lotus.vector_store.vs import VS

+try:
+    from qdrant_client import QdrantClient
+    from qdrant_client.models import Distance, PointStruct, VectorParams
+except ImportError as err:
+    raise ImportError("Please install the qdrant client") from err


 class QdrantVS(VS):
-    def __init__(self):
-        try:
-            import qdrant_client
-        except ImportError:
-            qdrant_client = None
-
-        if qdrant_client is None:
-            raise ImportError(
-                "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`",
-            )
-        pass
+    def __init__(self, max_batch_size: int = 64):
+        """Initialize the Qdrant client; the cluster URL and API key are read from the environment instead of being committed to source control."""
+        super().__init__()
+        self.client: QdrantClient = QdrantClient(
+            url=os.environ["QDRANT_URL"],
+            api_key=os.environ["QDRANT_API_KEY"],
+        )
+        self.max_batch_size = max_batch_size
+
+    def __del__(self):
+        self.client.close()
+
+    def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]):
+        """Create a collection and add documents with their embeddings"""
+        self.index_dir = index_dir
+
+        # Determine the vector dimension from the embeddings themselves
+        dimension = np.reshape(embeddings, (len(embeddings), -1)).shape[1]
+
+        # Create collection if it doesn't exist
+        if not self.client.collection_exists(index_dir):
+            self.client.create_collection(
+                collection_name=index_dir,
+                vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
+            )
+        collection_info = self.client.get_collection(index_dir)
+        if (collection_info is not None and collection_info.config is not None
+                and collection_info.config.params and collection_info.config.params.vectors):
+            vectors = collection_info.config.params.vectors
+            if isinstance(vectors, dict):
+                # Named-vector collections store a mapping; take the first entry's params
+                vector = next(iter(vectors.values()))
+                size = vector.size
+            elif isinstance(vectors, VectorParams):
+                size = vectors.size
+            else:
+                size = None
+
+            if size != dimension:
+                # On a dimension mismatch, drop and recreate the collection
+                self.client.delete_collection(index_dir)
+                self.client.create_collection(
+                    collection_name=index_dir,
+                    vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
+                )
+
+        # Convert docs to list if it's a pandas Series
+        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
+
+        # Prepare points for upload
+        points = []
+        for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)):
+            points.append(
+                PointStruct(
+                    id=idx,
+                    vector=embedding.tolist(),
+                    payload={
+                        "content": doc,
+                        "doc_id": idx
+                    }
+                )
+            )
+
+        # Upload in batches
+        batch_size = 100
+        for i in tqdm(range(0, len(points), batch_size), desc="Uploading to Qdrant"):
+            batch = points[i:i + batch_size]
+            self.client.upsert(
+                collection_name=index_dir,
+                points=batch
+            )
+
+    def load_index(self, index_dir: str):
+        """Set the collection name to use"""
+        if not self.client.collection_exists(index_dir):
+            raise ValueError(f"Collection {index_dir} not found")
+        self.index_dir = index_dir
+
+    def __call__(
+        self,
+        query_vectors,
+        K: int,
+        **kwargs: dict[str, Any]
+    ) -> RMOutput:
+        """Perform vector search using Qdrant"""
+        if self.index_dir is None:
+            raise ValueError("No collection loaded. Call load_index first.")
+
+        # Query embedding happens in the RM; only pre-computed vectors arrive here.
+        all_distances = []
+        all_indices = []
+
+        for query_vector in query_vectors:
+            results = self.client.search(
+                collection_name=self.index_dir,
+                query_vector=query_vector.tolist(),
+                limit=K,
+                with_payload=True
+            )
+
+            # Extract distances and indices
+            distances = []
+            indices = []
+
+            for result in results:
+                indices.append(result.id)
+                distances.append(result.score)  # Qdrant returns cosine similarity directly
+
+            # Pad results if fewer than K matches
+            while len(indices) < K:
+                indices.append(-1)
+                distances.append(0.0)
+
+            all_distances.append(distances)
+            all_indices.append(indices)
+
+        return RMOutput(
+            distances=np.array(all_distances, dtype=np.float32).tolist(),
+            indices=np.array(all_indices, dtype=np.int64).tolist()
+        )
+
+    def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
+        """Retrieve vectors for specific document IDs"""
+        if self.index_dir != index_dir:
+            self.load_index(index_dir)
+
+        # Fetch points from Qdrant
+        points = self.client.retrieve(
+            collection_name=index_dir,
+            ids=ids,
+            with_vectors=True,
+            with_payload=False
+        )
+
+        # Extract and return vectors
+        vectors = []
+        for point in points:
+            if point.vector is not None:
+                vectors.append(point.vector)
+            else:
+                raise ValueError(f"Vector not found for id {point.id}")
+
+        return np.array(vectors, dtype=np.float64)
diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index bb4878eb..7e1818e6 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -2,31 +2,35 @@
 from typing import Any

 import numpy as np
-import pandas as pd
 from numpy.typing import NDArray
-from PIL import Image

 from lotus.types import RMOutput


 class VS(ABC):
     """Abstract class for vector stores."""

     def __init__(self) -> None:
-        pass
+        self.index_dir: str | None = None
+        self.max_batch_size: int = 64

     @abstractmethod
-    def index(self, docs: pd.Series, index_dir):
+    def index(self, docs, embeddings: Any, index_dir: str, **kwargs: dict[str, Any]):
+        """Create an index from pre-computed embeddings and store it in the vector store"""
         pass

     @abstractmethod
+    def load_index(self, index_dir: str):
+        """Load the index from the vector store into memory if needed"""
+        pass
+
+    @abstractmethod
-    def search(self,
-               queries: pd.Series | str | Image.Image | list | NDArray[np.float64],
+    def __call__(self,
+                 query_vectors: Any,
                 K: int,
                 **kwargs: dict[str, Any],
                 ) -> RMOutput:
         pass

     @abstractmethod
-    def 
get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
-        pass
\ No newline at end of file
+    def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.float64]:
+        pass
diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 3b747235..67b4afc1 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -1,13 +1,188 @@
+import os
+from typing import Any, List
+
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+
+from lotus.types import RMOutput
 from lotus.vector_store.vs import VS

+try:
+    import weaviate
+    from weaviate.classes.config import Configure, DataType, Property
+    from weaviate.classes.init import Auth
+    from weaviate.classes.query import Filter, MetadataQuery
+except ImportError as err:
+    raise ImportError("Please install the weaviate client") from err


 class WeaviateVS(VS):
-    def __init__(self):
-        try:
-            import weaviate
-        except ImportError:
-            weaviate = None
-
-        if weaviate is None:
-            raise ImportError("Please install the weaviate client")
-        pass
+    def __init__(self, max_batch_size: int = 64):
+        """Initialize with a Weaviate Cloud client; the cluster URL and API key are read from the environment instead of being committed to source control."""
+        super().__init__()
+        self.client = weaviate.connect_to_weaviate_cloud(
+            cluster_url=os.environ["WEAVIATE_URL"],
+            auth_credentials=Auth.api_key(os.environ["WEAVIATE_API_KEY"]),
+        )
+        self.max_batch_size = max_batch_size
+
+    def __del__(self):
+        self.client.close()
+
+    def get_collection_dimension(self, index_dir: str) -> int | None:
+        # Weaviate's collection config does not expose the vector dimension directly,
+        # so infer it from one stored object's vector (None for an empty collection)
+        collection = self.client.collections.get(index_dir)
+        response = collection.query.fetch_objects(limit=1, include_vector=True)
+        if not response.objects:
+            return None
+        return len(response.objects[0].vector["default"])
+
+    def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]):
+        """Create a collection and add documents with their embeddings"""
+        self.index_dir = index_dir
+
+        embedding_dim = np.reshape(embeddings, (len(embeddings), -1)).shape[1]
+
+        # Create collection without vectorizer config (we'll provide vectors directly)
+        if not self.client.collections.exists(index_dir):
+            collection = self.client.collections.create(
+                name=index_dir,
+                properties=[
+                    Property(
+                        name='content',
+                        data_type=DataType.TEXT
+                    ),
+                    Property(
+                        name='doc_id',
+                        data_type=DataType.INT,
+                    )
+                ],
+                vectorizer_config=None,  # No vectorizer needed as we provide vectors
+                vector_index_config=Configure.VectorIndex.hnsw()
+            )
+        else:
+            collection = self.client.collections.get(index_dir)
+            if self.get_collection_dimension(index_dir) != embedding_dim:
+                # Dimension mismatch: drop and recreate the collection
+                self.client.collections.delete(index_dir)
+                collection = self.client.collections.create(
+                    name=index_dir,
+                    properties=[
+                        Property(
+                            name='content',
+                            data_type=DataType.TEXT
+                        ),
+                        Property(
+                            name='doc_id',
+                            data_type=DataType.INT,
+                        )
+                    ],
+                    vectorizer_config=None,  # No vectorizer needed as we provide vectors
+                    vector_index_config=Configure.VectorIndex.hnsw()
+                )
+
+        # Convert docs to list if it's a pandas Series
+        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
+
+        # Add documents to 
collection with their embeddings
+        with collection.batch.dynamic() as batch:
+            for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)):
+                properties = {
+                    "content": doc,
+                    "doc_id": idx
+                }
+                batch.add_object(
+                    properties=properties,
+                    vector=embedding.tolist(),  # Provide pre-computed vector
+                )
+
+    def load_index(self, index_dir: str):
+        """Load/set the collection name to use"""
+        self.index_dir = index_dir
+        # Verify collection exists
+        try:
+            self.client.collections.get(index_dir)
+        except weaviate.exceptions.UnexpectedStatusCodeException:
+            raise ValueError(f"Collection {index_dir} not found")
+
+    def __call__(self,
+                 query_vectors,
+                 K: int,
+                 **kwargs: dict[str, Any]
+                 ) -> RMOutput:
+        """Perform vector search using pre-computed query vectors"""
+        if self.index_dir is None:
+            raise ValueError("No collection loaded. Call load_index first.")
+
+        collection = self.client.collections.get(self.index_dir)
+
+        # Query embedding happens in the RM; only pre-computed vectors arrive here.
+        results = []
+        for query_vector in query_vectors:
+            response = (collection.query
+                        .near_vector(
+                            near_vector=query_vector.tolist(),
+                            limit=K,
+                            return_metadata=MetadataQuery(distance=True)
+                        ))
+            results.append(response)
+
+        # Process results into expected format
+        all_distances = []
+        all_indices = []
+
+        for result in results:
+            objects = result.objects
+
+            distances: List[float] = []
+            indices = []
+            for obj in objects:
+                indices.append(obj.properties.get('doc_id', -1))
+                # Convert cosine distance to similarity score
+                distance = obj.metadata.distance if obj.metadata and obj.metadata.distance is not None else 1.0
+                distances.append(1 - distance)
+            # Pad results if fewer than K matches
+            while len(indices) < K:
+                indices.append(-1)
+                distances.append(0.0)
+
+            all_distances.append(distances)
+            all_indices.append(indices)
+
+        return RMOutput(
+            distances=np.array(all_distances, dtype=np.float32).tolist(),
+            indices=np.array(all_indices, dtype=np.int64).tolist()
+        )
+
+    def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.float64]:
+        """Retrieve vectors for specific document IDs"""
+        collection = self.client.collections.get(index_dir)
+
+        # Look each doc_id up with a filter instead of rescanning the whole collection per id
+        vectors = []
+        for doc_id in ids:
+            response = collection.query.fetch_objects(
+                filters=Filter.by_property("doc_id").equal(doc_id),
+                limit=1,
+                include_vector=True,
+            )
+            if not response.objects:
+                raise ValueError(f"{doc_id} does not exist in {index_dir}")
+            vectors.append(response.objects[0].vector["default"])
+        return np.array(vectors, dtype=np.float64)
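
---

Reviewer note: a minimal end-to-end sketch of the RM/VS split this diff introduces, using only APIs the diff itself adds (the RM produces embeddings; the VS owns the index lifecycle and top-k search). The index directory name and sample data below are illustrative, not part of the change:

```python
import pandas as pd

import lotus
from lotus.models import SentenceTransformersRM
from lotus.vector_store import FaissVS

# The RM is now embedding-only; the VS handles indexing and retrieval.
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
vs = FaissVS()
lotus.settings.configure(rm=rm, vs=vs)

df = pd.DataFrame({"Course Name": ["Cooking", "Riemannian Geometry"]})

# sem_index embeds via rm(...) and hands the vectors to vs.index(...)
df = df.sem_index("Course Name", "course_index")

# sem_search converts the query with rm.convert_query_to_query_vector(...)
# and runs top-k search in the configured vector store
print(df.sem_search("Course Name", "math", K=1))
```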