Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@ Changes:
* Fix: `StructuralAttack` now respects the `report_individual` flag. Per-record
`record_level_results` and `attack_metrics["individual"]` are only populated when the
flag is set to `True`, matching the behaviour of `LIRAAttack` and `QMIAAttack`.
* Breaking: `LIRAAttack` per-record output field renamed from `"score"` to
`"member_prob"` to standardise with `QMIAAttack` and `WorstCaseAttack`. Affects
the `individual` dict in `attack_metrics`, the `report.json` payload, and the
externalised `.npz` sidecar key (`individual.score` becomes
`individual.member_prob`). Existing LiRA `report.json` files written before this
release will not be readable by `MetaAttack`'s `use_existing_only` mode and
should be regenerated.
* Feat: `WorstCaseAttack` now accepts a `report_individual` flag. When enabled, each
repetition's metrics dict gains an `"individual"` key holding per-record
`"member_prob"` (the attack classifier's membership probability) and `"member"`
(the ground truth label), matching the per-record output convention used by
`LIRAAttack` and `QMIAAttack`. Arrays are sized to the attack-model test slice.

## Version 1.4.3 (Jan 29, 2026)

Expand Down
16 changes: 16 additions & 0 deletions sacroml/attacks/worst_case_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
search_type: str = "grid",
search_n_iter: int = 10,
tuning_metric: str | Callable = "AUC",
report_individual: bool = False,
) -> None:
"""Construct an object to execute a worst case attack.

Expand Down Expand Up @@ -135,6 +136,14 @@ def __init__(
:data:`sacroml.attacks._scorers.SCORERS`, any sklearn scoring
string, or a custom callable following the sklearn
``(estimator, X, y)`` protocol.
report_individual : bool
Whether to expose per-record membership probabilities in the
output. When True, each repetition's metrics dict gains an
``"individual"`` key holding ``"member_prob"`` (the attack
classifier's probability of membership for each test sample)
and ``"member"`` (the ground truth label). The arrays are
sized to the attack-model test slice, not the full target
training set.
"""
super().__init__(output_dir=output_dir, write_report=write_report)
self.n_reps: int = n_reps
Expand Down Expand Up @@ -167,6 +176,7 @@ def __init__(
)
self.tuning_metric = "AUC"
self._resolved_tuning_scorer = resolve_scorer("AUC")
self.report_individual: bool = report_individual
self.dummy_attack_metrics: list = []
self._tuned_params: dict | None = None
self._tuning_info: dict | None = None
Expand Down Expand Up @@ -547,6 +557,12 @@ def run_attack_reps(
y_pred_proba = attack_classifier.predict_proba(mi_test_x)
mia_metrics.append(metrics.get_metrics(y_pred_proba, mi_test_y))

if self.report_individual:
mia_metrics[-1]["individual"] = {
"member_prob": y_pred_proba[:, 1].tolist(),
"member": np.asarray(mi_test_y).tolist(),
}

if self.include_model_correct_feature and train_correct is not None:
# Compute the Yeom TPR and FPR
yeom_preds = mi_test_x[:, -1]
Expand Down
46 changes: 46 additions & 0 deletions tests/attacks/test_worst_case_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,49 @@ def test_no_tuning_leaves_metadata_clean(common_setup):
assert "tuning" not in output["metadata"]
assert attack_obj._tuned_params is None
assert attack_obj._tuning_info is None


def test_wc_report_individual_default_off_omits_individual(common_setup):
"""Default report_individual=False: per-rep dicts have no "individual" key."""
target = common_setup
attack_obj = worst_case_attack.WorstCaseAttack(
n_reps=2,
n_dummy_reps=0,
p_thresh=0.05,
test_prop=0.5,
output_dir="test_output_worstcase",
)
assert attack_obj.report_individual is False

output = attack_obj.attack(target)
instances = output["attack_experiment_logger"]["attack_instance_logger"]
for inst in instances.values():
assert "individual" not in inst


def test_wc_report_individual_on_populates_per_rep_member_prob(common_setup):
"""Report_individual=True populates per-rep individual member_prob and member."""
target = common_setup
attack_obj = worst_case_attack.WorstCaseAttack(
n_reps=2,
n_dummy_reps=0,
p_thresh=0.05,
test_prop=0.5,
output_dir="test_output_worstcase",
report_individual=True,
)

output = attack_obj.attack(target)
instances = output["attack_experiment_logger"]["attack_instance_logger"]
assert len(instances) == 2

for inst in instances.values():
assert "individual" in inst
ind = inst["individual"]
assert set(ind.keys()) == {"member_prob", "member"}
assert len(ind["member_prob"]) == len(ind["member"])
assert len(ind["member_prob"]) > 0
for prob in ind["member_prob"]:
assert 0.0 <= prob <= 1.0
for label in ind["member"]:
assert label in (0, 1)