From 92a6405724604df99ce06a5024771b2757ca6bf0 Mon Sep 17 00:00:00 2001 From: "Md. Tareq Mahmood" Date: Wed, 18 Mar 2026 15:54:59 -0500 Subject: [PATCH 1/2] Add physical= parameter for per-operator physical implementation hints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a `physical=` dict parameter to sem_filter, sem_map, sem_flat_map, sem_join, and sem_agg that lets users override the optimizer's physical operator selection. The dict requires an "implementation" key (the physical operator class). All other keys are forwarded as constructor kwargs to that operator, overriding rule-generated defaults. Only the matching implementation receives the extra kwargs — other rules build operators normally and are filtered out post-substitution. Changes: - logical.py: store physical on LogicalOperator, validate at construction, include in get_logical_op_params() but not get_logical_id_params() - dataset.py: thread physical= through semantic Dataset methods - rules.py: guard extra kwargs injection by implementation class match - tasks.py: post-filter expressions by implementation, warn on empty --- src/palimpzest/core/data/dataset.py | 22 +++++++++------- src/palimpzest/query/operators/logical.py | 20 ++++++++++++++ src/palimpzest/query/optimizer/rules.py | 19 ++++++++++++-- src/palimpzest/query/optimizer/tasks.py | 32 +++++++++++++++++++++++ 4 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/palimpzest/core/data/dataset.py b/src/palimpzest/core/data/dataset.py index 3c1181de3..bfcb16879 100644 --- a/src/palimpzest/core/data/dataset.py +++ b/src/palimpzest/core/data/dataset.py @@ -266,7 +266,7 @@ def join(self, other: Dataset, on: str | list[str], how: str = "inner") -> Datas return Dataset(sources=[self, other], operator=operator, schema=combined_schema) - def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depends_on: str | list[str] | None = None, how: str = "inner") -> Dataset: + def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depends_on: str | list[str] | None = None, how: str = "inner", physical: dict | None = None) -> Dataset: """ Perform a semantic (inner) join on the specified join predicate """ @@ -285,6 +285,7 @@ def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depe how=how, desc=desc, depends_on=depends_on, + physical=physical, ) return Dataset(sources=[self, other], operator=operator, schema=combined_schema) @@ -319,6 +320,7 @@ def sem_filter( filter: str, desc: str | None = None, depends_on: str | list[str] | None = None, + physical: dict | None = None, ) -> Dataset: """Add a natural language description of a filter to the Set. This filter will possibly restrict the items that are returned later.""" # construct Filter object @@ -333,14 +335,15 @@ def sem_filter( depends_on = [depends_on] # construct logical operator - operator = FilteredScan(input_schema=self.schema, output_schema=self.schema, filter=f, desc=desc, depends_on=depends_on) + operator = FilteredScan(input_schema=self.schema, output_schema=self.schema, filter=f, desc=desc, depends_on=depends_on, physical=physical) return Dataset(sources=[self], operator=operator, schema=self.schema) def _sem_map(self, cols: list[dict] | type[BaseModel] | None, cardinality: Cardinality, desc: str | None = None, - depends_on: str | list[str] | None = None) -> Dataset: + depends_on: str | list[str] | None = None, + physical: dict | None = None) -> Dataset: """Execute the semantic map operation with the appropriate cardinality.""" # construct new output schema new_output_schema = None @@ -366,6 +369,7 @@ def _sem_map(self, cols: list[dict] | type[BaseModel] | None, udf=None, desc=desc, depends_on=depends_on, + physical=physical, ) return Dataset(sources=[self], operator=operator, schema=new_output_schema) @@ -399,7 +403,7 @@ def sem_add_columns(self, cols: list[dict] | type[BaseModel], return self._sem_map(cols, cardinality, desc, depends_on) - def sem_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None) -> Dataset: + def sem_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None, physical: dict | None = None) -> Dataset: """ Compute new field(s) by specifying their names, descriptions, and types. For each input there will be one output. The field(s) will be computed during the execution of the Dataset. @@ -411,9 +415,9 @@ def sem_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, d {'name': 'full_name', 'desc': 'The name of the person', 'type': str}] ) """ - return self._sem_map(cols, Cardinality.ONE_TO_ONE, desc, depends_on) + return self._sem_map(cols, Cardinality.ONE_TO_ONE, desc, depends_on, physical) - def sem_flat_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None) -> Dataset: + def sem_flat_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None, physical: dict | None = None) -> Dataset: """ Compute new field(s) by specifying their names, descriptions, and types. For each input there will be one or more output(s). The field(s) will be computed during the execution of the Dataset. @@ -427,7 +431,7 @@ def sem_flat_map(self, cols: list[dict] | type[BaseModel], desc: str | None = No ] ) """ - return self._sem_map(cols, Cardinality.ONE_TO_MANY, desc, depends_on) + return self._sem_map(cols, Cardinality.ONE_TO_MANY, desc, depends_on, physical) def _map(self, udf: Callable, cols: list[dict] | type[BaseModel] | None, @@ -577,7 +581,7 @@ def groupby(self, groupby: GroupBySig) -> Dataset: operator = GroupByAggregate(input_schema=self.schema, output_schema=output_schema, group_by_sig=groupby) return Dataset(sources=[self], operator=operator, schema=output_schema) - def sem_agg(self, col: dict | type[BaseModel], agg: str, depends_on: str | list[str] | None = None) -> Dataset: + def sem_agg(self, col: dict | type[BaseModel], agg: str, depends_on: str | list[str] | None = None, physical: dict | None = None) -> Dataset: """ Apply a semantic aggregation to this set. The `agg` string will be applied using an LLM over the entire set of inputs' fields specified in `depends_on` to generate the output `col`. @@ -604,7 +608,7 @@ def sem_agg(self, col: dict | type[BaseModel], agg: str, depends_on: str | list[ depends_on = [depends_on] # construct logical operator - operator = Aggregate(input_schema=self.schema, output_schema=new_output_schema, agg_str=agg, depends_on=depends_on) + operator = Aggregate(input_schema=self.schema, output_schema=new_output_schema, agg_str=agg, depends_on=depends_on, physical=physical) return Dataset(sources=[self], operator=operator, schema=operator.output_schema) diff --git a/src/palimpzest/query/operators/logical.py b/src/palimpzest/query/operators/logical.py index d933ef0f7..2fa1593af 100644 --- a/src/palimpzest/query/operators/logical.py +++ b/src/palimpzest/query/operators/logical.py @@ -40,11 +40,15 @@ def __init__( output_schema: type[BaseModel], input_schema: type[BaseModel] | None = None, depends_on: list[str] | None = None, + physical: dict | None = None, ): # TODO: can we eliminate input_schema? self.output_schema = output_schema self.input_schema = input_schema self.depends_on = [] if depends_on is None else sorted(depends_on) + self.physical = physical + if physical is not None: + self._validate_physical(physical) self.logical_op_id: str | None = None self.unique_logical_op_id: str | None = None @@ -54,6 +58,21 @@ def __init__( [field_name for field_name in self.output_schema.model_fields if field_name not in input_field_names] ) + @staticmethod + def _validate_physical(physical: dict) -> None: + """Validate the physical= dict at query construction time.""" + if not isinstance(physical, dict): + raise TypeError(f"physical must be a dict, got {type(physical).__name__}") + + if "implementation" not in physical: + raise ValueError("physical dict must contain an 'implementation' key") + + impl_cls = physical["implementation"] + if not isinstance(impl_cls, type): + raise TypeError( + f"physical['implementation'] must be a class, got {type(impl_cls).__name__}" + ) + def __str__(self) -> str: raise NotImplementedError("Abstract method") @@ -107,6 +126,7 @@ def get_logical_op_params(self) -> dict: "input_schema": self.input_schema, "output_schema": self.output_schema, "depends_on": self.depends_on, + "physical": self.physical, } def get_logical_op_id(self): diff --git a/src/palimpzest/query/optimizer/rules.py b/src/palimpzest/query/optimizer/rules.py index 93c9529be..c4534544c 100644 --- a/src/palimpzest/query/optimizer/rules.py +++ b/src/palimpzest/query/optimizer/rules.py @@ -533,11 +533,15 @@ def _embedding_model_matches_input(cls, model: Model, logical_expression: Logica @classmethod def _get_fixed_op_kwargs(cls, logical_expression: LogicalExpression, runtime_kwargs: dict) -> dict: """Get the fixed set of physical op kwargs provided by the logical expression and the runtime keyword arguments.""" - # get logical operator + # get logical operator logical_op = logical_expression.operator # set initial set of parameters for physical op op_kwargs = logical_op.get_logical_op_params() + + # remove physical — it's used by the optimizer, not the physical operator + op_kwargs.pop("physical", None) + op_kwargs.update( { "verbose": runtime_kwargs["verbose"], @@ -581,6 +585,16 @@ def _perform_substitution( # get physical operator kwargs which are fixed for each instance of the physical operator fixed_op_kwargs = cls._get_fixed_op_kwargs(logical_expression, runtime_kwargs) + # if the user specified a physical= dict, only inject extra kwargs when + # this rule is building the requested implementation class + physical = getattr(logical_expression.operator, "physical", None) or {} + impl_cls = physical.get("implementation") + extra_physical_kwargs = {} + if impl_cls is None or physical_op_class is impl_cls: + extra_physical_kwargs = { + k: v for k, v in physical.items() if k != "implementation" + } + # make variable_op_kwargs a list of dictionaries if variable_op_kwargs is None: variable_op_kwargs = [{}] @@ -591,7 +605,8 @@ def _perform_substitution( physical_expressions = [] for var_op_kwargs in variable_op_kwargs: # get kwargs for this physical operator instance - op_kwargs = {**fixed_op_kwargs, **var_op_kwargs} + # extra_physical_kwargs override rule-generated values when user specified them + op_kwargs = {**fixed_op_kwargs, **var_op_kwargs, **extra_physical_kwargs} # construct the physical operator op = physical_op_class(**op_kwargs) diff --git a/src/palimpzest/query/optimizer/tasks.py b/src/palimpzest/query/optimizer/tasks.py index 9b5247ad8..d9e8146d8 100644 --- a/src/palimpzest/query/optimizer/tasks.py +++ b/src/palimpzest/query/optimizer/tasks.py @@ -14,6 +14,25 @@ logger = logging.getLogger(__name__) + +def _filter_expressions_by_physical(expressions, physical): + """ + Filter physical expressions based on a ``physical`` dict from the logical operator. + + Only the ``"implementation"`` key is used for filtering (exact class match). + All other keys are constructor kwargs — they are forwarded to the matching + physical operator by ``_perform_substitution`` in rules.py. + """ + if not isinstance(physical, dict): + return list(expressions) + + impl_cls = physical.get("implementation") + if impl_cls is None: + return list(expressions) + + return [e for e in expressions if type(e.operator) is impl_cls] + + class Task: """ Base class for a task. Each task has a method called perform() which executes the task. @@ -246,6 +265,19 @@ def perform( else: # apply implementation rule new_expressions = self.rule.substitute(self.logical_expression, **physical_op_params) + + # filter physical expressions by physical dict (if present on the logical operator) + physical = getattr(self.logical_expression.operator, "physical", None) + if physical is not None: + pre_filter_count = len(new_expressions) + new_expressions = _filter_expressions_by_physical(new_expressions, physical) + if pre_filter_count > 0 and len(new_expressions) == 0: + logger.warning( + f"physical= hint {physical} on {self.logical_expression.operator.logical_op_name()} " + f"filtered out all {pre_filter_count} candidate(s) from {self.rule.get_rule_id()}. " + f"If no other rule produces a match, optimization will fail." + ) + new_expressions = [expr for expr in new_expressions if expr.expr_id not in expressions] # get the costed_full_op_ids from the context (if provided) and compute whether this From 3130bd42e3d9cec5c4539cac688891702b9fdd85 Mon Sep 17 00:00:00 2001 From: "Md. Tareq Mahmood" Date: Wed, 18 Mar 2026 15:55:08 -0500 Subject: [PATCH 2/2] Add unit tests for physical= operator hints 18 tests covering: - Expression filtering by implementation class (exact type match) - Validation (rejects missing/invalid implementation key) - Propagation through logical operators and copy - Dataset API integration (sem_filter, sem_map, sem_flat_map) - End-to-end usage pattern --- tests/pytest/test_hints.py | 221 +++++++++++++++++++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 tests/pytest/test_hints.py diff --git a/tests/pytest/test_hints.py b/tests/pytest/test_hints.py new file mode 100644 index 000000000..d40aaf09d --- /dev/null +++ b/tests/pytest/test_hints.py @@ -0,0 +1,221 @@ +"""Tests for the physical= query hinting system for physical operator selection.""" + +import pytest +from pydantic import BaseModel, Field + +from palimpzest.constants import Cardinality, Model +from palimpzest.core.elements.filters import Filter +from palimpzest.query.operators.convert import LLMConvertBonded +from palimpzest.query.operators.filter import LLMFilter, NonLLMFilter +from palimpzest.query.operators.logical import ConvertScan, FilteredScan +from palimpzest.query.operators.mixture_of_agents import MixtureOfAgentsFilter +from palimpzest.query.operators.rag import RAGFilter +from palimpzest.query.optimizer.tasks import _filter_expressions_by_physical + + +# --- Fixtures --- + + +@pytest.fixture +def schema(): + class SimpleSchema(BaseModel): + text: str = Field(description="The text of the document") + + return SimpleSchema + + +@pytest.fixture +def output_schema(): + class OutputSchema(BaseModel): + text: str = Field(description="The text of the document") + summary: str = Field(description="Summary of the text") + + return OutputSchema + + +# --- Mock physical expressions --- + + +class _MockOp: + """A lightweight stand-in for a physical operator.""" + + def __init__(self, cls): + self.__class__ = cls + + +def _make_mock_expr(cls): + op = _MockOp(cls) + expr = type("Expr", (), {"operator": op, "expr_id": f"{cls.__name__}-{id(op)}"})() + return expr + + +# --- Tests for _filter_expressions_by_physical --- + + +class TestFilterExpressionsByPhysical: + def test_impl_filters_by_exact_class(self): + exprs = [ + _make_mock_expr(LLMFilter), + _make_mock_expr(NonLLMFilter), + _make_mock_expr(MixtureOfAgentsFilter), + ] + result = _filter_expressions_by_physical(exprs, {"implementation": LLMFilter}) + assert len(result) == 1 + assert type(result[0].operator) is LLMFilter + + def test_impl_no_subclass_matching(self): + """Exact type match — LLMFilter should NOT match MixtureOfAgentsFilter subclass.""" + exprs = [ + _make_mock_expr(LLMFilter), + _make_mock_expr(MixtureOfAgentsFilter), + ] + result = _filter_expressions_by_physical(exprs, {"implementation": LLMFilter}) + assert len(result) == 1 + assert type(result[0].operator) is LLMFilter + + def test_no_match_returns_empty(self): + exprs = [_make_mock_expr(LLMFilter)] + result = _filter_expressions_by_physical(exprs, {"implementation": RAGFilter}) + assert len(result) == 0 + + def test_none_physical_returns_all(self): + exprs = [_make_mock_expr(LLMFilter), _make_mock_expr(NonLLMFilter)] + result = _filter_expressions_by_physical(exprs, None) + assert len(result) == 2 + + def test_no_implementation_key_returns_all(self): + exprs = [_make_mock_expr(LLMFilter), _make_mock_expr(NonLLMFilter)] + result = _filter_expressions_by_physical(exprs, {"model": Model.GPT_4o}) + assert len(result) == 2 + + def test_extra_kwargs_ignored_during_filtering(self): + """Extra keys beyond implementation don't affect filtering.""" + exprs = [_make_mock_expr(LLMFilter), _make_mock_expr(RAGFilter)] + result = _filter_expressions_by_physical( + exprs, {"implementation": RAGFilter, "chunk_size": 2000} + ) + assert len(result) == 1 + assert type(result[0].operator) is RAGFilter + + +# --- Tests for physical validation --- + + +class TestPhysicalValidation: + def test_valid_physical_dict(self, schema): + """LLMFilter accepts 'model' via its constructor.""" + f = Filter("text contains 'hello'") + op = FilteredScan( + input_schema=schema, output_schema=schema, filter=f, + physical={"implementation": LLMFilter, "model": Model.GPT_4o}, + ) + assert op.physical["model"] is Model.GPT_4o + + def test_rejects_missing_implementation(self, schema): + f = Filter("text contains 'hello'") + with pytest.raises(ValueError, match="implementation"): + FilteredScan( + input_schema=schema, output_schema=schema, filter=f, + physical={"model": Model.GPT_4o}, + ) + + def test_rejects_non_class_implementation(self, schema): + f = Filter("text contains 'hello'") + with pytest.raises(TypeError, match="must be a class"): + FilteredScan( + input_schema=schema, output_schema=schema, filter=f, + physical={"implementation": "LLMFilter"}, + ) + + def test_accepts_valid_rag_kwargs(self, schema): + """RAGFilter accepts chunk_size, embedding_model, num_chunks_per_field.""" + f = Filter("text contains 'hello'") + op = FilteredScan( + input_schema=schema, output_schema=schema, filter=f, + physical={ + "implementation": RAGFilter, + "model": Model.GPT_4o, + "embedding_model": Model.TEXT_EMBEDDING_3_SMALL, + "chunk_size": 2000, + "num_chunks_per_field": 4, + }, + ) + assert op.physical["chunk_size"] == 2000 + + def test_no_physical_is_fine(self, schema): + f = Filter("text contains 'hello'") + op = FilteredScan(input_schema=schema, output_schema=schema, filter=f) + assert op.physical is None + + +# --- Tests for physical propagation through logical operators --- + + +class TestPhysicalPropagation: + def test_physical_not_in_id_params(self, schema): + """physical should NOT affect the logical operator identity.""" + f = Filter("text contains 'hello'") + op_with = FilteredScan( + input_schema=schema, output_schema=schema, filter=f, + physical={"implementation": LLMFilter}, + ) + op_without = FilteredScan(input_schema=schema, output_schema=schema, filter=f) + assert op_with.get_logical_op_id() == op_without.get_logical_op_id() + + def test_copy_preserves_physical(self, schema): + f = Filter("text contains 'hello'") + phys = {"implementation": MixtureOfAgentsFilter} + op = FilteredScan(input_schema=schema, output_schema=schema, filter=f, physical=phys) + op_copy = op.copy() + assert op_copy.physical == phys + + +# --- Tests for physical in Dataset API --- + + +class TestDatasetPhysicalAPI: + def test_sem_filter_accepts_physical(self): + from palimpzest.core.data.iter_dataset import MemoryDataset + + ds = MemoryDataset(id="test", vals=["hello", "world"]) + phys = {"implementation": LLMFilter} + result = ds.sem_filter("text contains 'hello'", physical=phys) + assert result._operator.physical is phys + + def test_sem_map_accepts_physical(self): + from palimpzest.core.data.iter_dataset import MemoryDataset + + ds = MemoryDataset(id="test", vals=["hello", "world"]) + phys = {"implementation": LLMConvertBonded, "model": Model.GPT_4o} + result = ds.sem_map( + [{"name": "summary", "desc": "Summary", "type": str}], + physical=phys, + ) + assert result._operator.physical is phys + + def test_sem_flat_map_accepts_physical(self): + from palimpzest.core.data.iter_dataset import MemoryDataset + + ds = MemoryDataset(id="test", vals=["hello", "world"]) + phys = {"implementation": LLMConvertBonded} + result = ds.sem_flat_map( + [{"name": "word", "desc": "A word", "type": str}], + physical=phys, + ) + assert result._operator.physical is phys + + +# --- Test end-to-end usage pattern --- + + +class TestUsagePattern: + def test_example_usage_pattern(self): + from palimpzest.core.data.iter_dataset import MemoryDataset + + ds = MemoryDataset(id="demo", vals=["a", "b"]) + result = ds.sem_filter( + "text is scientific", + physical={"implementation": LLMFilter, "model": Model.GPT_4o}, + ) + assert result._operator.physical["implementation"] is LLMFilter + assert result._operator.physical["model"] is Model.GPT_4o