diff --git a/huf/ai/knowledge/backends/__init__.py b/huf/ai/knowledge/backends/__init__.py index 0858b42e..cf35147d 100644 --- a/huf/ai/knowledge/backends/__init__.py +++ b/huf/ai/knowledge/backends/__init__.py @@ -2,7 +2,7 @@ Knowledge Backend Abstraction This module provides a unified interface for knowledge storage backends. -Supported: SQLite FTS (keyword search), SQLite Vec (vector search), ChromaDB (vector search) +Supported: SQLite FTS (keyword search), SQLite Vec (vector search), ChromaDB (vector search), PGVector (vector search) """ from abc import ABC, abstractmethod @@ -66,10 +66,11 @@ def get_backend(backend_type: str) -> type: "sqlite_fts": "huf.ai.knowledge.backends.sqlite_fts.SQLiteFTSBackend", "sqlite_vec": "huf.ai.knowledge.backends.sqlite_vec_backend.SQLiteVecBackend", "chroma": "huf.ai.knowledge.backends.chroma_backend.ChromaBackend", + "pgvector": "huf.ai.knowledge.backends.pgvector_backend.PGVectorBackend", } if backend_type not in backends: raise ValueError(f"Unknown backend type: {backend_type}") import frappe - return frappe.get_attr(backends[backend_type]) + return frappe.get_attr(backends[backend_type]) \ No newline at end of file diff --git a/huf/ai/knowledge/backends/pgvector_backend.py b/huf/ai/knowledge/backends/pgvector_backend.py new file mode 100644 index 00000000..b353f02d --- /dev/null +++ b/huf/ai/knowledge/backends/pgvector_backend.py @@ -0,0 +1,368 @@ +# Copyright (c) 2025, Huf and contributors +# For license information, please see license.txt + +"""PostgreSQL/PGVector backend for HUF knowledge storage.""" + +import json +import re +import uuid +from contextlib import contextmanager +from typing import Any, Dict, List, Optional + +import frappe +from frappe import _ + +from . import ChunkResult, KnowledgeBackend + +try: + import psycopg + from psycopg import sql + PSYCOPG_AVAILABLE = True +except ImportError: + PSYCOPG_AVAILABLE = False + + +VALID_IDENTIFIER = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +DISTANCE_OPERATORS = { + "cosine": "<=>", + "l2": "<->", + "inner_product": "<#>", +} + + +class PGVectorBackend(KnowledgeBackend): + """PostgreSQL backend using the pgvector extension for semantic search.""" + + def __init__(self): + self.knowledge_source = None + self.config = {} + self.table_name = "huf_knowledge_vectors" + self.dimension = 1536 + self.distance_metric = "cosine" + self.connection_mode = "External PostgreSQL" + self._initialized = False + + def initialize(self, knowledge_source: str, config: Dict[str, Any]) -> None: + if not PSYCOPG_AVAILABLE: + frappe.throw( + _("psycopg is required for pgvector knowledge sources. " + "Install it with: pip install psycopg[binary]") + ) + + self.knowledge_source = knowledge_source + self.config = config or {} + self.table_name = self.config.get("table_name") or "huf_knowledge_vectors" + self.dimension = int(self.config.get("vector_dimension") or 1536) + self.distance_metric = self.config.get("distance_metric") or "cosine" + self.connection_mode = self.config.get("connection_mode") or "External PostgreSQL" + + self._validate_config() + self._ensure_schema() + self._initialized = True + + def _validate_config(self) -> None: + if not VALID_IDENTIFIER.match(self.table_name): + frappe.throw(_("PGVector table name must be a valid PostgreSQL identifier")) + + if self.distance_metric not in DISTANCE_OPERATORS: + frappe.throw(_("Unsupported PGVector distance metric: {0}").format(self.distance_metric)) + + if self.dimension <= 0: + frappe.throw(_("PGVector vector dimension must be positive")) + + @contextmanager + def _get_connection(self): + conn = psycopg.connect(**self._get_connection_params()) + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + def _get_connection_params(self) -> Dict[str, Any]: + if self.connection_mode == "Site PostgreSQL": + if frappe.conf.db_type != "postgres": + frappe.throw( + _("Site PostgreSQL mode requires a PostgreSQL-backed Frappe site. " + "Use External PostgreSQL for MariaDB-backed sites.") + ) + return { + "host": frappe.conf.db_host or "localhost", + "port": int(frappe.conf.db_port or 5432), + "dbname": frappe.conf.db_name, + "user": frappe.conf.db_user, + "password": frappe.conf.db_password, + } + + params = { + "host": self.config.get("host") or "localhost", + "port": int(self.config.get("port") or 5432), + "dbname": self.config.get("database"), + "user": self.config.get("user"), + "password": self.config.get("password"), + } + sslmode = self.config.get("sslmode") + if sslmode: + params["sslmode"] = sslmode + return params + + def _ensure_schema(self) -> None: + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") + cursor.execute( + sql.SQL( + """ + CREATE TABLE IF NOT EXISTS {table} ( + id BIGSERIAL PRIMARY KEY, + site_name TEXT NOT NULL, + knowledge_source TEXT NOT NULL, + input_id TEXT NOT NULL, + input_type TEXT NOT NULL, + chunk_id TEXT NOT NULL UNIQUE, + source_title TEXT, + chunk_index INTEGER, + text TEXT NOT NULL, + char_start INTEGER, + char_end INTEGER, + metadata JSONB DEFAULT '{{}}'::jsonb, + embedding VECTOR({dimension}) NOT NULL, + created_at TIMESTAMPTZ DEFAULT now(), + updated_at TIMESTAMPTZ DEFAULT now() + ) + """ + ).format( + table=sql.Identifier(self.table_name), + dimension=sql.SQL(str(self.dimension)), + ) + ) + cursor.execute( + sql.SQL( + "CREATE INDEX IF NOT EXISTS {index} ON {table} (site_name, knowledge_source)" + ).format( + index=sql.Identifier(f"idx_{self.table_name}_source"), + table=sql.Identifier(self.table_name), + ) + ) + cursor.execute( + sql.SQL( + "CREATE INDEX IF NOT EXISTS {index} ON {table} (site_name, knowledge_source, input_id)" + ).format( + index=sql.Identifier(f"idx_{self.table_name}_input"), + table=sql.Identifier(self.table_name), + ) + ) + cursor.execute( + sql.SQL( + "CREATE INDEX IF NOT EXISTS {index} ON {table} USING GIN (metadata)" + ).format( + index=sql.Identifier(f"idx_{self.table_name}_metadata"), + table=sql.Identifier(self.table_name), + ) + ) + + def add_chunks(self, chunks: List[Dict[str, Any]]) -> int: + if not chunks: + return 0 + + from huf.ai.knowledge.embedding import get_embeddings, resolve_embedding_config + + texts = [chunk["text"] for chunk in chunks] + embed_config = resolve_embedding_config(self.knowledge_source) + embeddings = get_embeddings( + texts=texts, + model=embed_config["model"], + api_key=embed_config.get("api_key"), + api_base=embed_config.get("api_base"), + ) + + with self._get_connection() as conn: + with conn.cursor() as cursor: + for chunk, embedding in zip(chunks, embeddings): + chunk_id = chunk.get("chunk_id") or str(uuid.uuid4()) + metadata = json.dumps(chunk.get("metadata") or {}) + cursor.execute( + sql.SQL( + """ + INSERT INTO {table} + (site_name, knowledge_source, input_id, input_type, chunk_id, source_title, + chunk_index, text, char_start, char_end, metadata, embedding, updated_at) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s::jsonb, %s::vector, now()) + ON CONFLICT (chunk_id) DO UPDATE SET + site_name = EXCLUDED.site_name, + knowledge_source = EXCLUDED.knowledge_source, + input_id = EXCLUDED.input_id, + input_type = EXCLUDED.input_type, + source_title = EXCLUDED.source_title, + chunk_index = EXCLUDED.chunk_index, + text = EXCLUDED.text, + char_start = EXCLUDED.char_start, + char_end = EXCLUDED.char_end, + metadata = EXCLUDED.metadata, + embedding = EXCLUDED.embedding, + updated_at = now() + """ + ).format(table=sql.Identifier(self.table_name)), + ( + frappe.local.site, + self.knowledge_source, + chunk["input_id"], + chunk["input_type"], + chunk_id, + chunk.get("source_title"), + chunk.get("chunk_index"), + chunk["text"], + chunk.get("char_start"), + chunk.get("char_end"), + metadata, + self._format_vector(embedding), + ), + ) + return len(chunks) + + def delete_chunks(self, input_id: str) -> int: + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + sql.SQL( + "DELETE FROM {table} WHERE site_name = %s AND knowledge_source = %s AND input_id = %s" + ).format(table=sql.Identifier(self.table_name)), + (frappe.local.site, self.knowledge_source, input_id), + ) + return cursor.rowcount or 0 + + def search( + self, + query: str, + top_k: int = 5, + filters: Optional[Dict[str, Any]] = None, + ) -> List[ChunkResult]: + if not query or not query.strip(): + return [] + + from huf.ai.knowledge.embedding import get_embedding, resolve_embedding_config + + embed_config = resolve_embedding_config(self.knowledge_source) + query_embedding = get_embedding( + text=query, + model=embed_config["model"], + api_key=embed_config.get("api_key"), + api_base=embed_config.get("api_base"), + ) + + where_parts = [sql.SQL("site_name = %s"), sql.SQL("knowledge_source = %s")] + params: List[Any] = [frappe.local.site, self.knowledge_source] + if filters: + for key, value in filters.items(): + where_parts.append(sql.SQL("metadata ->> %s = %s")) + params.extend([key, str(value)]) + + operator = sql.SQL(DISTANCE_OPERATORS[self.distance_metric]) + vector_text = self._format_vector(query_embedding) + params.extend([vector_text, vector_text, int(top_k)]) + + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + sql.SQL( + """ + SELECT chunk_id, text, source_title, input_id, metadata, + embedding {operator} %s::vector AS distance + FROM {table} + WHERE {where_sql} + ORDER BY embedding {operator} %s::vector + LIMIT %s + """ + ).format( + table=sql.Identifier(self.table_name), + operator=operator, + where_sql=sql.SQL(" AND ").join(where_parts), + ), + params, + ) + results = [] + for row in cursor.fetchall(): + chunk_id, text, title, input_id, metadata, distance = row + results.append( + ChunkResult( + chunk_id=chunk_id, + text=text, + title=title, + score=self._distance_to_score(distance), + source=input_id, + metadata=metadata or {}, + ) + ) + return results + + def clear(self) -> None: + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + sql.SQL("DELETE FROM {table} WHERE site_name = %s AND knowledge_source = %s").format( + table=sql.Identifier(self.table_name) + ), + (frappe.local.site, self.knowledge_source), + ) + + def get_stats(self) -> Dict[str, Any]: + stats = { + "backend_type": "pgvector", + "knowledge_source": self.knowledge_source, + "table_name": self.table_name, + "chunk_count": 0, + "input_count": 0, + "vector_dimension": self.dimension, + "distance_metric": self.distance_metric, + } + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + sql.SQL( + """ + SELECT COUNT(*), COUNT(DISTINCT input_id) + FROM {table} + WHERE site_name = %s AND knowledge_source = %s + """ + ).format(table=sql.Identifier(self.table_name)), + (frappe.local.site, self.knowledge_source), + ) + chunk_count, input_count = cursor.fetchone() + stats["chunk_count"] = chunk_count or 0 + stats["input_count"] = input_count or 0 + return stats + + def health_check(self): + try: + with self._get_connection() as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT 1") + return (True, "Healthy") + except Exception as exc: + return (False, str(exc)) + + def supports_filters(self) -> bool: + return True + + def supports_hybrid_search(self) -> bool: + return False + + def _format_vector(self, embedding: List[float]) -> str: + if len(embedding) != self.dimension: + frappe.throw( + _("Embedding dimension mismatch. Expected {0}, got {1}").format( + self.dimension, len(embedding) + ) + ) + return "[" + ",".join(str(float(value)) for value in embedding) + "]" + + def _distance_to_score(self, distance) -> float: + if distance is None: + return 0.0 + distance = float(distance) + if self.distance_metric == "cosine": + return max(0.0, 1.0 - distance) + return 1.0 / (1.0 + abs(distance)) diff --git a/huf/ai/knowledge/indexer.py b/huf/ai/knowledge/indexer.py index e9ad9b16..09bbb7cd 100644 --- a/huf/ai/knowledge/indexer.py +++ b/huf/ai/knowledge/indexer.py @@ -14,14 +14,14 @@ def _build_backend_config(source) -> dict: """Build configuration dict for backend initialization. Includes chunking settings for all backends and adds embedding - configuration for the sqlite_vec vector backend. + configuration for vector backends. """ config = { "chunk_size": source.chunk_size, "chunk_overlap": source.chunk_overlap, } - if source.knowledge_type in ("sqlite_vec", "chroma"): + if source.knowledge_type in ("sqlite_vec", "chroma", "pgvector"): config["embedding_model"] = source.embedding_model config["vector_dimension"] = source.vector_dimension config["embedding_provider"] = getattr(source, "embedding_provider", None) @@ -39,6 +39,20 @@ def _build_backend_config(source) -> dict: safe_name = frappe.scrub(source.name) config["persist_directory"] = os.path.join(files_path, "knowledge", f"{safe_name}_chroma") + if source.knowledge_type == "pgvector": + config.update({ + "connection_mode": getattr(source, "pgvector_connection_mode", None) or "External PostgreSQL", + "table_name": getattr(source, "pgvector_table_name", None) or "huf_knowledge_vectors", + "distance_metric": getattr(source, "pgvector_distance_metric", None) or "cosine", + "index_type": getattr(source, "pgvector_index_type", None) or "hnsw", + "host": getattr(source, "pgvector_host", None) or "localhost", + "port": int(getattr(source, "pgvector_port", None) or 5432), + "database": getattr(source, "pgvector_database", None), + "user": getattr(source, "pgvector_user", None), + "password": source.get_password("pgvector_password") if getattr(source, "pgvector_password", None) else None, + "sslmode": getattr(source, "pgvector_sslmode", None) or "prefer", + }) + return config @@ -261,44 +275,26 @@ def _extract_text(doc) -> ExtractedText: return extractor.extract(file_path) elif doc.input_type == "URL": - # Get URL extractor - from .extractors.url import URLExtractor - extractor = URLExtractor() - return extractor.extract(doc.url) + # Fetch URL content + import requests + response = requests.get(doc.url, timeout=30) + response.raise_for_status() + + # Extract text from HTML + extractor = TextExtractor.get_extractor("html") + return extractor.extract_from_content(response.text, doc.url) - raise ValueError(f"Unknown input type: {doc.input_type}") + else: + frappe.throw(_("Unsupported input type: {0}").format(doc.input_type)) def update_source_stats(source, backend): - """Update knowledge source statistics.""" - stats = backend.get_stats() - - source.reload() - source.total_chunks = stats.get("chunk_count", 0) - source.total_inputs = stats.get("input_count", 0) - source.index_size_bytes = stats.get("size_bytes", 0) - db_path = getattr(backend, "db_path", None) - source.sqlite_file_path = db_path - - # Update SQLite file reference - if db_path and os.path.exists(db_path): - # Create or update file reference - from frappe.utils import get_files_path - files_path = get_files_path(is_private=True) - relative_path = os.path.relpath(backend.db_path, files_path) - file_url = f"/private/files/{relative_path.replace(os.sep, '/')}" - - # Check if file doc exists - existing_file = frappe.db.exists("File", {"file_url": file_url}) - if not existing_file: - file_doc = frappe.get_doc({ - "doctype": "File", - "file_name": os.path.basename(backend.db_path), - "file_url": file_url, - "is_private": 1, - }) - file_doc.insert(ignore_permissions=True) - - source.sqlite_file = file_url - - source.save(ignore_permissions=True) + """Update knowledge source statistics from backend.""" + try: + stats = backend.get_stats() + source.total_chunks = stats.get("chunk_count", 0) + source.total_inputs = stats.get("input_count", 0) + source.index_size_bytes = stats.get("size_bytes", 0) + source.save(ignore_permissions=True) + except Exception as e: + frappe.logger().warning(f"Failed to update knowledge source stats: {str(e)}") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 83562e6c..df754186 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "httpx>=0.24.0", "chromadb", "llama-index-vector-stores-chroma", + "psycopg[binary]>=3.1.0", "pypdf", "python-docx", "beautifulsoup4", @@ -65,4 +66,4 @@ typing-modules = ["frappe.types.DF"] [tool.ruff.format] quote-style = "double" indent-style = "tab" -docstring-code-format = true +docstring-code-format = true \ No newline at end of file