diff --git a/changai/changai/api/v2/emb_load_service.py b/changai/changai/api/v2/emb_load_service.py new file mode 100644 index 0000000..0aebd52 --- /dev/null +++ b/changai/changai/api/v2/emb_load_service.py @@ -0,0 +1,36 @@ +from fastapi import FastAPI +from pydantic import BaseModel +from langchain_huggingface import HuggingFaceEmbeddings +import os +MODEL_PATH = os.environ.get("CHANGAI_EMBEDDING_MODEL_PATH") + +app = FastAPI() + +embedding_model = HuggingFaceEmbeddings( + model_name=MODEL_PATH, + model_kwargs={ + "device": "cpu", + "trust_remote_code": True, + }, + encode_kwargs={ + "normalize_embeddings": True, + }, +) + +# warmup during service startup +embedding_model.embed_query("changai warmup") + + +class EmbedRequest(BaseModel): + text: str + + +@app.get("/health") +def health(): + return {"ok": True, "model_loaded": True} + + +@app.post("/embed") +def embed(req: EmbedRequest): + vector = embedding_model.embed_query(req.text) + return {"embedding": vector} \ No newline at end of file diff --git a/changai/changai/api/v2/embedding_client.py b/changai/changai/api/v2/embedding_client.py new file mode 100644 index 0000000..d0b1cf7 --- /dev/null +++ b/changai/changai/api/v2/embedding_client.py @@ -0,0 +1,23 @@ +import requests +from langchain_core.embeddings import Embeddings +import frappe + +@frappe.whitelist(allow_guest=False) +def get_local_embedding(text: str): + try: + res = requests.post( + "http://127.0.0.1:8001/embed", + json={"text": text}, + timeout=30 + ) + res.raise_for_status() + return res.json()["embedding"] + except Exception as e: + frappe.throw(f"Embedding service error :{str(e)}") + +class LocalEmbeddingService(Embeddings): + def embed_query(self, text: str) -> list[float]: + return get_local_embedding(text) + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [get_local_embedding(t) for t in texts] \ No newline at end of file diff --git a/changai/changai/api/v2/text2sql_pipeline_v2.py b/changai/changai/api/v2/text2sql_pipeline_v2.py index cb74f89..6513ddd 100644 --- a/changai/changai/api/v2/text2sql_pipeline_v2.py +++ b/changai/changai/api/v2/text2sql_pipeline_v2.py @@ -47,6 +47,8 @@ from pathlib import Path import numpy as np from typing import List, Dict, Any +# from changai.changai.api.v2.embedding_client import LocalEmbeddingService +# from changai.changai.api.v2.embedding_client import get_local_embedding # from symspellpy.symspellpy import SymSpell sym_spell = None _GEMINI_CLIENT = None @@ -210,11 +212,22 @@ def enrich_fields_for_sql_context(table: str, fields: list[str]) -> list[str]: return out -def format_schema_context(grouped: dict[str, list[str]]) -> str: +def format_schema_context(grouped: dict) -> str: parts = [] - for table, raw_fields in grouped.items(): - child = is_child_table(table) + for table, table_data in grouped.items(): + if isinstance(table_data, dict): + raw_fields = table_data.get("fields", []) + is_table_value = table_data.get("is_table") + + if is_table_value is None: + child = is_child_table(table) + else: + child = bool(is_table_value) + else: + raw_fields = table_data + child = is_child_table(table) + fields = enrich_fields_for_sql_context(table, raw_fields) parts.append(f"TABLE: {table}") @@ -435,15 +448,6 @@ def load_field_matrix(): return docs, embs, table_to_idx -def _get_cached_embedding_test(q: str) -> tuple: - t0=time.time() - emb = get_embedding_engine() - emb_load_time = time.time() - t0 - t1 = time.time() - vec = emb.embed_query(q) - embed_query_time = time.time() - t1 - return emb_load_time,embed_query_time # tuple for hashability - def get_embedding_engine(): global _EMBEDDER_INSTANCE @@ -1433,14 +1437,12 @@ def check_memory_status() -> dict: } } + @lru_cache(maxsize=512) def _get_cached_embedding(q: str, request_id: str) -> tuple: - publish_pipeline_update( - request_id, - "embedding_start", - "embedding started" - ) + # vec = get_local_embedding(q) emb = get_embedding_engine() + publish_pipeline_update( request_id, "embedding_end", @@ -1492,7 +1494,6 @@ def create_entity(state:SQLState): - def call_fvs_table_search(get_table: bool, q: str, request_id: str) -> List[str]: # get cached embedding publish_pipeline_update( @@ -1501,6 +1502,11 @@ def call_fvs_table_search(get_table: bool, q: str, request_id: str) -> List[str] _("Inside the Table Search Function") ) q_vec = np.array(_get_cached_embedding(q,request_id), dtype="float32") + publish_pipeline_update( + request_id, + "Completed Embed for Table Search Function", + _("Completed Embed for Table Search Function") + ) # use FAISS index directly instead of similarity_search publish_pipeline_update( @@ -1555,7 +1561,6 @@ def build_hnsw_index(embeddings): return index - def call_retrieve_multi_line(user_question: str, request_id: str) -> Dict[str, Any]: try: top_tables = call_fvs_table_search(True, user_question, request_id) @@ -1586,20 +1591,23 @@ def call_retrieve_multi_line(user_question: str, request_id: str) -> Dict[str, A except Exception as e: return {"selected_fields": {}, "selected_tables": [], "top_tables": [], "error": str(e)} + def call_fvs_field_search_global_k( user_question: str, selected_tables: List[str], k_total: int = 40, request_id: Optional[str] = None ) -> str: - + if isinstance(selected_tables, str): + try: + selected_tables = json.loads(selected_tables) + except Exception: + selected_tables = [selected_tables] if not user_question or not selected_tables: return "" docs, embs, table_to_idx = load_field_matrix() - # emb = get_embedding_engine() - q_vec = np.array( _get_cached_embedding(user_question, request_id), dtype="float32" @@ -1607,12 +1615,33 @@ def call_fvs_field_search_global_k( q_vec = q_vec / max(np.linalg.norm(q_vec), 1e-12) - # collect indices all_idxs = [] + for t in selected_tables: - all_idxs.extend(table_to_idx.get(t, [])) + t = str(t).strip() + if not t: + continue + + candidates = [ + t, + f"tab{t}" if not t.startswith("tab") else t, + t.replace("tab", "", 1) if t.startswith("tab") else t, + ] + + for key in candidates: + if key in table_to_idx: + all_idxs.extend(table_to_idx[key]) + break if not all_idxs: + frappe.log_error( + title="ChangAI Field Search: No Indexes Found", + message=json.dumps({ + "user_question": user_question, + "selected_tables": selected_tables, + "sample_table_to_idx_keys": list(table_to_idx.keys())[:50], + }, indent=2, default=str) + ) return "" sub_embs = embs[all_idxs] @@ -1628,9 +1657,10 @@ def call_fvs_field_search_global_k( d = docs[doc_i] meta = getattr(d, "metadata", {}) or {} + is_table = meta.get("is_table") table = meta.get("table") - field = meta.get("field") or meta.get("name") + field = meta.get("field") or meta.get("name") if not table or not field: continue @@ -1638,31 +1668,45 @@ def call_fvs_field_search_global_k( key = (table, field) if key in seen: continue + seen.add(key) name = field - # join hint - if meta.get("join_hint"): - linked_table = meta["join_hint"].get("table") + join_hint = meta.get("join_hint") + if isinstance(join_hint, dict): + linked_table = join_hint.get("table") if linked_table: name += f" -> {linked_table}" + elif isinstance(join_hint, str) and join_hint.strip(): + name += f" -> {join_hint.strip()}" - # options - if meta.get("options"): - opts = meta["options"] + opts = meta.get("options") + if opts: if isinstance(opts, list): name += " {" + ", ".join(str(o) for o in opts[:5]) + "}" + else: + name += " {" + str(opts) + "}" + grouped.setdefault(table, { "is_table": is_table, "fields": [] }) - grouped[table]["fields"].append(name) - - res = format_schema_context(grouped) - # 🔥 final compact string - return res + grouped[table]["fields"].append(name) + + if not grouped: + frappe.log_error( + title="ChangAI Field Search: Empty Grouped Result", + message=json.dumps({ + "user_question": user_question, + "selected_tables": selected_tables, + "all_idxs_count": len(all_idxs), + "top_global_count": len(top_global), + }, indent=2, default=str) + ) + return "" + return format_schema_context(grouped) # Node 1: Retrive with Fiass Vector Store. @@ -3398,6 +3442,7 @@ def non_erp_response(non_erp_q: str) -> Optional[str]: @frappe.whitelist(allow_guest=False) def run_text2sql_pipeline(user_question: str, chat_id: str, request_id: str, sendNonErptoAI: bool = False) -> Dict: + err = "" memory_status = check_memory_status() logs = find_similar_log_question(user_question) if logs.get("matched") and logs["error"] == "" and logs.get("type") != "NonERP": diff --git a/changai/changai/prompts/sql_user_prompt.txt b/changai/changai/prompts/sql_user_prompt.txt index 5bdbfdb..07666b2 100644 --- a/changai/changai/prompts/sql_user_prompt.txt +++ b/changai/changai/prompts/sql_user_prompt.txt @@ -10,4 +10,4 @@ GENERIC FIELDS (available on ALL transaction doctypes): name, creation, modified, owner, company, docstatus, naming_series, amended_from. GENERIC FIELDS (available on ALL master doctypes): name, creation, modified, owner, disabled, naming_series -REMINDER: Use only fields from SCHEMA CONTEXT.and never ever use a field or a table that is not in the given schema. +REMINDER: Use only fields from SCHEMA CONTEXT.and never ever use any field or any table that is not in the given schema.that is important.beware of using non existing fields.