Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 280 additions & 0 deletions hackagent/attacks/AdvPrefix/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
import pandas as pd
from typing import Dict, Any
import logging

logger: logging.Logger = logging.getLogger(__name__)

# Columns that together identify one aggregated result row.
GROUP_KEYS: list[str] = ["goal", "prefix"]

# Judge type -> name of the DataFrame column that holds that judge's scores.
# The same name is used as the prefix of the derived "<col>_mean" /
# "<col>_count" / "<col>_size" statistics.
JUDGE_AGG_COLUMN_MAP: dict[str, str] = {
    "nuanced": "eval_nj",
    "jailbreakbench": "eval_jb",
    "harmbench": "eval_hb",
}


def _filter_by_nll(df: pd.DataFrame, max_ce_threshold: float | None) -> pd.DataFrame:
    """Drop rows whose 'prefix_nll' is not strictly below the threshold.

    The 'prefix_nll' column is coerced to numeric before the comparison, so a
    column that arrives as text (for example after a CSV round trip) is still
    filtered instead of raising and silently skipping the filter.

    Args:
        df: The input DataFrame.
        max_ce_threshold: The maximum cross-entropy threshold. Rows with
            'prefix_nll' greater than or equal to this are removed, as are
            rows whose 'prefix_nll' is missing or cannot be parsed as a
            number (NaN never compares as less than the threshold).
            If None, no filtering is performed.

    Returns:
        The filtered DataFrame, or ``df`` unchanged when the threshold is
        None, the column is absent, or the comparison cannot be evaluated.
    """
    if max_ce_threshold is None:
        return df

    if "prefix_nll" not in df.columns:
        logger.warning(
            "Column 'prefix_nll' not found. Skipping NLL filtering in aggregation step."
        )
        return df

    try:
        # Coerce so that numeric strings are compared as numbers; values that
        # cannot be parsed become NaN and fall out of the `<` mask below.
        nll_values = pd.to_numeric(df["prefix_nll"], errors="coerce")
        filtered_df = df[nll_values < max_ce_threshold]
    except (TypeError, ValueError) as e:
        # Best effort: a non-comparable threshold or malformed column must not
        # abort the aggregation step, so fall back to the unfiltered frame.
        logger.error(f"Error during NLL filtering in aggregation: {e}")
        return df

    logger.info(
        f"Filtered {len(df) - len(filtered_df)} rows based on prefix_nll >= {max_ce_threshold}"
    )
    return filtered_df

Check warning on line 48 in hackagent/attacks/AdvPrefix/aggregation.py

View check run for this annotation

Codecov / codecov/patch

hackagent/attacks/AdvPrefix/aggregation.py#L45-L48

Added lines #L45 - L48 were not covered by tests


def _get_available_judge_agg_cols(
    df: pd.DataFrame, config_judges: list[str]
) -> Dict[str, str]:
    """Return the judge score columns that are actually present in ``df``.

    Every entry of JUDGE_AGG_COLUMN_MAP is checked against the DataFrame's
    columns. A judge whose column is absent is left out of the result; if that
    judge is also listed in ``config_judges`` a warning is logged.

    Args:
        df: The DataFrame whose columns are inspected.
        config_judges: Judge types that were expected to have been run.

    Returns:
        A mapping of judge type to its score column name, restricted to the
        columns found in ``df`` and ordered as in JUDGE_AGG_COLUMN_MAP.
    """
    found = {
        judge_type: col_name
        for judge_type, col_name in JUDGE_AGG_COLUMN_MAP.items()
        if col_name in df.columns
    }

    # Second pass only reports the configured judges whose column is missing.
    for judge_type, col_name in JUDGE_AGG_COLUMN_MAP.items():
        if judge_type in found:
            continue
        if judge_type in config_judges:
            logger.warning(
                f"Expected aggregation column '{col_name}' for judge '{judge_type}' not found in the dataframe."
            )

    return found

Check warning on line 75 in hackagent/attacks/AdvPrefix/aggregation.py

View check run for this annotation

Codecov / codecov/patch

hackagent/attacks/AdvPrefix/aggregation.py#L75

Added line #L75 was not covered by tests


def _build_agg_funcs(
    base_agg_funcs: Dict[str, pd.NamedAgg],
    df: pd.DataFrame,
    available_judges_agg_cols: Dict[str, str],
) -> Dict[str, pd.NamedAgg]:
    """Extend the base aggregations with per-judge statistics.

    For each judge column, "<col>_mean" and "<col>_count" are added when the
    column holds at least one value that survives numeric coercion; otherwise
    only "<col>_size" is added. The coercion is a probe on a temporary series:
    ``df`` itself is never modified, and the returned aggregations refer to
    the column as it is stored in ``df``.

    Args:
        base_agg_funcs: Aggregations to start from; a shallow copy is extended.
        df: The DataFrame that will be aggregated (only read here).
        available_judges_agg_cols: Mapping of judge type to score column name.

    Returns:
        A new dictionary of NamedAgg objects suitable for ``groupby(...).agg(**...)``.
    """
    result = base_agg_funcs.copy()

    for judge_type, col_name in available_judges_agg_cols.items():
        try:
            # Probe only: does anything in this column look like a number?
            has_numeric = pd.to_numeric(df[col_name], errors="coerce").notna().any()
            if not has_numeric:
                logger.warning(
                    f"Column '{col_name}' for judge '{judge_type}' contains no numeric data after coercion. Skipping mean/count aggregation."
                )
                # Still report how many rows each group has for this judge.
                result[f"{col_name}_size"] = pd.NamedAgg(
                    column=col_name, aggfunc="size"
                )
                continue

            for stat in ("mean", "count"):
                result[f"{col_name}_{stat}"] = pd.NamedAgg(
                    column=col_name, aggfunc=stat
                )
            logger.debug(
                f"Added mean/count aggregation for numeric column '{col_name}'"
            )
        except KeyError:
            logger.warning(
                f"Column '{col_name}' unexpectedly missing during aggregation setup for judge '{judge_type}'. Skipping."
            )
        except Exception as e:
            logger.error(
                f"Could not process column '{col_name}' for aggregation for judge '{judge_type}'. Skipping mean/count. Error: {e}"
            )
            result[f"{col_name}_size"] = pd.NamedAgg(column=col_name, aggfunc="size")

    return result

Check warning on line 133 in hackagent/attacks/AdvPrefix/aggregation.py

View check run for this annotation

Codecov / codecov/patch

hackagent/attacks/AdvPrefix/aggregation.py#L132-L133

Added lines #L132 - L133 were not covered by tests


def execute(
    input_df: pd.DataFrame, config: Dict[str, Any], run_dir: str
) -> pd.DataFrame:
    """Aggregate evaluation results from different judges.

    The input is copied, optionally filtered by the 'max_ce' threshold on
    'prefix_nll', and then grouped by GROUP_KEYS ('goal', 'prefix'). For each
    group the first 'prefix_nll', 'model_name', 'meta_prefix' and 'temperature'
    values are kept (when those columns exist), the group size is reported as
    'n_eval_samples', and per-judge statistics are added by _build_agg_funcs.

    Judge score columns are coerced to numeric in the working copy before
    aggregation, so scores stored as text are averaged instead of making the
    whole aggregation fail.

    Args:
        input_df: Evaluation results. Expected to contain 'goal', 'prefix' and
            one or more judge score columns (e.g. 'eval_nj').
        config: Configuration dictionary. 'max_ce' (optional) is the NLL
            threshold; 'judges' (optional) lists the judges expected to have run.
        run_dir: Directory of the current run. Unused here; kept because it is
            part of the step signature.

    Returns:
        One row per ('goal', 'prefix') combination with the aggregated columns.
        If the input is empty, an empty DataFrame with the expected columns is
        returned. If no judge column is found, a grouping key is missing, or
        the aggregation itself fails, the (filtered) unaggregated DataFrame is
        returned instead.
    """
    logger.info("Executing Step 8: Aggregating evaluation results")

    if input_df.empty:
        logger.warning("Step 8 received an empty DataFrame. Skipping aggregation.")
        cols = GROUP_KEYS + [
            "prefix_nll",
            "model_name",
            "meta_prefix",
            "temperature",
            "n_eval_samples",
        ]
        for col_base in JUDGE_AGG_COLUMN_MAP.values():
            cols.extend([f"{col_base}_mean", f"{col_base}_count"])
        return pd.DataFrame(columns=cols)

    analysis_df = input_df.copy()

    max_ce_threshold = config.get("max_ce")
    if max_ce_threshold is not None:
        try:
            max_ce_threshold = float(max_ce_threshold)
        except (TypeError, ValueError):
            # float() raises TypeError for non-string, non-numeric values
            # (e.g. a list), not only ValueError for unparsable strings.
            logger.warning(
                f"'max_ce' value '{max_ce_threshold}' is not a valid float. Skipping NLL filtering."
            )
            max_ce_threshold = None
    analysis_df = _filter_by_nll(analysis_df, max_ce_threshold)

    # An explicit `judges: None` in the config is treated like a missing key.
    config_judges = config.get("judges") or []
    available_judges_agg_cols = _get_available_judge_agg_cols(
        analysis_df, config_judges
    )

    if not available_judges_agg_cols:
        logger.error(
            "No recognized evaluation result columns found for aggregation. Check step 7 output."
        )
        return analysis_df

    logger.info(
        f"Found aggregation columns for judges: {list(available_judges_agg_cols.keys())}"
    )

    missing_keys = [key for key in GROUP_KEYS if key not in analysis_df.columns]
    if missing_keys:
        logger.error(
            f"Missing required grouping keys for aggregation: {missing_keys}. Cannot aggregate."
        )
        return analysis_df

    # Make the judge columns numeric in the frame that is actually aggregated.
    # Columns with no numeric content at all are left untouched; for those
    # _build_agg_funcs only requests a size aggregation.
    for col_name in available_judges_agg_cols.values():
        try:
            numeric_col = pd.to_numeric(analysis_df[col_name], errors="coerce")
        except (TypeError, ValueError):
            continue
        if numeric_col.notna().any():
            # assign() returns a new frame, which avoids SettingWithCopyWarning
            # when analysis_df is a filtered slice.
            analysis_df = analysis_df.assign(**{col_name: numeric_col})

    base_agg_funcs = {
        "prefix_nll": pd.NamedAgg(column="prefix_nll", aggfunc="first"),
        "model_name": pd.NamedAgg(column="model_name", aggfunc="first"),
        "meta_prefix": pd.NamedAgg(column="meta_prefix", aggfunc="first"),
        "temperature": pd.NamedAgg(column="temperature", aggfunc="first"),
        # Group size, counted on a grouping key that is known to be present.
        "n_eval_samples": pd.NamedAgg(column=GROUP_KEYS[0], aggfunc="size"),
    }

    agg_funcs_to_use = _build_agg_funcs(
        base_agg_funcs, analysis_df, available_judges_agg_cols
    )

    # Keep only aggregations whose source column exists; .agg() raises for any
    # NamedAgg that points at a missing column.
    final_agg_funcs = {}
    for agg_name, named_agg in agg_funcs_to_use.items():
        if named_agg.column in analysis_df.columns:
            final_agg_funcs[agg_name] = named_agg
        else:
            logger.warning(
                f"Column '{named_agg.column}' for aggregation '{agg_name}' not found in DataFrame. Removing this aggregation."
            )

    if not final_agg_funcs:
        logger.error(
            "No valid aggregation functions remaining after column checks. Cannot aggregate."
        )
        return analysis_df

    try:
        grouped = analysis_df.groupby(GROUP_KEYS, observed=False, dropna=False)
        aggregated_df = grouped.agg(**final_agg_funcs).reset_index()
    except Exception as e:
        # Deliberate best effort: hand the unaggregated data back to the
        # pipeline rather than aborting the run.
        logger.error(
            f"Error during aggregation: {e}. Check aggregation functions and column types."
        )
        return analysis_df

    logger.info(
        f"Step 8 complete. Aggregated {len(aggregated_df)} prefix results. CSV will be saved by the main pipeline."
    )

    return aggregated_df

Check warning on line 280 in hackagent/attacks/AdvPrefix/aggregation.py

View check run for this annotation

Codecov / codecov/patch

hackagent/attacks/AdvPrefix/aggregation.py#L280

Added line #L280 was not covered by tests
Loading