From e3abd9027618e1c267165eb4f8ce6f8a70ddac18 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Sun, 12 Jan 2025 11:12:21 -0800 Subject: [PATCH 01/65] initial scaffolding for adding vector store / vector database integration --- lotus/__init__.py | 2 ++ lotus/settings.py | 2 ++ lotus/vector_store/__init__.py | 13 +++++++++++++ lotus/vector_store/chroma_vs.py | 20 ++++++++++++++++++++ lotus/vector_store/pinecone_vs.py | 18 ++++++++++++++++++ lotus/vector_store/qdrant_vs.py | 21 +++++++++++++++++++++ lotus/vector_store/vs.py | 13 +++++++++++++ lotus/vector_store/weaviate_vs.py | 15 +++++++++++++++ 8 files changed, 104 insertions(+) create mode 100644 lotus/vector_store/__init__.py create mode 100644 lotus/vector_store/chroma_vs.py create mode 100644 lotus/vector_store/pinecone_vs.py create mode 100644 lotus/vector_store/qdrant_vs.py create mode 100644 lotus/vector_store/vs.py create mode 100644 lotus/vector_store/weaviate_vs.py diff --git a/lotus/__init__.py b/lotus/__init__.py index f66cfb5c..d20f710d 100644 --- a/lotus/__init__.py +++ b/lotus/__init__.py @@ -1,6 +1,7 @@ import logging import lotus.dtype_extensions import lotus.models +import lotus.vector_store import lotus.nl_expression import lotus.templates import lotus.utils @@ -44,6 +45,7 @@ "templates", "logger", "models", + "vector_store", "utils", "dtype_extensions", ] diff --git a/lotus/settings.py b/lotus/settings.py index 99e59449..a1c54c56 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -1,4 +1,5 @@ import lotus.models +import lotus.vector_store from lotus.types import SerializationFormat # NOTE: Settings class is not thread-safe @@ -10,6 +11,7 @@ class Settings: rm: lotus.models.RM | None = None helper_lm: lotus.models.LM | None = None reranker: lotus.models.Reranker | None = None + vs: lotus.vector_store.VS | None = None # Cache settings enable_message_cache: bool = False diff --git a/lotus/vector_store/__init__.py b/lotus/vector_store/__init__.py new file mode 100644 index 00000000..0f634b9a --- /dev/null +++ b/lotus/vector_store/__init__.py @@ -0,0 +1,13 @@ +from lotus.vector_store.vs import VS +from lotus.vector_store.weaviate_vs import WeaviateVS +from lotus.vector_store.pinecone_vs import PineconeVS +from lotus.vector_store.chroma_vs import ChromaVS +from lotus.vector_store.qdrant_vs import QdrantVS + +__all__ = [ + "VS", + "WeaviateVS", + "PineconeVS", + "ChromaVS", + "QdrantVS" +] \ No newline at end of file diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py new file mode 100644 index 00000000..4a70720e --- /dev/null +++ b/lotus/vector_store/chroma_vs.py @@ -0,0 +1,20 @@ +from lotus.vector_store.vs import VS + + +try: + import chromadb +except ImportError: + chromadb = None + + +if chromadb is None: + raise ImportError( + "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`", + ) + + + +class ChromaVS(VS): + + def __init__(self): + pass \ No newline at end of file diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py new file mode 100644 index 00000000..ae1f116b --- /dev/null +++ b/lotus/vector_store/pinecone_vs.py @@ -0,0 +1,18 @@ +from lotus.vector_store.vs import VS + +try: + import pinecone +except ImportError: + pinecone = None + + +if pinecone is None: + raise ImportError( + "The pinecone library is required to use PineconeVS. 
Install it with `pip install pinecone`", + ) + +class PineconeVS(VS): + + def __init__(self): + pass + diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py new file mode 100644 index 00000000..fa628e56 --- /dev/null +++ b/lotus/vector_store/qdrant_vs.py @@ -0,0 +1,21 @@ +from lotus.vector_store.vs import VS + + +try: + import qdrant_client +except ImportError: + qdrant_client = None + + +if qdrant_client is None: + raise ImportError( + "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`", + ) + + + + +class QdrantVS(VS): + + def __init__(self): + pass \ No newline at end of file diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py new file mode 100644 index 00000000..e13f64e5 --- /dev/null +++ b/lotus/vector_store/vs.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + +import pandas as pd + +class VS(ABC): + """Abstract class for vector stores.""" + + def __init__(self) -> None: + pass + + @abstractmethod + def index(self, docs: pd.Series, index_dir): + pass \ No newline at end of file diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py new file mode 100644 index 00000000..4f18f35a --- /dev/null +++ b/lotus/vector_store/weaviate_vs.py @@ -0,0 +1,15 @@ +from lotus.vector_store.vs import VS + +try: + import weaviate +except ImportError as err: + raise ImportError( + "Please install the weaviate client" + ) + +class WeaviateVS(VS): + + def __init__(self): + pass + + \ No newline at end of file From bd1e8fddf01f6931d25a88cf00018851b86a80bd Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Sun, 12 Jan 2025 13:29:07 -0800 Subject: [PATCH 02/65] fixed linting, ruff checks pass --- lotus/vector_store/__init__.py | 10 ++-------- lotus/vector_store/chroma_vs.py | 7 ++----- lotus/vector_store/pinecone_vs.py | 7 +++---- lotus/vector_store/qdrant_vs.py | 10 +++------- lotus/vector_store/vs.py | 11 ++++++----- lotus/vector_store/weaviate_vs.py | 16 +++++++--------- 6 files changed, 23 insertions(+), 38 deletions(-) diff --git a/lotus/vector_store/__init__.py b/lotus/vector_store/__init__.py index 0f634b9a..34c41998 100644 --- a/lotus/vector_store/__init__.py +++ b/lotus/vector_store/__init__.py @@ -1,13 +1,7 @@ -from lotus.vector_store.vs import VS +from lotus.vector_store.vs import VS from lotus.vector_store.weaviate_vs import WeaviateVS from lotus.vector_store.pinecone_vs import PineconeVS from lotus.vector_store.chroma_vs import ChromaVS from lotus.vector_store.qdrant_vs import QdrantVS -__all__ = [ - "VS", - "WeaviateVS", - "PineconeVS", - "ChromaVS", - "QdrantVS" -] \ No newline at end of file +__all__ = ["VS", "WeaviateVS", "PineconeVS", "ChromaVS", "QdrantVS"] diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 4a70720e..e94544e0 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,10 +1,9 @@ from lotus.vector_store.vs import VS - try: import chromadb except ImportError: - chromadb = None + chromadb = None if chromadb is None: @@ -13,8 +12,6 @@ ) - class ChromaVS(VS): - def __init__(self): - pass \ No newline at end of file + pass diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index ae1f116b..19fc0967 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -3,7 +3,7 @@ try: import pinecone except ImportError: - pinecone = None + pinecone = None if pinecone is None: @@ -11,8 +11,7 @@ "The pinecone library is required to use PineconeVS. 
Install it with `pip install pinecone`", ) -class PineconeVS(VS): +class PineconeVS(VS): def __init__(self): - pass - + pass diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index fa628e56..6fc82dae 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,10 +1,9 @@ -from lotus.vector_store.vs import VS - +from lotus.vector_store.vs import VS try: import qdrant_client except ImportError: - qdrant_client = None + qdrant_client = None if qdrant_client is None: @@ -13,9 +12,6 @@ ) - - class QdrantVS(VS): - def __init__(self): - pass \ No newline at end of file + pass diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index e13f64e5..8bfc43e8 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -1,13 +1,14 @@ -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod + +import pandas as pd -import pandas as pd class VS(ABC): """Abstract class for vector stores.""" def __init__(self) -> None: - pass + pass - @abstractmethod + @abstractmethod def index(self, docs: pd.Series, index_dir): - pass \ No newline at end of file + pass diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 4f18f35a..0ee5ded2 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,15 +1,13 @@ +from typing import Optional, Union + from lotus.vector_store.vs import VS try: import weaviate -except ImportError as err: - raise ImportError( - "Please install the weaviate client" - ) - -class WeaviateVS(VS): +except ImportError: + raise ImportError("Please install the weaviate client") - def __init__(self): - pass - \ No newline at end of file +class WeaviateVS(VS): + def __init__(self, weaviate_collection_name:str, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], weaviate_collection_text_key: Optional[str] = "content"): + pass From 880c31f110ec14b72ed35496643aa3f58abcc9fb Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Sun, 12 Jan 2025 13:58:02 -0800 Subject: [PATCH 03/65] added changes to requirements.txt file and added additional abstract methods --- lotus/vector_store/vs.py | 18 ++++++++++++++++++ requirements.txt | 6 +++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 8bfc43e8..bb4878eb 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -1,6 +1,12 @@ from abc import ABC, abstractmethod +from typing import Any +import numpy as np import pandas as pd +from numpy.typing import NDArray +from PIL import Image + +from lotus.types import RMOutput class VS(ABC): @@ -12,3 +18,15 @@ def __init__(self) -> None: @abstractmethod def index(self, docs: pd.Series, index_dir): pass + + @abstractmethod + def search(self, + queries: pd.Series | str | Image.Image | list | NDArray[np.float64], + K:int, + **kwargs: dict[str, Any], + ) -> RMOutput: + pass + + @abstractmethod + def get_vectors_from_index(self, collection_name:str, ids: list[int]) -> NDArray[np.float64]: + pass \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 226370bc..e645c716 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,8 @@ numpy==1.26.4 pandas==2.2.2 sentence-transformers==3.0.1 tiktoken==0.7.0 -tqdm==4.66.4 \ No newline at end of file +tqdm==4.66.4 +weaviate-client==4.10.2 +pinecone==5.4.2 +chromadb==0.6.2 +qdrant-client==1.12.2 \ No newline at end of file From 7b5dfd375ffa19a9e70fc252180755d5b3e7d28c Mon Sep 17 00:00:00 2001 From: 
Amogh Tantradi Date: Sun, 12 Jan 2025 14:09:02 -0800 Subject: [PATCH 04/65] refactored --- lotus/vector_store/chroma_vs.py | 21 ++++++++++----------- lotus/vector_store/pinecone_vs.py | 20 ++++++++++---------- lotus/vector_store/qdrant_vs.py | 21 ++++++++++----------- lotus/vector_store/weaviate_vs.py | 16 ++++++++-------- 4 files changed, 38 insertions(+), 40 deletions(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index e94544e0..298f05e9 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,17 +1,16 @@ from lotus.vector_store.vs import VS -try: - import chromadb -except ImportError: - chromadb = None - - -if chromadb is None: - raise ImportError( - "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`", - ) - class ChromaVS(VS): def __init__(self): + try: + import chromadb + except ImportError: + chromadb = None + + + if chromadb is None: + raise ImportError( + "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`", + ) pass diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 19fc0967..fdb89800 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -1,17 +1,17 @@ from lotus.vector_store.vs import VS -try: - import pinecone -except ImportError: - pinecone = None +class PineconeVS(VS): + def __init__(self): + try: + import pinecone + except ImportError: + pinecone = None -if pinecone is None: - raise ImportError( - "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`", - ) + if pinecone is None: + raise ImportError( + "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`", + ) -class PineconeVS(VS): - def __init__(self): pass diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index 6fc82dae..5ded2d7d 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,17 +1,16 @@ from lotus.vector_store.vs import VS -try: - import qdrant_client -except ImportError: - qdrant_client = None - - -if qdrant_client is None: - raise ImportError( - "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`", - ) - class QdrantVS(VS): def __init__(self): + try: + import qdrant_client + except ImportError: + qdrant_client = None + + + if qdrant_client is None: + raise ImportError( + "The qdrant library is required to use QdrantVS. 
Install it with `pip install qdrant_client`", + ) pass diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 0ee5ded2..3b747235 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,13 +1,13 @@ -from typing import Optional, Union - from lotus.vector_store.vs import VS -try: - import weaviate -except ImportError: - raise ImportError("Please install the weaviate client") - class WeaviateVS(VS): - def __init__(self, weaviate_collection_name:str, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], weaviate_collection_text_key: Optional[str] = "content"): + def __init__(self): + try: + import weaviate + except ImportError: + weaviate = None + + if weaviate is None: + raise ImportError("Please install the weaviate client") pass From 08dfabab9fc7d36dd85c8fd26d5bdf6b93e745f7 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Sun, 12 Jan 2025 22:29:30 -0800 Subject: [PATCH 05/65] added tests for clustering and filtering --- tests/test_cluster.py | 104 ++++++++++++++++++++++++++++++++++++++++++ tests/test_filter.py | 102 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 tests/test_cluster.py create mode 100644 tests/test_filter.py diff --git a/tests/test_cluster.py b/tests/test_cluster.py new file mode 100644 index 00000000..5266feae --- /dev/null +++ b/tests/test_cluster.py @@ -0,0 +1,104 @@ +import pandas as pd +import pytest + +from tests.base_test import BaseTest + + +@pytest.fixture +def sample_df(): + return pd.DataFrame({ + "Course Name": [ + "Probability and Random Processes", + "Statistics and Data Analysis", + "Cooking Basics", + "Advanced Culinary Arts", + "Digital Circuit Design", + "Computer Architecture" + ] + }) + + +class TestClusterBy(BaseTest): + def test_basic_clustering(self, sample_df): + """Test basic clustering functionality with 2 clusters""" + result = sample_df.sem_cluster_by("Course Name", 2) + assert "cluster_id" in result.columns + assert len(result["cluster_id"].unique()) == 2 + assert len(result) == len(sample_df) + + + # Get the two clusters + cluster_0_courses = set(result[result["cluster_id"] == 0]["Course Name"]) + cluster_1_courses = set(result[result["cluster_id"] == 1]["Course Name"]) + + # Define the expected course groupings + tech_courses = { + "Probability and Random Processes", + "Statistics and Data Analysis", + "Digital Circuit Design", + "Computer Architecture" + } + culinary_courses = { + "Cooking Basics", + "Advanced Culinary Arts" + } + + # Check that one cluster contains tech courses and the other contains culinary courses + assert (cluster_0_courses == tech_courses and cluster_1_courses == culinary_courses) or \ + (cluster_1_courses == tech_courses and cluster_0_courses == culinary_courses), \ + "Clusters don't match expected course groupings" + + def test_clustering_with_more_clusters(self, sample_df): + """Test clustering with more clusters than necessary""" + result = sample_df.sem_cluster_by("Course Name", 3) + assert len(result["cluster_id"].unique()) == 3 + assert len(result) == len(sample_df) + + def test_clustering_with_single_cluster(self, sample_df): + """Test clustering with single cluster""" + result = sample_df.sem_cluster_by("Course Name", 1) + assert len(result["cluster_id"].unique()) == 1 + assert result["cluster_id"].iloc[0] == 0 + + def test_clustering_with_invalid_column(self, sample_df): + """Test clustering with non-existent column""" + with pytest.raises(ValueError, match="Column .* not found in 
DataFrame"): + sample_df.sem_cluster_by("NonExistentColumn", 2) + + def test_clustering_with_empty_dataframe(self): + """Test clustering on empty dataframe""" + empty_df = pd.DataFrame(columns=["Course Name"]) + result = empty_df.sem_cluster_by("Course Name", 2) + assert len(result) == 0 + assert "cluster_id" in result.columns + + def test_clustering_similar_items(self, sample_df): + """Test that similar items are clustered together""" + result = sample_df.sem_cluster_by("Course Name", 3) + + # Get cluster IDs for similar courses + stats_cluster = result[result["Course Name"].str.contains("Statistics")]["cluster_id"].iloc[0] + prob_cluster = result[result["Course Name"].str.contains("Probability")]["cluster_id"].iloc[0] + + # Similar courses should be in the same cluster + assert stats_cluster == prob_cluster + + cooking_cluster = result[result["Course Name"].str.contains("Cooking")]["cluster_id"].iloc[0] + culinary_cluster = result[result["Course Name"].str.contains("Culinary")]["cluster_id"].iloc[0] + + assert cooking_cluster == culinary_cluster + + def test_clustering_with_verbose(self, sample_df): + """Test clustering with verbose output""" + result = sample_df.sem_cluster_by("Course Name", 2, verbose=True) + assert "cluster_id" in result.columns + assert len(result["cluster_id"].unique()) == 2 + + def test_clustering_with_iterations(self, sample_df): + """Test clustering with different iteration counts""" + result1 = sample_df.sem_cluster_by("Course Name", 2, niter=5) + result2 = sample_df.sem_cluster_by("Course Name", 2, niter=20) + + # Both should produce valid clusterings + assert len(result1["cluster_id"].unique()) == 2 + assert len(result2["cluster_id"].unique()) == 2 diff --git a/tests/test_filter.py b/tests/test_filter.py new file mode 100644 index 00000000..c3023824 --- /dev/null +++ b/tests/test_filter.py @@ -0,0 +1,102 @@ +import pandas as pd +import pytest + +from lotus.types import CascadeArgs +from tests.base_test import BaseTest + + +@pytest.fixture +def sample_df(): + return pd.DataFrame({ + "Name": ["Alice", "Bob", "Charlie"], + "Age": [25, 30, 17], + "City": ["New York", "London", "Paris"] + }) + + +class TestFilteredSearch(BaseTest): + def test_basic_filter(self, sample_df): + """Test basic filtering functionality""" + result = sample_df.sem_filter("Age greater than 20") + assert len(result) == 2 + assert all(age > 20 for age in result["Age"]) + + def test_filter_with_examples(self, sample_df): + """Test filtering with example data""" + examples = pd.DataFrame({ + "Name": ["David", "Eve"], + "Age": [40, 15], + "City": ["Berlin", "Tokyo"], + "Answer": [True, False] + }) + result = sample_df.sem_filter( + "Age greater than 20", + examples=examples + ) + assert len(result) == 2 + assert all(age > 20 for age in result["Age"]) + + def test_filter_with_explanations(self, sample_df): + """Test filtering with explanations returned""" + result = sample_df.sem_filter( + "Age greater than 20", + return_explanations=True + ) + assert "explanation_filter" in result.columns + assert len(result["explanation_filter"]) == len(result) + + def test_filter_with_raw_outputs(self, sample_df): + """Test filtering with raw outputs returned""" + result = sample_df.sem_filter( + "Age greater than 20", + return_raw_outputs=True + ) + assert "raw_output_filter" in result.columns + assert len(result["raw_output_filter"]) == len(result) + + def test_filter_with_cot_strategy(self, sample_df): + """Test filtering with chain-of-thought reasoning""" + examples = pd.DataFrame({ + "Name": ["David"], + 
"Age": [40], + "City": ["Berlin"], + "Answer": [True], + "Reasoning": ["The age is 40, which is greater than 20"] + }) + result = sample_df.sem_filter( + "Age greater than 20", + examples=examples, + strategy="cot", + return_explanations=True + ) + assert "explanation_filter" in result.columns + assert len(result) == 2 + + def test_filter_with_invalid_column(self, sample_df): + """Test filtering with non-existent column""" + with pytest.raises(ValueError, match="Column .* not found in DataFrame"): + sample_df.sem_filter("InvalidColumn greater than 20") + + def test_filter_with_cascade(self, sample_df): + """Test filtering with cascade arguments""" + cascade_args = CascadeArgs( + recall_target=0.9, + precision_target=0.9, + sampling_percentage=0.1, + failure_probability=0.2 + ) + result, stats = sample_df.sem_filter( + "Age greater than 20", + cascade_args=cascade_args, + return_stats=True + ) + assert isinstance(stats, dict) + assert "pos_cascade_threshold" in stats + assert "neg_cascade_threshold" in stats + assert len(result) == 2 + + def test_empty_dataframe(self): + """Test filtering on empty dataframe""" + empty_df = pd.DataFrame(columns=["Name", "Age", "City"]) + result = empty_df.sem_filter("Age greater than 20") + assert len(result) == 0 \ No newline at end of file From f3a82c1f80390c1b5cff6e1b03552834c2fc2cc5 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Mon, 13 Jan 2025 12:05:44 -0800 Subject: [PATCH 06/65] made edits to test_filter --- tests/test_filter.py | 170 ++++++++++++++++++++++--------------------- 1 file changed, 89 insertions(+), 81 deletions(-) diff --git a/tests/test_filter.py b/tests/test_filter.py index c3023824..1611340a 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -1,102 +1,110 @@ import pandas as pd import pytest -from lotus.types import CascadeArgs from tests.base_test import BaseTest @pytest.fixture def sample_df(): return pd.DataFrame({ - "Name": ["Alice", "Bob", "Charlie"], - "Age": [25, 30, 17], - "City": ["New York", "London", "Paris"] + "Course Name": [ + "Introduction to Programming", + "Advanced Programming", + "Cooking Basics", + "Advanced Culinary Arts", + "Data Structures", + "Algorithms", + "French Cuisine", + "Italian Cooking" + ], + "Department": [ + "CS", "CS", "Culinary", "Culinary", + "CS", "CS", "Culinary", "Culinary" + ], + "Level": [ + 100, 200, 100, 200, + 300, 300, 200, 200 + ] }) -class TestFilteredSearch(BaseTest): - def test_basic_filter(self, sample_df): - """Test basic filtering functionality""" - result = sample_df.sem_filter("Age greater than 20") +class TestSearch(BaseTest): + def test_basic_search(self, sample_df): + """Test basic semantic search functionality""" + df = sample_df.sem_index("Course Name", "course_index") + result = df.sem_search("Course Name", "programming courses", K=2) assert len(result) == 2 - assert all(age > 20 for age in result["Age"]) + assert "Introduction to Programming" in result["Course Name"].values + assert "Advanced Programming" in result["Course Name"].values - def test_filter_with_examples(self, sample_df): - """Test filtering with example data""" - examples = pd.DataFrame({ - "Name": ["David", "Eve"], - "Age": [40, 15], - "City": ["Berlin", "Tokyo"], - "Answer": [True, False] - }) - result = sample_df.sem_filter( - "Age greater than 20", - examples=examples - ) + def test_filtered_search_relational(self, sample_df): + """Test semantic search with relational filter""" + # Index the dataframe + df = sample_df.sem_index("Course Name", "course_index") + + # Apply relational 
filter and search + filtered_df = df[df["Department"] == "CS"] + result = filtered_df.sem_search("Course Name", "advanced courses", K=2) + assert len(result) == 2 - assert all(age > 20 for age in result["Age"]) - - def test_filter_with_explanations(self, sample_df): - """Test filtering with explanations returned""" - result = sample_df.sem_filter( - "Age greater than 20", - return_explanations=True - ) - assert "explanation_filter" in result.columns - assert len(result["explanation_filter"]) == len(result) + # Should only return CS courses + assert all(dept == "CS" for dept in result["Department"]) + assert "Advanced Programming" in result["Course Name"].values - def test_filter_with_raw_outputs(self, sample_df): - """Test filtering with raw outputs returned""" - result = sample_df.sem_filter( - "Age greater than 20", - return_raw_outputs=True - ) - assert "raw_output_filter" in result.columns - assert len(result["raw_output_filter"]) == len(result) + def test_filtered_search_semantic(self, sample_df): + """Test semantic search after semantic filter""" + # Index the dataframe + df = sample_df.sem_index("Course Name", "course_index") + + # Apply semantic filter and search + filtered_df = df.sem_filter("{Course Name} is related to cooking") + result = filtered_df.sem_search("Course Name", "advanced level courses", K=2) + + assert len(result) == 2 + # Should only return cooking-related courses + assert all(dept == "Culinary" for dept in result["Department"]) + assert "Advanced Culinary Arts" in result["Course Name"].values - def test_filter_with_cot_strategy(self, sample_df): - """Test filtering with chain-of-thought reasoning""" - examples = pd.DataFrame({ - "Name": ["David"], - "Age": [40], - "City": ["Berlin"], - "Answer": [True], - "Reasoning": ["The age is 40, which is greater than 20"] - }) - result = sample_df.sem_filter( - "Age greater than 20", - examples=examples, - strategy="cot", - return_explanations=True - ) - assert "explanation_filter" in result.columns + def test_filtered_search_combined(self, sample_df): + """Test semantic search with both relational and semantic filters""" + # Index the dataframe + df = sample_df.sem_index("Course Name", "course_index") + + # Apply both filters and search + filtered_df = df[df["Level"] >= 200] # relational filter + filtered_df = filtered_df.sem_filter("{Course Name} is related to computer science") # semantic filter + result = filtered_df.sem_search("Course Name", "data structures and algorithms", K=2) + assert len(result) == 2 + # Should only return advanced CS courses + assert all(dept == "CS" for dept in result["Department"]) + assert all(level >= 200 for level in result["Level"]) + assert "Data Structures" in result["Course Name"].values + assert "Algorithms" in result["Course Name"].values - def test_filter_with_invalid_column(self, sample_df): - """Test filtering with non-existent column""" - with pytest.raises(ValueError, match="Column .* not found in DataFrame"): - sample_df.sem_filter("InvalidColumn greater than 20") + def test_filtered_search_empty_result(self, sample_df): + """Test semantic search when filter returns empty result""" + df = sample_df.sem_index("Course Name", "course_index") + + # Apply filter that should return no results + filtered_df = df[df["Level"] > 1000] + result = filtered_df.sem_search("Course Name", "any course", K=2) + + assert len(result) == 0 - def test_filter_with_cascade(self, sample_df): - """Test filtering with cascade arguments""" - cascade_args = CascadeArgs( - recall_target=0.9, - 
precision_target=0.9, - sampling_percentage=0.1, - failure_probability=0.2 - ) - result, stats = sample_df.sem_filter( - "Age greater than 20", - cascade_args=cascade_args, - return_stats=True + def test_filtered_search_with_scores(self, sample_df): + """Test filtered semantic search with similarity scores""" + df = sample_df.sem_index("Course Name", "course_index") + + filtered_df = df[df["Department"] == "CS"] + result = filtered_df.sem_search( + "Course Name", + "programming courses", + K=2, + return_scores=True ) - assert isinstance(stats, dict) - assert "pos_cascade_threshold" in stats - assert "neg_cascade_threshold" in stats - assert len(result) == 2 - - def test_empty_dataframe(self): - """Test filtering on empty dataframe""" - empty_df = pd.DataFrame(columns=["Name", "Age", "City"]) - result = empty_df.sem_filter("Age greater than 20") - assert len(result) == 0 \ No newline at end of file + + assert "vec_scores_sim_score" in result.columns + assert len(result["vec_scores_sim_score"]) == 2 + # Scores should be between 0 and 1 + assert all(0 <= score <= 1 for score in result["vec_scores_sim_score"]) \ No newline at end of file From fc62846cbda59281499824af33de0bae53661c00 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Mon, 13 Jan 2025 22:07:32 -0800 Subject: [PATCH 07/65] added implementations for weaviate and pinecone vs --- lotus/vector_store/pinecone_vs.py | 154 +++++++++++++++++++++++++-- lotus/vector_store/vs.py | 30 +++++- lotus/vector_store/weaviate_vs.py | 171 ++++++++++++++++++++++++++++-- 3 files changed, 335 insertions(+), 20 deletions(-) diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index fdb89800..19e7f810 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -1,17 +1,153 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image +from tqdm import tqdm + +from lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + from pinecone import Pinecone +except ImportError as err: + raise ImportError( + "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`", + ) from err class PineconeVS(VS): - def __init__(self): - try: - import pinecone - except ImportError: - pinecone = None + def __init__(self, api_key: str, environment: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + """Initialize Pinecone client with API key and environment""" + super().__init__(embedding_model) + self.pinecone = Pinecone(api_key=api_key) + self.index = None + self.max_batch_size = max_batch_size - if pinecone is None: - raise ImportError( - "The pinecone library is required to use PineconeVS. 
Install it with `pip install pinecone`", + def index(self, docs: pd.Series, collection_name: str): + """Create an index and add documents to it""" + self.collection_name = collection_name + + # Get sample embedding to determine vector dimension + sample_embedding = self._embed([docs.iloc[0]]) + dimension = sample_embedding.shape[1] + + # Check if index already exists + if collection_name not in self.pinecone.list_indexes(): + # Create new index with the correct dimension + self.pinecone.create_index( + name=collection_name, + dimension=dimension, + metric="cosine" ) + + # Connect to index + self.pc_index = self.pinecone.Index(collection_name) + + # Convert docs to list if it's a pandas Series + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + + # Create embeddings using the provided embedding model + embeddings = self._batch_embed(docs_list) + + # Prepare vectors for upsert + vectors = [] + for idx, (embedding, doc) in enumerate(zip(embeddings, docs_list)): + vectors.append({ + "id": str(idx), + "values": embedding.tolist(), # Pinecone expects lists, not numpy arrays + "metadata": { + "content": doc, + "doc_id": idx + } + }) + + # Upsert in batches of 100 + batch_size = 100 + for i in tqdm(range(0, len(vectors), batch_size), desc="Uploading to Pinecone"): + batch = vectors[i:i + batch_size] + self.pc_index.upsert(vectors=batch) + + def load_index(self, collection_name: str): + """Connect to an existing Pinecone index""" + if collection_name not in self.pinecone.list_indexes(): + raise ValueError(f"Index {collection_name} not found") + + self.collection_name = collection_name + self.pc_index = self.pinecone.Index(collection_name) + + def __call__( + self, + queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using Pinecone""" + if self.pc_index is None: + raise ValueError("No index loaded. 
Call load_index first.") + + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._batch_embed(queries) + + # Perform searches + all_distances = [] + all_indices = [] + + for query_vector in query_vectors: + # Query Pinecone + results = self.index.query( + vector=query_vector.tolist(), + top_k=K, + include_metadata=True, + **kwargs + ) + + # Extract distances and indices + distances = [] + indices = [] + + for match in results.matches: + indices.append(int(match.metadata["doc_id"])) + distances.append(match.score) + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) # Use -1 for padding + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + if self.pc_index is None or self.collection_name != collection_name: + self.load_index(collection_name) + + # Fetch vectors from Pinecone + vectors = [] + for doc_id in ids: + response = self.pc_index.fetch(ids=[str(doc_id)]) + if str(doc_id) in response.vectors: + vector = response.vectors[str(doc_id)].values + vectors.append(vector) + else: + raise ValueError(f"Document with id {doc_id} not found") - pass + return np.array(vectors, dtype=np.float64) diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index bb4878eb..89a7bce8 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -1,26 +1,38 @@ from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any import numpy as np import pandas as pd +import tqdm from numpy.typing import NDArray from PIL import Image +from lotus.dtype_extensions import convert_to_base_data from lotus.types import RMOutput class VS(ABC): """Abstract class for vector stores.""" - def __init__(self) -> None: + def __init__(self, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]]) -> None: + self.collection_name: str | None = None + self._embed: Callable[[pd.Series | list], NDArray[np.float64]] = embedding_model pass @abstractmethod - def index(self, docs: pd.Series, index_dir): + def index(self, docs: pd.Series, collection_name: str): + """ + Create index and store it in vector store + """ pass + @abstractmethod + def load_index(self, collection_name: str): + """Load the index from the vector store into memory ?? 
(not sure if this is needed )""" + @abstractmethod - def search(self, + def __call__(self, queries: pd.Series | str | Image.Image | list | NDArray[np.float64], K:int, **kwargs: dict[str, Any], @@ -29,4 +41,14 @@ def search(self, @abstractmethod def get_vectors_from_index(self, collection_name:str, ids: list[int]) -> NDArray[np.float64]: - pass \ No newline at end of file + pass + + def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]: + """Create embeddings using the provided embedding model with batching""" + all_embeddings = [] + for i in tqdm(range(0, len(docs), self.max_batch_size), desc="Creating embeddings"): + batch = docs[i : i + self.max_batch_size] + _batch = convert_to_base_data(batch) + embeddings = self._embed(_batch) + all_embeddings.append(embeddings) + return np.vstack(all_embeddings) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 3b747235..7e84a2a4 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,13 +1,170 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image + +from lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + from uuid import uuid4 + + import weaviate + from weaviate.util import get_valid_uuid +except ImportError as err: + raise ImportError("Please install the weaviate client") from err class WeaviateVS(VS): - def __init__(self): + def __init__(self, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + """Initialize with Weaviate client and embedding model""" + super().__init__(embedding_model) + self.client = weaviate_client + self.max_batch_size = max_batch_size + + def index(self, docs: pd.Series, collection_name: str): + """Create a collection and add documents with their embeddings""" + self.collection_name = collection_name + + # Get sample embedding to determine vector dimension + sample_embedding = self._embed([docs.iloc[0]]) + vector_dim = sample_embedding.shape[1] + + # Create collection without vectorizer config (we'll provide vectors directly) + collection = self.client.collections.create( + name=collection_name, + properties=[ + { + "name": "content", + "dataType": ["text"], + }, + { + "name": "doc_id", + "dataType": ["int"], + } + ], + vectorizer_config=None, # No vectorizer needed as we provide vectors + vector_index_config={"distance": "cosine"}, + vectorIndexConfig={ + "distance": "cosine", + "dimension": vector_dim + } + ) + + # Generate embeddings for all documents + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + embeddings = self._batch_embed(docs_list) + + # Add documents to collection with their embeddings + with collection.batch.dynamic() as batch: + for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)): + properties = { + "content": doc, + "doc_id": idx + } + batch.add_object( + properties=properties, + vector=embedding.tolist(), # Provide pre-computed vector + uuid=get_valid_uuid(str(uuid4())) + ) + + def load_index(self, collection_name: str): + """Load/set the collection name to use""" + self.collection_name = collection_name + # Verify collection exists try: - import weaviate - except ImportError: - weaviate = None + self.client.collections.get(collection_name) + except weaviate.exceptions.UnexpectedStatusCodeException: + raise ValueError(f"Collection {collection_name} not found") + + 
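# A quick sketch of how the index/load_index methods above and the shared
# _batch_embed helper are meant to compose with the __call__ search entry
# point defined just below (the search internals are still reworked in later
# patches in this series). The client, encoder, and sample data here are
# illustrative assumptions, not part of this patch:
import pandas as pd
import weaviate
from sentence_transformers import SentenceTransformer

client = weaviate.connect_to_local()             # assumes a local Weaviate instance
model = SentenceTransformer("all-MiniLM-L6-v2")  # any list -> ndarray encoder works

vs = WeaviateVS(client, embedding_model=model.encode)
vs.index(pd.Series(["probability theory", "cooking basics"]), "courses")  # embeds via _batch_embed
vs.load_index("courses")   # attach to the same collection in a later session
out = vs("statistics", K=1)  # RMOutput carrying indices and distances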
def __call__(self, + queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using pre-computed query vectors""" + if self.collection_name is None: + raise ValueError("No collection loaded. Call load_index first.") + + collection = self.client.collections.get(self.collection_name) + + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Generate embeddings for text queries + query_vectors = self._batch_embed(queries) + + # Perform searches + results = [] + for query_vector in query_vectors: + response = (collection.query + .near_vector({ + "vector": query_vector.tolist() + }) + .with_limit(K) + .with_additional(['distance']) + .with_fields(['doc_id']) + .do()) + results.append(response) + + # Process results into expected format + all_distances = [] + all_indices = [] + + for result in results: + objects = result.get('data', {}).get('Get', {}).get(self.collection_name, []) + + distances = [] + indices = [] + for obj in objects: + indices.append(obj['doc_id']) + # Convert cosine distance to similarity score + distance = obj.get('_additional', {}).get('distance', 0) + distances.append(1 - distance) # Convert distance to similarity + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + collection = self.client.collections.get(collection_name) + + # Query for documents with specific doc_ids + vectors = [] + for doc_id in ids: + response = (collection.query + .with_fields(['_additional {vector}']) + .with_where({ + 'path': ['doc_id'], + 'operator': 'Equal', + 'valueNumber': doc_id + }) + .do()) - if weaviate is None: - raise ImportError("Please install the weaviate client") - pass + # Extract vector from response + objects = response.get('data', {}).get('Get', {}).get(collection_name, []) + if objects: + vector = objects[0].get('_additional', {}).get('vector', []) + vectors.append(vector) + else: + raise ValueError(f"Document with id {doc_id} not found") + + return np.array(vectors, dtype=np.float64) + + \ No newline at end of file From f2937adcd0ee6f02589e922e0b426a95e8410d0a Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 13:37:21 -0800 Subject: [PATCH 08/65] added extra refactoring and added implementations for qdrant and chroma_vs --- lotus/vector_store/chroma_vs.py | 147 ++++++++++++++++++++++++++-- lotus/vector_store/pinecone_vs.py | 2 +- lotus/vector_store/qdrant_vs.py | 155 ++++++++++++++++++++++++++++-- lotus/vector_store/vs.py | 3 +- lotus/vector_store/weaviate_vs.py | 2 +- 5 files changed, 289 insertions(+), 20 deletions(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 298f05e9..d05b1e47 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,16 +1,147 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image +from tqdm import tqdm + +from 
lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + import chromadb + from chromadb.api import Collection +except ImportError as err: + raise ImportError( + "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`" + ) from err class ChromaVS(VS): - def __init__(self): + def __init__(self, client: chromadb.Client, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + """Initialize with ChromaDB client and embedding model""" + super().__init__(embedding_model) + self.client = client + self.collection: Collection | None = None + self.collection_name = None + self.max_batch_size = max_batch_size + + + def index(self, docs: pd.Series, collection_name: str): + """Create a collection and add documents with their embeddings""" + self.collection_name = collection_name + + # Create collection without embedding function (we'll provide embeddings directly) + self.collection = self.client.create_collection( + name=collection_name, + metadata={"hnsw:space": "cosine"} # Use cosine similarity for consistency + ) + + # Convert docs to list if it's a pandas Series + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + + # Generate embeddings + embeddings = self._batch_embed(docs_list) + + # Prepare documents for addition + ids = [str(i) for i in range(len(docs_list))] + metadatas = [{"doc_id": i} for i in range(len(docs_list))] + + # Add documents in batches + batch_size = 100 + for i in tqdm(range(0, len(docs_list), batch_size), desc="Uploading to ChromaDB"): + end_idx = min(i + batch_size, len(docs_list)) + self.collection.add( + ids=ids[i:end_idx], + documents=docs_list[i:end_idx], + embeddings=embeddings[i:end_idx].tolist(), + metadatas=metadatas[i:end_idx] + ) + + def load_index(self, collection_name: str): + """Load an existing collection""" try: - import chromadb - except ImportError: - chromadb = None + self.collection = self.client.get_collection(collection_name) + self.collection_name = collection_name + except ValueError as e: + raise ValueError(f"Collection {collection_name} not found") from e + + def __call__( + self, + queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using ChromaDB""" + if self.collection is None: + raise ValueError("No collection loaded. Call load_index first.") + + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._batch_embed(queries) + # Perform searches + all_distances = [] + all_indices = [] - if chromadb is None: - raise ImportError( - "The chromadb library is required to use ChromaVS. 
Install it with `pip install chromadb`", + for query_vector in query_vectors: + results = self.collection.query( + query_embeddings=[query_vector.tolist()], + n_results=K, + include=['metadatas', 'distances'] ) - pass + + # Extract distances and indices + distances = [] + indices = [] + + if results['metadatas']: + for metadata, distance in zip(results['metadatas'][0], results['distances'][0]): + indices.append(metadata['doc_id']) + # ChromaDB returns squared L2 distances, convert to cosine similarity + # similarity = 1 - (distance / 2) # Convert L2 distance to cosine similarity + distances.append(1 - (distance / 2)) + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + if self.collection is None or self.collection_name != collection_name: + self.load_index(collection_name) + + # Convert integer ids to strings for ChromaDB + str_ids = [str(id) for id in ids] + + # Get embeddings from ChromaDB + results = self.collection.get( + ids=str_ids, + include=['embeddings'] + ) + + if not results['embeddings']: + raise ValueError("No vectors found for the given ids") + + return np.array(results['embeddings'], dtype=np.float64) + + \ No newline at end of file diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 19e7f810..da8cc343 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -17,7 +17,7 @@ ) from err class PineconeVS(VS): - def __init__(self, api_key: str, environment: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): """Initialize Pinecone client with API key and environment""" super().__init__(embedding_model) self.pinecone = Pinecone(api_key=api_key) diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index 5ded2d7d..c672727a 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,16 +1,153 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image +from tqdm import tqdm + +from lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + from qdrant_client import QdrantClient + from qdrant_client.models import Distance, PointStruct, VectorParams +except ImportError as err: + raise ImportError("Please install the qdrant client") from err class QdrantVS(VS): - def __init__(self): - try: - import qdrant_client - except ImportError: - qdrant_client = None + def __init__(self, client: QdrantClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + """Initialize with Qdrant client and embedding model""" + super().__init__(embedding_model) # Fixed the super() call syntax + self.client = client + self.max_batch_size = max_batch_size + + def index(self, docs: pd.Series, collection_name: str): + """Create a collection and add documents with their embeddings""" + self.collection_name = collection_name + + # Get sample embedding to determine vector 
dimension + sample_embedding = self._embed([docs.iloc[0]]) + dimension = sample_embedding.shape[1] + + # Create collection if it doesn't exist + if not self.client.collection_exists(collection_name): + self.client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams(size=dimension, distance=Distance.COSINE) + ) + + # Convert docs to list if it's a pandas Series + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + + # Generate embeddings + embeddings = self._batch_embed(docs_list) + + # Prepare points for upload + points = [] + for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)): + points.append( + PointStruct( + id=idx, + vector=embedding.tolist(), + payload={ + "content": doc, + "doc_id": idx + } + ) + ) + + # Upload in batches + batch_size = 100 + for i in tqdm(range(0, len(points), batch_size), desc="Uploading to Qdrant"): + batch = points[i:i + batch_size] + self.client.upsert( + collection_name=collection_name, + points=batch + ) + + def load_index(self, collection_name: str): + """Set the collection name to use""" + if not self.client.collection_exists(collection_name): + raise ValueError(f"Collection {collection_name} not found") + self.collection_name = collection_name + def __call__( + self, + queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using Qdrant""" + if self.collection_name is None: + raise ValueError("No collection loaded. Call load_index first.") - if qdrant_client is None: - raise ImportError( - "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`", + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._batch_embed(queries) + + # Perform searches + all_distances = [] + all_indices = [] + + for query_vector in query_vectors: + results = self.client.search( + collection_name=self.collection_name, + query_vector=query_vector.tolist(), + limit=K, + with_payload=True ) - pass + + # Extract distances and indices + distances = [] + indices = [] + + for result in results: + indices.append(result.payload["doc_id"]) + distances.append(result.score) # Qdrant returns cosine similarity directly + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + if self.collection_name != collection_name: + self.load_index(collection_name) + + # Fetch points from Qdrant + points = self.client.retrieve( + collection_name=collection_name, + ids=ids, + with_vectors=True, + with_payload=False + ) + + # Extract and return vectors + vectors = [] + for point in points: + if point.vector is not None: + vectors.append(point.vector) + else: + raise ValueError(f"Vector not found for id {point.id}") + + return np.array(vectors, dtype=np.float64) \ No newline at end of file 
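With Qdrant in place, all four backends now implement the same index / load_index / __call__ / get_vectors_from_index lifecycle. A minimal usage sketch for the Qdrant path, assuming an in-memory client and a SentenceTransformer-style encoder; the model name and sample rows are illustrative, not taken from the patch:

    import pandas as pd
    from qdrant_client import QdrantClient
    from sentence_transformers import SentenceTransformer

    from lotus.vector_store import QdrantVS

    model = SentenceTransformer("all-MiniLM-L6-v2")
    vs = QdrantVS(QdrantClient(":memory:"), embedding_model=model.encode)

    docs = pd.Series(["Cooking Basics", "Computer Architecture"])
    vs.index(docs, "courses")                         # create collection, embed, upsert points
    out = vs("hardware design", K=1)                  # cosine top-K search, returns RMOutput
    vecs = vs.get_vectors_from_index("courses", [0])  # stored vector for point id 0

Point ids and doc_ids coincide here by construction (PointStruct is built with id=idx and payload doc_id=idx), which is what lets the final patch in this excerpt read result.id directly instead of going through the payload.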
diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 89a7bce8..0f37a1a7 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -18,7 +18,8 @@ class VS(ABC): def __init__(self, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]]) -> None: self.collection_name: str | None = None self._embed: Callable[[pd.Series | list], NDArray[np.float64]] = embedding_model - pass + self.max_batch_size:int = 64 + @abstractmethod def index(self, docs: pd.Series, collection_name: str): diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 364bca31..e1e957b5 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -17,7 +17,7 @@ raise ImportError("Please install the weaviate client") from err class WeaviateVS(VS): - def __init__(self, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): """Initialize with Weaviate client and embedding model""" super().__init__(embedding_model) self.client = weaviate_client From a4c741817b74c3e243f2f0bce79f91445cbc25df Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 15:00:51 -0800 Subject: [PATCH 09/65] fixed some type errors --- lotus/vector_store/weaviate_vs.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index e1e957b5..0de158af 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -12,6 +12,7 @@ from uuid import uuid4 import weaviate + from weaviate.classes.config import Configure, DataType, Property from weaviate.util import get_valid_uuid except ImportError as err: raise ImportError("Please install the weaviate client") from err @@ -27,29 +28,21 @@ def index(self, docs: pd.Series, collection_name: str): """Create a collection and add documents with their embeddings""" self.collection_name = collection_name - # Get sample embedding to determine vector dimension - sample_embedding = self._embed([docs.iloc[0]]) - vector_dim = sample_embedding.shape[1] - # Create collection without vectorizer config (we'll provide vectors directly) collection = self.client.collections.create( name=collection_name, properties=[ - { - "name": "content", - "dataType": ["text"], - }, - { - "name": "doc_id", - "dataType": ["int"], - } + Property( + name='content', + data_type=DataType.TEXT + ), + Property( + name='doc_id', + data_type=DataType.INT, + ) ], vectorizer_config=None, # No vectorizer needed as we provide vectors - vector_index_config={"distance": "cosine"}, - vectorIndexConfig={ - "distance": "cosine", - "dimension": vector_dim - } + vector_index_config=Configure.VectorIndex.dynamic() ) # Generate embeddings for all documents From 1357fb339fd5abb7a4492242152bac2bee8a32cd Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 18:24:45 -0800 Subject: [PATCH 10/65] made further corrections --- lotus/vector_store/weaviate_vs.py | 48 ++++++++++++------------------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 0de158af..d05fe6eb 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -13,6 +13,7 @@ 
import weaviate from weaviate.classes.config import Configure, DataType, Property + from weaviate.classes.query import MetadataQuery from weaviate.util import get_valid_uuid except ImportError as err: raise ImportError("Please install the weaviate client") from err @@ -58,7 +59,7 @@ def index(self, docs: pd.Series, collection_name: str): } batch.add_object( properties=properties, - vector=embedding.tolist(), # Provide pre-computed vector + vector=embedding.tolist(), # Provide pre-computed vector uuid=get_valid_uuid(str(uuid4())) ) @@ -97,13 +98,11 @@ def __call__(self, results = [] for query_vector in query_vectors: response = (collection.query - .near_vector({ - "vector": query_vector.tolist() - }) - .with_limit(K) - .with_additional(['distance']) - .with_fields(['doc_id']) - .do()) + .near_vector( + near_vector=query_vector.tolist(), + limit=K, + return_metadata=MetadataQuery(distance=True) + )) results.append(response) # Process results into expected format @@ -111,14 +110,14 @@ def __call__(self, all_indices = [] for result in results: - objects = result.get('data', {}).get('Get', {}).get(self.collection_name, []) + objects = result.objects distances = [] indices = [] for obj in objects: - indices.append(obj['doc_id']) + indices.append(obj.properties.get('content')) # Convert cosine distance to similarity score - distance = obj.get('_additional', {}).get('distance', 0) + distance = obj.metadata.distance distances.append(1 - distance) # Convert distance to similarity # Pad results if fewer than K matches @@ -130,8 +129,8 @@ def __call__(self, all_indices.append(indices) return RMOutput( - distances=np.array(all_distances, dtype=np.float32), - indices=np.array(all_indices, dtype=np.int64) + distances=np.array(all_distances, dtype=np.float32).tolist(), + indices=np.array(all_indices, dtype=np.int64).tolist() ) def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: @@ -140,23 +139,14 @@ def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArra # Query for documents with specific doc_ids vectors = [] - for doc_id in ids: - response = (collection.query - .with_fields(['_additional {vector}']) - .with_where({ - 'path': ['doc_id'], - 'operator': 'Equal', - 'valueNumber': doc_id - }) - .do()) - - # Extract vector from response - objects = response.get('data', {}).get('Get', {}).get(collection_name, []) - if objects: - vector = objects[0].get('_additional', {}).get('vector', []) - vectors.append(vector) + + response = collection.query.fetch_objects_by_ids(ids=ids) + for id in ids: + response = collection.query.fetch_object_by_id(uuid=id) + if response: + vectors.append(response.vector) else: - raise ValueError(f"Document with id {doc_id} not found") + raise ValueError(f'{id} does not exist in {collection_name}') return np.array(vectors, dtype=np.float64) From c76b658519512dc4ab52e997efa257eade523c15 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 18:33:53 -0800 Subject: [PATCH 11/65] edit uuid type --- lotus/vector_store/weaviate_vs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index d05fe6eb..eab46d0c 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -133,7 +133,7 @@ def __call__(self, indices=np.array(all_indices, dtype=np.int64).tolist() ) - def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + def get_vectors_from_index(self, 
collection_name: str, ids: list[uuid4]) -> NDArray[np.float64]: """Retrieve vectors for specific document IDs""" collection = self.client.collections.get(collection_name) From 9f257f77eab80be60ea30a876414f6294287b5ed Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 18:57:31 -0800 Subject: [PATCH 12/65] changed uuid type --- lotus/vector_store/weaviate_vs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index eab46d0c..82fff0b4 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -133,7 +133,7 @@ def __call__(self, indices=np.array(all_indices, dtype=np.int64).tolist() ) - def get_vectors_from_index(self, collection_name: str, ids: list[uuid4]) -> NDArray[np.float64]: + def get_vectors_from_index(self, collection_name: str, ids: list[str]) -> NDArray[np.float64]: """Retrieve vectors for specific document IDs""" collection = self.client.collections.get(collection_name) From 99cb535ad92f0bcab42f818b8366e973dd4c8ed7 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 22:15:43 -0800 Subject: [PATCH 13/65] made type changes to weaviate file --- lotus/vector_store/vs.py | 2 +- lotus/vector_store/weaviate_vs.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 0f37a1a7..c370df61 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -41,7 +41,7 @@ def __call__(self, pass @abstractmethod - def get_vectors_from_index(self, collection_name:str, ids: list[int]) -> NDArray[np.float64]: + def get_vectors_from_index(self, collection_name:str, ids: list[any]) -> NDArray[np.float64]: pass def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]: diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 82fff0b4..26585cc8 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, List, Union import numpy as np import pandas as pd @@ -103,6 +103,7 @@ def __call__(self, limit=K, return_metadata=MetadataQuery(distance=True) )) + response.objects[0].metadata.distance results.append(response) # Process results into expected format @@ -112,12 +113,12 @@ def __call__(self, for result in results: objects = result.objects - distances = [] + distances:List[float] = [] indices = [] for obj in objects: indices.append(obj.properties.get('content')) # Convert cosine distance to similarity score - distance = obj.metadata.distance + distance:float = obj.metadata.distance distances.append(1 - distance) # Convert distance to similarity # Pad results if fewer than K matches @@ -133,21 +134,19 @@ def __call__(self, indices=np.array(all_indices, dtype=np.int64).tolist() ) - def get_vectors_from_index(self, collection_name: str, ids: list[str]) -> NDArray[np.float64]: + def get_vectors_from_index(self, collection_name: str, ids: list[any]) -> NDArray[np.float64]: """Retrieve vectors for specific document IDs""" collection = self.client.collections.get(collection_name) # Query for documents with specific doc_ids vectors = [] - response = collection.query.fetch_objects_by_ids(ids=ids) for id in ids: response = collection.query.fetch_object_by_id(uuid=id) if response: vectors.append(response.vector) else: raise ValueError(f'{id} does not exist in {collection_name}') - return np.array(vectors, 
dtype=np.float64) From 3c8a742f5123888f3c0197acee18228893d22fa9 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 22:21:27 -0800 Subject: [PATCH 14/65] made another change --- lotus/vector_store/vs.py | 2 +- lotus/vector_store/weaviate_vs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index c370df61..7d3e2c00 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -41,7 +41,7 @@ def __call__(self, pass @abstractmethod - def get_vectors_from_index(self, collection_name:str, ids: list[any]) -> NDArray[np.float64]: + def get_vectors_from_index(self, collection_name:str, ids: list[Any]) -> NDArray[np.float64]: pass def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]: diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 26585cc8..702fb510 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -134,7 +134,7 @@ def __call__(self, indices=np.array(all_indices, dtype=np.int64).tolist() ) - def get_vectors_from_index(self, collection_name: str, ids: list[any]) -> NDArray[np.float64]: + def get_vectors_from_index(self, collection_name: str, ids: list[Any]) -> NDArray[np.float64]: """Retrieve vectors for specific document IDs""" collection = self.client.collections.get(collection_name) From ccd9e489a5ff23d6a85e3420b5a3972d8f18f544 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 22:30:20 -0800 Subject: [PATCH 15/65] typecheck passes for weaviate? --- lotus/vector_store/weaviate_vs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 702fb510..44eb7463 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -118,9 +118,8 @@ def __call__(self, for obj in objects: indices.append(obj.properties.get('content')) # Convert cosine distance to similarity score - distance:float = obj.metadata.distance - distances.append(1 - distance) # Convert distance to similarity - + distance = obj.metadata.distance if obj.metadata and obj.metadata.distance is not None else 1.0 + distances.append(1 - distance) # Convert distance to similarity # Pad results if fewer than K matches while len(indices) < K: indices.append(-1) From 89bf9743ec5358241ac96595d05a3ffbba075739 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Wed, 15 Jan 2025 17:09:59 -0800 Subject: [PATCH 16/65] type changes for weaviate and qdrant files --- lotus/vector_store/qdrant_vs.py | 6 +++--- lotus/vector_store/weaviate_vs.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index c672727a..28a7c7dd 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -113,7 +113,7 @@ def __call__( indices = [] for result in results: - indices.append(result.payload["doc_id"]) + indices.append(result.id) distances.append(result.score) # Qdrant returns cosine similarity directly # Pad results if fewer than K matches @@ -125,8 +125,8 @@ def __call__( all_indices.append(indices) return RMOutput( - distances=np.array(all_distances, dtype=np.float32), - indices=np.array(all_indices, dtype=np.int64) + distances=np.array(all_distances, dtype=np.float32).tolist(), + indices=np.array(all_indices, dtype=np.int64).tolist() ) def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: diff 
--git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 44eb7463..9f18ce0c 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -103,7 +103,7 @@ def __call__(self, limit=K, return_metadata=MetadataQuery(distance=True) )) - response.objects[0].metadata.distance + response.objects[0].uuid results.append(response) # Process results into expected format @@ -116,7 +116,7 @@ def __call__(self, distances:List[float] = [] indices = [] for obj in objects: - indices.append(obj.properties.get('content')) + indices.append(obj.uuid) # Convert cosine distance to similarity score distance = obj.metadata.distance if obj.metadata and obj.metadata.distance is not None else 1.0 distances.append(1 - distance) # Convert distance to similarity From a76adb76c29dfc3c21d409077e9b08c411e3a746 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Wed, 15 Jan 2025 17:18:23 -0800 Subject: [PATCH 17/65] made changes to weaviate file --- lotus/vector_store/weaviate_vs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 9f18ce0c..e12f34e2 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -9,6 +9,7 @@ from lotus.vector_store.vs import VS try: + import uuid from uuid import uuid4 import weaviate @@ -122,7 +123,7 @@ def __call__(self, distances.append(1 - distance) # Convert distance to similarity # Pad results if fewer than K matches while len(indices) < K: - indices.append(-1) + indices.append(uuid.UUID(0)) distances.append(0.0) all_distances.append(distances) From c3e0f0c9e4f54657a7647dff523f002e11e46918 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Wed, 15 Jan 2025 17:29:03 -0800 Subject: [PATCH 18/65] made changes to weaviate file --- lotus/vector_store/weaviate_vs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index e12f34e2..8d9937fa 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -123,7 +123,7 @@ def __call__(self, distances.append(1 - distance) # Convert distance to similarity # Pad results if fewer than K matches while len(indices) < K: - indices.append(uuid.UUID(0)) + indices.append(uuid.UUID(int=0)) distances.append(0.0) all_distances.append(distances) From 1782281d957bc7f59b09dd05f55c4f522cc40d62 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Wed, 15 Jan 2025 22:14:07 -0800 Subject: [PATCH 19/65] fixed pinecone type errors --- lotus/vector_store/pinecone_vs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index da8cc343..34867018 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -21,7 +21,7 @@ def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], N """Initialize Pinecone client with API key and environment""" super().__init__(embedding_model) self.pinecone = Pinecone(api_key=api_key) - self.index = None + self.pc_index = None self.max_batch_size = max_batch_size @@ -107,7 +107,7 @@ def __call__( for query_vector in query_vectors: # Query Pinecone - results = self.index.query( + results = self.pc_index.query( vector=query_vector.tolist(), top_k=K, include_metadata=True, @@ -131,8 +131,8 @@ def __call__( all_indices.append(indices) return RMOutput( - distances=np.array(all_distances, dtype=np.float32), - 
indices=np.array(all_indices, dtype=np.int64) + distances=np.array(all_distances, dtype=np.float32).tolist(), + indices=np.array(all_indices, dtype=np.int64).tolist() ) def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: From 0621b9baffaa19e6d5142bcaa53b707a79d1266b Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Wed, 15 Jan 2025 22:46:15 -0800 Subject: [PATCH 20/65] fixed pinecone type errors --- lotus/vector_store/pinecone_vs.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 34867018..892faaa2 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -10,7 +10,7 @@ from lotus.vector_store.vs import VS try: - from pinecone import Pinecone + from pinecone import Index, Pinecone except ImportError as err: raise ImportError( "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`", @@ -21,7 +21,7 @@ def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], N """Initialize Pinecone client with API key and environment""" super().__init__(embedding_model) self.pinecone = Pinecone(api_key=api_key) - self.pc_index = None + self.pc_index:Index | None = None self.max_batch_size = max_batch_size @@ -140,6 +140,11 @@ def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArra if self.pc_index is None or self.collection_name != collection_name: self.load_index(collection_name) + if self.pc_index is None: # Add this check after load_index + raise ValueError("Failed to initialize Pinecone index") + + + # Fetch vectors from Pinecone vectors = [] for doc_id in ids: From b568d1ef5868f0538c8cd3abbb3571aac50366f0 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Thu, 16 Jan 2025 09:58:33 -0800 Subject: [PATCH 21/65] type checks all pass locally --- lotus/vector_store/chroma_vs.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index d05b1e47..3735f798 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, Mapping, Union import numpy as np import pandas as pd @@ -11,14 +11,16 @@ try: import chromadb + from chromadb import ClientAPI from chromadb.api import Collection + from chromadb.api.types import IncludeEnum except ImportError as err: raise ImportError( "The chromadb library is required to use ChromaVS. 
Install it with `pip install chromadb`" ) from err class ChromaVS(VS): - def __init__(self, client: chromadb.Client, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, client: ClientAPI, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): """Initialize with ChromaDB client and embedding model""" super().__init__(embedding_model) self.client = client @@ -45,7 +47,7 @@ def index(self, docs: pd.Series, collection_name: str): # Prepare documents for addition ids = [str(i) for i in range(len(docs_list))] - metadatas = [{"doc_id": i} for i in range(len(docs_list))] + metadatas: list[Mapping[str, Union[str, int, float, bool]]] = [{"doc_id": int(i)} for i in range(len(docs_list))] # Add documents in batches batch_size = 100 @@ -98,14 +100,14 @@ def __call__( results = self.collection.query( query_embeddings=[query_vector.tolist()], n_results=K, - include=['metadatas', 'distances'] + include=[IncludeEnum.metadatas, IncludeEnum.distances] ) # Extract distances and indices distances = [] indices = [] - if results['metadatas']: + if results['metadatas'] and results['distances']: for metadata, distance in zip(results['metadatas'][0], results['distances'][0]): indices.append(metadata['doc_id']) # ChromaDB returns squared L2 distances, convert to cosine similarity @@ -121,8 +123,8 @@ def __call__( all_indices.append(indices) return RMOutput( - distances=np.array(all_distances, dtype=np.float32), - indices=np.array(all_indices, dtype=np.int64) + distances=np.array(all_distances, dtype=np.float32).tolist(), + indices=np.array(all_indices, dtype=np.int64).tolist() ) def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: @@ -130,13 +132,18 @@ def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArra if self.collection is None or self.collection_name != collection_name: self.load_index(collection_name) + + if self.collection is None: # Add this check after load_index + raise ValueError(f"Failed to load collection {collection_name}") + + # Convert integer ids to strings for ChromaDB str_ids = [str(id) for id in ids] # Get embeddings from ChromaDB results = self.collection.get( ids=str_ids, - include=['embeddings'] + include=[IncludeEnum.embeddings] ) if not results['embeddings']: From 9b33a1f832e7e7c96af5134c8161a19ffe1dae67 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Thu, 16 Jan 2025 09:59:16 -0800 Subject: [PATCH 22/65] fixed linting errors --- lotus/vector_store/chroma_vs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 3735f798..7a7e8de9 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -10,10 +10,9 @@ from lotus.vector_store.vs import VS try: - import chromadb - from chromadb import ClientAPI + from chromadb import ClientAPI from chromadb.api import Collection - from chromadb.api.types import IncludeEnum + from chromadb.api.types import IncludeEnum except ImportError as err: raise ImportError( "The chromadb library is required to use ChromaVS. 
Install it with `pip install chromadb`" From 820f3beff44ee9ed123d233695b5d1a3d29a2220 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Thu, 16 Jan 2025 21:30:00 -0800 Subject: [PATCH 23/65] made refactors to allow for testing --- .github/tests/rm_tests.py | 6 ++++++ lotus/vector_store/chroma_vs.py | 4 ++-- lotus/vector_store/pinecone_vs.py | 4 ++-- lotus/vector_store/qdrant_vs.py | 4 ++-- lotus/vector_store/vs.py | 23 +++++++++++++++++++---- lotus/vector_store/weaviate_vs.py | 4 ++-- 6 files changed, 33 insertions(+), 12 deletions(-) diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 2c00e116..f9bb5fae 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -41,9 +41,15 @@ def setup_models(): for model_name in ENABLED_MODEL_NAMES: models[model_name] = MODEL_NAME_TO_CLS[model_name](model=model_name) + + return models +@pytest.fixture(scope='session') +def setup_vs(): + pass + ################################################################################ # RM Only Tests ################################################################################ diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 7a7e8de9..c304658e 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Mapping, Union +from typing import Any, Mapping, Union import numpy as np import pandas as pd @@ -19,7 +19,7 @@ ) from err class ChromaVS(VS): - def __init__(self, client: ClientAPI, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, client: ClientAPI, embedding_model: str, max_batch_size: int = 64): """Initialize with ChromaDB client and embedding model""" super().__init__(embedding_model) self.client = client diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 892faaa2..2aeedda8 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Union import numpy as np import pandas as pd @@ -17,7 +17,7 @@ ) from err class PineconeVS(VS): - def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, api_key: str, embedding_model: str, max_batch_size: int = 64): """Initialize Pinecone client with API key and environment""" super().__init__(embedding_model) self.pinecone = Pinecone(api_key=api_key) diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index 28a7c7dd..f6ae180b 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Union import numpy as np import pandas as pd @@ -16,7 +16,7 @@ raise ImportError("Please install the qdrant client") from err class QdrantVS(VS): - def __init__(self, client: QdrantClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, client: QdrantClient, embedding_model: str, max_batch_size: int = 64): """Initialize with Qdrant client and embedding model""" super().__init__(embedding_model) # Fixed the super() call syntax self.client = client diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 7d3e2c00..2ad4c063 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -1,24 +1,39 @@ from abc import ABC, abstractmethod 
-from collections.abc import Callable from typing import Any import numpy as np import pandas as pd import tqdm +from litellm import embedding from numpy.typing import NDArray from PIL import Image +from sentence_transformers import CrossEncoder, SentenceTransformer from lotus.dtype_extensions import convert_to_base_data from lotus.types import RMOutput +MODEL_NAME_TO_CLS = { + "intfloat/e5-small-v2": lambda model: SentenceTransformer(model_name_or_path=model), + "mixedbread-ai/mxbai-rerank-xsmall-v1": lambda model: CrossEncoder(model_name=model), + "text-embedding-3-small": lambda model: lambda batch: embedding(model=model, input=batch), +} + + +def initialize(model_name): + if model_name == 'intfloat/e5-small-v2': + return SentenceTransformer(model_name=model_name) + elif model_name== 'mixedbread-ai/mxbai-rerank-xsmall-v1': + return CrossEncoder(model_name=model_name) + return lambda batch: embedding(model=model_name, input=batch) + class VS(ABC): """Abstract class for vector stores.""" - def __init__(self, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]]) -> None: + def __init__(self, embedding_model: str) -> None: self.collection_name: str | None = None - self._embed: Callable[[pd.Series | list], NDArray[np.float64]] = embedding_model - self.max_batch_size:int = 64 + self._embed = initialize(embedding_model) + self.max_batch_size:int = 64 @abstractmethod diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 8d9937fa..786fe49a 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Union +from typing import Any, List, Union import numpy as np import pandas as pd @@ -20,7 +20,7 @@ raise ImportError("Please install the weaviate client") from err class WeaviateVS(VS): - def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: str, max_batch_size: int = 64): """Initialize with Weaviate client and embedding model""" super().__init__(embedding_model) self.client = weaviate_client From a0a70d27f91010ae41d2ad7de329cfbc4ff5448c Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 21 Jan 2025 18:05:32 -0800 Subject: [PATCH 24/65] made changes to tests --- .github/tests/rm_tests.py | 153 +++++++++++++++++++++++++++++- lotus/sem_ops/sem_index.py | 2 +- lotus/sem_ops/sem_search.py | 4 +- lotus/sem_ops/sem_sim_join.py | 7 +- lotus/settings.py | 7 ++ lotus/vector_store/chroma_vs.py | 33 ++++--- lotus/vector_store/pinecone_vs.py | 34 ++++--- lotus/vector_store/qdrant_vs.py | 45 +++++---- lotus/vector_store/vs.py | 11 ++- lotus/vector_store/weaviate_vs.py | 43 ++++++--- 10 files changed, 270 insertions(+), 69 deletions(-) diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index f9bb5fae..08893115 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -5,6 +5,7 @@ import lotus from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM +from lotus.vector_store import ChromaVS, PineconeVS, QdrantVS, WeaviateVS ################################################################################ # Setup @@ -30,6 +31,13 @@ "text-embedding-3-small": LiteLLMRM, } +VECTOR_STORE_TO_CLS = { + 'weaviate':WeaviateVS, + 'pinecone': PineconeVS, + 'chroma': ChromaVS, + 'qdrant': QdrantVS +} + def get_enabled(*candidate_models: str) -> list[str]: 
return [model for model in candidate_models if model in ENABLED_MODEL_NAMES] @@ -48,7 +56,13 @@ def setup_models(): @pytest.fixture(scope='session') def setup_vs(): - pass + vs_and_embed_model = {} + + for vs in VECTOR_STORE_TO_CLS: + for model_name in ENABLED_MODEL_NAMES: + vs_and_embed_model[(vs, model_name)] = VECTOR_STORE_TO_CLS[vs](embedding_model=model_name) + + return vs_and_embed_model ################################################################################ # RM Only Tests @@ -131,6 +145,143 @@ def test_sim_join(setup_models, model): def test_dedup(setup_models): rm = setup_models["intfloat/e5-small-v2"] lotus.settings.configure(rm=rm) + data = { + "Text": [ + "Probability and Random Processes", + "Probability and Markov Chains", + "Harry Potter",3 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + "Harry James Potter", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Text", "index_dir").sem_dedup("Text", threshold=0.85) + kept = df["Text"].tolist() + kept.sort() + assert len(kept) == 2, kept + assert "Harry" in kept[0], kept + assert "Probability" in kept[1], kept + + + +################################################################################ +# VS Only Tests +################################################################################ + + +@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_vs_cluster_by(setup_vs, vs, model): + my_vs = setup_vs[(vs, model)] + lotus.settings.configure(vs=my_vs) + + data = { + "Course Name": [ + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Course Name", "index_dir") + df = df.sem_cluster_by("Course Name", 2) + groups = df.groupby("cluster_id")["Course Name"].apply(set).to_dict() + assert len(groups) == 2, groups + if "Cooking" in groups[0]: + cooking_group = groups[0] + probability_group = groups[1] + else: + cooking_group = groups[1] + probability_group = groups[0] + + assert cooking_group == {"Cooking", "Food Sciences"}, groups + assert probability_group == {"Probability and Random Processes", "Optimization Methods in Engineering"}, groups + +@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_vs_search_rm_only(setup_vs, vs, model): + my_vs = setup_vs[(vs, model)] + lotus.settings.configure(vs=my_vs) + + data = { + "Course Name": [ + "Probability and Random Processes", + "Cooking", + "Food Sciences", + "Optimization Methods in Engineering", + ] + } + df = pd.DataFrame(data) + df = df.sem_index("Course Name", "index_dir") + df = df.sem_search("Course Name", "Optimization", K=1) + assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] + +@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) +@pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) +def test_vs_sim_join(setup_vs, vs, model): + my_vs = setup_vs[(vs, model)] + lotus.settings.configure(vs=my_vs) + + data1 = { + "Course Name": [ + "History of the Atlantic World", + "Riemannian Geometry", + ] + } + + data2 = {"Skill": ["Math", "History"]} + + df1 = pd.DataFrame(data1) + df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir") + joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) + joined_pairs = 
set(zip(joined_df["Course Name"], joined_df["Skill"]))
+    expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")}
+    assert joined_pairs == expected_pairs, joined_pairs
+
+
+# TODO: threshold is hardcoded for intfloat/e5-small-v2
+@pytest.mark.skipif(
+    "intfloat/e5-small-v2" not in ENABLED_MODEL_NAMES,
+    reason="Skipping test because intfloat/e5-small-v2 is not enabled",
+)
+@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
+def test_vs_dedup(setup_vs, vs):
+    my_vs = setup_vs[(vs, "intfloat/e5-small-v2")]
+    lotus.settings.configure(vs=my_vs)
     data = {
         "Text": [
             "Probability and Random Processes",
diff --git a/lotus/sem_ops/sem_index.py b/lotus/sem_ops/sem_index.py
index cb7d8cd4..7a31a3f4 100644
--- a/lotus/sem_ops/sem_index.py
+++ b/lotus/sem_ops/sem_index.py
@@ -35,7 +35,7 @@ def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame:
                 "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
             )
 
-        rm = lotus.settings.rm
+        rm = lotus.settings.get_rm_or_vs()
         rm.index(self._obj[col_name], index_dir)
         self._obj.attrs["index_dirs"][col_name] = index_dir
         return self._obj
diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py
index de2df357..340db425 100644
--- a/lotus/sem_ops/sem_search.py
+++ b/lotus/sem_ops/sem_search.py
@@ -47,10 +47,10 @@ def __call__(
         assert not (K is None and n_rerank is None), "K or n_rerank must be provided"
         if K is not None:
             # get retriever model and index
-            rm = lotus.settings.rm
+            rm = lotus.settings.get_rm_or_vs()
             if rm is None:
                 raise ValueError(
-                    "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
+                    "The retrieval model must be an instance of RM or VS. Please configure a valid retrieval model or vector store using lotus.settings.configure()"
                 )
 
             col_index_dir = self._obj.attrs["index_dirs"][col_name]
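The operator changes above and the settings change below converge on a single lookup: `Settings.get_rm_or_vs()` hands back whichever retrieval backend is configured, preferring the retrieval model. A minimal caller-side sketch of the pattern, grounded only in the accessor and error message shown in these hunks (the DataFrame `df` is illustrative):

    import lotus

    retriever = lotus.settings.get_rm_or_vs()  # settings.rm if set, otherwise settings.vs
    if retriever is None:
        raise ValueError(
            "The retrieval model must be an instance of RM or VS. Please configure a valid "
            "retrieval model or vector store using lotus.settings.configure()"
        )
    # RM and VS expose the same index() entry point, so operators stay backend-agnostic:
    retriever.index(df["Course Name"], "index_dir")

Because `get_rm_or_vs()` is `self.rm or self.vs`, the vector store is only consulted when no retrieval model is registered.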
diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py
index 47d3cbe3..15c62c84 100644
--- a/lotus/sem_ops/sem_sim_join.py
+++ b/lotus/sem_ops/sem_sim_join.py
@@ -6,6 +6,7 @@
 from lotus.cache import operator_cache
 from lotus.models import RM
 from lotus.types import RMOutput
+from lotus.vector_store import VS
 
 
 @pd.api.extensions.register_dataframe_accessor("sem_sim_join")
@@ -51,10 +52,10 @@ def __call__(
                 raise ValueError("Other Series must have a name")
             other = pd.DataFrame({other.name: other})
 
-        rm = lotus.settings.rm
-        if not isinstance(rm, RM):
+        rm = lotus.settings.get_rm_or_vs()
+        if not isinstance(rm, RM) and not isinstance(rm, VS):
             raise ValueError(
-                "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()"
+                "The retrieval model must be an instance of RM or VS. Please configure a valid retrieval model or vector store using lotus.settings.configure()"
             )
 
         # load query embeddings from index if they exist
diff --git a/lotus/settings.py b/lotus/settings.py
index f7277a1c..97122f65 100644
--- a/lotus/settings.py
+++ b/lotus/settings.py
@@ -13,6 +13,7 @@ class Settings:
     reranker: lotus.models.Reranker | None = None
     vs: lotus.vector_store.VS | None = None
 
+
     # Cache settings
     enable_cache: bool = False
@@ -26,10 +27,16 @@ class Settings:
     def configure(self, **kwargs):
         for key, value in kwargs.items():
             if not hasattr(self, key):
                 raise ValueError(f"Invalid setting: {key}")
+            if (key == 'vs' and hasattr(self, 'rm')) or (key == 'rm' and hasattr(self, 'vs')):
+                raise ValueError('Invalid settings: you can only set a retriever module or a vector store, but not both')
+
             setattr(self, key, value)
 
     def __str__(self):
         return str(vars(self))
+
+    def get_rm_or_vs(self):
+        return self.rm or self.vs
 
 
 settings = Settings()
diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py
index c304658e..583b238a 100644
--- a/lotus/vector_store/chroma_vs.py
+++ b/lotus/vector_store/chroma_vs.py
@@ -10,7 +10,7 @@
 from lotus.vector_store.vs import VS
 
 try:
-    from chromadb import ClientAPI
+    from chromadb import Client, ClientAPI
     from chromadb.api import Collection
     from chromadb.api.types import IncludeEnum
 except ImportError as err:
@@ -19,22 +19,27 @@
     ) from err
 
 class ChromaVS(VS):
-    def __init__(self, client: ClientAPI, embedding_model: str, max_batch_size: int = 64):
+    def __init__(self, embedding_model: str, max_batch_size: int = 64):
+
+        client: ClientAPI = Client()
+
         """Initialize with ChromaDB client and embedding model"""
         super().__init__(embedding_model)
         self.client = client
         self.collection: Collection | None = None
-        self.collection_name = None
+        self.index_dir = None
         self.max_batch_size = max_batch_size
 
+    def __del__(self):
+        return
+
-    def index(self, docs: pd.Series, collection_name: str):
+    def index(self, docs: pd.Series, index_dir: str):
         """Create a collection and add documents with their embeddings"""
-        self.collection_name = collection_name
+        self.index_dir = index_dir
 
         # Create collection without embedding function (we'll provide embeddings directly)
         self.collection = self.client.create_collection(
-            name=collection_name,
+            name=index_dir,
             metadata={"hnsw:space": "cosine"}  # Use cosine similarity for consistency
         )
@@ -59,13 +64,13 @@ def index(self, docs: pd.Series, index_dir: str):
                 metadatas=metadatas[i:end_idx]
             )
 
-    def load_index(self, collection_name: str):
+    def load_index(self, index_dir: str):
         """Load an existing collection"""
         try:
-            self.collection = self.client.get_collection(collection_name)
-            self.collection_name = collection_name
+            self.collection = self.client.get_collection(index_dir)
+            self.index_dir = index_dir
         except ValueError as e:
-            raise ValueError(f"Collection {collection_name} not found") from e
+            raise ValueError(f"Collection {index_dir} not found") from e
 
     def __call__(
         self,
@@ -126,14 +131,14 @@ def __call__(
             indices=np.array(all_indices, dtype=np.int64).tolist()
         )
 
-    def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
         """Retrieve vectors for specific document IDs"""
-        if self.collection is None or self.collection_name != collection_name:
-            self.load_index(collection_name)
+        if self.collection is None or self.index_dir != index_dir:
+            self.load_index(index_dir)
 
         if self.collection is None:  # Add this check
after load_index - raise ValueError(f"Failed to load collection {collection_name}") + raise ValueError(f"Failed to load collection {index_dir}") # Convert integer ids to strings for ChromaDB diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 2aeedda8..4a1d58a3 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -17,33 +17,39 @@ ) from err class PineconeVS(VS): - def __init__(self, api_key: str, embedding_model: str, max_batch_size: int = 64): + def __init__(self, embedding_model: str, max_batch_size: int = 64): + + api_key = 'pcsk_45ecSY_CW62eJeL4jwj6dUfaqM6j9dL3uwK12rudednzGisWMxJv9bHH2DLz6tWoY91W84' + """Initialize Pinecone client with API key and environment""" super().__init__(embedding_model) self.pinecone = Pinecone(api_key=api_key) self.pc_index:Index | None = None self.max_batch_size = max_batch_size + def __del__(self): + return + - def index(self, docs: pd.Series, collection_name: str): + def index(self, docs: pd.Series, index_dir: str): """Create an index and add documents to it""" - self.collection_name = collection_name + self.index_dir = index_dir # Get sample embedding to determine vector dimension sample_embedding = self._embed([docs.iloc[0]]) dimension = sample_embedding.shape[1] # Check if index already exists - if collection_name not in self.pinecone.list_indexes(): + if index_dir not in self.pinecone.list_indexes(): # Create new index with the correct dimension self.pinecone.create_index( - name=collection_name, + name=index_dir, dimension=dimension, metric="cosine" ) # Connect to index - self.pc_index = self.pinecone.Index(collection_name) + self.pc_index = self.pinecone.Index(index_dir) # Convert docs to list if it's a pandas Series docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs @@ -69,13 +75,13 @@ def index(self, docs: pd.Series, collection_name: str): batch = vectors[i:i + batch_size] self.pc_index.upsert(vectors=batch) - def load_index(self, collection_name: str): + def load_index(self, index_dir: str): """Connect to an existing Pinecone index""" - if collection_name not in self.pinecone.list_indexes(): - raise ValueError(f"Index {collection_name} not found") + if index_dir not in self.pinecone.list_indexes(): + raise ValueError(f"Index {index_dir} not found") - self.collection_name = collection_name - self.pc_index = self.pinecone.Index(collection_name) + self.index_dir = index_dir + self.pc_index = self.pinecone.Index(index_dir) def __call__( self, @@ -135,10 +141,10 @@ def __call__( indices=np.array(all_indices, dtype=np.int64).tolist() ) - def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: """Retrieve vectors for specific document IDs""" - if self.pc_index is None or self.collection_name != collection_name: - self.load_index(collection_name) + if self.pc_index is None or self.index_dir != index_dir: + self.load_index(index_dir) if self.pc_index is None: # Add this check after load_index raise ValueError("Failed to initialize Pinecone index") diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index f6ae180b..a7cb5d15 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -16,24 +16,37 @@ raise ImportError("Please install the qdrant client") from err class QdrantVS(VS): - def __init__(self, client: QdrantClient, embedding_model: str, max_batch_size: int = 64): + def 
__init__(self, embedding_model: str, max_batch_size: int = 64): + + API_KEY = '_Mic3dVln2gAkS6NLyia6p-CCyMScK42ayuq8Rapm5-xsV5j5_UlIA' + + URL = "https://6f8b9aec-a788-4aac-9aeb-417d307493e8.europe-west3-0.gcp.cloud.qdrant.io:6333" + + client: QdrantClient = QdrantClient( + url=URL, + api_key=API_KEY + ) + """Initialize with Qdrant client and embedding model""" super().__init__(embedding_model) # Fixed the super() call syntax self.client = client self.max_batch_size = max_batch_size - def index(self, docs: pd.Series, collection_name: str): + def __del__(self): + self.client.close() + + def index(self, docs: pd.Series, index_dir: str): """Create a collection and add documents with their embeddings""" - self.collection_name = collection_name + self.index_dir = index_dir # Get sample embedding to determine vector dimension sample_embedding = self._embed([docs.iloc[0]]) dimension = sample_embedding.shape[1] # Create collection if it doesn't exist - if not self.client.collection_exists(collection_name): + if not self.client.collection_exists(index_dir): self.client.create_collection( - collection_name=collection_name, + collection_name=index_dir, vectors_config=VectorParams(size=dimension, distance=Distance.COSINE) ) @@ -62,15 +75,15 @@ def index(self, docs: pd.Series, collection_name: str): for i in tqdm(range(0, len(points), batch_size), desc="Uploading to Qdrant"): batch = points[i:i + batch_size] self.client.upsert( - collection_name=collection_name, + index_dir=index_dir, points=batch ) - def load_index(self, collection_name: str): + def load_index(self, index_dir: str): """Set the collection name to use""" - if not self.client.collection_exists(collection_name): - raise ValueError(f"Collection {collection_name} not found") - self.collection_name = collection_name + if not self.client.collection_exists(index_dir): + raise ValueError(f"Collection {index_dir} not found") + self.index_dir = index_dir def __call__( self, @@ -79,7 +92,7 @@ def __call__( **kwargs: dict[str, Any] ) -> RMOutput: """Perform vector search using Qdrant""" - if self.collection_name is None: + if self.index_dir is None: raise ValueError("No collection loaded. 
Call load_index first.") # Convert single query to list @@ -102,7 +115,7 @@ def __call__( for query_vector in query_vectors: results = self.client.search( - collection_name=self.collection_name, + collection_name=self.index_dir, query_vector=query_vector.tolist(), limit=K, with_payload=True @@ -129,14 +142,14 @@ def __call__( indices=np.array(all_indices, dtype=np.int64).tolist() ) - def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: """Retrieve vectors for specific document IDs""" - if self.collection_name != collection_name: - self.load_index(collection_name) + if self.index_dir != index_dir: + self.load_index(index_dir) # Fetch points from Qdrant points = self.client.retrieve( - collection_name=collection_name, + collection_name=index_dir, ids=ids, with_vectors=True, with_payload=False diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 2ad4c063..6513c4ac 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -3,11 +3,11 @@ import numpy as np import pandas as pd -import tqdm from litellm import embedding from numpy.typing import NDArray from PIL import Image from sentence_transformers import CrossEncoder, SentenceTransformer +from tqdm import tqdm from lotus.dtype_extensions import convert_to_base_data from lotus.types import RMOutput @@ -31,7 +31,7 @@ class VS(ABC): """Abstract class for vector stores.""" def __init__(self, embedding_model: str) -> None: - self.collection_name: str | None = None + self.index_dir: str | None = None self._embed = initialize(embedding_model) self.max_batch_size:int = 64 @@ -44,8 +44,9 @@ def index(self, docs: pd.Series, collection_name: str): pass @abstractmethod - def load_index(self, collection_name: str): - """Load the index from the vector store into memory ?? 
(not sure if this is needed )""" + def load_index(self, index_dir: str): + """Load the index from the vector store into memory if needed""" + pass @abstractmethod def __call__(self, @@ -56,7 +57,7 @@ def __call__(self, pass @abstractmethod - def get_vectors_from_index(self, collection_name:str, ids: list[Any]) -> NDArray[np.float64]: + def get_vectors_from_index(self, index_dir:str, ids: list[Any]) -> NDArray[np.float64]: pass def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]: diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 786fe49a..025e4d9a 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -14,25 +14,42 @@ import weaviate from weaviate.classes.config import Configure, DataType, Property + from weaviate.classes.init import Auth from weaviate.classes.query import MetadataQuery from weaviate.util import get_valid_uuid except ImportError as err: raise ImportError("Please install the weaviate client") from err class WeaviateVS(VS): - def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: str, max_batch_size: int = 64): + def __init__(self, embedding_model: str, max_batch_size: int = 64): + + REST_URL = 'https://dovieiknqr20pmgoticrmw.c0.us-west3.gcp.weaviate.cloud' + + API_KEY = 'nwRhjKLulSWbhPjX67WBklmJs7dgUS9XGWrZ' + + weaviate_client: weaviate.WeaviateClient = None # need to set this up + + + weaviate_client = weaviate.connect_to_weaviate_cloud( + cluster_url=REST_URL, # Replace with your Weaviate Cloud URL + auth_credentials=Auth.api_key(API_KEY), # Replace with your Weaviate Cloud key + ) + """Initialize with Weaviate client and embedding model""" super().__init__(embedding_model) self.client = weaviate_client self.max_batch_size = max_batch_size - def index(self, docs: pd.Series, collection_name: str): + def __del__(self): + self.client.close() + + def index(self, docs: pd.Series, index_dir: str): """Create a collection and add documents with their embeddings""" - self.collection_name = collection_name + self.index_dir = index_dir # Create collection without vectorizer config (we'll provide vectors directly) collection = self.client.collections.create( - name=collection_name, + name=index_dir, properties=[ Property( name='content', @@ -64,14 +81,14 @@ def index(self, docs: pd.Series, collection_name: str): uuid=get_valid_uuid(str(uuid4())) ) - def load_index(self, collection_name: str): + def load_index(self, index_dir: str): """Load/set the collection name to use""" - self.collection_name = collection_name + self.index_dir = index_dir # Verify collection exists try: - self.client.collections.get(collection_name) + self.client.collections.get(index_dir) except weaviate.exceptions.UnexpectedStatusCodeException: - raise ValueError(f"Collection {collection_name} not found") + raise ValueError(f"Collection {index_dir} not found") def __call__(self, queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], @@ -79,10 +96,10 @@ def __call__(self, **kwargs: dict[str, Any] ) -> RMOutput: """Perform vector search using pre-computed query vectors""" - if self.collection_name is None: + if self.index_dir is None: raise ValueError("No collection loaded. 
Call load_index first.") - collection = self.client.collections.get(self.collection_name) + collection = self.client.collections.get(self.index_dir) # Convert single query to list if isinstance(queries, (str, Image.Image)): @@ -134,9 +151,9 @@ def __call__(self, indices=np.array(all_indices, dtype=np.int64).tolist() ) - def get_vectors_from_index(self, collection_name: str, ids: list[Any]) -> NDArray[np.float64]: + def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.float64]: """Retrieve vectors for specific document IDs""" - collection = self.client.collections.get(collection_name) + collection = self.client.collections.get(index_dir) # Query for documents with specific doc_ids vectors = [] @@ -146,7 +163,7 @@ def get_vectors_from_index(self, collection_name: str, ids: list[Any]) -> NDArra if response: vectors.append(response.vector) else: - raise ValueError(f'{id} does not exist in {collection_name}') + raise ValueError(f'{id} does not exist in {index_dir}') return np.array(vectors, dtype=np.float64) From 6dbd1dbdc97f91b0e41faec3d32741f303b56f24 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 21 Jan 2025 18:56:18 -0800 Subject: [PATCH 25/65] fixed --- lotus/vector_store/vs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 6513c4ac..bbcc4042 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -21,7 +21,7 @@ def initialize(model_name): if model_name == 'intfloat/e5-small-v2': - return SentenceTransformer(model_name=model_name) + return SentenceTransformer(model_name_or_path=model_name) elif model_name== 'mixedbread-ai/mxbai-rerank-xsmall-v1': return CrossEncoder(model_name=model_name) return lambda batch: embedding(model=model_name, input=batch) From bea1d190201f775eff5c9a59a344e21aaca4ba19 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 21 Jan 2025 19:04:49 -0800 Subject: [PATCH 26/65] changed setattr to getattr --- lotus/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/settings.py b/lotus/settings.py index 97122f65..b4225ad3 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -27,7 +27,7 @@ def configure(self, **kwargs): for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") - if (key == 'vs' and hasattr(self, 'rm')) or (key == 'rm' and hasattr(self, 'vs')): + if (key == 'vs' and getattr(self, 'rm') is not None) or (key == 'rm' and getattr(self, 'vs') is not None): raise ValueError('Invalid settings: you can only set a retriever module or a vector store, but not both') setattr(self, key, value) From f93f7ed137a55d644753206117492a4bf6adce10 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 21 Jan 2025 19:21:55 -0800 Subject: [PATCH 27/65] fixed a test --- .github/tests/rm_tests.py | 36 +----------------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 08893115..53f9b897 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -149,41 +149,7 @@ def test_dedup(setup_models): "Text": [ "Probability and Random Processes", "Probability and Markov Chains", - "Harry Potter",3 - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + "Harry Potter", "Harry James Potter", ] } From 38ff87dd2fdf51144e39f67639b2f850bd887d74 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Fri, 24 Jan 2025 
19:47:22 -0800 Subject: [PATCH 28/65] over --- lotus/settings.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lotus/settings.py b/lotus/settings.py index b4225ad3..fa302e54 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -24,10 +24,15 @@ class Settings: parallel_groupby_max_threads: int = 8 def configure(self, **kwargs): + + if 'rm' in kwargs and 'vs' in kwargs: + raise ValueError('Invalid settings: you can only set a retriever module or a vector store, but not both') + + for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") - if (key == 'vs' and getattr(self, 'rm') is not None) or (key == 'rm' and getattr(self, 'vs') is not None): + if (key == 'vs' and (getattr(self, 'rm') is not None)) or (key == 'rm' and (getattr(self, 'vs') is not None)): raise ValueError('Invalid settings: you can only set a retriever module or a vector store, but not both') setattr(self, key, value) From c885dbcd90cb500f5207ef1359f3b55c0bbf87ae Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Fri, 24 Jan 2025 20:05:39 -0800 Subject: [PATCH 29/65] another change --- lotus/settings.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lotus/settings.py b/lotus/settings.py index fa302e54..ad24652a 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -32,9 +32,6 @@ def configure(self, **kwargs): for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") - if (key == 'vs' and (getattr(self, 'rm') is not None)) or (key == 'rm' and (getattr(self, 'vs') is not None)): - raise ValueError('Invalid settings: you can only set a retriever module or a vector store, but not both') - setattr(self, key, value) def __str__(self): From 8eefac05df2fc9649a615f4fa9b7f78cd5162307 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Fri, 24 Jan 2025 20:22:27 -0800 Subject: [PATCH 30/65] fixed type check errors --- lotus/vector_store/qdrant_vs.py | 4 ++-- lotus/vector_store/weaviate_vs.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index a7cb5d15..bf4892b2 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -29,7 +29,7 @@ def __init__(self, embedding_model: str, max_batch_size: int = 64): """Initialize with Qdrant client and embedding model""" super().__init__(embedding_model) # Fixed the super() call syntax - self.client = client + self.client: QdrantClient = client self.max_batch_size = max_batch_size def __del__(self): @@ -75,7 +75,7 @@ def index(self, docs: pd.Series, index_dir: str): for i in tqdm(range(0, len(points), batch_size), desc="Uploading to Qdrant"): batch = points[i:i + batch_size] self.client.upsert( - index_dir=index_dir, + collection_name=index_dir, points=batch ) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 025e4d9a..1d1b82ab 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -27,7 +27,7 @@ def __init__(self, embedding_model: str, max_batch_size: int = 64): API_KEY = 'nwRhjKLulSWbhPjX67WBklmJs7dgUS9XGWrZ' - weaviate_client: weaviate.WeaviateClient = None # need to set this up + weaviate_client: weaviate.WeaviateClient | None = None # need to set this up weaviate_client = weaviate.connect_to_weaviate_cloud( From 23bafa52e2a562ec9a51a36248a65e470dd990e4 Mon Sep 17 00:00:00 2001 
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 08:32:59 -0800 Subject: [PATCH 31/65] second refactor (removed index_dir) --- .github/tests/lm_tests.py | 4 +- .github/tests/multimodality_tests.py | 10 ++-- .github/tests/rm_tests.py | 45 ++++++++++------- examples/op_examples/cluster.py | 4 +- examples/op_examples/dedup.py | 5 +- examples/op_examples/join_cascade.py | 4 +- examples/op_examples/partition.py | 4 +- examples/op_examples/search.py | 6 ++- examples/op_examples/sim_join.py | 4 +- lotus/models/colbertv2_rm.py | 3 ++ lotus/models/litellm_rm.py | 4 +- lotus/models/rm.py | 50 ++++++++++++++----- lotus/models/sentence_transformers_rm.py | 9 ++-- lotus/sem_ops/sem_cluster_by.py | 6 ++- lotus/sem_ops/sem_dedup.py | 6 ++- lotus/sem_ops/sem_index.py | 6 ++- lotus/sem_ops/sem_search.py | 20 ++++---- lotus/sem_ops/sem_sim_join.py | 31 ++++++------ lotus/settings.py | 7 +-- lotus/utils.py | 13 ++--- lotus/vector_store/__init__.py | 3 +- lotus/vector_store/chroma_vs.py | 16 +++--- .../faiss_rm.py => vector_store/faiss_vs.py} | 25 ++++------ lotus/vector_store/pinecone_vs.py | 20 +++----- lotus/vector_store/qdrant_vs.py | 22 ++++---- lotus/vector_store/vs.py | 21 +++----- lotus/vector_store/weaviate_vs.py | 16 +++--- 27 files changed, 207 insertions(+), 157 deletions(-) rename lotus/{models/faiss_rm.py => vector_store/faiss_vs.py} (74%) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 1704bbd8..f7f7a7b0 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -8,6 +8,7 @@ from lotus.cache import CacheConfig, CacheFactory, CacheType from lotus.models import LM, SentenceTransformersRM from lotus.types import CascadeArgs +from lotus.vector_store import FaissVS ################################################################################ # Setup @@ -289,7 +290,8 @@ def test_filter_cascade(setup_models): def test_join_cascade(setup_models): models = setup_models rm = SentenceTransformersRM(model="intfloat/e5-base-v2") - lotus.settings.configure(lm=models["gpt-4o-mini"], rm=rm) + vs = FaissVS() + lotus.settings.configure(lm=models["gpt-4o-mini"], rm=rm, vs=vs) data1 = { "School": [ diff --git a/.github/tests/multimodality_tests.py b/.github/tests/multimodality_tests.py index c6311d81..64ce0bc9 100644 --- a/.github/tests/multimodality_tests.py +++ b/.github/tests/multimodality_tests.py @@ -6,6 +6,7 @@ import lotus from lotus.dtype_extensions import ImageArray from lotus.models import LM, SentenceTransformersRM +from lotus.vector_store import FaissVS ################################################################################ # Setup @@ -160,7 +161,8 @@ def test_topk_with_groupby_operation(setup_models, model): @pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) def test_search_operation(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) image_url = [ "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", @@ -180,7 +182,8 @@ def test_search_operation(setup_models, model): @pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) def test_sim_join_operation_image_index(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) image_url = [ "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", @@ -205,7 +208,8 @@ def 
test_sim_join_operation_image_index(setup_models, model): @pytest.mark.parametrize("model", get_enabled("clip-ViT-B-32")) def test_sim_join_operation_text_index(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) image_url = [ "https://img.etsystatic.com/il/4bee20/1469037676/il_340x270.1469037676_iiti.jpg?version=0", diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 53f9b897..22db3544 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -5,7 +5,7 @@ import lotus from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM -from lotus.vector_store import ChromaVS, PineconeVS, QdrantVS, WeaviateVS +from lotus.vector_store import ChromaVS, FaissVS, PineconeVS, QdrantVS, WeaviateVS ################################################################################ # Setup @@ -32,6 +32,7 @@ } VECTOR_STORE_TO_CLS = { + 'local': FaissVS, 'weaviate':WeaviateVS, 'pinecone': PineconeVS, 'chroma': ChromaVS, @@ -56,13 +57,13 @@ def setup_models(): @pytest.fixture(scope='session') def setup_vs(): - vs_and_embed_model = {} + vs_model = {} for vs in VECTOR_STORE_TO_CLS: for model_name in ENABLED_MODEL_NAMES: - vs_and_embed_model[(vs, model_name)] = VECTOR_STORE_TO_CLS[vs](embedding_model=model_name) + vs_model[model_name] = VECTOR_STORE_TO_CLS[vs]() - return vs_and_embed_model + return vs_model ################################################################################ # RM Only Tests @@ -70,7 +71,8 @@ def setup_vs(): @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_cluster_by(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) data = { "Course Name": [ @@ -99,7 +101,9 @@ def test_cluster_by(setup_models, model): @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_search_rm_only(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + + lotus.settings.configure(rm=rm, vs=vs) data = { "Course Name": [ @@ -118,7 +122,8 @@ def test_search_rm_only(setup_models, model): @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_sim_join(setup_models, model): rm = setup_models[model] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm, vs=vs) data1 = { "Course Name": [ @@ -144,7 +149,8 @@ def test_sim_join(setup_models, model): ) def test_dedup(setup_models): rm = setup_models["intfloat/e5-small-v2"] - lotus.settings.configure(rm=rm) + vs = FaissVS() + lotus.settings.configure(rm=rm,vs=vs) data = { "Text": [ "Probability and Random Processes", @@ -171,8 +177,9 @@ def test_dedup(setup_models): @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_vs_cluster_by(setup_vs, vs, model): - my_vs = setup_vs[(vs, model)] - lotus.settings.configure(vs=my_vs) + rm = setup_models[model] + my_vs = setup_vs[vs] + lotus.settings.configure(rm=rm, vs=my_vs) data = { "Course Name": [ @@ -200,8 +207,9 @@ def test_vs_cluster_by(setup_vs, vs, model): @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_vs_search_rm_only(setup_vs, vs, model): - 
my_vs = setup_vs[(vs, model)] - lotus.settings.configure(vs=my_vs) + rm = setup_models[model] + my_vs = setup_vs[vs] + lotus.settings.configure(rm=rm, vs=my_vs) data = { "Course Name": [ @@ -219,8 +227,9 @@ def test_vs_search_rm_only(setup_vs, vs, model): @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) def test_vs_sim_join(setup_vs, vs, model): - my_vs = setup_vs[(vs, model)] - lotus.settings.configure(vs=my_vs) + rm = setup_models[model] + my_vs = setup_vs[vs] + lotus.settings.configure(rm=rm, vs=my_vs) data1 = { "Course Name": [ @@ -246,8 +255,9 @@ def test_vs_sim_join(setup_vs, vs, model): ) @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) def test_vs_dedup(setup_vs, vs): - my_vs = setup_vs[(vs ,"intfloat/e5-small-v2")] - lotus.settings.configure(vs=my_vs) + rm = setup_models["intfloat/e5-small-v2"] + my_vs = setup_vs[vs] + lotus.settings.configure(rm=rm, vs=my_vs) data = { "Text": [ "Probability and Random Processes", @@ -294,8 +304,9 @@ def test_search_reranker_only(setup_models, model): def test_search(setup_models): models = setup_models rm = models["intfloat/e5-small-v2"] + vs = FaissVS() reranker = models["mixedbread-ai/mxbai-rerank-xsmall-v1"] - lotus.settings.configure(rm=rm, reranker=reranker) + lotus.settings.configure(rm=rm, vs = vs, reranker=reranker) data = { "Course Name": [ diff --git a/examples/op_examples/cluster.py b/examples/op_examples/cluster.py index e117b249..e37e08f8 100644 --- a/examples/op_examples/cluster.py +++ b/examples/op_examples/cluster.py @@ -2,11 +2,13 @@ import lotus from lotus.models import LM, SentenceTransformersRM +from lotus.vector_store import FaissVS lm = LM(model="gpt-4o-mini") rm = SentenceTransformersRM(model="intfloat/e5-base-v2") +vs = FaissVS() -lotus.settings.configure(lm=lm, rm=rm) +lotus.settings.configure(lm=lm, rm=rm, vs=vs) data = { "Course Name": [ "Probability and Random Processes", diff --git a/examples/op_examples/dedup.py b/examples/op_examples/dedup.py index 1494df95..7a866918 100644 --- a/examples/op_examples/dedup.py +++ b/examples/op_examples/dedup.py @@ -2,10 +2,11 @@ import lotus from lotus.models import SentenceTransformersRM +from lotus.vector_store import FaissVS rm = SentenceTransformersRM(model="intfloat/e5-base-v2") - -lotus.settings.configure(rm=rm) +vs = FaissVS() +lotus.settings.configure(rm=rm, vs=vs) data = { "Text": [ "Probability and Random Processes", diff --git a/examples/op_examples/join_cascade.py b/examples/op_examples/join_cascade.py index 05e3e562..517ee6b7 100644 --- a/examples/op_examples/join_cascade.py +++ b/examples/op_examples/join_cascade.py @@ -3,11 +3,13 @@ import lotus from lotus.models import LM, SentenceTransformersRM from lotus.types import CascadeArgs +from lotus.vector_store import FaissVS lm = LM(model="gpt-4o-mini") rm = SentenceTransformersRM(model="intfloat/e5-base-v2") +vs = FaissVS() -lotus.settings.configure(lm=lm, rm=rm) +lotus.settings.configure(lm=lm, rm=rm, vs=vs) data = { "Course Name": [ "Digital Design and Integrated Circuits", diff --git a/examples/op_examples/partition.py b/examples/op_examples/partition.py index 932b170b..ed82015c 100644 --- a/examples/op_examples/partition.py +++ b/examples/op_examples/partition.py @@ -2,11 +2,13 @@ import lotus from lotus.models import LM, SentenceTransformersRM +from lotus.vector_store import FaissVS lm = LM(max_tokens=2048) rm = SentenceTransformersRM(model="intfloat/e5-base-v2") +vs = FaissVS() 
-lotus.settings.configure(lm=lm, rm=rm)
+lotus.settings.configure(lm=lm, rm=rm, vs=vs)
 data = {
     "Course Name": [
         "Probability and Random Processes",
diff --git a/examples/op_examples/search.py b/examples/op_examples/search.py
index c9382aae..49466ac3 100644
--- a/examples/op_examples/search.py
+++ b/examples/op_examples/search.py
@@ -2,12 +2,14 @@
 import lotus
 from lotus.models import LM, CrossEncoderReranker, SentenceTransformersRM
+from lotus.vector_store import FaissVS
 
 lm = LM(model="gpt-4o-mini")
 rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
 reranker = CrossEncoderReranker(model="mixedbread-ai/mxbai-rerank-large-v1")
+vs = FaissVS()
 
-lotus.settings.configure(lm=lm, rm=rm, reranker=reranker)
+lotus.settings.configure(lm=lm, rm=rm, reranker=reranker, vs=vs)
 data = {
     "Course Name": [
         "Probability and Random Processes",
diff --git a/examples/op_examples/sim_join.py b/examples/op_examples/sim_join.py
index 108ccb28..6a6bfc27 100644
--- a/examples/op_examples/sim_join.py
+++ b/examples/op_examples/sim_join.py
@@ -2,11 +2,13 @@
 import lotus
 from lotus.models import LM, LiteLLMRM
+from lotus.vector_store import FaissVS
 
 lm = LM(model="gpt-4o-mini")
 rm = LiteLLMRM(model="text-embedding-3-small")
+vs = FaissVS()
 
-lotus.settings.configure(lm=lm, rm=rm)
+lotus.settings.configure(lm=lm, rm=rm, vs=vs)
 data = {
     "Course Name": [
         "History of the Atlantic World",
diff --git a/lotus/models/colbertv2_rm.py b/lotus/models/colbertv2_rm.py
index 2bd8ed99..481ac5f9 100644
--- a/lotus/models/colbertv2_rm.py
+++ b/lotus/models/colbertv2_rm.py
@@ -46,6 +46,9 @@ def load_index(self, index_dir: str) -> None:
     def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
         raise NotImplementedError("This method is not implemented for ColBERTv2RM")
 
+
+
+    # this should be called in vs.py if it's
     def __call__(
         self,
         queries: str | Image.Image | list | NDArray[np.float64],
diff --git a/lotus/models/litellm_rm.py b/lotus/models/litellm_rm.py
index a4486dc6..c0d8b14c 100644
--- a/lotus/models/litellm_rm.py
+++ b/lotus/models/litellm_rm.py
@@ -7,10 +7,10 @@ from tqdm import tqdm
 
 from lotus.dtype_extensions import convert_to_base_data
-from lotus.models.faiss_rm import FaissRM
+from lotus.models.rm import RM
 
 
-class LiteLLMRM(FaissRM):
+class LiteLLMRM(RM):
     def __init__(
         self,
         model: str = "text-embedding-3-small",
diff --git a/lotus/models/rm.py b/lotus/models/rm.py
index c58a0c38..72324db8 100644
--- a/lotus/models/rm.py
+++ b/lotus/models/rm.py
@@ -1,42 +1,63 @@
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Union
 
 import numpy as np
 import pandas as pd
 from numpy.typing import NDArray
 from PIL import Image
 
-from lotus.types import RMOutput
-
 
 class RM(ABC):
-    """Abstract class for retriever models."""
+    #Abstract class for retriever models.
def __init__(self) -> None: - self.index_dir: str | None = None + pass + @abstractmethod + def _embed(self, docs:pd.Series | list): + pass + + def __call__(self, docs: pd.Series | list): + return self._embed(docs) + + def convert_query_to_query_vector(self, queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], +): + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._embed(queries) + return query_vectors + """ @abstractmethod def index(self, docs: pd.Series, index_dir: str, **kwargs: dict[str, Any]) -> None: - """Create index and store it to a directory. + Create index and store it to a directory. Args: docs (list[str]): A list of documents to index. index_dir (str): The directory to save the index in. - """ + pass @abstractmethod def load_index(self, index_dir: str) -> None: - """Load the index into memory. + Load the index into memory. Args: index_dir (str): The directory of where the index is stored. - """ + pass @abstractmethod def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: - """Get the vectors from the index. + Get the vectors from the index. Args: index_dir (str): Directory of the index. @@ -44,7 +65,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f Returns: NDArray[np.float64]: The vectors matching the specified ids. - """ + pass @@ -55,7 +76,7 @@ def __call__( K: int, **kwargs: dict[str, Any], ) -> RMOutput: - """Run top-k search on the index. + Run top-k search on the index. Args: queries (str | list[str] | NDArray[np.float64]): Either a query or a list of queries or a 2D FP32 array. @@ -64,5 +85,8 @@ def __call__( Returns: RMOutput: An RMOutput object containing the distances and indices of the top-k vectors. - """ + pass + + + """ \ No newline at end of file diff --git a/lotus/models/sentence_transformers_rm.py b/lotus/models/sentence_transformers_rm.py index 76db6f3f..23559e21 100644 --- a/lotus/models/sentence_transformers_rm.py +++ b/lotus/models/sentence_transformers_rm.py @@ -1,4 +1,3 @@ -import faiss import numpy as np import pandas as pd import torch @@ -7,20 +6,18 @@ from tqdm import tqdm from lotus.dtype_extensions import convert_to_base_data -from lotus.models.faiss_rm import FaissRM +from lotus.models.rm import RM -class SentenceTransformersRM(FaissRM): +class SentenceTransformersRM(RM): def __init__( self, model: str = "intfloat/e5-base-v2", max_batch_size: int = 64, normalize_embeddings: bool = True, device: str | None = None, - factory_string: str = "Flat", - metric=faiss.METRIC_INNER_PRODUCT, ): - super().__init__(factory_string, metric) + #super().__init__(factory_string, metric) self.model: str = model self.max_batch_size: int = max_batch_size self.normalize_embeddings: bool = normalize_embeddings diff --git a/lotus/sem_ops/sem_cluster_by.py b/lotus/sem_ops/sem_cluster_by.py index fc8a9c6f..43e5324b 100644 --- a/lotus/sem_ops/sem_cluster_by.py +++ b/lotus/sem_ops/sem_cluster_by.py @@ -42,9 +42,11 @@ def __call__( Returns: pd.DataFrame: The DataFrame with the cluster assignments. """ - if lotus.settings.rm is None: + rm = lotus.settings.rm + vs = lotus.settings.vs + if rm is None or vs is None : raise ValueError( - "The retrieval model must be an instance of RM. 
Please configure a valid retrieval model using lotus.settings.configure()" + "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()" ) cluster_fn = lotus.utils.cluster(col_name, ncentroids) diff --git a/lotus/sem_ops/sem_dedup.py b/lotus/sem_ops/sem_dedup.py index 7c900073..2a4f358b 100644 --- a/lotus/sem_ops/sem_dedup.py +++ b/lotus/sem_ops/sem_dedup.py @@ -34,9 +34,11 @@ def __call__( Returns: pd.DataFrame: The DataFrame with duplicates removed. """ - if lotus.settings.rm is None: + rm = lotus.settings.rm + vs = lotus.settings.vs + if rm is None or vs is None: raise ValueError( - "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" + "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()" ) joined_df = self._obj.sem_sim_join(self._obj, col_name, col_name, len(self._obj), lsuffix="_l", rsuffix="_r") diff --git a/lotus/sem_ops/sem_index.py b/lotus/sem_ops/sem_index.py index 7a31a3f4..a7a4e0db 100644 --- a/lotus/sem_ops/sem_index.py +++ b/lotus/sem_ops/sem_index.py @@ -35,7 +35,9 @@ def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame: "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" ) - rm = lotus.settings.get_rm_or_vs() - rm.index(self._obj[col_name], index_dir) + rm = lotus.settings.rm + vs = lotus.settings.vs + embeddings = rm(self._obj[col_name]) + vs.index(self._obj[col_name], embeddings, index_dir) self._obj.attrs["index_dirs"][col_name] = index_dir return self._obj diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py index 340db425..45090e01 100644 --- a/lotus/sem_ops/sem_search.py +++ b/lotus/sem_ops/sem_search.py @@ -47,25 +47,27 @@ def __call__( assert not (K is None and n_rerank is None), "K or n_rerank must be provided" if K is not None: # get retriever model and index - rm = lotus.settings.get_rm_or_vs() - if rm is None: + rm = lotus.settings.rm + vs = lotus.settings.vs + if rm is None or vs is None : raise ValueError( - "The retrieval model must be an instance of RM or VS. Please configure a valid retrieval model pr vector store using lotus.settings.configure()" + "The retrieval model must be an instance of RM, and the vector store should be an instance of VS. 
Please configure a valid retrieval model and vector store using lotus.settings.configure()" ) col_index_dir = self._obj.attrs["index_dirs"][col_name] - if rm.index_dir != col_index_dir: - rm.load_index(col_index_dir) - assert rm.index_dir == col_index_dir + if vs.index_dir != col_index_dir: + vs.load_index(col_index_dir) + assert vs.index_dir == col_index_dir df_idxs = self._obj.index K = min(K, len(df_idxs)) search_K = K while True: - rm_output: RMOutput = rm(query, search_K) - doc_idxs = rm_output.indices[0] - scores = rm_output.distances[0] + query_vectors = rm.convert_query_to_query_vector(query) + vs_output: RMOutput = vs(query_vectors, search_K) + doc_idxs = vs_output.indices[0] + scores = vs_output.distances[0] assert len(doc_idxs) == len(scores) postfiltered_doc_idxs = [] diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index 15c62c84..3de4a853 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -52,20 +52,21 @@ def __call__( raise ValueError("Other Series must have a name") other = pd.DataFrame({other.name: other}) - rm = lotus.settings.get_rm_or_vs() - if not isinstance(rm, RM) and not isinstance(rm, VS): + rm = lotus.settings.rm + vs = lotus.settings.vs + if not isinstance(rm, RM) or not isinstance(vs, VS): raise ValueError( - "The retrieval model must be an instance of RM or VS. Please configure a valid retrieval model or vector store using lotus.settings.configure()" + "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model or vector store using lotus.settings.configure()" ) - + # load query embeddings from index if they exist if left_on in self._obj.attrs.get("index_dirs", []): query_index_dir = self._obj.attrs["index_dirs"][left_on] - if rm.index_dir != query_index_dir: - rm.load_index(query_index_dir) - assert rm.index_dir == query_index_dir + if vs.index_dir != query_index_dir: + vs.load_index(query_index_dir) + assert vs.index_dir == query_index_dir try: - queries = rm.get_vectors_from_index(query_index_dir, self._obj.index) + queries = vs.get_vectors_from_index(query_index_dir, self._obj.index) except NotImplementedError: queries = self._obj[left_on] else: @@ -76,13 +77,15 @@ def __call__( col_index_dir = other.attrs["index_dirs"][right_on] except KeyError: raise ValueError(f"Index directory for column {right_on} not found in DataFrame") - if rm.index_dir != col_index_dir: - rm.load_index(col_index_dir) - assert rm.index_dir == col_index_dir + if vs.index_dir != col_index_dir: + vs.load_index(col_index_dir) + assert vs.index_dir == col_index_dir + + query_vectors = rm.convert_query_to_query_vector(queries) - rm_output: RMOutput = rm(queries, K) - distances = rm_output.distances - indices = rm_output.indices + vs_output: RMOutput = vs(query_vectors, K) + distances = vs_output.distances + indices = vs_output.indices other_index_set = set(other.index) join_results = [] diff --git a/lotus/settings.py b/lotus/settings.py index ad24652a..571e6155 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -8,7 +8,7 @@ class Settings: # Models lm: lotus.models.LM | None = None - rm: lotus.models.RM | None = None + rm: lotus.models.RM | None = None # supposed to only generate embeddings helper_lm: lotus.models.LM | None = None reranker: lotus.models.Reranker | None = None vs: lotus.vector_store.VS | None = None @@ -25,10 +25,7 @@ class Settings: def configure(self, **kwargs): - if 'rm' in kwargs and 'vs' in kwargs: - raise ValueError('Invalid 
settings: you can only set a retriever module or a vector store, but not both') - for key, value in kwargs.items(): if not hasattr(self, key): raise ValueError(f"Invalid setting: {key}") @@ -37,8 +34,6 @@ def configure(self, **kwargs): def __str__(self): return str(vars(self)) - def get_rm_or_vs(self): - return self.rm or self.vs settings = Settings() diff --git a/lotus/utils.py b/lotus/utils.py index b3e68ac9..da4dc82a 100644 --- a/lotus/utils.py +++ b/lotus/utils.py @@ -42,9 +42,10 @@ def ret( # get rmodel and index rm = lotus.settings.rm - if rm is None: + vs = lotus.settings.vs + if rm is None or vs is None: raise ValueError( - "The retrieval model must be an instance of RM. Please configure a valid retrieval model using lotus.settings.configure()" + "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()" ) try: @@ -52,12 +53,12 @@ def ret( except KeyError: raise ValueError(f"Index directory for column {col_name} not found in DataFrame") - if rm.index_dir != col_index_dir: - rm.load_index(col_index_dir) - assert rm.index_dir == col_index_dir + if vs.index_dir != col_index_dir: + vs.load_index(col_index_dir) + assert vs.index_dir == col_index_dir ids = df.index.tolist() # assumes df index hasn't been resest and corresponds to faiss index ids - vec_set = rm.get_vectors_from_index(col_index_dir, ids) + vec_set = vs.get_vectors_from_index(col_index_dir, ids) d = vec_set.shape[1] kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose) kmeans.train(vec_set) diff --git a/lotus/vector_store/__init__.py b/lotus/vector_store/__init__.py index 34c41998..b1efbeac 100644 --- a/lotus/vector_store/__init__.py +++ b/lotus/vector_store/__init__.py @@ -1,7 +1,8 @@ from lotus.vector_store.vs import VS +from lotus.vector_store.faiss_vs import FaissVS from lotus.vector_store.weaviate_vs import WeaviateVS from lotus.vector_store.pinecone_vs import PineconeVS from lotus.vector_store.chroma_vs import ChromaVS from lotus.vector_store.qdrant_vs import QdrantVS -__all__ = ["VS", "WeaviateVS", "PineconeVS", "ChromaVS", "QdrantVS"] +__all__ = ["VS", "FaissVS", "WeaviateVS", "PineconeVS", "ChromaVS", "QdrantVS"] diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 583b238a..7a7d24f5 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -3,7 +3,6 @@ import numpy as np import pandas as pd from numpy.typing import NDArray -from PIL import Image from tqdm import tqdm from lotus.types import RMOutput @@ -19,12 +18,12 @@ ) from err class ChromaVS(VS): - def __init__(self, embedding_model: str, max_batch_size: int = 64): + def __init__(self, max_batch_size: int = 64): client: ClientAPI = Client() """Initialize with ChromaDB client and embedding model""" - super().__init__(embedding_model) + super() self.client = client self.collection: Collection | None = None self.index_dir = None @@ -33,7 +32,7 @@ def __init__(self, embedding_model: str, max_batch_size: int = 64): def __del__(self): return - def index(self, docs: pd.Series, index_dir: str): + def index(self, docs: Any, embeddings: Any, index_dir: str): """Create a collection and add documents with their embeddings""" self.index_dir = index_dir @@ -46,9 +45,6 @@ def index(self, docs: pd.Series, index_dir: str): # Convert docs to list if it's a pandas Series docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs - # Generate embeddings - embeddings = 
self._batch_embed(docs_list) - # Prepare documents for addition ids = [str(i) for i in range(len(docs_list))] metadatas: list[Mapping[str, Union[str, int, float, bool]]] = [{"doc_id": int(i)} for i in range(len(docs_list))] @@ -74,7 +70,7 @@ def load_index(self, index_dir: str): def __call__( self, - queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + query_vectors, K: int, **kwargs: dict[str, Any] ) -> RMOutput: @@ -82,6 +78,8 @@ def __call__( if self.collection is None: raise ValueError("No collection loaded. Call load_index first.") + + """ # Convert single query to list if isinstance(queries, (str, Image.Image)): queries = [queries] @@ -96,6 +94,8 @@ def __call__( # Create embeddings for text queries query_vectors = self._batch_embed(queries) + """ + # Perform searches all_distances = [] all_indices = [] diff --git a/lotus/models/faiss_rm.py b/lotus/vector_store/faiss_vs.py similarity index 74% rename from lotus/models/faiss_rm.py rename to lotus/vector_store/faiss_vs.py index ace1fd92..47654f0c 100644 --- a/lotus/models/faiss_rm.py +++ b/lotus/vector_store/faiss_vs.py @@ -1,19 +1,17 @@ import os import pickle -from abc import abstractmethod from typing import Any import faiss import numpy as np import pandas as pd from numpy.typing import NDArray -from PIL import Image -from lotus.models.rm import RM from lotus.types import RMOutput +from lotus.vector_store.vs import VS -class FaissRM(RM): +class FaissVS(VS): def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODUCT): super().__init__() self.factory_string = factory_string @@ -22,15 +20,14 @@ def __init__(self, factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODU self.faiss_index: faiss.Index | None = None self.vecs: NDArray[np.float64] | None = None - def index(self, docs: pd.Series, index_dir: str, **kwargs: dict[str, Any]) -> None: - vecs = self._embed(docs) - self.faiss_index = faiss.index_factory(vecs.shape[1], self.factory_string, self.metric) - self.faiss_index.add(vecs) + def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]) -> None: + self.faiss_index = faiss.index_factory(embeddings.shape[1], self.factory_string, self.metric) + self.faiss_index.add(embeddings) self.index_dir = index_dir os.makedirs(index_dir, exist_ok=True) with open(f"{index_dir}/vecs", "wb") as fp: - pickle.dump(vecs, fp) + pickle.dump(embeddings, fp) faiss.write_index(self.faiss_index, f"{index_dir}/index") def load_index(self, index_dir: str) -> None: @@ -45,8 +42,11 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f return vecs[ids] def __call__( - self, queries: pd.Series | str | Image.Image | list | NDArray[np.float64], K: int, **kwargs: dict[str, Any] + self, embedded_queries, K: int, **kwargs: dict[str, Any] ) -> RMOutput: + + """ + do this processing in the rm if isinstance(queries, str) or isinstance(queries, Image.Image): queries = [queries] @@ -54,13 +54,10 @@ def __call__( embedded_queries = self._embed(queries) else: embedded_queries = np.asarray(queries, dtype=np.float32) - + """ if self.faiss_index is None: raise ValueError("Index not loaded") distances, indices = self.faiss_index.search(embedded_queries, K) return RMOutput(distances=distances, indices=indices) - @abstractmethod - def _embed(self, docs: pd.Series | list) -> NDArray[np.float64]: - pass diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 4a1d58a3..a85362bc 100644 --- a/lotus/vector_store/pinecone_vs.py +++ 
b/lotus/vector_store/pinecone_vs.py @@ -1,9 +1,8 @@ -from typing import Any, Union +from typing import Any import numpy as np import pandas as pd from numpy.typing import NDArray -from PIL import Image from tqdm import tqdm from lotus.types import RMOutput @@ -17,12 +16,12 @@ ) from err class PineconeVS(VS): - def __init__(self, embedding_model: str, max_batch_size: int = 64): + def __init__(self, max_batch_size: int = 64): api_key = 'pcsk_45ecSY_CW62eJeL4jwj6dUfaqM6j9dL3uwK12rudednzGisWMxJv9bHH2DLz6tWoY91W84' """Initialize Pinecone client with API key and environment""" - super().__init__(embedding_model) + super() self.pinecone = Pinecone(api_key=api_key) self.pc_index:Index | None = None self.max_batch_size = max_batch_size @@ -31,13 +30,11 @@ def __del__(self): return - def index(self, docs: pd.Series, index_dir: str): + def index(self, docs: pd.Series, embeddings: Any, index_dir: str): """Create an index and add documents to it""" self.index_dir = index_dir - # Get sample embedding to determine vector dimension - sample_embedding = self._embed([docs.iloc[0]]) - dimension = sample_embedding.shape[1] + dimension = embeddings.shape[1] # Check if index already exists if index_dir not in self.pinecone.list_indexes(): @@ -54,9 +51,6 @@ def index(self, docs: pd.Series, index_dir: str): # Convert docs to list if it's a pandas Series docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs - # Create embeddings using the provided embedding model - embeddings = self._batch_embed(docs_list) - # Prepare vectors for upsert vectors = [] for idx, (embedding, doc) in enumerate(zip(embeddings, docs_list)): @@ -85,7 +79,7 @@ def load_index(self, index_dir: str): def __call__( self, - queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + query_vectors, K: int, **kwargs: dict[str, Any] ) -> RMOutput: @@ -93,6 +87,7 @@ def __call__( if self.pc_index is None: raise ValueError("No index loaded. 
Call load_index first.") + """ # Convert single query to list if isinstance(queries, (str, Image.Image)): queries = [queries] @@ -106,6 +101,7 @@ def __call__( queries = queries.tolist() # Create embeddings for text queries query_vectors = self._batch_embed(queries) + """ # Perform searches all_distances = [] diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index bf4892b2..5ae59a35 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,9 +1,8 @@ -from typing import Any, Union +from typing import Any import numpy as np import pandas as pd from numpy.typing import NDArray -from PIL import Image from tqdm import tqdm from lotus.types import RMOutput @@ -16,7 +15,7 @@ raise ImportError("Please install the qdrant client") from err class QdrantVS(VS): - def __init__(self, embedding_model: str, max_batch_size: int = 64): + def __init__(self, max_batch_size: int = 64): API_KEY = '_Mic3dVln2gAkS6NLyia6p-CCyMScK42ayuq8Rapm5-xsV5j5_UlIA' @@ -28,20 +27,19 @@ def __init__(self, embedding_model: str, max_batch_size: int = 64): ) """Initialize with Qdrant client and embedding model""" - super().__init__(embedding_model) # Fixed the super() call syntax + super() # Fixed the super() call syntax self.client: QdrantClient = client self.max_batch_size = max_batch_size def __del__(self): self.client.close() - def index(self, docs: pd.Series, index_dir: str): + def index(self, docs:pd.Series, embeddings, index_dir: str): """Create a collection and add documents with their embeddings""" self.index_dir = index_dir # Get sample embedding to determine vector dimension - sample_embedding = self._embed([docs.iloc[0]]) - dimension = sample_embedding.shape[1] + dimension = embeddings.shape[1] # Create collection if it doesn't exist if not self.client.collection_exists(index_dir): @@ -53,9 +51,6 @@ def index(self, docs: pd.Series, index_dir: str): # Convert docs to list if it's a pandas Series docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs - # Generate embeddings - embeddings = self._batch_embed(docs_list) - # Prepare points for upload points = [] for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)): @@ -87,7 +82,7 @@ def load_index(self, index_dir: str): def __call__( self, - queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + query_vectors, K: int, **kwargs: dict[str, Any] ) -> RMOutput: @@ -95,6 +90,10 @@ def __call__( if self.index_dir is None: raise ValueError("No collection loaded. 
Call load_index first.") + """ + + do this in retriever module before passing into here + # Convert single query to list if isinstance(queries, (str, Image.Image)): queries = [queries] @@ -108,6 +107,7 @@ def __call__( queries = queries.tolist() # Create embeddings for text queries query_vectors = self._batch_embed(queries) + """ # Perform searches all_distances = [] diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index bbcc4042..37be83ec 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -2,16 +2,11 @@ from typing import Any import numpy as np -import pandas as pd -from litellm import embedding from numpy.typing import NDArray -from PIL import Image -from sentence_transformers import CrossEncoder, SentenceTransformer -from tqdm import tqdm -from lotus.dtype_extensions import convert_to_base_data from lotus.types import RMOutput +""" MODEL_NAME_TO_CLS = { "intfloat/e5-small-v2": lambda model: SentenceTransformer(model_name_or_path=model), "mixedbread-ai/mxbai-rerank-xsmall-v1": lambda model: CrossEncoder(model_name=model), @@ -25,19 +20,17 @@ def initialize(model_name): elif model_name== 'mixedbread-ai/mxbai-rerank-xsmall-v1': return CrossEncoder(model_name=model_name) return lambda batch: embedding(model=model_name, input=batch) - +""" class VS(ABC): """Abstract class for vector stores.""" - def __init__(self, embedding_model: str) -> None: + def __init__(self) -> None: self.index_dir: str | None = None - self._embed = initialize(embedding_model) self.max_batch_size:int = 64 - @abstractmethod - def index(self, docs: pd.Series, collection_name: str): + def index(self, docs, embeddings: Any, collection_name: str): """ Create index and store it in vector store """ @@ -50,7 +43,7 @@ def load_index(self, index_dir: str): @abstractmethod def __call__(self, - queries: pd.Series | str | Image.Image | list | NDArray[np.float64], + query_vectors:Any, K:int, **kwargs: dict[str, Any], ) -> RMOutput: @@ -60,8 +53,9 @@ def __call__(self, def get_vectors_from_index(self, index_dir:str, ids: list[Any]) -> NDArray[np.float64]: pass + """ def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]: - """Create embeddings using the provided embedding model with batching""" + Create embeddings using the provided embedding model with batching all_embeddings = [] for i in tqdm(range(0, len(docs), self.max_batch_size), desc="Creating embeddings"): batch = docs[i : i + self.max_batch_size] @@ -69,3 +63,4 @@ def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]: embeddings = self._embed(_batch) all_embeddings.append(embeddings) return np.vstack(all_embeddings) + """ \ No newline at end of file diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 1d1b82ab..36620161 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,9 +1,8 @@ -from typing import Any, List, Union +from typing import Any, List import numpy as np import pandas as pd from numpy.typing import NDArray -from PIL import Image from lotus.types import RMOutput from lotus.vector_store.vs import VS @@ -21,7 +20,7 @@ raise ImportError("Please install the weaviate client") from err class WeaviateVS(VS): - def __init__(self, embedding_model: str, max_batch_size: int = 64): + def __init__(self, max_batch_size: int = 64): REST_URL = 'https://dovieiknqr20pmgoticrmw.c0.us-west3.gcp.weaviate.cloud' @@ -36,14 +35,14 @@ def __init__(self, embedding_model: str, max_batch_size: int = 64): ) """Initialize with Weaviate client and 
embedding model""" - super().__init__(embedding_model) + super() self.client = weaviate_client self.max_batch_size = max_batch_size def __del__(self): self.client.close() - def index(self, docs: pd.Series, index_dir: str): + def index(self, docs: pd.Series, embeddings, index_dir: str): """Create a collection and add documents with their embeddings""" self.index_dir = index_dir @@ -66,7 +65,6 @@ def index(self, docs: pd.Series, index_dir: str): # Generate embeddings for all documents docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs - embeddings = self._batch_embed(docs_list) # Add documents to collection with their embeddings with collection.batch.dynamic() as batch: @@ -91,7 +89,7 @@ def load_index(self, index_dir: str): raise ValueError(f"Collection {index_dir} not found") def __call__(self, - queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + query_vectors, K: int, **kwargs: dict[str, Any] ) -> RMOutput: @@ -101,6 +99,9 @@ def __call__(self, collection = self.client.collections.get(self.index_dir) + """ + + do this in the retriever module # Convert single query to list if isinstance(queries, (str, Image.Image)): queries = [queries] @@ -111,6 +112,7 @@ def __call__(self, else: # Generate embeddings for text queries query_vectors = self._batch_embed(queries) + """ # Perform searches results = [] From 75d11ea5334afb4e7c8d373eb3f82c56ef456c96 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 09:25:11 -0800 Subject: [PATCH 32/65] fixed type checks --- lotus/models/colbertv2_rm.py | 3 +-- lotus/models/litellm_rm.py | 2 +- lotus/sem_ops/sem_index.py | 11 +++++++---- lotus/vector_store/chroma_vs.py | 2 +- lotus/vector_store/faiss_vs.py | 4 ++-- lotus/vector_store/pinecone_vs.py | 2 +- lotus/vector_store/qdrant_vs.py | 2 +- lotus/vector_store/vs.py | 2 +- lotus/vector_store/weaviate_vs.py | 2 +- 9 files changed, 16 insertions(+), 14 deletions(-) diff --git a/lotus/models/colbertv2_rm.py b/lotus/models/colbertv2_rm.py index 481ac5f9..8018221f 100644 --- a/lotus/models/colbertv2_rm.py +++ b/lotus/models/colbertv2_rm.py @@ -6,7 +6,6 @@ from numpy.typing import NDArray from PIL import Image -from lotus.models.rm import RM from lotus.types import RMOutput try: @@ -16,7 +15,7 @@ pass -class ColBERTv2RM(RM): +class ColBERTv2RM(): def __init__(self) -> None: self.docs: list[str] | None = None self.kwargs: dict[str, Any] = {"doc_maxlen": 300, "nbits": 2} diff --git a/lotus/models/litellm_rm.py b/lotus/models/litellm_rm.py index c0d8b14c..b0ee07be 100644 --- a/lotus/models/litellm_rm.py +++ b/lotus/models/litellm_rm.py @@ -18,7 +18,7 @@ def __init__( factory_string: str = "Flat", metric=faiss.METRIC_INNER_PRODUCT, ): - super().__init__(factory_string, metric) + super() self.model: str = model self.max_batch_size: int = max_batch_size diff --git a/lotus/sem_ops/sem_index.py b/lotus/sem_ops/sem_index.py index a7a4e0db..ba4038ff 100644 --- a/lotus/sem_ops/sem_index.py +++ b/lotus/sem_ops/sem_index.py @@ -30,13 +30,16 @@ def __call__(self, col_name: str, index_dir: str) -> pd.DataFrame: Returns: pd.DataFrame: The DataFrame with the index directory saved. """ - if lotus.settings.rm is None: - raise ValueError( - "The retrieval model must be an instance of RM. 
Please configure a valid retrieval model using lotus.settings.configure()" - ) rm = lotus.settings.rm vs = lotus.settings.vs + if rm is None or vs is None: + raise ValueError( + "The retrieval model must be an instance of RM, and the vector store must be an instance of VS. Please configure a valid retrieval model using lotus.settings.configure()" + ) + + + embeddings = rm(self._obj[col_name]) vs.index(self._obj[col_name], embeddings, index_dir) self._obj.attrs["index_dirs"][col_name] = index_dir diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 7a7d24f5..b7e1ce9f 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -32,7 +32,7 @@ def __init__(self, max_batch_size: int = 64): def __del__(self): return - def index(self, docs: Any, embeddings: Any, index_dir: str): + def index(self, docs: Any, embeddings: Any, index_dir: str, **kwargs: dict[str, Any]): """Create a collection and add documents with their embeddings""" self.index_dir = index_dir diff --git a/lotus/vector_store/faiss_vs.py b/lotus/vector_store/faiss_vs.py index 47654f0c..9f682469 100644 --- a/lotus/vector_store/faiss_vs.py +++ b/lotus/vector_store/faiss_vs.py @@ -42,7 +42,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f return vecs[ids] def __call__( - self, embedded_queries, K: int, **kwargs: dict[str, Any] + self, query_vectors, K: int, **kwargs: dict[str, Any] ) -> RMOutput: """ @@ -58,6 +58,6 @@ def __call__( if self.faiss_index is None: raise ValueError("Index not loaded") - distances, indices = self.faiss_index.search(embedded_queries, K) + distances, indices = self.faiss_index.search(query_vectors, K) return RMOutput(distances=distances, indices=indices) diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index a85362bc..cb8cf4e1 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -30,7 +30,7 @@ def __del__(self): return - def index(self, docs: pd.Series, embeddings: Any, index_dir: str): + def index(self, docs: pd.Series, embeddings: Any, index_dir: str, **kwargs: dict[str, Any]): """Create an index and add documents to it""" self.index_dir = index_dir diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index 5ae59a35..e8337d02 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -34,7 +34,7 @@ def __init__(self, max_batch_size: int = 64): def __del__(self): self.client.close() - def index(self, docs:pd.Series, embeddings, index_dir: str): + def index(self, docs:pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]): """Create a collection and add documents with their embeddings""" self.index_dir = index_dir diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 37be83ec..7e1818e6 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -30,7 +30,7 @@ def __init__(self) -> None: self.max_batch_size:int = 64 @abstractmethod - def index(self, docs, embeddings: Any, collection_name: str): + def index(self, docs, embeddings: Any, index_dir: str, **kwargs: dict[str, Any]): """ Create index and store it in vector store """ diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 36620161..cb8c7956 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -42,7 +42,7 @@ def __init__(self, max_batch_size: int = 64): def __del__(self): self.client.close() - def index(self, docs: pd.Series, 
embeddings, index_dir: str): + def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]): """Create a collection and add documents with their embeddings""" self.index_dir = index_dir From 0b0bf38102d05ebd929443cf754fd2c9118d1cc7 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 09:33:04 -0800 Subject: [PATCH 33/65] fixed retriever module errors --- .github/tests/rm_tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 22db3544..be36d4e1 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -176,7 +176,7 @@ def test_dedup(setup_models): @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) -def test_vs_cluster_by(setup_vs, vs, model): +def test_vs_cluster_by(setup_models, setup_vs, vs, model): rm = setup_models[model] my_vs = setup_vs[vs] lotus.settings.configure(rm=rm, vs=my_vs) @@ -206,7 +206,7 @@ def test_vs_cluster_by(setup_vs, vs, model): @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) -def test_vs_search_rm_only(setup_vs, vs, model): +def test_vs_search_rm_only(setup_models, setup_vs, vs, model): rm = setup_models[model] my_vs = setup_vs[vs] lotus.settings.configure(rm=rm, vs=my_vs) @@ -226,7 +226,7 @@ def test_vs_search_rm_only(setup_vs, vs, model): @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small")) -def test_vs_sim_join(setup_vs, vs, model): +def test_vs_sim_join(setup_models, setup_vs, vs, model): rm = setup_models[model] my_vs = setup_vs[vs] lotus.settings.configure(rm=rm, vs=my_vs) @@ -254,7 +254,7 @@ def test_vs_sim_join(setup_vs, vs, model): reason="Skipping test because intfloat/e5-small-v2 is not enabled", ) @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys()) -def test_vs_dedup(setup_vs, vs): +def test_vs_dedup(setup_models, setup_vs, vs): rm = setup_models["intfloat/e5-small-v2"] my_vs = setup_vs[vs] lotus.settings.configure(rm=rm, vs=my_vs) From 6bf79266e527c3ed6a576fab52d7cbacdaaef7ea Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 09:38:15 -0800 Subject: [PATCH 34/65] fixed key error --- .github/tests/rm_tests.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index be36d4e1..d0c31b99 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -60,8 +60,7 @@ def setup_vs(): vs_model = {} for vs in VECTOR_STORE_TO_CLS: - for model_name in ENABLED_MODEL_NAMES: - vs_model[model_name] = VECTOR_STORE_TO_CLS[vs]() + vs_model[vs] = VECTOR_STORE_TO_CLS[vs]() return vs_model From f7071a2403c5ad0ada481cadf7617f807cc5c8fc Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:07:42 -0800 Subject: [PATCH 35/65] added fixes to failing rm tests --- lotus/vector_store/chroma_vs.py | 4 +-- lotus/vector_store/pinecone_vs.py | 8 +++-- lotus/vector_store/qdrant_vs.py | 2 +- lotus/vector_store/weaviate_vs.py | 53 +++++++++++++++---------------- 4 files changed, 35 insertions(+), 32 deletions(-) diff --git a/lotus/vector_store/chroma_vs.py 
b/lotus/vector_store/chroma_vs.py index b7e1ce9f..877fc1d1 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -150,8 +150,8 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f include=[IncludeEnum.embeddings] ) - if not results['embeddings']: - raise ValueError("No vectors found for the given ids") + if not results['embeddings'].all(): + raise ValueError("No vectors found for the given ids", results['embeddings']) return np.array(results['embeddings'], dtype=np.float64) diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index cb8cf4e1..c27e4dfb 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -9,7 +9,7 @@ from lotus.vector_store.vs import VS try: - from pinecone import Index, Pinecone + from pinecone import Index, Pinecone, ServerlessSpec except ImportError as err: raise ImportError( "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`", @@ -42,7 +42,11 @@ def index(self, docs: pd.Series, embeddings: Any, index_dir: str, **kwargs: dic self.pinecone.create_index( name=index_dir, dimension=dimension, - metric="cosine" + metric="cosine", + spec=ServerlessSpec( + cloud='aws', + region='us-west-1' + ) ) # Connect to index diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index e8337d02..ff3ff9c5 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -39,7 +39,7 @@ def index(self, docs:pd.Series, embeddings, index_dir: str, **kwargs: dict[str, self.index_dir = index_dir # Get sample embedding to determine vector dimension - dimension = embeddings.shape[1] + dimension = embeddings.reshape((len(embeddings), -1)).shape()[1] # Create collection if it doesn't exist if not self.client.collection_exists(index_dir): diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index cb8c7956..6e1301dc 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -8,14 +8,10 @@ from lotus.vector_store.vs import VS try: - import uuid - from uuid import uuid4 - import weaviate from weaviate.classes.config import Configure, DataType, Property from weaviate.classes.init import Auth from weaviate.classes.query import MetadataQuery - from weaviate.util import get_valid_uuid except ImportError as err: raise ImportError("Please install the weaviate client") from err @@ -47,21 +43,24 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, self.index_dir = index_dir # Create collection without vectorizer config (we'll provide vectors directly) - collection = self.client.collections.create( - name=index_dir, - properties=[ - Property( - name='content', - data_type=DataType.TEXT - ), - Property( - name='doc_id', - data_type=DataType.INT, - ) - ], - vectorizer_config=None, # No vectorizer needed as we provide vectors - vector_index_config=Configure.VectorIndex.dynamic() - ) + if not self.client.collections.exists(index_dir): + collection = self.client.collections.create( + name=index_dir, + properties=[ + Property( + name='content', + data_type=DataType.TEXT + ), + Property( + name='doc_id', + data_type=DataType.INT, + ) + ], + vectorizer_config=None, # No vectorizer needed as we provide vectors + vector_index_config=Configure.VectorIndex.dynamic() + ) + else: + collection = self.client.collections.get(index_dir) # Generate embeddings for all documents docs_list = docs.tolist() if isinstance(docs, 
pd.Series) else docs @@ -76,7 +75,6 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, batch.add_object( properties=properties, vector=embedding.tolist(), # Provide pre-computed vector - uuid=get_valid_uuid(str(uuid4())) ) def load_index(self, index_dir: str): @@ -123,7 +121,6 @@ def __call__(self, limit=K, return_metadata=MetadataQuery(distance=True) )) - response.objects[0].uuid results.append(response) # Process results into expected format @@ -136,13 +133,13 @@ def __call__(self, distances:List[float] = [] indices = [] for obj in objects: - indices.append(obj.uuid) + indices.append(obj.properties.get('doc_id', -1)) # Convert cosine distance to similarity score distance = obj.metadata.distance if obj.metadata and obj.metadata.distance is not None else 1.0 distances.append(1 - distance) # Convert distance to similarity # Pad results if fewer than K matches while len(indices) < K: - indices.append(uuid.UUID(int=0)) + indices.append(-1) distances.append(0.0) all_distances.append(distances) @@ -161,10 +158,12 @@ def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.f vectors = [] for id in ids: - response = collection.query.fetch_object_by_id(uuid=id) - if response: - vectors.append(response.vector) - else: + exists = False + for obj in collection.query.fetch_objects().objects: + if id == obj.properties.get('doc_id', -1): + exists = True + vectors.append(obj.vector) + if not exists: raise ValueError(f'{id} does not exist in {index_dir}') return np.array(vectors, dtype=np.float64) From 6ebe40771383b1b8501ec001471fd91abe65c5c0 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:10:21 -0800 Subject: [PATCH 36/65] fixed chroma --- lotus/vector_store/chroma_vs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 877fc1d1..01cd77a8 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -37,7 +37,7 @@ def index(self, docs: Any, embeddings: Any, index_dir: str, **kwargs: dict[str, self.index_dir = index_dir # Create collection without embedding function (we'll provide embeddings directly) - self.collection = self.client.create_collection( + self.collection = self.client.get_or_create_collection( name=index_dir, metadata={"hnsw:space": "cosine"} # Use cosine similarity for consistency ) From e588beea127554a548804d5774cb8a178cc531e3 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:19:28 -0800 Subject: [PATCH 37/65] removed dynamic indexing for weaviatevs --- lotus/vector_store/weaviate_vs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 6e1301dc..e42f1a9b 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -9,7 +9,7 @@ try: import weaviate - from weaviate.classes.config import Configure, DataType, Property + from weaviate.classes.config import DataType, Property from weaviate.classes.init import Auth from weaviate.classes.query import MetadataQuery except ImportError as err: @@ -57,7 +57,7 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, ) ], vectorizer_config=None, # No vectorizer needed as we provide vectors - vector_index_config=Configure.VectorIndex.dynamic() + #vector_index_config=Configure.VectorIndex.dynamic() 
) else: collection = self.client.collections.get(index_dir) From d6a86e1f51540e55faf8c253929c14d4f3f24169 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:32:29 -0800 Subject: [PATCH 38/65] fixed type errors --- lotus/vector_store/chroma_vs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 01cd77a8..fcd88d68 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -150,7 +150,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.f include=[IncludeEnum.embeddings] ) - if not results['embeddings'].all(): + if results['embeddings'] is None: raise ValueError("No vectors found for the given ids", results['embeddings']) return np.array(results['embeddings'], dtype=np.float64) From ddfd549ff200a3cc6c359fc37b6ab88e1f3a0660 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:36:12 -0800 Subject: [PATCH 39/65] changed weaviate index config --- lotus/vector_store/weaviate_vs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index e42f1a9b..4858601a 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -9,7 +9,7 @@ try: import weaviate - from weaviate.classes.config import DataType, Property + from weaviate.classes.config import Configure, DataType, Property from weaviate.classes.init import Auth from weaviate.classes.query import MetadataQuery except ImportError as err: @@ -57,7 +57,7 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, ) ], vectorizer_config=None, # No vectorizer needed as we provide vectors - #vector_index_config=Configure.VectorIndex.dynamic() + vector_index_config=Configure.VectorIndex.hnsw() ) else: collection = self.client.collections.get(index_dir) From 20206e1f9284ca49b4dcb15250b76624751a882a Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:38:56 -0800 Subject: [PATCH 40/65] changed rm tests index name to avoid pinecone failures --- .github/tests/rm_tests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index d0c31b99..3e47d5a7 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -189,7 +189,7 @@ def test_vs_cluster_by(setup_models, setup_vs, vs, model): ] } df = pd.DataFrame(data) - df = df.sem_index("Course Name", "index_dir") + df = df.sem_index("Course Name", "index-dir") df = df.sem_cluster_by("Course Name", 2) groups = df.groupby("cluster_id")["Course Name"].apply(set).to_dict() assert len(groups) == 2, groups @@ -219,7 +219,7 @@ def test_vs_search_rm_only(setup_models, setup_vs, vs, model): ] } df = pd.DataFrame(data) - df = df.sem_index("Course Name", "index_dir") + df = df.sem_index("Course Name", "index-dir") df = df.sem_search("Course Name", "Optimization", K=1) assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] @@ -240,7 +240,7 @@ def test_vs_sim_join(setup_models, setup_vs, vs, model): data2 = {"Skill": ["Math", "History"]} df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir") + df2 = pd.DataFrame(data2).sem_index("Skill", "index-dir") joined_df = df1.sem_sim_join(df2, left_on="Course 
Name", right_on="Skill", K=1) joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")} @@ -266,7 +266,7 @@ def test_vs_dedup(setup_models, setup_vs, vs): ] } df = pd.DataFrame(data) - df = df.sem_index("Text", "index_dir").sem_dedup("Text", threshold=0.85) + df = df.sem_index("Text", "index-dir").sem_dedup("Text", threshold=0.85) kept = df["Text"].tolist() kept.sort() assert len(kept) == 2, kept From e7ea24f0a110a90138d13c6dfa91574f48b69c9b Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Tue, 28 Jan 2025 08:29:00 -0800 Subject: [PATCH 41/65] fixed naming convention for index_dir and fixed serverless spec for pc index --- .github/tests/rm_tests.py | 8 ++++---- .github/workflows/tests.yml | 2 +- lotus/vector_store/pinecone_vs.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py index 3e47d5a7..3afcb745 100644 --- a/.github/tests/rm_tests.py +++ b/.github/tests/rm_tests.py @@ -189,7 +189,7 @@ def test_vs_cluster_by(setup_models, setup_vs, vs, model): ] } df = pd.DataFrame(data) - df = df.sem_index("Course Name", "index-dir") + df = df.sem_index("Course Name", "indexdir") df = df.sem_cluster_by("Course Name", 2) groups = df.groupby("cluster_id")["Course Name"].apply(set).to_dict() assert len(groups) == 2, groups @@ -219,7 +219,7 @@ def test_vs_search_rm_only(setup_models, setup_vs, vs, model): ] } df = pd.DataFrame(data) - df = df.sem_index("Course Name", "index-dir") + df = df.sem_index("Course Name", "indexdir") df = df.sem_search("Course Name", "Optimization", K=1) assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"] @@ -240,7 +240,7 @@ def test_vs_sim_join(setup_models, setup_vs, vs, model): data2 = {"Skill": ["Math", "History"]} df1 = pd.DataFrame(data1) - df2 = pd.DataFrame(data2).sem_index("Skill", "index-dir") + df2 = pd.DataFrame(data2).sem_index("Skill", "indexdir") joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1) joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"])) expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")} @@ -266,7 +266,7 @@ def test_vs_dedup(setup_models, setup_vs, vs): ] } df = pd.DataFrame(data) - df = df.sem_index("Text", "index-dir").sem_dedup("Text", threshold=0.85) + df = df.sem_index("Text", "indexdir").sem_dedup("Text", threshold=0.85) kept = df["Text"].tolist() kept.sort() assert len(kept) == 2, kept diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 494ded8c..b5c1e595 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -155,7 +155,7 @@ jobs: rm_test: name: Retrieval Model Tests runs-on: ubuntu-latest - timeout-minutes: 5 + timeout-minutes: 10 steps: - name: Checkout code diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index c27e4dfb..78326979 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -45,7 +45,7 @@ def index(self, docs: pd.Series, embeddings: Any, index_dir: str, **kwargs: dic metric="cosine", spec=ServerlessSpec( cloud='aws', - region='us-west-1' + region='us-west-2' ) ) From f152b54fa20373fe6f21fd336880e4fcb579a1ca Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Tue, 28 Jan 2025 19:35:01 -0800 Subject: [PATCH 42/65] changed 
serverless spec for pc index due to free plan --- lotus/vector_store/pinecone_vs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 78326979..f68d768e 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -45,7 +45,7 @@ def index(self, docs: pd.Series, embeddings: Any, index_dir: str, **kwargs: dic metric="cosine", spec=ServerlessSpec( cloud='aws', - region='us-west-2' + region='us-east-1' ) ) From 2e21a977fdf86b40f907e4ec70e437432531ea6c Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Tue, 28 Jan 2025 19:37:04 -0800 Subject: [PATCH 43/65] added debug statement --- lotus/vector_store/weaviate_vs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 4858601a..4e554f38 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -162,6 +162,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.f for obj in collection.query.fetch_objects().objects: if id == obj.properties.get('doc_id', -1): exists = True + print(f'vector example {obj.vector}') vectors.append(obj.vector) if not exists: raise ValueError(f'{id} does not exist in {index_dir}') From 524b501dd5d12e475706ff8d9cd098a63c27ace5 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Tue, 28 Jan 2025 20:07:23 -0800 Subject: [PATCH 44/65] made changes to errors --- lotus/vector_store/chroma_vs.py | 28 ++++++++++++++++++++++------ lotus/vector_store/pinecone_vs.py | 2 +- lotus/vector_store/qdrant_vs.py | 2 +- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index fcd88d68..ad3009a8 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -12,6 +12,7 @@ from chromadb import Client, ClientAPI from chromadb.api import Collection from chromadb.api.types import IncludeEnum + from chromadb.errors import InvalidDimensionException except ImportError as err: raise ImportError( "The chromadb library is required to use ChromaVS. 
     ) from err
@@ -53,12 +54,27 @@ def index(self, docs: Any, embeddings: Any, index_dir: str, **kwargs: dict[str,
         batch_size = 100
         for i in tqdm(range(0, len(docs_list), batch_size), desc="Uploading to ChromaDB"):
             end_idx = min(i + batch_size, len(docs_list))
-            self.collection.add(
-                ids=ids[i:end_idx],
-                documents=docs_list[i:end_idx],
-                embeddings=embeddings[i:end_idx].tolist(),
-                metadatas=metadatas[i:end_idx]
-            )
+            try:
+                self.collection.add(
+                    ids=ids[i:end_idx],
+                    documents=docs_list[i:end_idx],
+                    embeddings=embeddings[i:end_idx].tolist(),
+                    metadatas=metadatas[i:end_idx]
+                )
+            except InvalidDimensionException:
+                # delete, recreate, then add
+                self.client.delete_collection(index_dir)
+                # Create collection without embedding function (we'll provide embeddings directly)
+                self.collection = self.client.get_or_create_collection(
+                    name=index_dir,
+                    metadata={"hnsw:space": "cosine"}  # Use cosine similarity for consistency
+                )
+                self.collection.add(
+                    ids=ids[i:end_idx],
+                    documents=docs_list[i:end_idx],
+                    embeddings=embeddings[i:end_idx].tolist(),
+                    metadatas=metadatas[i:end_idx]
+                )

     def load_index(self, index_dir: str):
         """Load an existing collection"""

diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py
index f68d768e..cb55b85d 100644
--- a/lotus/vector_store/pinecone_vs.py
+++ b/lotus/vector_store/pinecone_vs.py
@@ -37,7 +37,7 @@ def index(self, docs: pd.Series, embeddings: Any, index_dir: str, **kwargs: dic
         dimension = embeddings.shape[1]

         # Check if index already exists
-        if index_dir not in self.pinecone.list_indexes():
+        if index_dir not in self.pinecone.list_indexes().names():
             # Create new index with the correct dimension
             self.pinecone.create_index(
                 name=index_dir,

diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py
index ff3ff9c5..1e1d203c 100644
--- a/lotus/vector_store/qdrant_vs.py
+++ b/lotus/vector_store/qdrant_vs.py
@@ -39,7 +39,7 @@ def index(self, docs:pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
         self.index_dir = index_dir

         # Get sample embedding to determine vector dimension
-        dimension = embeddings.reshape((len(embeddings), -1)).shape()[1]
+        dimension = np.reshape(embeddings, (len(embeddings), -1)).shape[1]

         # Create collection if it doesn't exist
         if not self.client.collection_exists(index_dir):

From e9959969eea61147f9da10c6b5996d94143ff211 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Wed, 29 Jan 2025 06:48:50 -0800
Subject: [PATCH 45/65] added some fixes to collection upload error handling

---
 lotus/vector_store/pinecone_vs.py | 12 ++++++++++++
 lotus/vector_store/qdrant_vs.py   |  9 ++++++++-
 lotus/vector_store/weaviate_vs.py | 31 +++++++++++++++++++++++++++++--
 3 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py
index cb55b85d..f5c4ff32 100644
--- a/lotus/vector_store/pinecone_vs.py
+++ b/lotus/vector_store/pinecone_vs.py
@@ -48,6 +48,18 @@ def index(self, docs: pd.Series, embeddings: Any, index_dir: str, **kwargs: dic
                     region='us-east-1'
                 )
             )
+        elif self.pinecone.describe_index(index_dir).dimension != dimension:
+            # resolve any potential dimension-mismatch errors
+            self.pinecone.delete_index(index_dir)
+            self.pinecone.create_index(
+                name=index_dir,
+                dimension=dimension,
+                metric="cosine",
+                spec=ServerlessSpec(
+                    cloud='aws',
+                    region='us-east-1'
+                )
+            )

         # Connect to index
         self.pc_index = self.pinecone.Index(index_dir)

diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py
index 1e1d203c..605eeb54 100644
--- a/lotus/vector_store/qdrant_vs.py
+++ b/lotus/vector_store/qdrant_vs.py
@@ -47,7 +47,14 @@ def index(self, docs:pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
                 collection_name=index_dir,
                 vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
             )
-
+        collection_info = self.client.get_collection(index_dir)
+        if collection_info['vectors']['size'] != dimension:
+            # if there's a discrepancy, create a new version of that collection
+            self.client.delete_collection(index_dir)
+            self.client.create_collection(
+                collection_name=index_dir,
+                vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
+            )
         # Convert docs to list if it's a pandas Series
         docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 4e554f38..eb118188 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -38,9 +38,17 @@ def __init__(self, max_batch_size: int = 64):
     def __del__(self):
         self.client.close()

+    def get_collection_dimension(self, index_dir):
+        schema = self.client.schema.get()
+        for cls in schema['classes']:
+            if cls['class'] == index_dir:
+                return cls['vectorizer']['config']['dimensions']
+
     def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]):
         """Create a collection and add documents with their embeddings"""
         self.index_dir = index_dir
+
+        embedding_dim = np.reshape(embeddings, (len(embeddings), -1)).shape[1]

         # Create collection without vectorizer config (we'll provide vectors directly)
         if not self.client.collections.exists(index_dir):
@@ -61,10 +69,29 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
             )
         else:
             collection = self.client.collections.get(index_dir)
+            if self.get_collection_dimension(index_dir) != embedding_dim:
+                self.client.collections.delete(index_dir)
+                collection = self.client.collections.create(
+                    name=index_dir,
+                    properties=[
+                        Property(
+                            name='content',
+                            data_type=DataType.TEXT
+                        ),
+                        Property(
+                            name='doc_id',
+                            data_type=DataType.INT,
+                        )
+                    ],
+                    vectorizer_config=None,  # No vectorizer needed as we provide vectors
+                    vector_index_config=Configure.VectorIndex.hnsw()
+                )
+

         # Generate embeddings for all documents
         docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs

+
         # Add documents to collection with their embeddings
         with collection.batch.dynamic() as batch:
             for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)):
@@ -162,8 +189,8 @@ def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.f
             for obj in collection.query.fetch_objects().objects:
                 if id == obj.properties.get('doc_id', -1):
                     exists = True
-                    print(f'vector example {obj.vector}')
-                    vectors.append(obj.vector)
+                    print(f'vector example {obj.vector} {obj.vector.values()}')
+                    vectors.append(obj.vector.values())
             if not exists:
                 raise ValueError(f'{id} does not exist in {index_dir}')
         return np.array(vectors, dtype=np.float64)

From 1a7548662aa675ca76855c7f2f4f1df46ed43a15 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Wed, 29 Jan 2025 07:16:10 -0800
Subject: [PATCH 46/65] made some other change

---
 lotus/vector_store/qdrant_vs.py   |  2 +-
 lotus/vector_store/weaviate_vs.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py
index 605eeb54..b9167a24 100644
--- a/lotus/vector_store/qdrant_vs.py
+++ b/lotus/vector_store/qdrant_vs.py
@@ -48,7 +48,7 @@ def index(self, docs:pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
             vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
         )
         collection_info = self.client.get_collection(index_dir)
-        if collection_info['vectors']['size'] != dimension:
+        if collection_info.config.params.vectors.size != dimension:
             # if there's a discrepancy, create a new version of that collection
             self.client.delete_collection(index_dir)
             self.client.create_collection(

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index eb118188..cd52bf0c 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -39,10 +39,7 @@ def __del__(self):
         self.client.close()

     def get_collection_dimension(self, index_dir):
-        schema = self.client.schema.get()
-        for cls in schema['classes']:
-            if cls['class'] == index_dir:
-                return cls['vectorizer']['config']['dimensions']
+        self.client.collections.get(index_dir).config

     def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]):
         """Create a collection and add documents with their embeddings"""
@@ -69,7 +66,9 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
             )
         else:
             collection = self.client.collections.get(index_dir)
-            if self.get_collection_dimension(index_dir) != embedding_dim:
+            print(self.client.collections.get(index_dir).config)
+            """
+            if self.get_collection_dimension(index_dir) != embedding_dim:
                 self.client.collections.delete(index_dir)
                 collection = self.client.collections.create(
                     name=index_dir,
@@ -86,6 +85,7 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
                     vectorizer_config=None,  # No vectorizer needed as we provide vectors
                     vector_index_config=Configure.VectorIndex.hnsw()
                 )
+            """

         # Generate embeddings for all documents

From c5f50f65560c0f99c74bf93eb9b1d2669febee50 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 16:41:55 -0800
Subject: [PATCH 47/65] fixed type errors for qdrant vs

---
 lotus/vector_store/qdrant_vs.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py
index b9167a24..50f1017f 100644
--- a/lotus/vector_store/qdrant_vs.py
+++ b/lotus/vector_store/qdrant_vs.py
@@ -48,13 +48,27 @@ def index(self, docs:pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
             vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
         )
         collection_info = self.client.get_collection(index_dir)
-        if collection_info.config.params.vectors.size != dimension:
-            # if there's a discrepancy, create a new version of that collection
-            self.client.delete_collection(index_dir)
-            self.client.create_collection(
-                collection_name=index_dir,
-                vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
-            )
+        if (collection_info is not None and collection_info.config is not None and collection_info.config.params and collection_info.config.params.vectors):
+
+            vectors = collection_info.config.params.vectors
+            if isinstance(vectors, dict):
+                # If it's a dict, decide how to handle it.
+                # Here we extract the first vector, but you may need a different logic.
+                vector = next(iter(vectors.values()))
+                size = vector.size
+            elif isinstance(vectors, VectorParams):
+                size = vectors.size
+            else:
+                size = None
+
+            if size != dimension:
+                # If there's a discrepancy, create a new version of that collection
+                self.client.delete_collection(index_dir)
+                self.client.create_collection(
+                    collection_name=index_dir,
+                    vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
+                )
+
         # Convert docs to list if it's a pandas Series
         docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs

From 85daf512cfe1ba528748f8097cec241369f3ce59 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 17:20:15 -0800
Subject: [PATCH 48/65] changed endpoint

---
 lotus/vector_store/weaviate_vs.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index cd52bf0c..40e85171 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -18,9 +18,9 @@ class WeaviateVS(VS):
     def __init__(self, max_batch_size: int = 64):

-        REST_URL = 'https://dovieiknqr20pmgoticrmw.c0.us-west3.gcp.weaviate.cloud'
+        REST_URL = 'https://aliugucnqnkzihdc3jqdig.c0.us-west3.gcp.weaviate.cloud'

-        API_KEY = 'nwRhjKLulSWbhPjX67WBklmJs7dgUS9XGWrZ'
+        API_KEY = 'e1VUifT3atB7PHLB3kXQPYhL2PNXeG0JeGYK'

         weaviate_client: weaviate.WeaviateClient | None = None  # need to set this up

From 6b80fd3c9bac2fd01b9282e006d3d7a9b5cb6334 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 17:47:18 -0800
Subject: [PATCH 49/65] added changes

---
 lotus/vector_store/weaviate_vs.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 40e85171..690bd235 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -66,9 +66,7 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
             )
         else:
             collection = self.client.collections.get(index_dir)
-            print(self.client.collections.get(index_dir).config)
-            """
-            if self.get_collection_dimension(index_dir) != embedding_dim:
+            if self.get_collection_dimension(index_dir) != embedding_dim:
                 self.client.collections.delete(index_dir)
                 collection = self.client.collections.create(
                     name=index_dir,
@@ -85,7 +83,6 @@ def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str,
                     vectorizer_config=None,  # No vectorizer needed as we provide vectors
                     vector_index_config=Configure.VectorIndex.hnsw()
                 )
-            """

         # Generate embeddings for all documents

From f90ff0fcacf4d1602001cd7905ab95a9a029b347 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 19:43:33 -0800
Subject: [PATCH 50/65] added fixes

---
 lotus/vector_store/weaviate_vs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 690bd235..1857dc39 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -187,7 +187,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.f
                 if id == obj.properties.get('doc_id', -1):
                     exists = True
                     print(f'vector example {obj.vector} {obj.vector.values()}')
-                    vectors.append(obj.vector.values())
+                    vectors.append(list(obj.vector.values()))
             if not exists:
                 raise ValueError(f'{id} does not exist in {index_dir}')
         return np.array(vectors, dtype=np.float64)

From cccfa394a8552765f1ab77abd3ded13115cc4296 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 20:53:27 -0800
Subject: [PATCH 51/65] added some changes

---
 .github/tests/rm_tests.py         | 6 +++---
 lotus/vector_store/weaviate_vs.py | 3 ++-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py
index 3afcb745..31e6c0c0 100644
--- a/.github/tests/rm_tests.py
+++ b/.github/tests/rm_tests.py
@@ -219,7 +219,7 @@ def test_vs_search_rm_only(setup_models, setup_vs, vs, model):
         ]
     }
     df = pd.DataFrame(data)
-    df = df.sem_index("Course Name", "indexdir")
+    df = df.sem_index("Course Name", "secondindexdir")
     df = df.sem_search("Course Name", "Optimization", K=1)
     assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]
@@ -240,7 +240,7 @@ def test_vs_sim_join(setup_models, setup_vs, vs, model):
     data2 = {"Skill": ["Math", "History"]}

     df1 = pd.DataFrame(data1)
-    df2 = pd.DataFrame(data2).sem_index("Skill", "indexdir")
+    df2 = pd.DataFrame(data2).sem_index("Skill", "thirdindexdir")
     joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1)
     joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"]))
     expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")}
@@ -266,7 +266,7 @@ def test_vs_dedup(setup_models, setup_vs, vs):
         ]
     }
     df = pd.DataFrame(data)
-    df = df.sem_index("Text", "indexdir").sem_dedup("Text", threshold=0.85)
+    df = df.sem_index("Text", "fourthindexdir").sem_dedup("Text", threshold=0.85)
     kept = df["Text"].tolist()
     kept.sort()
     assert len(kept) == 2, kept

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 1857dc39..925b26be 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -184,10 +184,11 @@ def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.f
         for id in ids:
             exists = False
             for obj in collection.query.fetch_objects().objects:
+                print(f'object properties: {object.properties}')
                 if id == obj.properties.get('doc_id', -1):
                     exists = True
                     print(f'vector example {obj.vector} {obj.vector.values()}')
-                    vectors.append(list(obj.vector.values()))
+                    vectors.append((obj.vector.values()))
             if not exists:
                 raise ValueError(f'{id} does not exist in {index_dir}')
         return np.array(vectors, dtype=np.float64)

From 6cf4f0a230c7655e98fba2e8b148b4259489654d Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 21:14:52 -0800
Subject: [PATCH 52/65] added some change

---
 lotus/sem_ops/sem_search.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py
index 45090e01..6ae0e5ef 100644
--- a/lotus/sem_ops/sem_search.py
+++ b/lotus/sem_ops/sem_search.py
@@ -5,6 +5,7 @@
 import lotus
 from lotus.cache import operator_cache
 from lotus.types import RerankerOutput, RMOutput
+from lotus.vector_store.pinecone_vs import PineconeVS


 @pd.api.extensions.register_dataframe_accessor("sem_search")
@@ -60,7 +61,10 @@ def __call__(
         assert vs.index_dir == col_index_dir

         df_idxs = self._obj.index
-        K = min(K, len(df_idxs))
+        cur_min = len(df_idxs)
+        if isinstance(vs, PineconeVS):
+            cur_min = min(cur_min, 10000)
+        K = min(K, cur_min)
         search_K = K

         while True:

From 0438b185003350202ac150131de59d3f5a8f436c Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 21:15:49 -0800
Subject: [PATCH 53/65] another set of changes

---
 lotus/vector_store/weaviate_vs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 925b26be..92415fcf 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -184,7 +184,7 @@ def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.f
         for id in ids:
             exists = False
             for obj in collection.query.fetch_objects().objects:
-                print(f'object properties: {object.properties}')
+                print(f'object properties: {obj.properties}')
                 if id == obj.properties.get('doc_id', -1):
                     exists = True
                     print(f'vector example {obj.vector} {obj.vector.values()}')

From 43e9bc314b0d8a6c1a12606c6af28577acd2c6c0 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 21:54:27 -0800
Subject: [PATCH 54/65] added other logs

---
 lotus/vector_store/weaviate_vs.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 92415fcf..1b5906ad 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -184,11 +184,10 @@ def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.f
         for id in ids:
             exists = False
             for obj in collection.query.fetch_objects().objects:
-                print(f'object properties: {obj.properties}')
+                print(f'object vector: {obj.vector}')
                 if id == obj.properties.get('doc_id', -1):
                     exists = True
-                    print(f'vector example {obj.vector} {obj.vector.values()}')
-                    vectors.append((obj.vector.values()))
+                    vectors.append((1))
             if not exists:
                 raise ValueError(f'{id} does not exist in {index_dir}')
         return np.array(vectors, dtype=np.float64)

From 90d07d026c171337c685f1c88315b04185e74700 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sat, 8 Feb 2025 22:13:33 -0800
Subject: [PATCH 55/65] added logging

---
 lotus/vector_store/weaviate_vs.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 1b5906ad..67b4afc1 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -180,11 +180,10 @@ def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.f
         # Query for documents with specific doc_ids
         vectors = []
-
         for id in ids:
             exists = False
             for obj in collection.query.fetch_objects().objects:
-                print(f'object vector: {obj.vector}')
+                print(f'object vector: {collection.query.__extract_vector_for_object(obj.metadata)}')
                 if id == obj.properties.get('doc_id', -1):
                     exists = True
                     vectors.append((1))

From 6bf69ff1dfce9aea2ed01d5bbef05d75a223da9a Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Thu, 13 Feb 2025 19:03:51 -0800
Subject: [PATCH 56/65] chroma_vs implementation

---
 lotus/vector_store/pinecone_vs.py | 176 ---------------------
 lotus/vector_store/qdrant_vs.py   | 187 ----------------------
 lotus/vector_store/weaviate_vs.py | 194 ------------------------
 3 files changed, 557 deletions(-)
 delete mode 100644 lotus/vector_store/pinecone_vs.py
 delete mode 100644 lotus/vector_store/qdrant_vs.py
 delete mode 100644 lotus/vector_store/weaviate_vs.py

diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py
deleted file mode 100644
index f5c4ff32..00000000
--- a/lotus/vector_store/pinecone_vs.py
+++ /dev/null
@@ -1,176 +0,0 @@
-from typing import Any
-
-import numpy as np
-import pandas as pd
-from numpy.typing import NDArray
-from tqdm import tqdm
-
-from lotus.types import RMOutput
-from lotus.vector_store.vs import VS
-
-try:
-    from pinecone import Index, Pinecone, ServerlessSpec
-except ImportError as err:
-    raise ImportError(
-        "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`",
-    ) from err
-
-class PineconeVS(VS):
-    def __init__(self, max_batch_size: int = 64):
-
-        api_key = 'pcsk_45ecSY_CW62eJeL4jwj6dUfaqM6j9dL3uwK12rudednzGisWMxJv9bHH2DLz6tWoY91W84'
-
-        """Initialize Pinecone client with API key and environment"""
-        super()
-        self.pinecone = Pinecone(api_key=api_key)
-        self.pc_index:Index | None = None
-        self.max_batch_size = max_batch_size
-
-    def __del__(self):
-        return
-
-
-    def index(self, docs: pd.Series, embeddings: Any, index_dir: str, **kwargs: dict[str, Any]):
-        """Create an index and add documents to it"""
-        self.index_dir = index_dir
-
-        dimension = embeddings.shape[1]
-
-        # Check if index already exists
-        if index_dir not in self.pinecone.list_indexes().names():
-            # Create new index with the correct dimension
-            self.pinecone.create_index(
-                name=index_dir,
-                dimension=dimension,
-                metric="cosine",
-                spec=ServerlessSpec(
-                    cloud='aws',
-                    region='us-east-1'
-                )
-            )
-        elif self.pinecone.describe_index(index_dir).dimension != dimension:
-            # resolve any potential dimension-mismatch errors
-            self.pinecone.delete_index(index_dir)
-            self.pinecone.create_index(
-                name=index_dir,
-                dimension=dimension,
-                metric="cosine",
-                spec=ServerlessSpec(
-                    cloud='aws',
-                    region='us-east-1'
-                )
-            )
-
-        # Connect to index
-        self.pc_index = self.pinecone.Index(index_dir)
-
-        # Convert docs to list if it's a pandas Series
-        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
-
-        # Prepare vectors for upsert
-        vectors = []
-        for idx, (embedding, doc) in enumerate(zip(embeddings, docs_list)):
-            vectors.append({
-                "id": str(idx),
-                "values": embedding.tolist(),  # Pinecone expects lists, not numpy arrays
-                "metadata": {
-                    "content": doc,
-                    "doc_id": idx
-                }
-            })
-
-        # Upsert in batches of 100
-        batch_size = 100
-        for i in tqdm(range(0, len(vectors), batch_size), desc="Uploading to Pinecone"):
-            batch = vectors[i:i + batch_size]
-            self.pc_index.upsert(vectors=batch)
-
-    def load_index(self, index_dir: str):
-        """Connect to an existing Pinecone index"""
-        if index_dir not in self.pinecone.list_indexes():
-            raise ValueError(f"Index {index_dir} not found")
-
-        self.index_dir = index_dir
-        self.pc_index = self.pinecone.Index(index_dir)
-
-    def __call__(
-        self,
-        query_vectors,
-        K: int,
-        **kwargs: dict[str, Any]
-    ) -> RMOutput:
-        """Perform vector search using Pinecone"""
-        if self.pc_index is None:
-            raise ValueError("No index loaded. Call load_index first.")
Call load_index first.") - - """ - # Convert single query to list - if isinstance(queries, (str, Image.Image)): - queries = [queries] - - # Handle numpy array queries (pre-computed vectors) - if isinstance(queries, np.ndarray): - query_vectors = queries - else: - # Convert queries to list if needed - if isinstance(queries, pd.Series): - queries = queries.tolist() - # Create embeddings for text queries - query_vectors = self._batch_embed(queries) - """ - - # Perform searches - all_distances = [] - all_indices = [] - - for query_vector in query_vectors: - # Query Pinecone - results = self.pc_index.query( - vector=query_vector.tolist(), - top_k=K, - include_metadata=True, - **kwargs - ) - - # Extract distances and indices - distances = [] - indices = [] - - for match in results.matches: - indices.append(int(match.metadata["doc_id"])) - distances.append(match.score) - - # Pad results if fewer than K matches - while len(indices) < K: - indices.append(-1) # Use -1 for padding - distances.append(0.0) - - all_distances.append(distances) - all_indices.append(indices) - - return RMOutput( - distances=np.array(all_distances, dtype=np.float32).tolist(), - indices=np.array(all_indices, dtype=np.int64).tolist() - ) - - def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]: - """Retrieve vectors for specific document IDs""" - if self.pc_index is None or self.index_dir != index_dir: - self.load_index(index_dir) - - if self.pc_index is None: # Add this check after load_index - raise ValueError("Failed to initialize Pinecone index") - - - - # Fetch vectors from Pinecone - vectors = [] - for doc_id in ids: - response = self.pc_index.fetch(ids=[str(doc_id)]) - if str(doc_id) in response.vectors: - vector = response.vectors[str(doc_id)].values - vectors.append(vector) - else: - raise ValueError(f"Document with id {doc_id} not found") - - return np.array(vectors, dtype=np.float64) diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py deleted file mode 100644 index 50f1017f..00000000 --- a/lotus/vector_store/qdrant_vs.py +++ /dev/null @@ -1,187 +0,0 @@ -from typing import Any - -import numpy as np -import pandas as pd -from numpy.typing import NDArray -from tqdm import tqdm - -from lotus.types import RMOutput -from lotus.vector_store.vs import VS - -try: - from qdrant_client import QdrantClient - from qdrant_client.models import Distance, PointStruct, VectorParams -except ImportError as err: - raise ImportError("Please install the qdrant client") from err - -class QdrantVS(VS): - def __init__(self, max_batch_size: int = 64): - - API_KEY = '_Mic3dVln2gAkS6NLyia6p-CCyMScK42ayuq8Rapm5-xsV5j5_UlIA' - - URL = "https://6f8b9aec-a788-4aac-9aeb-417d307493e8.europe-west3-0.gcp.cloud.qdrant.io:6333" - - client: QdrantClient = QdrantClient( - url=URL, - api_key=API_KEY - ) - - """Initialize with Qdrant client and embedding model""" - super() # Fixed the super() call syntax - self.client: QdrantClient = client - self.max_batch_size = max_batch_size - - def __del__(self): - self.client.close() - - def index(self, docs:pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]): - """Create a collection and add documents with their embeddings""" - self.index_dir = index_dir - - # Get sample embedding to determine vector dimension - dimension = np.reshape(embeddings, (len(embeddings), -1)).shape[1] - - # Create collection if it doesn't exist - if not self.client.collection_exists(index_dir): - self.client.create_collection( - collection_name=index_dir, - 
-            )
-        collection_info = self.client.get_collection(index_dir)
-        if (collection_info is not None and collection_info.config is not None and collection_info.config.params and collection_info.config.params.vectors):
-
-            vectors = collection_info.config.params.vectors
-            if isinstance(vectors, dict):
-                # If it's a dict, decide how to handle it.
-                # Here we extract the first vector, but you may need a different logic.
-                vector = next(iter(vectors.values()))
-                size = vector.size
-            elif isinstance(vectors, VectorParams):
-                size = vectors.size
-            else:
-                size = None
-
-            if size != dimension:
-                # If there's a discrepancy, create a new version of that collection
-                self.client.delete_collection(index_dir)
-                self.client.create_collection(
-                    collection_name=index_dir,
-                    vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
-                )
-
-        # Convert docs to list if it's a pandas Series
-        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
-
-        # Prepare points for upload
-        points = []
-        for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)):
-            points.append(
-                PointStruct(
-                    id=idx,
-                    vector=embedding.tolist(),
-                    payload={
-                        "content": doc,
-                        "doc_id": idx
-                    }
-                )
-            )
-
-        # Upload in batches
-        batch_size = 100
-        for i in tqdm(range(0, len(points), batch_size), desc="Uploading to Qdrant"):
-            batch = points[i:i + batch_size]
-            self.client.upsert(
-                collection_name=index_dir,
-                points=batch
-            )
-
-    def load_index(self, index_dir: str):
-        """Set the collection name to use"""
-        if not self.client.collection_exists(index_dir):
-            raise ValueError(f"Collection {index_dir} not found")
-        self.index_dir = index_dir
-
-    def __call__(
-        self,
-        query_vectors,
-        K: int,
-        **kwargs: dict[str, Any]
-    ) -> RMOutput:
-        """Perform vector search using Qdrant"""
-        if self.index_dir is None:
-            raise ValueError("No collection loaded. Call load_index first.")
-
-        """
-
-        do this in retriever module before passing into here
-
-        # Convert single query to list
-        if isinstance(queries, (str, Image.Image)):
-            queries = [queries]
-
-        # Handle numpy array queries (pre-computed vectors)
-        if isinstance(queries, np.ndarray):
-            query_vectors = queries
-        else:
-            # Convert queries to list if needed
-            if isinstance(queries, pd.Series):
-                queries = queries.tolist()
-            # Create embeddings for text queries
-            query_vectors = self._batch_embed(queries)
-        """
-
-        # Perform searches
-        all_distances = []
-        all_indices = []
-
-        for query_vector in query_vectors:
-            results = self.client.search(
-                collection_name=self.index_dir,
-                query_vector=query_vector.tolist(),
-                limit=K,
-                with_payload=True
-            )
-
-            # Extract distances and indices
-            distances = []
-            indices = []
-
-            for result in results:
-                indices.append(result.id)
-                distances.append(result.score)  # Qdrant returns cosine similarity directly
-
-            # Pad results if fewer than K matches
-            while len(indices) < K:
-                indices.append(-1)
-                distances.append(0.0)
-
-            all_distances.append(distances)
-            all_indices.append(indices)
-
-        return RMOutput(
-            distances=np.array(all_distances, dtype=np.float32).tolist(),
-            indices=np.array(all_indices, dtype=np.int64).tolist()
-        )
-
-    def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
-        """Retrieve vectors for specific document IDs"""
-        if self.index_dir != index_dir:
-            self.load_index(index_dir)
-
-        # Fetch points from Qdrant
-        points = self.client.retrieve(
-            collection_name=index_dir,
-            ids=ids,
-            with_vectors=True,
-            with_payload=False
-        )
-
-        # Extract and return vectors
-        vectors = []
-        for point in points:
-            if point.vector is not None:
-                vectors.append(point.vector)
-            else:
-                raise ValueError(f"Vector not found for id {point.id}")
-
-        return np.array(vectors, dtype=np.float64)
\ No newline at end of file

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
deleted file mode 100644
index 67b4afc1..00000000
--- a/lotus/vector_store/weaviate_vs.py
+++ /dev/null
@@ -1,194 +0,0 @@
-from typing import Any, List
-
-import numpy as np
-import pandas as pd
-from numpy.typing import NDArray
-
-from lotus.types import RMOutput
-from lotus.vector_store.vs import VS
-
-try:
-    import weaviate
-    from weaviate.classes.config import Configure, DataType, Property
-    from weaviate.classes.init import Auth
-    from weaviate.classes.query import MetadataQuery
-except ImportError as err:
-    raise ImportError("Please install the weaviate client") from err
-
-class WeaviateVS(VS):
-    def __init__(self, max_batch_size: int = 64):
-
-        REST_URL = 'https://aliugucnqnkzihdc3jqdig.c0.us-west3.gcp.weaviate.cloud'
-
-        API_KEY = 'e1VUifT3atB7PHLB3kXQPYhL2PNXeG0JeGYK'
-
-        weaviate_client: weaviate.WeaviateClient | None = None  # need to set this up
-
-
-        weaviate_client = weaviate.connect_to_weaviate_cloud(
-            cluster_url=REST_URL,  # Replace with your Weaviate Cloud URL
-            auth_credentials=Auth.api_key(API_KEY),  # Replace with your Weaviate Cloud key
-        )
-
-        """Initialize with Weaviate client and embedding model"""
-        super()
-        self.client = weaviate_client
-        self.max_batch_size = max_batch_size
-
-    def __del__(self):
-        self.client.close()
-
-    def get_collection_dimension(self, index_dir):
-        self.client.collections.get(index_dir).config
-
-    def index(self, docs: pd.Series, embeddings, index_dir: str, **kwargs: dict[str, Any]):
-        """Create a collection and add documents with their embeddings"""
-        self.index_dir = index_dir
-
-        embedding_dim = np.reshape(embeddings, (len(embeddings), -1)).shape[1]
-
-        # Create collection without vectorizer config (we'll provide vectors directly)
-        if not self.client.collections.exists(index_dir):
-            collection = self.client.collections.create(
-                name=index_dir,
-                properties=[
-                    Property(
-                        name='content',
-                        data_type=DataType.TEXT
-                    ),
-                    Property(
-                        name='doc_id',
-                        data_type=DataType.INT,
-                    )
-                ],
-                vectorizer_config=None,  # No vectorizer needed as we provide vectors
-                vector_index_config=Configure.VectorIndex.hnsw()
-            )
-        else:
-            collection = self.client.collections.get(index_dir)
-            if self.get_collection_dimension(index_dir) != embedding_dim:
-                self.client.collections.delete(index_dir)
-                collection = self.client.collections.create(
-                    name=index_dir,
-                    properties=[
-                        Property(
-                            name='content',
-                            data_type=DataType.TEXT
-                        ),
-                        Property(
-                            name='doc_id',
-                            data_type=DataType.INT,
-                        )
-                    ],
-                    vectorizer_config=None,  # No vectorizer needed as we provide vectors
-                    vector_index_config=Configure.VectorIndex.hnsw()
-                )
-
-
-        # Generate embeddings for all documents
-        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
-
-
-        # Add documents to collection with their embeddings
-        with collection.batch.dynamic() as batch:
-            for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)):
-                properties = {
-                    "content": doc,
-                    "doc_id": idx
-                }
-                batch.add_object(
-                    properties=properties,
-                    vector=embedding.tolist(),  # Provide pre-computed vector
-                )
-
-    def load_index(self, index_dir: str):
-        """Load/set the collection name to use"""
-        self.index_dir = index_dir
-        # Verify collection exists
-        try:
-            self.client.collections.get(index_dir)
-        except weaviate.exceptions.UnexpectedStatusCodeException:
-            raise ValueError(f"Collection {index_dir} not found")
-
-    def __call__(self,
-        query_vectors,
-        K: int,
-        **kwargs: dict[str, Any]
-    ) -> RMOutput:
-        """Perform vector search using pre-computed query vectors"""
-        if self.index_dir is None:
-            raise ValueError("No collection loaded. Call load_index first.")
-
-        collection = self.client.collections.get(self.index_dir)
-
-        """
-
-        do this in the retriever module
-        # Convert single query to list
-        if isinstance(queries, (str, Image.Image)):
-            queries = [queries]
-
-        # Handle numpy array queries (pre-computed vectors)
-        if isinstance(queries, np.ndarray):
-            query_vectors = queries
-        else:
-            # Generate embeddings for text queries
-            query_vectors = self._batch_embed(queries)
-        """
-
-        # Perform searches
-        results = []
-        for query_vector in query_vectors:
-            response = (collection.query
-                .near_vector(
-                    near_vector=query_vector.tolist(),
-                    limit=K,
-                    return_metadata=MetadataQuery(distance=True)
-                ))
-            results.append(response)
-
-        # Process results into expected format
-        all_distances = []
-        all_indices = []
-
-        for result in results:
-            objects = result.objects
-
-            distances:List[float] = []
-            indices = []
-            for obj in objects:
-                indices.append(obj.properties.get('doc_id', -1))
-                # Convert cosine distance to similarity score
-                distance = obj.metadata.distance if obj.metadata and obj.metadata.distance is not None else 1.0
-                distances.append(1 - distance)  # Convert distance to similarity
-            # Pad results if fewer than K matches
-            while len(indices) < K:
-                indices.append(-1)
-                distances.append(0.0)
-
-            all_distances.append(distances)
-            all_indices.append(indices)
-
-        return RMOutput(
-            distances=np.array(all_distances, dtype=np.float32).tolist(),
-            indices=np.array(all_indices, dtype=np.int64).tolist()
-        )
-
-    def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.float64]:
-        """Retrieve vectors for specific document IDs"""
-        collection = self.client.collections.get(index_dir)
-
-        # Query for documents with specific doc_ids
-        vectors = []
-        for id in ids:
-            exists = False
-            for obj in collection.query.fetch_objects().objects:
-                print(f'object vector: {collection.query.__extract_vector_for_object(obj.metadata)}')
-                if id == obj.properties.get('doc_id', -1):
-                    exists = True
-                    vectors.append((1))
-            if not exists:
-                raise ValueError(f'{id} does not exist in {index_dir}')
-        return np.array(vectors, dtype=np.float64)
-
-

From 38580ce09ce6c9dcbf2ef8230d76603bd32ad8ac Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Thu, 13 Feb 2025 19:12:45 -0800
Subject: [PATCH 57/65] removed unused imports

---
 .github/tests/rm_tests.py      | 5 +----
 lotus/vector_store/__init__.py | 5 +----
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py
index 31e6c0c0..38ce90e1 100644
--- a/.github/tests/rm_tests.py
+++ b/.github/tests/rm_tests.py
@@ -5,7 +5,7 @@
 import lotus
 from lotus.models import CrossEncoderReranker, LiteLLMRM, SentenceTransformersRM
-from lotus.vector_store import ChromaVS, FaissVS, PineconeVS, QdrantVS, WeaviateVS
+from lotus.vector_store import ChromaVS, FaissVS

 ################################################################################
 # Setup
 ################################################################################
@@ -33,10 +33,7 @@

 VECTOR_STORE_TO_CLS = {
     'local': FaissVS,
-    'weaviate':WeaviateVS,
-    'pinecone': PineconeVS,
     'chroma': ChromaVS,
-    'qdrant': QdrantVS
 }

diff --git a/lotus/vector_store/__init__.py b/lotus/vector_store/__init__.py
index b1efbeac..327ab06a 100644
--- a/lotus/vector_store/__init__.py
+++ b/lotus/vector_store/__init__.py
@@ -1,8 +1,5 @@
 from lotus.vector_store.vs import VS
 from lotus.vector_store.faiss_vs import FaissVS
-from lotus.vector_store.weaviate_vs import WeaviateVS
-from lotus.vector_store.pinecone_vs import PineconeVS
 from lotus.vector_store.chroma_vs import ChromaVS
-from lotus.vector_store.qdrant_vs import QdrantVS

-__all__ = ["VS", "FaissVS", "WeaviateVS", "PineconeVS", "ChromaVS", "QdrantVS"]
+__all__ = ["VS", "FaissVS", "ChromaVS"]

From 988e072a06f58d2fea6c0c343dceb5dd503424a1 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Thu, 13 Feb 2025 19:18:32 -0800
Subject: [PATCH 58/65] removed pinecone reference

---
 lotus/sem_ops/sem_search.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py
index 6ae0e5ef..4cdf973d 100644
--- a/lotus/sem_ops/sem_search.py
+++ b/lotus/sem_ops/sem_search.py
@@ -5,7 +5,6 @@
 import lotus
 from lotus.cache import operator_cache
 from lotus.types import RerankerOutput, RMOutput
-from lotus.vector_store.pinecone_vs import PineconeVS


 @pd.api.extensions.register_dataframe_accessor("sem_search")
@@ -62,8 +61,6 @@ def __call__(

         df_idxs = self._obj.index
         cur_min = len(df_idxs)
-        if isinstance(vs, PineconeVS):
-            cur_min = min(cur_min, 10000)
         K = min(K, cur_min)
         search_K = K

From c58e4799a6a66545e436d5b81516fc9eecfe7903 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sun, 16 Feb 2025 14:03:47 -0800
Subject: [PATCH 59/65] removed merge conflicts

---
 lotus/sem_ops/sem_search.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py
index 3a66cd33..4cdf973d 100644
--- a/lotus/sem_ops/sem_search.py
+++ b/lotus/sem_ops/sem_search.py
@@ -61,10 +61,6 @@ def __call__(

         df_idxs = self._obj.index
         cur_min = len(df_idxs)
-<<<<<<< HEAD
-=======
-
->>>>>>> 6b9bcfa5439dd6aeff87f754e303127803ed6cb6
         K = min(K, cur_min)
         search_K = K

From 3ffd039651f34af3e36f203cd8ad34539a14da96 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sun, 16 Feb 2025 14:12:26 -0800
Subject: [PATCH 60/65] modified chroma_vs call function to include optional
 filtering with ids

---
 lotus/vector_store/chroma_vs.py | 54 +++++++++++++++------------------
 1 file changed, 25 insertions(+), 29 deletions(-)

diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py
index ad3009a8..52e53796 100644
--- a/lotus/vector_store/chroma_vs.py
+++ b/lotus/vector_store/chroma_vs.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Union
+from typing import Any, Mapping, Optional, Union

 import numpy as np
 import pandas as pd
@@ -88,59 +88,55 @@ def __call__(
         self,
         query_vectors,
         K: int,
+        ids: Optional[list[Any]] = None,
         **kwargs: dict[str, Any]
     ) -> RMOutput:
-        """Perform vector search using ChromaDB"""
-        if self.collection is None:
-            raise ValueError("No collection loaded. Call load_index first.")
+        """
+        Perform vector search using ChromaDB with optional filtering by document IDs.
+
+        Args:
+            query_vectors: Pre-embedded query vectors.
+            K (int): Number of nearest neighbors to retrieve.
+            ids (Optional[list[Any]]): If provided, the search will be limited to documents with these ids.
+            **kwargs: Additional parameters.
+
+        Returns:
+            RMOutput: Contains the distances and indices of the nearest neighbors.
""" - # Convert single query to list - if isinstance(queries, (str, Image.Image)): - queries = [queries] - - # Handle numpy array queries (pre-computed vectors) - if isinstance(queries, np.ndarray): - query_vectors = queries - else: - # Convert queries to list if needed - if isinstance(queries, pd.Series): - queries = queries.tolist() - # Create embeddings for text queries - query_vectors = self._batch_embed(queries) + if self.collection is None: + raise ValueError("No collection loaded. Call load_index first.") - """ + # If an ids list is provided, build a filter on the "doc_id" field. + where_filter = {"doc_id": {"$in": ids}} if ids is not None else None - # Perform searches all_distances = [] all_indices = [] + # Process each query vector. for query_vector in query_vectors: results = self.collection.query( query_embeddings=[query_vector.tolist()], n_results=K, - include=[IncludeEnum.metadatas, IncludeEnum.distances] + include=[IncludeEnum.metadatas, IncludeEnum.distances], + where=where_filter, ) - # Extract distances and indices distances = [] indices = [] - - if results['metadatas'] and results['distances']: - for metadata, distance in zip(results['metadatas'][0], results['distances'][0]): - indices.append(metadata['doc_id']) - # ChromaDB returns squared L2 distances, convert to cosine similarity - # similarity = 1 - (distance / 2) # Convert L2 distance to cosine similarity + + if results.get("metadatas") and results.get("distances"): + for metadata, distance in zip(results["metadatas"][0], results["distances"][0]): + indices.append(metadata["doc_id"]) + # ChromaDB returns squared L2 distances; convert these to cosine similarity. distances.append(1 - (distance / 2)) - # Pad results if fewer than K matches + # Pad results if fewer than K matches are returned. while len(indices) < K: indices.append(-1) distances.append(0.0) - all_distances.append(distances) all_indices.append(indices) + all_distances.append(distances) return RMOutput( distances=np.array(all_distances, dtype=np.float32).tolist(), From 9eca851fbbffcb5256d1e2d0fba09ddff78608e1 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com> Date: Sun, 16 Feb 2025 14:20:25 -0800 Subject: [PATCH 61/65] fixed where filter --- lotus/vector_store/chroma_vs.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 52e53796..491061c1 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -107,7 +107,13 @@ def __call__( raise ValueError("No collection loaded. Call load_index first.") # If an ids list is provided, build a filter on the "doc_id" field. 
-        where_filter = {"doc_id": {"$in": ids}} if ids is not None else None
+        where_filter = {
+            "$or":[
+                {"doc_id": {"$eq": id}} for id in ids
+            ]
+        }
+
+        #{"doc_id": {"$in": ids}} if ids is not None else None

         all_distances = []
         all_indices = []

From 63a378c30317b6c28342c95f1f72e1f2b891f85d Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sun, 16 Feb 2025 14:43:59 -0800
Subject: [PATCH 62/65] fixed typing errs

---
 lotus/vector_store/chroma_vs.py | 47 ++++++++++++++++++---------------
 lotus/vector_store/vs.py        |  4 +--
 2 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py
index 491061c1..7541fbe7 100644
--- a/lotus/vector_store/chroma_vs.py
+++ b/lotus/vector_store/chroma_vs.py
@@ -1,5 +1,6 @@
-from typing import Any, Mapping, Optional, Union
+from typing import Any, Mapping, Optional, Union, cast, List

+from chromadb import Where
 import numpy as np
 import pandas as pd
 from numpy.typing import NDArray
@@ -88,7 +89,7 @@ def __call__(
         self,
         query_vectors,
         K: int,
-        ids: Optional[list[Any]] = None,
+        ids: Optional[list[int]] = None,
         **kwargs: dict[str, Any]
     ) -> RMOutput:
         """
@@ -106,35 +107,39 @@ def __call__(
         if self.collection is None:
             raise ValueError("No collection loaded. Call load_index first.")

-        # If an ids list is provided, build a filter on the "doc_id" field.
-        where_filter = {
-            "$or":[
-                {"doc_id": {"$eq": id}} for id in ids
-            ]
-        }
-
-        #{"doc_id": {"$in": ids}} if ids is not None else None
-
-        all_distances = []
-        all_indices = []
+        all_distances: list[list[float]] = []
+        all_indices: list[list[int]] = []

         # Process each query vector.
         for query_vector in query_vectors:
+            # Prepare the where clause by casting ids to a list of allowed types.
+            where_clause: Optional[dict[str, Union[dict[str, List[Union[str, int, float, bool]]]]]] = None
+            if ids:
+                where_clause = {"doc_id": {"$in": cast(List[Union[str, int, float, bool]], ids)}}
+
             results = self.collection.query(
                 query_embeddings=[query_vector.tolist()],
                 n_results=K,
                 include=[IncludeEnum.metadatas, IncludeEnum.distances],
-                where=where_filter,
+                where=cast(Where, where_clause),
             )

-            distances = []
-            indices = []
+            distances: list[float] = []
+            indices: list[int] = []

-            if results.get("metadatas") and results.get("distances"):
-                for metadata, distance in zip(results["metadatas"][0], results["distances"][0]):
-                    indices.append(metadata["doc_id"])
-                    # ChromaDB returns squared L2 distances; convert these to cosine similarity.
-                    distances.append(1 - (distance / 2))
+            # Retrieve and cast search results to help the type checker.
+            metadatas = results.get("metadatas")
+            dists = results.get("distances")
+            if metadatas is not None and dists is not None:
+                metadatas = cast(
+                    List[List[Mapping[str, Union[str, int, float, bool]]]], metadatas
+                )
+                dists = cast(List[List[float]], dists)
+                for metadata, distance in zip(metadatas[0], dists[0]):
+                    if metadata is not None and distance is not None:
+                        indices.append(int(metadata["doc_id"]))
+                        # Convert squared L2 distances to cosine similarity.
+                        distances.append(1 - (distance / 2))

             # Pad results if fewer than K matches are returned.
             while len(indices) < K:

diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index a3632219..01087757 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -33,7 +33,7 @@ def __call__(
         self,
         query_vectors: Any,
         K: int,
-        ids: Optional[list[Any]] = None,
+        ids: Optional[list[int]] = None,
         **kwargs: dict[str, Any],
     ) -> RMOutput:
         """
@@ -52,7 +52,7 @@ def __call__(
         pass

     @abstractmethod
-    def get_vectors_from_index(self, index_dir: str, ids: list[Any]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, index_dir: str, ids: list[int]) -> NDArray[np.float64]:
         """
         Retrieve vectors from a stored index given specific ids.
         """
         pass

From d376a49c88d6b1ba2147bc232e227d7e6d8653a6 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sun, 16 Feb 2025 14:47:51 -0800
Subject: [PATCH 63/65] fixed linting

---
 lotus/vector_store/chroma_vs.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py
index 7541fbe7..64ec9be2 100644
--- a/lotus/vector_store/chroma_vs.py
+++ b/lotus/vector_store/chroma_vs.py
@@ -1,8 +1,8 @@
-from typing import Any, Mapping, Optional, Union, cast, List
+from typing import Any, List, Mapping, Optional, Union, cast

-from chromadb import Where
 import numpy as np
 import pandas as pd
+from chromadb import Where
 from numpy.typing import NDArray
 from tqdm import tqdm

From 6aef44667a399a966ccde160da870a75c40c5796 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sun, 16 Feb 2025 15:00:52 -0800
Subject: [PATCH 64/65] changed threshold

---
 .github/tests/rm_tests.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py
index 7720ce1e..24c54e53 100644
--- a/.github/tests/rm_tests.py
+++ b/.github/tests/rm_tests.py
@@ -255,6 +255,9 @@ def test_vs_sim_join(setup_models, setup_vs, vs, model):
 )
 @pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
 def test_vs_dedup(setup_models, setup_vs, vs):
+    curr_threshold = 0.85
+    if vs == "chroma":
+        curr_threshold = 0.9
     rm = setup_models["intfloat/e5-small-v2"]
     my_vs = setup_vs[vs]
     lotus.settings.configure(rm=rm, vs=my_vs)
@@ -267,7 +270,7 @@ def test_vs_dedup(setup_models, setup_vs, vs):
         ]
     }
     df = pd.DataFrame(data)
-    df = df.sem_index("Text", "fourthindexdir").sem_dedup("Text", threshold=0.85)
+    df = df.sem_index("Text", "fourthindexdir").sem_dedup("Text", threshold=curr_threshold)
     kept = df["Text"].tolist()
     kept.sort()
     assert len(kept) == 2, kept

From 720e5529a23141d6bf9ea9078be03ab78e3175f4 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi <58890349+AmoghTantradi@users.noreply.github.com>
Date: Sun, 16 Feb 2025 15:13:09 -0800
Subject: [PATCH 65/65] added additional test for chroma_vs

---
 .github/tests/rm_tests.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py
index 24c54e53..0668bbe0 100644
--- a/.github/tests/rm_tests.py
+++ b/.github/tests/rm_tests.py
@@ -324,8 +324,9 @@ def test_search(setup_models):
     df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1)
     assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]

+@pytest.mark.parametrize("vs", VECTOR_STORE_TO_CLS.keys())
 @pytest.mark.parametrize("model", get_enabled("intfloat/e5-small-v2", "text-embedding-3-small"))
-def test_filtered_vector_search(setup_models, model):
+def test_filtered_vector_search(setup_models, setup_vs, vs, model):
     """
     Test filtered vector search.

@@ -340,7 +341,7 @@ def test_filtered_vector_search(setup_models, model):
     expected to pick out the culinary course "Gourmet Cooking Advanced".
     """
     rm = setup_models[model]
-    vs = FaissVS()
+    vs = setup_vs[vs]
     lotus.settings.configure(rm=rm, vs=vs)

     data = {