Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions changai/changai/api/v2/emb_load_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from fastapi import FastAPI
from pydantic import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
import os
# Model location, read from the CHANGAI_EMBEDDING_MODEL_PATH environment
# variable.
MODEL_PATH = os.environ.get("CHANGAI_EMBEDDING_MODEL_PATH")
if not MODEL_PATH:
    # Fail fast with an actionable message instead of handing ``None`` to the
    # embedding wrapper and surfacing an obscure error from deep inside it.
    raise RuntimeError(
        "CHANGAI_EMBEDDING_MODEL_PATH is not set; "
        "cannot load the embedding model."
    )

app = FastAPI()

# The model is built once at import time so every request reuses the same
# instance.
embedding_model = HuggingFaceEmbeddings(
    model_name=MODEL_PATH,
    model_kwargs={
        "device": "cpu",
        # NOTE(review): trust_remote_code lets the model repository run its
        # own Python code on load -- only point MODEL_PATH at trusted models.
        "trust_remote_code": True,
    },
    encode_kwargs={
        "normalize_embeddings": True,
    },
)

# warmup during service startup
embedding_model.embed_query("changai warmup")


# Request body accepted by the POST /embed endpoint.
class EmbedRequest(BaseModel):
    # The string to be embedded.
    text: str


@app.get("/health")
def health():
    """Return a static status payload for health checks."""
    status = {"ok": True, "model_loaded": True}
    return status


@app.post("/embed")
def embed(req: EmbedRequest):
    """Embed ``req.text`` and return the vector under the "embedding" key."""
    return {"embedding": embedding_model.embed_query(req.text)}
23 changes: 23 additions & 0 deletions changai/changai/api/v2/embedding_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import requests
from langchain_core.embeddings import Embeddings
import frappe

@frappe.whitelist(allow_guest=False)
def get_local_embedding(text: str):
    """Fetch the embedding of ``text`` from the local embedding service.

    Posts ``{"text": text}`` to the service's ``/embed`` endpoint and returns
    the value stored under ``"embedding"`` in the JSON reply. Any failure
    (connection, HTTP status, malformed reply) is reported via ``frappe.throw``.
    """
    endpoint = "http://127.0.0.1:8001/embed"
    payload = {"text": text}
    try:
        response = requests.post(endpoint, json=payload, timeout=30)
        response.raise_for_status()
        body = response.json()
        return body["embedding"]
    except Exception as exc:
        frappe.throw(f"Embedding service error :{str(exc)}")

class LocalEmbeddingService(Embeddings):
    """``Embeddings`` implementation that delegates to ``get_local_embedding``."""

    def embed_query(self, text: str) -> list[float]:
        """Return the embedding for a single query string."""
        return get_local_embedding(text)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding per entry of ``texts``, in input order."""
        # Each text is embedded with its own call to get_local_embedding.
        vectors = []
        for doc in texts:
            vectors.append(get_local_embedding(doc))
        return vectors
117 changes: 81 additions & 36 deletions changai/changai/api/v2/text2sql_pipeline_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from pathlib import Path
import numpy as np
from typing import List, Dict, Any
# from changai.changai.api.v2.embedding_client import LocalEmbeddingService
# from changai.changai.api.v2.embedding_client import get_local_embedding
# from symspellpy.symspellpy import SymSpell
sym_spell = None
_GEMINI_CLIENT = None
Expand Down Expand Up @@ -210,11 +212,22 @@ def enrich_fields_for_sql_context(table: str, fields: list[str]) -> list[str]:
return out


def format_schema_context(grouped: dict[str, list[str]]) -> str:
def format_schema_context(grouped: dict) -> str:
parts = []

for table, raw_fields in grouped.items():
child = is_child_table(table)
for table, table_data in grouped.items():
if isinstance(table_data, dict):
raw_fields = table_data.get("fields", [])
is_table_value = table_data.get("is_table")

if is_table_value is None:
child = is_child_table(table)
else:
child = bool(is_table_value)
else:
raw_fields = table_data
child = is_child_table(table)

fields = enrich_fields_for_sql_context(table, raw_fields)

parts.append(f"TABLE: {table}")
Expand Down Expand Up @@ -435,15 +448,6 @@ def load_field_matrix():

return docs, embs, table_to_idx

def _get_cached_embedding_test(q: str) -> tuple[float, float]:
    """Time the embedding engine lookup and a single query embedding.

    Despite its name, this helper returns timings, not the embedding itself.

    Args:
        q: Query text to embed.

    Returns:
        ``(emb_load_time, embed_query_time)`` in seconds: how long
        ``get_embedding_engine()`` took, and how long ``embed_query(q)`` took.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump when
    # the system clock is adjusted, which would skew (or negate) the durations.
    load_started = time.perf_counter()
    emb = get_embedding_engine()
    emb_load_time = time.perf_counter() - load_started

    query_started = time.perf_counter()
    emb.embed_query(q)  # result discarded: only the elapsed time matters here
    embed_query_time = time.perf_counter() - query_started

    return emb_load_time, embed_query_time


def get_embedding_engine():
global _EMBEDDER_INSTANCE
Expand Down Expand Up @@ -1433,14 +1437,12 @@ def check_memory_status() -> dict:
}
}


@lru_cache(maxsize=512)
def _get_cached_embedding(q: str, request_id: str) -> tuple:
publish_pipeline_update(
request_id,
"embedding_start",
"embedding started"
)
# vec = get_local_embedding(q)
emb = get_embedding_engine()

publish_pipeline_update(
request_id,
"embedding_end",
Expand Down Expand Up @@ -1492,7 +1494,6 @@ def create_entity(state:SQLState):




def call_fvs_table_search(get_table: bool, q: str, request_id: str) -> List[str]:
# get cached embedding
publish_pipeline_update(
Expand All @@ -1501,6 +1502,11 @@ def call_fvs_table_search(get_table: bool, q: str, request_id: str) -> List[str]
_("Inside the Table Search Function")
)
q_vec = np.array(_get_cached_embedding(q,request_id), dtype="float32")
publish_pipeline_update(
request_id,
"Completed Embed for Table Search Function",
_("Completed Embed for Table Search Function")
)

# use FAISS index directly instead of similarity_search
publish_pipeline_update(
Expand Down Expand Up @@ -1555,7 +1561,6 @@ def build_hnsw_index(embeddings):

return index


def call_retrieve_multi_line(user_question: str, request_id: str) -> Dict[str, Any]:
try:
top_tables = call_fvs_table_search(True, user_question, request_id)
Expand Down Expand Up @@ -1586,33 +1591,57 @@ def call_retrieve_multi_line(user_question: str, request_id: str) -> Dict[str, A
except Exception as e:
return {"selected_fields": {}, "selected_tables": [], "top_tables": [], "error": str(e)}


def call_fvs_field_search_global_k(
user_question: str,
selected_tables: List[str],
k_total: int = 40,
request_id: Optional[str] = None
) -> str:

if isinstance(selected_tables, str):
try:
selected_tables = json.loads(selected_tables)
except Exception:
selected_tables = [selected_tables]
if not user_question or not selected_tables:
return ""

docs, embs, table_to_idx = load_field_matrix()

# emb = get_embedding_engine()

q_vec = np.array(
_get_cached_embedding(user_question, request_id),
dtype="float32"
)

q_vec = q_vec / max(np.linalg.norm(q_vec), 1e-12)

# collect indices
all_idxs = []

for t in selected_tables:
all_idxs.extend(table_to_idx.get(t, []))
t = str(t).strip()
if not t:
continue

candidates = [
t,
f"tab{t}" if not t.startswith("tab") else t,
t.replace("tab", "", 1) if t.startswith("tab") else t,
]

for key in candidates:
if key in table_to_idx:
all_idxs.extend(table_to_idx[key])
break

if not all_idxs:
frappe.log_error(
title="ChangAI Field Search: No Indexes Found",
message=json.dumps({
"user_question": user_question,
"selected_tables": selected_tables,
"sample_table_to_idx_keys": list(table_to_idx.keys())[:50],
}, indent=2, default=str)
)
return ""

sub_embs = embs[all_idxs]
Expand All @@ -1628,41 +1657,56 @@ def call_fvs_field_search_global_k(
d = docs[doc_i]

meta = getattr(d, "metadata", {}) or {}

is_table = meta.get("is_table")
table = meta.get("table")
field = meta.get("field") or meta.get("name")
field = meta.get("field") or meta.get("name")

if not table or not field:
continue

key = (table, field)
if key in seen:
continue

seen.add(key)

name = field

# join hint
if meta.get("join_hint"):
linked_table = meta["join_hint"].get("table")
join_hint = meta.get("join_hint")
if isinstance(join_hint, dict):
linked_table = join_hint.get("table")
if linked_table:
name += f" -> {linked_table}"
elif isinstance(join_hint, str) and join_hint.strip():
name += f" -> {join_hint.strip()}"

# options
if meta.get("options"):
opts = meta["options"]
opts = meta.get("options")
if opts:
if isinstance(opts, list):
name += " {" + ", ".join(str(o) for o in opts[:5]) + "}"
else:
name += " {" + str(opts) + "}"

grouped.setdefault(table, {
"is_table": is_table,
"fields": []
})

grouped[table]["fields"].append(name)

res = format_schema_context(grouped)
# 🔥 final compact string
return res
grouped[table]["fields"].append(name)

if not grouped:
frappe.log_error(
title="ChangAI Field Search: Empty Grouped Result",
message=json.dumps({
"user_question": user_question,
"selected_tables": selected_tables,
"all_idxs_count": len(all_idxs),
"top_global_count": len(top_global),
}, indent=2, default=str)
)
return ""
return format_schema_context(grouped)


# Node 1: Retrieve with FAISS Vector Store.
Expand Down Expand Up @@ -3398,6 +3442,7 @@ def non_erp_response(non_erp_q: str) -> Optional[str]:

@frappe.whitelist(allow_guest=False)
def run_text2sql_pipeline(user_question: str, chat_id: str, request_id: str, sendNonErptoAI: bool = False) -> Dict:
err = ""
memory_status = check_memory_status()
logs = find_similar_log_question(user_question)
if logs.get("matched") and logs["error"] == "" and logs.get("type") != "NonERP":
Expand Down
2 changes: 1 addition & 1 deletion changai/changai/prompts/sql_user_prompt.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ GENERIC FIELDS (available on ALL transaction doctypes):
name, creation, modified, owner, company, docstatus, naming_series, amended_from.
GENERIC FIELDS (available on ALL master doctypes):
name, creation, modified, owner, disabled, naming_series
REMINDER: Use only fields from SCHEMA CONTEXT, and never ever use a field or a table that is not in the given schema.
REMINDER: Use only fields from SCHEMA CONTEXT, and never ever use any field or any table that is not in the given schema. This is important: beware of using non-existent fields.
Loading