diff --git a/pyproject.toml b/pyproject.toml index 7058106..4591a9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -516,6 +516,13 @@ module = [ "fastapi.*", "uvicorn.*", "whisper.*", + # Optional deps for ``locus.rag.multimodal`` — only loaded inside + # try/except for users who want OCR / PDF / image / audio extraction. + "pytesseract.*", + "PIL.*", + "pypdf.*", + "PyPDF2.*", + "pdf2image.*", ] ignore_missing_imports = true @@ -528,7 +535,6 @@ ignore_missing_imports = true # the strict block above. [[tool.mypy.overrides]] module = [ - "locus.rag.*", # ``locus.hooks.builtin.*`` is held back: the module-level migration # surfaces a real runtime bug — the LoggingHook / GuardrailsHook / # TelemetryHook signatures pre-date the event-based ``HookProvider`` diff --git a/src/locus/rag/__init__.py b/src/locus/rag/__init__.py index 63c9a79..de46285 100644 --- a/src/locus/rag/__init__.py +++ b/src/locus/rag/__init__.py @@ -62,6 +62,8 @@ ... ) """ +from typing import Any + # Embeddings from locus.rag.embeddings.base import ( BaseEmbedding, @@ -128,7 +130,7 @@ ] -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy import providers and stores.""" # Embedding providers if name == "OCIEmbeddings": diff --git a/src/locus/rag/embeddings/__init__.py b/src/locus/rag/embeddings/__init__.py index dea830f..07e48f8 100644 --- a/src/locus/rag/embeddings/__init__.py +++ b/src/locus/rag/embeddings/__init__.py @@ -9,6 +9,8 @@ - OpenAIEmbeddings: OpenAI text-embedding models """ +from typing import Any + from locus.rag.embeddings.base import ( BaseEmbedding, EmbeddingConfig, @@ -29,7 +31,7 @@ ] -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy import providers to avoid requiring all dependencies.""" if name == "OCIEmbeddings": from locus.rag.embeddings.oci import OCIEmbeddings diff --git a/src/locus/rag/embeddings/oci.py b/src/locus/rag/embeddings/oci.py index 43c3bf3..b380981 100644 --- a/src/locus/rag/embeddings/oci.py +++ b/src/locus/rag/embeddings/oci.py @@ -182,15 +182,19 @@ async def _get_client(self) -> GenerativeAiInferenceClient: config_file = os.path.expanduser(config_file) - self._oci_config_dict = oci.config.from_file( + # ``oci.config.from_file`` returns the parsed config dict; the + # ``_oci_config_dict`` field is declared as ``... | None`` to + # represent the pre-init state, so bind to a local before use. + config_dict: dict[str, Any] = oci.config.from_file( config_file, self.oci_config.profile_name, ) + self._oci_config_dict = config_dict # Determine service endpoint endpoint = self.oci_config.service_endpoint if endpoint is None: - region = self._oci_config_dict.get("region", "us-chicago-1") + region = config_dict.get("region", "us-chicago-1") endpoint = f"https://inference.generativeai.{region}.oci.oraclecloud.com" # Determine auth type - respect explicit setting, only auto-detect if needed @@ -202,21 +206,21 @@ async def _get_client(self) -> GenerativeAiInferenceClient: # 3. Config doesn't have user field (api_key profiles have user) if ( auth_type != "api_key" - and "security_token_file" in self._oci_config_dict - and "user" not in self._oci_config_dict + and "security_token_file" in config_dict + and "user" not in config_dict ): auth_type = "security_token" # Create client based on auth type if auth_type == "security_token": - token_file = self._oci_config_dict.get("security_token_file") + token_file = config_dict.get("security_token_file") if token_file: import os as os_module token_file = os_module.path.expanduser(token_file) with open(token_file) as f: token = f.read().strip() - key_file = os_module.path.expanduser(self._oci_config_dict["key_file"]) + key_file = os_module.path.expanduser(config_dict["key_file"]) private_key = oci.signer.load_private_key_from_file(key_file) signer = oci.auth.signers.SecurityTokenSigner(token, private_key) self._client = GenerativeAiInferenceClient( @@ -259,7 +263,8 @@ def _get_compartment_id(self) -> str: if self.oci_config.compartment_id: return self.oci_config.compartment_id if self._oci_config_dict: - return self._oci_config_dict.get("tenancy", "") + tenancy: str = self._oci_config_dict.get("tenancy", "") + return tenancy return "" async def embed(self, text: str) -> EmbeddingResult: diff --git a/src/locus/rag/multimodal.py b/src/locus/rag/multimodal.py index 370a6a1..82fb3cf 100644 --- a/src/locus/rag/multimodal.py +++ b/src/locus/rag/multimodal.py @@ -363,7 +363,7 @@ def __init__( ): self.use_whisper = use_whisper self.whisper_model = whisper_model - self._whisper = None + self._whisper: Any = None def supports(self, content_type: ContentType) -> bool: return content_type == ContentType.AUDIO @@ -448,8 +448,9 @@ async def _transcribe_whisper(self, audio_bytes: bytes, audio_format: str) -> st temp_path = f.name try: - result = self._whisper.transcribe(temp_path) - return result["text"].strip() + result: dict[str, Any] = self._whisper.transcribe(temp_path) + text: str = result["text"].strip() + return text finally: os.unlink(temp_path) @@ -476,7 +477,7 @@ def __init__( use_ocr: bool = True, use_whisper: bool = True, ): - self.processors = { + self.processors: dict[ContentType, ContentProcessor] = { ContentType.TEXT: TextProcessor(), ContentType.MARKDOWN: TextProcessor(), ContentType.HTML: TextProcessor(), diff --git a/src/locus/rag/retriever.py b/src/locus/rag/retriever.py index 4794361..ebd01ae 100644 --- a/src/locus/rag/retriever.py +++ b/src/locus/rag/retriever.py @@ -204,7 +204,8 @@ async def add_document( documents.append(doc) # Store all documents - return await self.store.add_batch(documents) + added: list[str] = await self.store.add_batch(documents) + return added async def add_documents( self, @@ -307,7 +308,8 @@ async def add_file( ) documents.append(doc) - return await self.store.add_batch(documents) + added: list[str] = await self.store.add_batch(documents) + return added async def add_image( self, @@ -345,7 +347,8 @@ async def add_image( raw_content=result.raw_content, ) - return await self.store.add(doc) + doc_added: str = await self.store.add(doc) + return doc_added async def add_pdf( self, @@ -412,7 +415,8 @@ async def add_audio( raw_content=result.raw_content, ) - return await self.store.add(doc) + doc_added: str = await self.store.add(doc) + return doc_added async def retrieve( self, @@ -495,21 +499,24 @@ async def retrieve_text( async def delete_document(self, doc_id: str) -> bool: """Delete a document by ID.""" - return await self.store.delete(doc_id) + deleted: bool = await self.store.delete(doc_id) + return deleted async def clear(self) -> int: """Delete all documents.""" - return await self.store.clear() + cleared: int = await self.store.clear() + return cleared async def count(self) -> int: """Count documents in store.""" - return await self.store.count() + n: int = await self.store.count() + return n async def close(self) -> None: """Close resources.""" await self.store.close() - def as_tool(self, name: str = "search_knowledge", description: str | None = None): + def as_tool(self, name: str = "search_knowledge", description: str | None = None) -> Any: """ Create a tool function for agent use. diff --git a/src/locus/rag/stores/__init__.py b/src/locus/rag/stores/__init__.py index d82311c..3688717 100644 --- a/src/locus/rag/stores/__init__.py +++ b/src/locus/rag/stores/__init__.py @@ -14,6 +14,8 @@ - InMemoryVectorStore: In-memory store (testing) """ +from typing import Any + from locus.rag.stores.base import ( BaseVectorStore, Document, @@ -41,7 +43,7 @@ ] -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy import stores to avoid requiring all dependencies.""" if name == "OracleVectorStore": from locus.rag.stores.oracle import OracleVectorStore diff --git a/src/locus/rag/stores/chroma.py b/src/locus/rag/stores/chroma.py index 73202f0..2ee0714 100644 --- a/src/locus/rag/stores/chroma.py +++ b/src/locus/rag/stores/chroma.py @@ -229,9 +229,11 @@ async def add(self, document: Document) -> str: }, } + # chromadb's typed shapes ask for ``Sequence[Sequence[float]]``; + # we pass the wider ``list[list[float]]`` we have at hand. collection.upsert( ids=[doc_id], - embeddings=[document.embedding], + embeddings=[document.embedding], # type: ignore[arg-type, unused-ignore] documents=[document.content], metadatas=[metadata], ) @@ -269,9 +271,9 @@ async def add_batch(self, documents: list[Document]) -> list[str]: if ids: collection.upsert( ids=ids, - embeddings=embeddings, + embeddings=embeddings, # type: ignore[arg-type, unused-ignore] documents=docs, - metadatas=metadatas, + metadatas=metadatas, # type: ignore[arg-type, unused-ignore] ) return ids @@ -291,7 +293,7 @@ async def get(self, doc_id: str) -> Document | None: if not result["ids"]: return None - metadata = result["metadatas"][0] if result["metadatas"] else {} + metadata: dict[str, Any] = dict(result["metadatas"][0]) if result["metadatas"] else {} created_at_str = metadata.pop("created_at", None) created_at = datetime.fromisoformat(created_at_str) if created_at_str else datetime.now(UTC) @@ -303,7 +305,7 @@ async def get(self, doc_id: str) -> Document | None: return Document( id=result["ids"][0], content=result["documents"][0] if result["documents"] else "", - embedding=embedding, + embedding=embedding, # type: ignore[arg-type, unused-ignore] metadata=metadata, created_at=created_at, ) @@ -334,7 +336,7 @@ async def search( collection = self._get_collection() # Build where filter for metadata - where = None + where: dict[str, Any] | None = None if metadata_filter: if len(metadata_filter) == 1: key, value = next(iter(metadata_filter.items())) @@ -343,7 +345,7 @@ async def search( where = {"$and": [{k: {"$eq": v}} for k, v in metadata_filter.items()]} result = collection.query( - query_embeddings=[query_embedding], + query_embeddings=[query_embedding], # type: ignore[arg-type, unused-ignore] n_results=limit, where=where, include=["embeddings", "documents", "metadatas", "distances"], @@ -374,7 +376,7 @@ async def search( if threshold is not None and score < threshold: continue - metadata = metadatas[i] if i < len(metadatas) else {} + metadata: dict[str, Any] = dict(metadatas[i]) if i < len(metadatas) else {} created_at_str = metadata.pop("created_at", None) created_at = ( datetime.fromisoformat(created_at_str) if created_at_str else datetime.now(UTC) @@ -383,7 +385,7 @@ async def search( doc = Document( id=doc_id, content=documents[i] if i < len(documents) else "", - embedding=embeddings[i] if i < len(embeddings) else None, + embedding=embeddings[i] if i < len(embeddings) else None, # type: ignore[arg-type, unused-ignore] metadata=metadata, created_at=created_at, ) @@ -401,12 +403,13 @@ async def search( async def count(self) -> int: """Count documents.""" collection = self._get_collection() - return collection.count() + n: int = collection.count() + return n async def clear(self) -> int: """Delete all documents.""" collection = self._get_collection() - count = collection.count() + count: int = collection.count() # Delete collection and recreate client = self._get_client() diff --git a/src/locus/rag/stores/opensearch.py b/src/locus/rag/stores/opensearch.py index 251d43c..0bf8279 100644 --- a/src/locus/rag/stores/opensearch.py +++ b/src/locus/rag/stores/opensearch.py @@ -217,7 +217,9 @@ async def add_batch(self, documents: list[Document]) -> list[str]: await self._ensure_index() client = await self._get_client() - actions = [] + # The OpenSearch bulk API alternates control headers and source bodies; + # both shapes are dicts with disparate value types, so widen to ``Any``. + actions: list[dict[str, Any]] = [] ids = [] for doc in documents: @@ -273,7 +275,7 @@ async def delete(self, doc_id: str) -> bool: client = await self._get_client() try: - result = await client.delete( + result: dict[str, Any] = await client.delete( index=self.os_config.index_name, id=doc_id, refresh=True, @@ -304,8 +306,9 @@ async def search( } # Add metadata filter if provided + query: dict[str, Any] if metadata_filter: - must_clauses = [knn_query] + must_clauses: list[dict[str, Any]] = [knn_query] for key, value in metadata_filter.items(): must_clauses.append({"term": {f"metadata.{key}": value}}) query = {"bool": {"must": must_clauses}} @@ -364,7 +367,8 @@ async def count(self) -> int: client = await self._get_client() result = await client.count(index=self.os_config.index_name) - return result["count"] + n: int = result["count"] + return n async def clear(self) -> int: """Delete all documents.""" diff --git a/src/locus/rag/stores/oracle.py b/src/locus/rag/stores/oracle.py index c20161f..2220265 100644 --- a/src/locus/rag/stores/oracle.py +++ b/src/locus/rag/stores/oracle.py @@ -376,7 +376,7 @@ async def delete(self, doc_id: str) -> bool: f"DELETE FROM {self._full_table_name} WHERE id = :id", {"id": doc_id}, ) - deleted = cursor.rowcount > 0 + deleted: bool = cursor.rowcount > 0 await conn.commit() return deleted diff --git a/src/locus/rag/stores/pgvector.py b/src/locus/rag/stores/pgvector.py index 76bb4eb..b7059e3 100644 --- a/src/locus/rag/stores/pgvector.py +++ b/src/locus/rag/stores/pgvector.py @@ -370,7 +370,7 @@ async def delete(self, doc_id: str) -> bool: pool = await self._get_pool() async with pool.acquire() as conn: - result = await conn.execute( + result: str = await conn.execute( f""" DELETE FROM {self._full_table_name} WHERE id = $1 @@ -590,12 +590,13 @@ async def has_index(self) -> bool: table_name = self.pgvector_config.table_name async with pool.acquire() as conn: - return await conn.fetchval(f""" + exists: bool = await conn.fetchval(f""" SELECT EXISTS ( SELECT 1 FROM pg_indexes WHERE indexname = 'idx_{table_name}_embedding' ) """) + return exists async def close(self) -> None: """Close connection pool.""" diff --git a/src/locus/rag/stores/pinecone.py b/src/locus/rag/stores/pinecone.py index 3610e02..e041f1a 100644 --- a/src/locus/rag/stores/pinecone.py +++ b/src/locus/rag/stores/pinecone.py @@ -380,9 +380,11 @@ async def count(self) -> int: if self.pinecone_config.namespace: ns_stats = stats.namespaces.get(self.pinecone_config.namespace, {}) - return ns_stats.get("vector_count", 0) + ns_count: int = ns_stats.get("vector_count", 0) + return ns_count - return stats.total_vector_count + total: int = stats.total_vector_count + return total async def clear(self) -> int: """Delete all documents.""" diff --git a/src/locus/rag/stores/qdrant.py b/src/locus/rag/stores/qdrant.py index 83f57c7..42e16f2 100644 --- a/src/locus/rag/stores/qdrant.py +++ b/src/locus/rag/stores/qdrant.py @@ -308,7 +308,7 @@ async def get(self, doc_id: str) -> Document | None: return Document( id=payload.get("doc_id", str(point.id)), content=payload.get("content", ""), - embedding=list(point.vector) if point.vector else None, + embedding=list(point.vector) if point.vector else None, # type: ignore[arg-type, unused-ignore] metadata=payload.get("metadata", {}), created_at=datetime.fromisoformat(payload["created_at"]) if payload.get("created_at") @@ -358,15 +358,19 @@ async def search( except ImportError: pass else: - conditions = [] - for key, value in metadata_filter.items(): - conditions.append( - FieldCondition( - key=f"metadata.{key}", - match=MatchValue(value=value), - ) + # ``Filter.must`` is typed as the broad union of all possible + # condition kinds; ``list[FieldCondition]`` is invariant and + # would be rejected. Sequence is covariant and accepted. + from collections.abc import Sequence + + conditions: Sequence[FieldCondition] = [ + FieldCondition( + key=f"metadata.{key}", + match=MatchValue(value=value), ) - query_filter = Filter(must=conditions) + for key, value in metadata_filter.items() + ] + query_filter = Filter(must=list(conditions)) # Search using query_points (newer API) @@ -378,8 +382,18 @@ async def search( with_vectors=True, ) - # Get points from result - points = search_result.points if hasattr(search_result, "points") else search_result + # ``query_points`` returns a ``QueryResponse`` whose ``.points`` is + # ``list[ScoredPoint]``; older builds fall through to the bare + # response object. Either way, the loop body below uses + # ScoredPoint-shape attributes. + from typing import cast as _cast + + from qdrant_client.models import ScoredPoint as _ScoredPoint + + points: list[_ScoredPoint] = _cast( + "list[_ScoredPoint]", + search_result.points if hasattr(search_result, "points") else search_result, + ) results = [] for point in points: @@ -394,7 +408,7 @@ async def search( doc = Document( id=payload.get("doc_id", str(point.id)), content=payload.get("content", ""), - embedding=list(point.vector) if point.vector else None, + embedding=list(point.vector) if point.vector else None, # type: ignore[arg-type, unused-ignore] metadata=payload.get("metadata", {}), created_at=datetime.fromisoformat(payload["created_at"]) if payload.get("created_at") diff --git a/src/locus/rag/tools.py b/src/locus/rag/tools.py index 28b9a24..19552c6 100644 --- a/src/locus/rag/tools.py +++ b/src/locus/rag/tools.py @@ -39,7 +39,7 @@ def create_rag_tool( description: str | None = None, limit: int = 5, threshold: float | None = 0.5, -): +) -> Any: """ Create a RAG search tool for agent use. @@ -125,7 +125,7 @@ def create_rag_context_tool( name: str = "get_context", description: str | None = None, limit: int = 3, -): +) -> Any: """ Create a RAG tool that returns context as formatted text. @@ -207,7 +207,7 @@ def __init__( self.retriever = retriever self.prefix = prefix - def get_tools(self) -> list: + def get_tools(self) -> list[Any]: """Get all RAG tools.""" return [ self.search_tool(), @@ -215,7 +215,7 @@ def get_tools(self) -> list: self.lookup_tool(), ] - def search_tool(self): + def search_tool(self) -> Any: """Get the search tool.""" return create_rag_tool( self.retriever, @@ -223,7 +223,7 @@ def search_tool(self): description="Search the knowledge base for relevant documents.", ) - def context_tool(self): + def context_tool(self) -> Any: """Get the context tool.""" return create_rag_context_tool( self.retriever, @@ -231,7 +231,7 @@ def context_tool(self): description="Get formatted context from the knowledge base.", ) - def lookup_tool(self): + def lookup_tool(self) -> Any: """Get the lookup tool.""" from locus.tools import tool as tool_decorator