Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,13 @@ module = [
"fastapi.*",
"uvicorn.*",
"whisper.*",
# Optional deps for ``locus.rag.multimodal``. They are imported only inside
# try/except blocks, for users who want OCR / PDF / image / audio extraction.
"pytesseract.*",
"PIL.*",
"pypdf.*",
"PyPDF2.*",
"pdf2image.*",
]
ignore_missing_imports = true

Expand All @@ -528,7 +535,6 @@ ignore_missing_imports = true
# the strict block above.
[[tool.mypy.overrides]]
module = [
"locus.rag.*",
# ``locus.hooks.builtin.*`` is held back: the module-level migration
# surfaces a real runtime bug — the LoggingHook / GuardrailsHook /
# TelemetryHook signatures pre-date the event-based ``HookProvider``
Expand Down
4 changes: 3 additions & 1 deletion src/locus/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
... )
"""

from typing import Any

# Embeddings
from locus.rag.embeddings.base import (
BaseEmbedding,
Expand Down Expand Up @@ -128,7 +130,7 @@
]


def __getattr__(name: str):
def __getattr__(name: str) -> Any:
"""Lazy import providers and stores."""
# Embedding providers
if name == "OCIEmbeddings":
Expand Down
4 changes: 3 additions & 1 deletion src/locus/rag/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
- OpenAIEmbeddings: OpenAI text-embedding models
"""

from typing import Any

from locus.rag.embeddings.base import (
BaseEmbedding,
EmbeddingConfig,
Expand All @@ -29,7 +31,7 @@
]


def __getattr__(name: str):
def __getattr__(name: str) -> Any:
"""Lazy import providers to avoid requiring all dependencies."""
if name == "OCIEmbeddings":
from locus.rag.embeddings.oci import OCIEmbeddings
Expand Down
19 changes: 12 additions & 7 deletions src/locus/rag/embeddings/oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,19 @@ async def _get_client(self) -> GenerativeAiInferenceClient:

config_file = os.path.expanduser(config_file)

self._oci_config_dict = oci.config.from_file(
# ``oci.config.from_file`` returns the parsed config dict. The
# ``_oci_config_dict`` field is declared as ``... | None`` to represent
# the pre-init state, so bind the result to a local first and read from
# that local below, where it is known not to be ``None``.
config_dict: dict[str, Any] = oci.config.from_file(
config_file,
self.oci_config.profile_name,
)
self._oci_config_dict = config_dict

# Determine service endpoint
endpoint = self.oci_config.service_endpoint
if endpoint is None:
region = self._oci_config_dict.get("region", "us-chicago-1")
region = config_dict.get("region", "us-chicago-1")
endpoint = f"https://inference.generativeai.{region}.oci.oraclecloud.com"

# Determine auth type - respect explicit setting, only auto-detect if needed
Expand All @@ -202,21 +206,21 @@ async def _get_client(self) -> GenerativeAiInferenceClient:
# 3. Config doesn't have user field (api_key profiles have user)
if (
auth_type != "api_key"
and "security_token_file" in self._oci_config_dict
and "user" not in self._oci_config_dict
and "security_token_file" in config_dict
and "user" not in config_dict
):
auth_type = "security_token"

# Create client based on auth type
if auth_type == "security_token":
token_file = self._oci_config_dict.get("security_token_file")
token_file = config_dict.get("security_token_file")
if token_file:
import os as os_module

token_file = os_module.path.expanduser(token_file)
with open(token_file) as f:
token = f.read().strip()
key_file = os_module.path.expanduser(self._oci_config_dict["key_file"])
key_file = os_module.path.expanduser(config_dict["key_file"])
private_key = oci.signer.load_private_key_from_file(key_file)
signer = oci.auth.signers.SecurityTokenSigner(token, private_key)
self._client = GenerativeAiInferenceClient(
Expand Down Expand Up @@ -259,7 +263,8 @@ def _get_compartment_id(self) -> str:
if self.oci_config.compartment_id:
return self.oci_config.compartment_id
if self._oci_config_dict:
return self._oci_config_dict.get("tenancy", "")
tenancy: str = self._oci_config_dict.get("tenancy", "")
return tenancy
return ""

async def embed(self, text: str) -> EmbeddingResult:
Expand Down
9 changes: 5 additions & 4 deletions src/locus/rag/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def __init__(
):
self.use_whisper = use_whisper
self.whisper_model = whisper_model
self._whisper = None
self._whisper: Any = None

def supports(self, content_type: ContentType) -> bool:
return content_type == ContentType.AUDIO
Expand Down Expand Up @@ -448,8 +448,9 @@ async def _transcribe_whisper(self, audio_bytes: bytes, audio_format: str) -> st
temp_path = f.name

try:
result = self._whisper.transcribe(temp_path)
return result["text"].strip()
result: dict[str, Any] = self._whisper.transcribe(temp_path)
text: str = result["text"].strip()
return text
finally:
os.unlink(temp_path)

Expand All @@ -476,7 +477,7 @@ def __init__(
use_ocr: bool = True,
use_whisper: bool = True,
):
self.processors = {
self.processors: dict[ContentType, ContentProcessor] = {
ContentType.TEXT: TextProcessor(),
ContentType.MARKDOWN: TextProcessor(),
ContentType.HTML: TextProcessor(),
Expand Down
23 changes: 15 additions & 8 deletions src/locus/rag/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ async def add_document(
documents.append(doc)

# Store all documents
return await self.store.add_batch(documents)
added: list[str] = await self.store.add_batch(documents)
return added

async def add_documents(
self,
Expand Down Expand Up @@ -307,7 +308,8 @@ async def add_file(
)
documents.append(doc)

return await self.store.add_batch(documents)
added: list[str] = await self.store.add_batch(documents)
return added

async def add_image(
self,
Expand Down Expand Up @@ -345,7 +347,8 @@ async def add_image(
raw_content=result.raw_content,
)

return await self.store.add(doc)
doc_added: str = await self.store.add(doc)
return doc_added

async def add_pdf(
self,
Expand Down Expand Up @@ -412,7 +415,8 @@ async def add_audio(
raw_content=result.raw_content,
)

return await self.store.add(doc)
doc_added: str = await self.store.add(doc)
return doc_added

async def retrieve(
self,
Expand Down Expand Up @@ -495,21 +499,24 @@ async def retrieve_text(

async def delete_document(self, doc_id: str) -> bool:
"""Delete a document by ID."""
return await self.store.delete(doc_id)
deleted: bool = await self.store.delete(doc_id)
return deleted

async def clear(self) -> int:
"""Delete all documents."""
return await self.store.clear()
cleared: int = await self.store.clear()
return cleared

async def count(self) -> int:
"""Count documents in store."""
return await self.store.count()
n: int = await self.store.count()
return n

async def close(self) -> None:
"""Close resources."""
await self.store.close()

def as_tool(self, name: str = "search_knowledge", description: str | None = None):
def as_tool(self, name: str = "search_knowledge", description: str | None = None) -> Any:
"""
Create a tool function for agent use.

Expand Down
4 changes: 3 additions & 1 deletion src/locus/rag/stores/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
- InMemoryVectorStore: In-memory store (testing)
"""

from typing import Any

from locus.rag.stores.base import (
BaseVectorStore,
Document,
Expand Down Expand Up @@ -41,7 +43,7 @@
]


def __getattr__(name: str):
def __getattr__(name: str) -> Any:
"""Lazy import stores to avoid requiring all dependencies."""
if name == "OracleVectorStore":
from locus.rag.stores.oracle import OracleVectorStore
Expand Down
25 changes: 14 additions & 11 deletions src/locus/rag/stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,11 @@ async def add(self, document: Document) -> str:
},
}

# chromadb's typed shapes ask for ``Sequence[Sequence[float]]``; the
# ``list[list[float]]`` we have at hand may not be accepted as-is by the
# type checker, so the ``arg-type`` mismatch is silenced on this call.
collection.upsert(
ids=[doc_id],
embeddings=[document.embedding],
embeddings=[document.embedding], # type: ignore[arg-type, unused-ignore]
documents=[document.content],
metadatas=[metadata],
)
Expand Down Expand Up @@ -269,9 +271,9 @@ async def add_batch(self, documents: list[Document]) -> list[str]:
if ids:
collection.upsert(
ids=ids,
embeddings=embeddings,
embeddings=embeddings, # type: ignore[arg-type, unused-ignore]
documents=docs,
metadatas=metadatas,
metadatas=metadatas, # type: ignore[arg-type, unused-ignore]
)

return ids
Expand All @@ -291,7 +293,7 @@ async def get(self, doc_id: str) -> Document | None:
if not result["ids"]:
return None

metadata = result["metadatas"][0] if result["metadatas"] else {}
metadata: dict[str, Any] = dict(result["metadatas"][0]) if result["metadatas"] else {}
created_at_str = metadata.pop("created_at", None)
created_at = datetime.fromisoformat(created_at_str) if created_at_str else datetime.now(UTC)

Expand All @@ -303,7 +305,7 @@ async def get(self, doc_id: str) -> Document | None:
return Document(
id=result["ids"][0],
content=result["documents"][0] if result["documents"] else "",
embedding=embedding,
embedding=embedding, # type: ignore[arg-type, unused-ignore]
metadata=metadata,
created_at=created_at,
)
Expand Down Expand Up @@ -334,7 +336,7 @@ async def search(
collection = self._get_collection()

# Build where filter for metadata
where = None
where: dict[str, Any] | None = None
if metadata_filter:
if len(metadata_filter) == 1:
key, value = next(iter(metadata_filter.items()))
Expand All @@ -343,7 +345,7 @@ async def search(
where = {"$and": [{k: {"$eq": v}} for k, v in metadata_filter.items()]}

result = collection.query(
query_embeddings=[query_embedding],
query_embeddings=[query_embedding], # type: ignore[arg-type, unused-ignore]
n_results=limit,
where=where,
include=["embeddings", "documents", "metadatas", "distances"],
Expand Down Expand Up @@ -374,7 +376,7 @@ async def search(
if threshold is not None and score < threshold:
continue

metadata = metadatas[i] if i < len(metadatas) else {}
metadata: dict[str, Any] = dict(metadatas[i]) if i < len(metadatas) else {}
created_at_str = metadata.pop("created_at", None)
created_at = (
datetime.fromisoformat(created_at_str) if created_at_str else datetime.now(UTC)
Expand All @@ -383,7 +385,7 @@ async def search(
doc = Document(
id=doc_id,
content=documents[i] if i < len(documents) else "",
embedding=embeddings[i] if i < len(embeddings) else None,
embedding=embeddings[i] if i < len(embeddings) else None, # type: ignore[arg-type, unused-ignore]
metadata=metadata,
created_at=created_at,
)
Expand All @@ -401,12 +403,13 @@ async def search(
async def count(self) -> int:
"""Count documents."""
collection = self._get_collection()
return collection.count()
n: int = collection.count()
return n

async def clear(self) -> int:
"""Delete all documents."""
collection = self._get_collection()
count = collection.count()
count: int = collection.count()

# Delete collection and recreate
client = self._get_client()
Expand Down
12 changes: 8 additions & 4 deletions src/locus/rag/stores/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ async def add_batch(self, documents: list[Document]) -> list[str]:
await self._ensure_index()
client = await self._get_client()

actions = []
# The OpenSearch bulk API alternates action/metadata entries and document
# source bodies; both are dicts with disparate value types, so the value
# type of the list's dicts is widened to ``Any``.
actions: list[dict[str, Any]] = []
ids = []

for doc in documents:
Expand Down Expand Up @@ -273,7 +275,7 @@ async def delete(self, doc_id: str) -> bool:
client = await self._get_client()

try:
result = await client.delete(
result: dict[str, Any] = await client.delete(
index=self.os_config.index_name,
id=doc_id,
refresh=True,
Expand Down Expand Up @@ -304,8 +306,9 @@ async def search(
}

# Add metadata filter if provided
query: dict[str, Any]
if metadata_filter:
must_clauses = [knn_query]
must_clauses: list[dict[str, Any]] = [knn_query]
for key, value in metadata_filter.items():
must_clauses.append({"term": {f"metadata.{key}": value}})
query = {"bool": {"must": must_clauses}}
Expand Down Expand Up @@ -364,7 +367,8 @@ async def count(self) -> int:
client = await self._get_client()

result = await client.count(index=self.os_config.index_name)
return result["count"]
n: int = result["count"]
return n

async def clear(self) -> int:
"""Delete all documents."""
Expand Down
2 changes: 1 addition & 1 deletion src/locus/rag/stores/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ async def delete(self, doc_id: str) -> bool:
f"DELETE FROM {self._full_table_name} WHERE id = :id",
{"id": doc_id},
)
deleted = cursor.rowcount > 0
deleted: bool = cursor.rowcount > 0
await conn.commit()

return deleted
Expand Down
Loading
Loading