Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/palimpzest/core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def join(self, other: Dataset, on: str | list[str], how: str = "inner") -> Datas

return Dataset(sources=[self, other], operator=operator, schema=combined_schema)

def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depends_on: str | list[str] | None = None, how: str = "inner") -> Dataset:
def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depends_on: str | list[str] | None = None, how: str = "inner", physical: dict | None = None) -> Dataset:
"""
Perform a semantic (inner) join on the specified join predicate
"""
Expand All @@ -285,6 +285,7 @@ def sem_join(self, other: Dataset, condition: str, desc: str | None = None, depe
how=how,
desc=desc,
depends_on=depends_on,
physical=physical,
)

return Dataset(sources=[self, other], operator=operator, schema=combined_schema)
Expand Down Expand Up @@ -319,6 +320,7 @@ def sem_filter(
filter: str,
desc: str | None = None,
depends_on: str | list[str] | None = None,
physical: dict | None = None,
) -> Dataset:
"""Add a natural language description of a filter to the Set. This filter will possibly restrict the items that are returned later."""
# construct Filter object
Expand All @@ -333,14 +335,15 @@ def sem_filter(
depends_on = [depends_on]

# construct logical operator
operator = FilteredScan(input_schema=self.schema, output_schema=self.schema, filter=f, desc=desc, depends_on=depends_on)
operator = FilteredScan(input_schema=self.schema, output_schema=self.schema, filter=f, desc=desc, depends_on=depends_on, physical=physical)

return Dataset(sources=[self], operator=operator, schema=self.schema)

def _sem_map(self, cols: list[dict] | type[BaseModel] | None,
cardinality: Cardinality,
desc: str | None = None,
depends_on: str | list[str] | None = None) -> Dataset:
depends_on: str | list[str] | None = None,
physical: dict | None = None) -> Dataset:
"""Execute the semantic map operation with the appropriate cardinality."""
# construct new output schema
new_output_schema = None
Expand All @@ -366,6 +369,7 @@ def _sem_map(self, cols: list[dict] | type[BaseModel] | None,
udf=None,
desc=desc,
depends_on=depends_on,
physical=physical,
)

return Dataset(sources=[self], operator=operator, schema=new_output_schema)
Expand Down Expand Up @@ -399,7 +403,7 @@ def sem_add_columns(self, cols: list[dict] | type[BaseModel],

return self._sem_map(cols, cardinality, desc, depends_on)

def sem_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None) -> Dataset:
def sem_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None, physical: dict | None = None) -> Dataset:
"""
Compute new field(s) by specifying their names, descriptions, and types. For each input there will
be one output. The field(s) will be computed during the execution of the Dataset.
Expand All @@ -411,9 +415,9 @@ def sem_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, d
{'name': 'full_name', 'desc': 'The name of the person', 'type': str}]
)
"""
return self._sem_map(cols, Cardinality.ONE_TO_ONE, desc, depends_on)
return self._sem_map(cols, Cardinality.ONE_TO_ONE, desc, depends_on, physical)

def sem_flat_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None) -> Dataset:
def sem_flat_map(self, cols: list[dict] | type[BaseModel], desc: str | None = None, depends_on: str | list[str] | None = None, physical: dict | None = None) -> Dataset:
"""
Compute new field(s) by specifying their names, descriptions, and types. For each input there will
be one or more output(s). The field(s) will be computed during the execution of the Dataset.
Expand All @@ -427,7 +431,7 @@ def sem_flat_map(self, cols: list[dict] | type[BaseModel], desc: str | None = No
]
)
"""
return self._sem_map(cols, Cardinality.ONE_TO_MANY, desc, depends_on)
return self._sem_map(cols, Cardinality.ONE_TO_MANY, desc, depends_on, physical)

def _map(self, udf: Callable,
cols: list[dict] | type[BaseModel] | None,
Expand Down Expand Up @@ -577,7 +581,7 @@ def groupby(self, groupby: GroupBySig) -> Dataset:
operator = GroupByAggregate(input_schema=self.schema, output_schema=output_schema, group_by_sig=groupby)
return Dataset(sources=[self], operator=operator, schema=output_schema)

def sem_agg(self, col: dict | type[BaseModel], agg: str, depends_on: str | list[str] | None = None) -> Dataset:
def sem_agg(self, col: dict | type[BaseModel], agg: str, depends_on: str | list[str] | None = None, physical: dict | None = None) -> Dataset:
"""
Apply a semantic aggregation to this set. The `agg` string will be applied using an LLM
over the entire set of inputs' fields specified in `depends_on` to generate the output `col`.
Expand All @@ -604,7 +608,7 @@ def sem_agg(self, col: dict | type[BaseModel], agg: str, depends_on: str | list[
depends_on = [depends_on]

# construct logical operator
operator = Aggregate(input_schema=self.schema, output_schema=new_output_schema, agg_str=agg, depends_on=depends_on)
operator = Aggregate(input_schema=self.schema, output_schema=new_output_schema, agg_str=agg, depends_on=depends_on, physical=physical)

return Dataset(sources=[self], operator=operator, schema=operator.output_schema)

Expand Down
20 changes: 20 additions & 0 deletions src/palimpzest/query/operators/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ def __init__(
output_schema: type[BaseModel],
input_schema: type[BaseModel] | None = None,
depends_on: list[str] | None = None,
physical: dict | None = None,
):
# TODO: can we eliminate input_schema?
self.output_schema = output_schema
self.input_schema = input_schema
self.depends_on = [] if depends_on is None else sorted(depends_on)
self.physical = physical
if physical is not None:
self._validate_physical(physical)
self.logical_op_id: str | None = None
self.unique_logical_op_id: str | None = None

Expand All @@ -54,6 +58,21 @@ def __init__(
[field_name for field_name in self.output_schema.model_fields if field_name not in input_field_names]
)

@staticmethod
def _validate_physical(physical: dict) -> None:
"""Validate the physical= dict at query construction time."""
if not isinstance(physical, dict):
raise TypeError(f"physical must be a dict, got {type(physical).__name__}")

if "implementation" not in physical:
raise ValueError("physical dict must contain an 'implementation' key")

impl_cls = physical["implementation"]
if not isinstance(impl_cls, type):
raise TypeError(
f"physical['implementation'] must be a class, got {type(impl_cls).__name__}"
)

def __str__(self) -> str:
raise NotImplementedError("Abstract method")

Expand Down Expand Up @@ -107,6 +126,7 @@ def get_logical_op_params(self) -> dict:
"input_schema": self.input_schema,
"output_schema": self.output_schema,
"depends_on": self.depends_on,
"physical": self.physical,
}

def get_logical_op_id(self):
Expand Down
19 changes: 17 additions & 2 deletions src/palimpzest/query/optimizer/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,11 +533,15 @@ def _embedding_model_matches_input(cls, model: Model, logical_expression: Logica
@classmethod
def _get_fixed_op_kwargs(cls, logical_expression: LogicalExpression, runtime_kwargs: dict) -> dict:
"""Get the fixed set of physical op kwargs provided by the logical expression and the runtime keyword arguments."""
# get logical operator
# get logical operator
logical_op = logical_expression.operator

# set initial set of parameters for physical op
op_kwargs = logical_op.get_logical_op_params()

# remove physical — it's used by the optimizer, not the physical operator
op_kwargs.pop("physical", None)

op_kwargs.update(
{
"verbose": runtime_kwargs["verbose"],
Expand Down Expand Up @@ -581,6 +585,16 @@ def _perform_substitution(
# get physical operator kwargs which are fixed for each instance of the physical operator
fixed_op_kwargs = cls._get_fixed_op_kwargs(logical_expression, runtime_kwargs)

# if the user specified a physical= dict, only inject extra kwargs when
# this rule is building the requested implementation class
physical = getattr(logical_expression.operator, "physical", None) or {}
impl_cls = physical.get("implementation")
extra_physical_kwargs = {}
if impl_cls is None or physical_op_class is impl_cls:
extra_physical_kwargs = {
k: v for k, v in physical.items() if k != "implementation"
}

# make variable_op_kwargs a list of dictionaries
if variable_op_kwargs is None:
variable_op_kwargs = [{}]
Expand All @@ -591,7 +605,8 @@ def _perform_substitution(
physical_expressions = []
for var_op_kwargs in variable_op_kwargs:
# get kwargs for this physical operator instance
op_kwargs = {**fixed_op_kwargs, **var_op_kwargs}
# extra_physical_kwargs override rule-generated values when user specified them
op_kwargs = {**fixed_op_kwargs, **var_op_kwargs, **extra_physical_kwargs}

# construct the physical operator
op = physical_op_class(**op_kwargs)
Expand Down
32 changes: 32 additions & 0 deletions src/palimpzest/query/optimizer/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,25 @@

logger = logging.getLogger(__name__)


def _filter_expressions_by_physical(expressions, physical):
"""
Filter physical expressions based on a ``physical`` dict from the logical operator.

Only the ``"implementation"`` key is used for filtering (exact class match).
All other keys are constructor kwargs — they are forwarded to the matching
physical operator by ``_perform_substitution`` in rules.py.
"""
if not isinstance(physical, dict):
return list(expressions)

impl_cls = physical.get("implementation")
if impl_cls is None:
return list(expressions)

return [e for e in expressions if type(e.operator) is impl_cls]


class Task:
"""
Base class for a task. Each task has a method called perform() which executes the task.
Expand Down Expand Up @@ -246,6 +265,19 @@ def perform(
else:
# apply implementation rule
new_expressions = self.rule.substitute(self.logical_expression, **physical_op_params)

# filter physical expressions by physical dict (if present on the logical operator)
physical = getattr(self.logical_expression.operator, "physical", None)
if physical is not None:
pre_filter_count = len(new_expressions)
new_expressions = _filter_expressions_by_physical(new_expressions, physical)
if pre_filter_count > 0 and len(new_expressions) == 0:
logger.warning(
f"physical= hint {physical} on {self.logical_expression.operator.logical_op_name()} "
f"filtered out all {pre_filter_count} candidate(s) from {self.rule.get_rule_id()}. "
f"If no other rule produces a match, optimization will fail."
)

new_expressions = [expr for expr in new_expressions if expr.expr_id not in expressions]

# get the costed_full_op_ids from the context (if provided) and compute whether this
Expand Down
Loading