Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4599422
tasks 1 and 2 implementation
Nov 29, 2025
b643be0
code for task 3
Dec 2, 2025
b65d5de
tests for Semantic GroupBy
Dec 2, 2025
f07f611
Removed back compatibility with GroupBySig
kepler11c Jan 4, 2026
9351efc
restored field check in ApplyGroupByOp
kepler11c Jan 4, 2026
8393b25
Simplified aggregation logic in Semantic GroupBy's call
kepler11c Jan 4, 2026
fdecc47
Added Implementation Rule for Semantic GroupBy
kepler11c Jan 4, 2026
12ba5f1
Updated implementation rule and added distinction between semantic an…
kepler11c Jan 4, 2026
dd9dd0b
New Implementation Rule for Non Semantic GroupBys
kepler11c Jan 6, 2026
ebe125d
Deleted get_fields_to_generate from SemanticGroupByOp
kepler11c Jan 6, 2026
ba1ec68
updated prompt strategy in SemanticGroupBy's implementation rule
kepler11c Jan 6, 2026
c45312b
SemanticGroupByOp's call uses output_schema to set output_field_names
kepler11c Jan 6, 2026
d6ba70d
updated schema initialization in test_semantic_groupby
kepler11c Jan 6, 2026
b1d8861
updated total cost parameter
kepler11c Jan 11, 2026
ff2a5c3
Added output schema during groupByAggregate creation
kepler11c Jan 11, 2026
142be5b
Created schema from fields helper for groupBy functions
kepler11c Jan 11, 2026
1f4d870
updated agg_field_name to align with previous changes
kepler11c Jan 11, 2026
f9b4631
Updated input parameters in groupby schema to field helper
kepler11c Jan 11, 2026
1fe6063
minor
mdr223 Jan 13, 2026
197564c
formatted queries for wildlife, ecommerce and amazon reviews
kepler11c Jan 29, 2026
c1dfee9
formatted queries - movies dataset
kepler11c Jan 29, 2026
094d14d
PZ program for movies query 1 + added functionality to handle usd per…
kepler11c Feb 9, 2026
3250c29
testing
kepler11c Feb 16, 2026
593303e
updated sem_groupBy
kepler11c Feb 24, 2026
72a5024
Queries 1 through 5
kepler11c Mar 3, 2026
0560038
checking in sem gby changes before refactor
kepler11c Mar 10, 2026
b925ef0
resolved conflict
kepler11c Mar 10, 2026
3bef618
merging in staging dev
kepler11c Mar 10, 2026
e1478e6
updated __call__ structure for SemanticGroupByOp
kepler11c Mar 15, 2026
509b4dd
WIP: updated Semantic group-by implementation
kepler11c Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 269 additions & 9 deletions src/palimpzest/core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,34 +577,293 @@ def groupby(self, gby_fields, agg_fields, agg_funcs) -> Dataset:
operator = GroupByAggregate(input_schema=self.schema, output_schema=output_schema, gby_fields=gby_fields, agg_fields=agg_fields, agg_funcs=agg_funcs)
return Dataset(sources=[self], operator=operator, schema=output_schema)

def group_by(
    self,
    group_cols: list[str] | list[dict],
    agg_func: Callable,
    output_col: str,
) -> Dataset:
    """
    Apply a semantic group-by operation with detailed field specifications.

    Args:
        group_cols: List of group-by field specifications. Each can be:
            - A string (field name): uses default grouping behavior
            - A dict with keys: 'name', 'desc', 'type', and optionally 'model'
        agg_func: Aggregation function to apply. NOTE: only counting is
            currently supported; any callable is mapped to the 'count'
            aggregation (see comment below).
        output_col: Name of the output aggregation column.

    Returns:
        A new Dataset whose operator is a semantic GroupByAggregate.

    Raises:
        ValueError: if an entry of `group_cols` is neither a str nor a dict.

    Example:
        ds.group_by(
            group_cols=[
                {'name': 'era', 'desc': 'Era bucket: pre-2000, 2000s, 2010s, or 2020s', 'type': str}
            ],
            agg_func=count_reviews,
            output_col="review_count"
        )
    """
    # Normalize group_cols so downstream operators always receive dicts
    # with 'name', 'desc', and 'type' keys.
    normalized_group_cols = []
    for col in group_cols:
        if isinstance(col, str):
            normalized_group_cols.append({
                'name': col,
                'desc': f'Group by {col}',
                'type': str,
            })
        elif isinstance(col, dict):
            normalized_group_cols.append(col)
        else:
            raise ValueError("group_cols must be a list of strings or dicts")

    # Extract bare field names for schema construction.
    gby_field_names = [col['name'] for col in normalized_group_cols]

    # Only the 'count' aggregation is implemented today: previously this
    # method inspected agg_func.__name__ but BOTH branches resolved to
    # 'count', so the inference was dead code. Keep the (identical)
    # behavior, but state the limitation explicitly.
    # TODO: dispatch on agg_func once other aggregations (sum, avg, ...)
    # are supported by the physical operator.
    agg_func_str = 'count'

    # Output schema combines the group-by fields with the aggregate column.
    output_schema = create_groupby_schema_from_fields(gby_field_names, [output_col])

    # Build the logical operator; full dict specifications are forwarded so
    # the LLM can use field descriptions when forming groups.
    operator = GroupByAggregate(
        input_schema=self.schema,
        is_semantic=True,
        output_schema=output_schema,
        gby_fields=normalized_group_cols,
        agg_fields=[output_col],
        agg_funcs=[agg_func_str],
    )

    return Dataset(sources=[self], operator=operator, schema=output_schema)

def sem_groupby(self, gby_fields: list[str] | list[dict], agg_fields: list[str] | list[dict], agg_funcs: list[str]) -> Dataset:
    """
    Apply a semantic group-by operation to this set using an LLM. This operator groups
    records by the specified `gby_fields` and applies the `agg_funcs` to the
    `agg_fields` for each group.

    Args:
        gby_fields: List of field specifications to group by. Each can be:
            - A string (field name): uses default grouping behavior
            - A dict with keys: 'name', 'desc', 'type', and optionally 'model'
        agg_fields: List of field specifications to aggregate. Each can be:
            - A string (field name): uses default aggregation behavior
            - A dict with keys: 'name', 'desc', 'type', and optionally 'model'
        agg_funcs: List of aggregation function names to apply (e.g., ['count']).

    Returns:
        A new Dataset whose operator is a semantic GroupByAggregate.

    Raises:
        ValueError: if an entry of `gby_fields` or `agg_fields` is neither
            a str nor a dict.

    Example:
        ds = pz.TextFileDataset(id="reviews", dir="product-reviews/")
        ds = ds.sem_groupby(
            gby_fields=[{'name': 'complaint', 'desc': 'Type of complaint', 'type': str}],
            agg_fields=['contents'],
            agg_funcs=['count']
        )
    """
    # Normalize gby_fields: plain strings get a default description/type so
    # downstream operators always receive fully-specified dicts.
    normalized_gby_fields = []
    for field in gby_fields:
        if isinstance(field, str):
            normalized_gby_fields.append({
                'name': field,
                'desc': f'Group by {field}',
                'type': str,
            })
        elif isinstance(field, dict):
            normalized_gby_fields.append(field)
        else:
            raise ValueError("gby_fields must be a list of strings or dicts")

    # Normalize agg_fields the same way.
    normalized_agg_fields = []
    for field in agg_fields:
        if isinstance(field, str):
            normalized_agg_fields.append({
                'name': field,
                'desc': f'Aggregate {field}',
                'type': str,
            })
        elif isinstance(field, dict):
            normalized_agg_fields.append(field)
        else:
            raise ValueError("agg_fields must be a list of strings or dicts")

    # Bare field names are what the schema helper expects.
    gby_field_names = [f['name'] for f in normalized_gby_fields]
    agg_field_names = [f['name'] for f in normalized_agg_fields]

    output_schema = create_groupby_schema_from_fields(gby_field_names, agg_field_names)

    # Create the logical operator with the full dict specifications so the
    # LLM can leverage per-field descriptions when assigning groups.
    operator = GroupByAggregate(
        input_schema=self.schema,
        is_semantic=True,
        output_schema=output_schema,
        gby_fields=normalized_gby_fields,
        agg_fields=normalized_agg_fields,
        agg_funcs=agg_funcs,
    )

    return Dataset(sources=[self], operator=operator, schema=output_schema)

def hierarchical_groupby(
    self,
    groupby_fields: list[list[str]],
    agg_fields: list[list[str]],
    agg_funcs: list[list[str]],
) -> "DataRecordSet | dict":
    """
    Perform hierarchical (nested) exact groupby operations across multiple levels.

    At each level except the last, records are partitioned by the groupby fields
    without aggregation; the last level applies full aggregation.

    Args:
        groupby_fields: List of lists of field names to group by at each level.
        agg_fields: List of lists of field names to aggregate at each level.
        agg_funcs: List of lists of aggregation function names at each level.

    Returns:
        A DataRecordSet when a single level is given, or a nested dict
        ``{group_key: <result_for_next_level>}`` for multiple levels.
        (NOTE: the previous ``-> dict`` annotation was wrong for the
        single-level case; the annotation now matches the docstring.)

    Raises:
        AssertionError: if the three per-level lists differ in length.
    """
    # Local imports avoid a circular dependency at module load time.
    from palimpzest.core.lib.schemas import create_groupby_schema_from_fields
    from palimpzest.query.operators.aggregate import ApplyGroupByOp

    assert len(groupby_fields) == len(agg_fields) == len(agg_funcs), \
        "groupby_fields, agg_fields, and agg_funcs must all have the same length"

    # Materialize the upstream records once; every level re-partitions them.
    result = self.run()
    candidates = result.data_records

    def run_level(candidates, level):
        # Per-level specifications.
        gby_names = groupby_fields[level]
        agg_names = agg_fields[level]
        funcs = agg_funcs[level]
        output_schema = create_groupby_schema_from_fields(gby_names, agg_names)
        op = ApplyGroupByOp(
            gby_fields=gby_names,
            agg_fields=agg_names,
            agg_funcs=funcs,
            output_schema=output_schema,
            input_schema=self.schema,
        )
        if level == len(groupby_fields) - 1:
            # Final level: full groupby with aggregation.
            return op(candidates)
        # Intermediate level: partition candidates by exact field values,
        # then recurse into the next level for each partition.
        outer_groups = {}
        for candidate in candidates:
            key = tuple(getattr(candidate, f, None) for f in gby_names)
            outer_groups.setdefault(key, []).append(candidate)
        return {key: run_level(grp, level + 1) for key, grp in outer_groups.items()}

    return run_level(candidates, 0)

def hierarchical_sem_groupby(
    self,
    groupby_fields: list[list[str | dict]],
    agg_fields: list[list[str | dict]],
    agg_funcs: list[list[str]],
    model=None,
    prompt_strategy=None,
    reasoning_effort=None,
) -> "tuple[DataRecordSet | dict, GenerationStats]":
    """
    Perform hierarchical (nested) semantic groupby operations using LLMs.

    At each intermediate level the LLM assigns group labels to the original records
    (without aggregation) so that inner levels can operate on the same raw records.
    The final level runs a full semantic groupby with aggregation.

    Args:
        groupby_fields: List of lists of field specs (str or dict with name/desc/type) per level.
        agg_fields: List of lists of field specs to aggregate per level.
        agg_funcs: List of lists of aggregation function names per level.
        model: Optional LLM model override.
        prompt_strategy: Optional prompt strategy override.
        reasoning_effort: Optional reasoning effort override.

    Returns:
        A 2-tuple ``(result, stats)`` where ``result`` is a DataRecordSet for
        a single level, or a nested dict ``{group_key: <result_for_next_level>}``
        for multiple levels, and ``stats`` is the accumulated GenerationStats
        across all levels. (NOTE: the previous ``-> dict`` annotation did not
        match the actual 2-tuple return value; both the annotation and this
        docstring are now accurate.)

    Raises:
        AssertionError: if the three per-level lists differ in length.
    """
    # Local imports avoid circular dependencies at module load time.
    from palimpzest.constants import Model, PromptStrategy
    from palimpzest.core.lib.schemas import create_groupby_schema_from_fields
    from palimpzest.core.models import GenerationStats
    from palimpzest.query.operators.aggregate import SemanticGroupByOp

    assert len(groupby_fields) == len(agg_fields) == len(agg_funcs), \
        "groupby_fields, agg_fields, and agg_funcs must all have the same length"

    # Default to GPT-4o if no model specified; sem_groupby requires an explicit model
    # because hierarchical_sem_groupby bypasses the query optimizer / policy system.
    _model = model if model is not None else Model.GPT_4o
    _prompt_strategy = prompt_strategy if prompt_strategy is not None else PromptStrategy.AGG

    # Materialize the upstream records once; every level re-labels them.
    result = self.run()
    candidates = result.data_records

    # Accumulate GenerationStats across all levels so callers can track
    # total cost / token usage for the entire hierarchical operation.
    accumulated_stats = GenerationStats()

    def normalize_fields(fields):
        # Promote bare string field names to fully-specified dicts.
        out = []
        for f in fields:
            if isinstance(f, str):
                out.append({'name': f, 'desc': f'Group by {f}', 'type': str})
            else:
                out.append(f)
        return out

    def run_level(candidates, level):
        nonlocal accumulated_stats
        gby_specs = normalize_fields(groupby_fields[level])
        agg_specs = normalize_fields(agg_fields[level])
        funcs = agg_funcs[level]
        gby_names = [s['name'] for s in gby_specs]
        agg_names = [s['name'] for s in agg_specs]
        output_schema = create_groupby_schema_from_fields(gby_names, agg_names)
        op = SemanticGroupByOp(
            gby_fields=gby_specs,
            agg_fields=agg_specs,
            agg_funcs=funcs,
            model=_model,
            prompt_strategy=_prompt_strategy,
            reasoning_effort=reasoning_effort,
            output_schema=output_schema,
            input_schema=self.schema,
        )
        if level == len(groupby_fields) - 1:
            # Final level: full groupby with aggregation.
            # Extract per-group RecordOpStats and fold into accumulated_stats.
            dataset_result = op(candidates)
            for ros in dataset_result.record_op_stats:
                accumulated_stats.total_input_tokens += ros.total_input_tokens
                accumulated_stats.total_output_tokens += ros.total_output_tokens
                accumulated_stats.total_input_cost += ros.total_input_cost
                accumulated_stats.total_output_cost += ros.total_output_cost
                accumulated_stats.llm_call_duration_secs += ros.llm_call_duration_secs
            return dataset_result
        # Intermediate level: LLM assigns group labels without aggregation.
        # Capture and accumulate the GenerationStats that were previously discarded.
        group_labels, gen_stats = op._assign_groups_llm(candidates)
        accumulated_stats += gen_stats
        outer_groups = {}
        for candidate, label in zip(candidates, group_labels):
            key = (label,) if not isinstance(label, tuple) else label
            outer_groups.setdefault(key, []).append(candidate)
        return {key: run_level(grp, level + 1) for key, grp in outer_groups.items()}

    nested_result = run_level(candidates, 0)
    return nested_result, accumulated_stats

def sem_agg(self, col: dict | type[BaseModel], agg: str, depends_on: str | list[str] | None = None) -> Dataset:
"""
Apply a semantic aggregation to this set. The `agg` string will be applied using an LLM
Expand Down Expand Up @@ -696,6 +955,7 @@ def run(self, config: QueryProcessorConfig | None = None, **kwargs):
"""Invoke the QueryProcessor to execute the query. `kwargs` will be applied to the QueryProcessorConfig."""
# TODO: this import currently needs to be here to avoid a circular import; we should fix this in a subsequent PR
from palimpzest.query.processor.query_processor_factory import QueryProcessorFactory
print("Running Query Processor...")

# as syntactic sugar, we will allow some keyword arguments to parameterize our policies
policy = construct_policy_from_kwargs(**kwargs)
Expand Down
3 changes: 3 additions & 0 deletions src/palimpzest/query/generators/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def __call__(self, candidate: DataRecord | list[DataRecord], fields: dict[str, F
logger.debug(f"PROMPT:\n{prompt}")
logger.debug(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)

print(f"PROMPT:\n{prompt}")
print(Fore.GREEN + f"{completion_text}\n" + Style.RESET_ALL)

# parse reasoning
reasoning = None
try:
Expand Down
Loading