diff --git a/CHANGELOG.md b/CHANGELOG.md index fd85fed3..c6c86809 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,13 +27,11 @@ Changes: * Fix: `StructuralAttack` now respects the `report_individual` flag. Per-record `record_level_results` and `attack_metrics["individual"]` are only populated when the flag is set to `True`, matching the behaviour of `LIRAAttack` and `QMIAAttack`. -* Breaking: `LIRAAttack` per-record output field renamed from `"score"` to - `"member_prob"` to standardise with `QMIAAttack` and `WorstCaseAttack`. Affects - the `individual` dict in `attack_metrics`, the `report.json` payload, and the - externalised `.npz` sidecar key (`individual.score` becomes - `individual.member_prob`). Existing LiRA `report.json` files written before this - release will not be readable by `MetaAttack`'s `use_existing_only` mode and - should be regenerated. +* Feat: `WorstCaseAttack` now accepts a `report_individual` flag. When enabled, each + repetition's metrics dict gains an `"individual"` key holding per-record + `"member_prob"` (the attack classifier's membership probability) and `"member"` + (the ground truth label), matching the per-record output convention used by + `LIRAAttack` and `QMIAAttack`. Arrays are sized to the attack-model test slice. ## Version 1.4.3 (Jan 29, 2026) diff --git a/sacroml/attacks/worst_case_attack.py b/sacroml/attacks/worst_case_attack.py index 119325a4..6e77d32b 100644 --- a/sacroml/attacks/worst_case_attack.py +++ b/sacroml/attacks/worst_case_attack.py @@ -75,6 +75,7 @@ def __init__( search_type: str = "grid", search_n_iter: int = 10, tuning_metric: str | Callable = "AUC", + report_individual: bool = False, ) -> None: """Construct an object to execute a worst case attack. @@ -135,6 +136,14 @@ def __init__( :data:`sacroml.attacks._scorers.SCORERS`, any sklearn scoring string, or a custom callable following the sklearn ``(estimator, X, y)`` protocol. + report_individual : bool + Whether to expose per-record membership probabilities in the + output. When True, each repetition's metrics dict gains an + ``"individual"`` key holding ``"member_prob"`` (the attack + classifier's probability of membership for each test sample) + and ``"member"`` (the ground truth label). The arrays are + sized to the attack-model test slice, not the full target + training set. """ super().__init__(output_dir=output_dir, write_report=write_report) self.n_reps: int = n_reps @@ -167,6 +176,7 @@ def __init__( ) self.tuning_metric = "AUC" self._resolved_tuning_scorer = resolve_scorer("AUC") + self.report_individual: bool = report_individual self.dummy_attack_metrics: list = [] self._tuned_params: dict | None = None self._tuning_info: dict | None = None @@ -547,6 +557,12 @@ def run_attack_reps( y_pred_proba = attack_classifier.predict_proba(mi_test_x) mia_metrics.append(metrics.get_metrics(y_pred_proba, mi_test_y)) + if self.report_individual: + mia_metrics[-1]["individual"] = { + "member_prob": y_pred_proba[:, 1].tolist(), + "member": np.asarray(mi_test_y).tolist(), + } + if self.include_model_correct_feature and train_correct is not None: # Compute the Yeom TPR and FPR yeom_preds = mi_test_x[:, -1] diff --git a/tests/attacks/test_worst_case_attack.py b/tests/attacks/test_worst_case_attack.py index d6229a77..2ac1c09b 100644 --- a/tests/attacks/test_worst_case_attack.py +++ b/tests/attacks/test_worst_case_attack.py @@ -358,3 +358,49 @@ def test_no_tuning_leaves_metadata_clean(common_setup): assert "tuning" not in output["metadata"] assert attack_obj._tuned_params is None assert attack_obj._tuning_info is None + + +def test_wc_report_individual_default_off_omits_individual(common_setup): + """Default report_individual=False: per-rep dicts have no "individual" key.""" + target = common_setup + attack_obj = worst_case_attack.WorstCaseAttack( + n_reps=2, + n_dummy_reps=0, + p_thresh=0.05, + test_prop=0.5, + output_dir="test_output_worstcase", + ) + assert attack_obj.report_individual is False + + output = attack_obj.attack(target) + instances = output["attack_experiment_logger"]["attack_instance_logger"] + for inst in instances.values(): + assert "individual" not in inst + + +def test_wc_report_individual_on_populates_per_rep_member_prob(common_setup): + """Report_individual=True populates per-rep individual member_prob and member.""" + target = common_setup + attack_obj = worst_case_attack.WorstCaseAttack( + n_reps=2, + n_dummy_reps=0, + p_thresh=0.05, + test_prop=0.5, + output_dir="test_output_worstcase", + report_individual=True, + ) + + output = attack_obj.attack(target) + instances = output["attack_experiment_logger"]["attack_instance_logger"] + assert len(instances) == 2 + + for inst in instances.values(): + assert "individual" in inst + ind = inst["individual"] + assert set(ind.keys()) == {"member_prob", "member"} + assert len(ind["member_prob"]) == len(ind["member"]) + assert len(ind["member_prob"]) > 0 + for prob in ind["member_prob"]: + assert 0.0 <= prob <= 1.0 + for label in ind["member"]: + assert label in (0, 1)