From fac3b0d9d43ac49330dcefaa843cbbd6d72a2d5c Mon Sep 17 00:00:00 2001 From: Daniel Namaki Date: Thu, 17 Jul 2025 15:23:39 +0200 Subject: [PATCH] support additional fields --- pyterrier_rag/prompt/_context_aggregation.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pyterrier_rag/prompt/_context_aggregation.py b/pyterrier_rag/prompt/_context_aggregation.py index f1a3e9e..d76150b 100644 --- a/pyterrier_rag/prompt/_context_aggregation.py +++ b/pyterrier_rag/prompt/_context_aggregation.py @@ -39,6 +39,7 @@ def __init__( self, in_fields: Optional[List[str]] = ["text"], out_field: Optional[str] = "qcontext", + additional_fields: Optional[List[str]] = None, text_loader: Optional[callable] = None, intermediate_format: Optional[callable] = None, tokenizer: Optional[Any] = None, @@ -53,6 +54,7 @@ def __init__( self.in_fields = in_fields self.out_field = out_field + self.additional_fields = additional_fields self.aggregate_func = aggregate_func self.text_loader = text_loader self.intermediate_format = intermediate_format @@ -93,7 +95,12 @@ def transform_by_query(self, inp: Iterable[dict]) -> Iterable[dict]: max_per_context=self.max_per_context, truncation_rate=self.truncation_rate, ) - return [{self.out_field: context, "qid": qid, "query": query}] + + # If additional fields are specified, include them in the output + if self.additional_fields is not None: + additional_fields = {field: inp[0].get(field, None) for field in self.additional_fields} + + return [{self.out_field: context, "qid": qid, "query": query, **additional_fields} if self.additional_fields else {self.out_field: context, "qid": qid, "query": query}] __all__ = ["Concatenator"]