From ed98ddf6f9b1ef08dcbcad03fd11f75414f5550f Mon Sep 17 00:00:00 2001
From: Ch-Abhinav-Chowdary
Date: Tue, 20 Jan 2026 19:13:53 +0530
Subject: [PATCH 1/2] Add query_transformer_prompt to RAG question answerers

---
 python/pathway/xpacks/llm/question_answering.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/python/pathway/xpacks/llm/question_answering.py b/python/pathway/xpacks/llm/question_answering.py
index 6bb607a0..75fae11b 100644
--- a/python/pathway/xpacks/llm/question_answering.py
+++ b/python/pathway/xpacks/llm/question_answering.py
@@ -456,6 +456,8 @@ class BaseRAGQuestionAnswerer(SummaryQuestionAnswerer):
        Either string template, callable or a pw.udf function is expected.
        Defaults to ``pathway.xpacks.llm.prompts.prompt_qa``.
        String template needs to have ``context`` and ``query`` placeholders in curly brackets ``{}``.
+    query_transformer_prompt: Optional UDF applied to the incoming query before retrieval or prompting,
+        allowing custom query rewriting, expansion, or normalization logic.
     context_processor: Utility for representing the fetched documents to the LLM. Callable, UDF or ``BaseContextProcessor`` is expected.
        Defaults to ``SimpleContextProcessor`` that keeps the 'path' metadata and joins the documents with double new lines.
     summarize_template: Template for text summarization. Defaults to ``pathway.xpacks.llm.prompts.prompt_summarize``.
@@ -516,6 +518,7 @@ def __init__(
         *,
         default_llm_name: str | None = None,
         prompt_template: str | Callable[[str, str], str] | pw.UDF = prompts.prompt_qa,
+        query_transformer_prompt: pw.UDF | None = None,
         context_processor: (
             BaseContextProcessor | Callable[[list[dict] | list[Doc]], str] | pw.UDF
         ) = SimpleContextProcessor(),
@@ -535,6 +538,7 @@ def __init__(
         self._init_schemas(default_llm_name)
 
         self.prompt_udf = _get_RAG_prompt_udf(prompt_template)
+        self.query_transformer_prompt = query_transformer_prompt
 
         if isinstance(context_processor, BaseContextProcessor):
             self.docs_to_context_transformer = context_processor.as_udf()
@@ -638,6 +642,11 @@ def add_score_to_doc(doc: pw.Json, score: float) -> dict:
 
     def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
         """Answer a question based on the available information."""
+        if self.query_transformer_prompt is not None:
+            pw_ai_queries = pw_ai_queries.with_columns(
+                prompt=self.query_transformer_prompt(pw.this.prompt)
+            )
+
         pw_ai_results = pw_ai_queries + self.indexer.retrieve_query(
             pw_ai_queries.select(
                 metadata_filter=pw.this.filters,
@@ -840,6 +849,8 @@ class AdaptiveRAGQuestionAnswerer(BaseRAGQuestionAnswerer):
        String template needs to have ``context`` and ``query`` placeholders in curly brackets ``{}``.
        For Adaptive RAG to work, prompt needs to instruct, to return
        `no_answer_string` when information is not found.
+    query_transformer_prompt: Optional UDF applied to the incoming query before retrieval or prompting,
+        allowing custom query rewriting, expansion, or normalization logic.
     no_answer_string: string that will be returned by the LLM when information is not found.
     context_processor: Utility for representing the fetched documents to the LLM. Callable, UDF or ``BaseContextProcessor`` is expected.
        Defaults to ``SimpleContextProcessor`` that keeps the 'path' metadata and joins the documents with double new lines.
@@ -896,6 +907,7 @@ def __init__(
         prompt_template: (
             str | Callable[[str, str], str] | pw.UDF
         ) = prompts.prompt_qa_geometric_rag,
+        query_transformer_prompt: pw.UDF | None = None,
         no_answer_string: str = "No information found.",
         context_processor: (
             BaseContextProcessor | Callable[[list[dict] | list[Doc]], str] | pw.UDF
@@ -909,6 +921,7 @@ def __init__(
             llm,
             indexer,
             default_llm_name=default_llm_name,
+            query_transformer_prompt=query_transformer_prompt,
             summarize_template=summarize_template,
             context_processor=context_processor,
         )

From 19d410391ec573cd144a77057fdde2924a233e8a Mon Sep 17 00:00:00 2001
From: Ch-Abhinav-Chowdary
Date: Thu, 22 Jan 2026 09:34:03 +0530
Subject: [PATCH 2/2] Send the query rewrite prompt through the LLM; add a test

---
 .../pathway/xpacks/llm/question_answering.py |  9 ++-
 python/pathway/xpacks/llm/tests/test_rag.py  | 55 +++++++++++++++++++
 2 files changed, 63 insertions(+), 1 deletion(-)

diff --git a/python/pathway/xpacks/llm/question_answering.py b/python/pathway/xpacks/llm/question_answering.py
index 75fae11b..12b5a312 100644
--- a/python/pathway/xpacks/llm/question_answering.py
+++ b/python/pathway/xpacks/llm/question_answering.py
@@ -644,8 +644,15 @@ def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table:
 
         if self.query_transformer_prompt is not None:
             pw_ai_queries = pw_ai_queries.with_columns(
-                prompt=self.query_transformer_prompt(pw.this.prompt)
+                rewrite_prompt=self.query_transformer_prompt(pw.this.prompt)
             )
+            pw_ai_queries += pw_ai_queries.select(
+                prompt=self.llm(
+                    llms.prompt_chat_single_qa(pw.this.rewrite_prompt),
+                    model=pw.this.model,
+                )
+            )
+            pw_ai_queries = pw_ai_queries.await_futures()
 
         pw_ai_results = pw_ai_queries + self.indexer.retrieve_query(
             pw_ai_queries.select(
diff --git a/python/pathway/xpacks/llm/tests/test_rag.py b/python/pathway/xpacks/llm/tests/test_rag.py
index 1bbf76b7..f058a0fe 100644
--- a/python/pathway/xpacks/llm/tests/test_rag.py
+++ b/python/pathway/xpacks/llm/tests/test_rag.py
@@ -168,3 +168,58 @@ def _accepts_call_arg(self, arg_name: str) -> bool:
 
     assert "context" in err_msg
     assert "query" in err_msg
+
+
+def test_base_rag_with_query_transformer_identity():
+    schema = pw.schema_from_types(data=bytes, _metadata=dict)
+    input = pw.debug.table_from_rows(
+        schema=schema, rows=[("foo", {}), ("bar", {}), ("baz", {})]
+    )
+
+    vector_server = VectorStoreServer(
+        input,
+        embedder=fake_embeddings_model,
+    )
+
+    class TrueIdentityChat(llms.BaseChat):
+        def _accepts_call_arg(self, arg_name: str) -> bool:
+            return False
+
+        async def __wrapped__(self, messages: list[dict] | pw.Json, model: str) -> str:
+            return messages[0]["content"].as_str()
+
+    @pw.udf
+    def identity_rewrite(q: str) -> str:
+        return q
+
+    rag = BaseRAGQuestionAnswerer(
+        TrueIdentityChat(),
+        vector_server,
+        prompt_template=_prompt_template,
+        summarize_template=_summarize_template,
+        search_topk=1,
+        query_transformer_prompt=identity_rewrite,
+    )
+
+    answer_queries = pw.debug.table_from_rows(
+        schema=rag.AnswerQuerySchema,
+        rows=[
+            ("foo", None, "gpt3.5", False),
+        ],
+    )
+
+    answer_output = rag.answer_query(answer_queries)
+
+    casted_table = answer_output.select(
+        result=pw.apply_with_type(lambda x: x.value, str, pw.this.result["response"])
+    )
+
+    assert_table_equality(
+        casted_table,
+        pw.debug.table_from_markdown(
+            """
+            result
+            foo
+            """
+        ),
+    )
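
---

Usage sketch (not applied by git am; for reviewers). After PATCH 2/2, the UDF
passed as query_transformer_prompt no longer replaces the query directly: it
maps the incoming query to a rewrite *prompt*, which is sent through
llms.prompt_chat_single_qa to the answerer's LLM, and the LLM's response
becomes the query used for retrieval. The names `chat` and `vector_server`
below are assumed to be an existing pathway chat model and VectorStoreServer,
and the rewrite wording is illustrative only:

    import pathway as pw
    from pathway.xpacks.llm.question_answering import BaseRAGQuestionAnswerer

    @pw.udf
    def expand_query(query: str) -> str:
        # Build the rewrite instruction; the answerer's LLM answers this
        # prompt and its response replaces the original query before
        # retrieval.
        return (
            "Rewrite the following search query to be specific and "
            "self-contained, keeping all named entities. "
            "Return only the rewritten query.\n"
            f"Query: {query}"
        )

    rag = BaseRAGQuestionAnswerer(
        chat,           # assumed: an llms.BaseChat instance
        vector_server,  # assumed: an existing VectorStoreServer
        query_transformer_prompt=expand_query,
    )

Passing query_transformer_prompt=None (the default) skips the rewrite branch
entirely, so existing pipelines are unaffected.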