diff --git a/docs/source/attacks/report.rst b/docs/source/attacks/report.rst index 78bbd848..4fd1336c 100644 --- a/docs/source/attacks/report.rst +++ b/docs/source/attacks/report.rst @@ -3,3 +3,47 @@ Report .. automodule:: sacroml.attacks.report :members: + +Converting legacy reports +------------------------- + +Older SACRO-ML versions write a flat ``report.json`` keyed by +``"<attack_name>_<log_id>"``. The new reporting pipeline expects a nested, +catalog-enriched document validated by the bundled JSON schema +(``sacroml/reporting/sacroml_attack_report.schema.json``). + +Use the ``convert-report`` command to upgrade an existing report without +re-running any attacks; the converted report is written to the output path: + +.. prompt:: bash + + sacroml convert-report report.json report_new.json + +The converter: + +* wraps the legacy experiments under a top-level ``attacks`` key; +* injects the four human-readable catalogs (``metric_catalog``, + ``parameter_catalog``, ``attack_category_catalog``, ``attack_catalog``) + from the bundled common definitions in + ``sacroml/reporting/catalog_definitions.json``; +* **warns** when a metric, parameter, attack or attack category observed in + the report is not present in the catalogs (conversion still succeeds); and +* validates the result against the JSON schema. + +Two small normalisations keep real-world legacy reports schema-valid: a +placeholder ``sacroml_version`` is injected when missing, and an empty +``attack_instance_logger`` is supplied for instance-less attacks (e.g. +structural attacks). Anything else that does not match the schema -- such as a +non-scalar metric value or a stray metadata key -- is reported as a schema +error rather than silently rewritten. + +Curve-valued arrays (``fpr`` / ``tpr`` / ``roc_thresh``) are passed through +unchanged. ``roc_thresh`` legitimately starts with ``null``, which the schema +does not yet permit, so such violations are reported as notices rather than +errors. Pass ``--no-validate`` to skip schema validation. 
def _run_convert_report(args: argparse.Namespace) -> int:
    """Run the ``convert-report`` sub-command and print the outcome.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; the ``input``, ``output`` and
        ``no_validate`` attributes are read.

    Returns
    -------
    int
        Process exit code. 0 when the conversion succeeds and either
        validation was skipped or the output is schema-valid. 1 when the
        input file does not exist, cannot be parsed as JSON, an ``OSError``
        is raised while reading or writing, or the converted report is not
        schema-valid.
    """
    source = args.input

    # Guard clause: nothing to do without an input file.
    if not os.path.isfile(source):
        print(f"Input report not found: {source}")
        return 1

    try:
        outcome = convert_report_file(
            source, args.output, validate=not args.no_validate
        )
    except json.JSONDecodeError as parse_error:
        print(f"Could not parse '{source}' as JSON: {parse_error}")
        return 1
    except OSError as io_error:
        print(f"Could not read or write report: {io_error}")
        return 1

    print(f"Converted '{source}' -> '{args.output}'")

    # One summary line per catalog dimension that has uncatalogued names.
    for dimension, names in outcome.coverage.items():
        if not names:
            continue
        print(f" {len(names)} uncatalogued {dimension.replace('_', ' ')}: {names}")

    if outcome.curve_warnings:
        print(
            f" {len(outcome.curve_warnings)} curve-array notice(s); "
            "fpr/tpr/roc_thresh passed through unchanged."
        )

    # Without validation there is no verdict to report.
    if args.no_validate:
        return 0

    if not outcome.is_valid:
        n_errors = len(outcome.schema_errors)
        print(f"Converted report is NOT schema-valid ({n_errors} error(s)):")
        for problem in outcome.schema_errors:
            print(f" - {problem}")
        return 1

    print("Converted report is schema-valid.")
    return 0
b/sacroml/reporting/__init__.py new file mode 100644 index 00000000..8a0b8877 --- /dev/null +++ b/sacroml/reporting/__init__.py @@ -0,0 +1,21 @@ +"""Reporting utilities for SACRO-ML attack reports. + +This subpackage contains the canonical attack-report JSON schema, a set of +common catalog definitions, and tooling to convert legacy ``report.json`` +files produced by older SACRO-ML versions into the new nested, +catalog-enriched report format. +""" + +from __future__ import annotations + +from sacroml.reporting.convert import ( + ConversionResult, + convert_report, + convert_report_file, +) + +__all__ = [ + "ConversionResult", + "convert_report", + "convert_report_file", +] diff --git a/sacroml/reporting/catalog_definitions.json b/sacroml/reporting/catalog_definitions.json new file mode 100644 index 00000000..dcbcb244 --- /dev/null +++ b/sacroml/reporting/catalog_definitions.json @@ -0,0 +1,780 @@ +{ + "metric_catalog": { + "version": "metrics_sacroml_v1", + "metrics": { + "AUC": { + "label": "Area Under ROC Curve", + "description": "Measures the ability of the attack to distinguish members from non-members.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "TPR": { + "label": "True Positive Rate", + "description": "Fraction of true members correctly identified.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "FPR": { + "label": "False Positive Rate", + "description": "Fraction of non-members incorrectly identified as members.", + "units": null, + "higher_is_better": false, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": 
null + }, + "risk_score": { + "label": "Aggregate Risk Score", + "description": "Summary privacy risk score aggregated from multiple attacks.", + "units": null, + "higher_is_better": true, + "category": "summary", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean" + ], + "notes": "Interpretation depends on aggregation method." + }, + "Advantage": { + "label": "Attack advantage", + "description": "Difference between true positive rate and false positive rate.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": null, + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "F1score": { + "label": "F1 score", + "description": "Harmonic mean of precision and recall.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "ACC": { + "label": "Accuracy", + "description": "Overall classification accuracy.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "AUC_sig": { + "label": "AUC significance", + "description": "Statistical significance of AUC relative to baseline.", + "units": null, + "higher_is_better": null, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Qualitative significance annotation." + }, + "PDIF_sig": { + "label": "PDIF significance", + "description": "Statistical significance of PDIF relative to baseline.", + "units": null, + "higher_is_better": null, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Qualitative significance annotation." 
+ }, + "FAR": { + "label": "False Acceptance Rate", + "description": "Fraction of non-members incorrectly accepted as members.", + "units": null, + "higher_is_better": false, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "TNR": { + "label": "True Negative Rate", + "description": "Fraction of non-members correctly identified as non-members.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "PPV": { + "label": "Positive Predictive Value", + "description": "Probability that an instance predicted as a member is truly a member.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": "Also known as precision." 
+ }, + "NPV": { + "label": "Negative Predictive Value", + "description": "Probability that an instance predicted as a non-member is truly a non-member.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "FNR": { + "label": "False Negative Rate", + "description": "Fraction of members incorrectly classified as non-members.", + "units": null, + "higher_is_better": false, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "P_HIGHER_AUC": { + "label": "Probability of Higher AUC", + "description": "Estimated probability that the observed AUC exceeds the null or baseline AUC.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max" + ], + "notes": "Often derived from permutation or bootstrap tests." + }, + "pred_prob_var": { + "label": "Prediction Probability Variance", + "description": "Variance of predicted membership probabilities.", + "units": null, + "higher_is_better": null, + "category": "diagnostic", + "typical_range": null, + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": "Used as a stability or confidence diagnostic." + }, + "null_auc_3sd_range": { + "label": "Null AUC \u00b13\u03c3 range", + "description": "Expected AUC range under the null hypothesis, defined as \u00b13 standard deviations.", + "units": null, + "higher_is_better": null, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Represented as a string range." 
+ }, + "n_sig_auc_p_vals": { + "label": "Number of significant AUC p-values", + "description": "Count of repetitions where AUC p-values are statistically significant.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "sum", + "min", + "max" + ], + "notes": null + }, + "n_sig_auc_p_vals_corrected": { + "label": "Number of significant AUC p-values (corrected)", + "description": "Count of statistically significant AUC p-values after multiple-testing correction.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "sum", + "min", + "max" + ], + "notes": null + }, + "n_sig_pdif_vals": { + "label": "Number of significant PDIF values", + "description": "Count of PDIF values that are statistically significant.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "sum", + "min", + "max" + ], + "notes": null + }, + "n_sig_pdif_vals_corrected": { + "label": "Number of significant PDIF values (corrected)", + "description": "Count of statistically significant PDIF values after multiple-testing correction.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "sum", + "min", + "max" + ], + "notes": null + }, + "n_pos_test_examples": { + "label": "Number of positive test examples", + "description": "Total number of member (positive) examples in the test set.", + "units": null, + "higher_is_better": null, + "category": "dataset", + "typical_range": null, + "allowed_aggregations": [ + "min", + "max", + "mean" + ], + "notes": "Used to contextualise attack performance metrics." 
+ }, + "n_neg_test_examples": { + "label": "Number of negative test examples", + "description": "Total number of non-member (negative) examples in the test set.", + "units": null, + "higher_is_better": null, + "category": "dataset", + "typical_range": null, + "allowed_aggregations": [ + "min", + "max", + "mean" + ], + "notes": "Used to contextualise attack performance metrics." + }, + "n_normal": { + "label": "Number of normal examples", + "description": "Number of examples considered normal or non-attacked in the evaluation.", + "units": null, + "higher_is_better": null, + "category": "dataset", + "typical_range": null, + "allowed_aggregations": [ + "min", + "max", + "mean" + ], + "notes": "Interpretation depends on attack context." + }, + "roc_thresh": { + "label": "ROC decision thresholds", + "description": "Decision thresholds used to compute ROC-based metrics.", + "units": null, + "higher_is_better": null, + "category": "diagnostic", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Typically used only for curve construction and diagnostics." + }, + "dof_risk": { + "label": "Dof Risk", + "description": "Degrees-of-freedom risk: model is overly complex relative to the training data size.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." + }, + "k_anonymity_risk": { + "label": "K Anonymity Risk", + "description": "K-anonymity risk: training records fall into equivalence classes smaller than k.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." 
+ }, + "class_disclosure_risk": { + "label": "Class Disclosure Risk", + "description": "Class disclosure risk: outputs may reveal small disclosive groups.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." + }, + "unnecessary_risk": { + "label": "Unnecessary Risk", + "description": "Unnecessary risk: hyperparameters set to values associated with higher MIA risk.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." + }, + "lowvals_cd_risk": { + "label": "Lowvals Cd Risk", + "description": "Low-values class disclosure risk: class frequency within an equivalence class below a safe threshold.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." + } + }, + "pattern_metrics": [ + { + "pattern": "^(fpr|tpr)$", + "label_template": "{metric_name} curve", + "description": "Curve-valued metric represented as an array of rates evaluated across decision thresholds. Intended for paired use (e.g. fpr vs tpr).", + "units": null, + "higher_is_better": null, + "category": "curve", + "allowed_aggregations": [ + "min", + "max" + ], + "notes": "Curve-valued metric. Should not be aggregated directly; use to derive scalar summaries such as AUC." 
+ }, + { + "pattern": "^(FMAX|FMIN|FDIF|PDIF)[0-9]+$", + "label_template": "{metric_name}", + "description": "Threshold-based extremum or difference metric computed at a specified operating point.", + "units": null, + "higher_is_better": true, + "category": "threshold_analysis", + "allowed_aggregations": [ + "mean", + "min", + "max" + ], + "notes": "Metric name suffix indicates the threshold or operating point." + }, + { + "pattern": "^TPR@.+$", + "label_template": "TPR at threshold {metric_name}", + "description": "True Positive Rate evaluated at a specific decision threshold indicated in the metric name.", + "units": null, + "higher_is_better": true, + "category": "performance", + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": "Metric suffix indicates the decision threshold (e.g. TPR@0.5, TPR@0.1)." + } + ] + }, + "parameter_catalog": { + "version": "params_v1", + "parameters": { + "n_reps": { + "label": "Number of repetitions", + "description": "Number of independent repetitions used to estimate variability.", + "parameter_scope": "attack", + "value_type": "integer", + "notes": null + }, + "p_thresh": { + "label": "Probability threshold", + "description": "Threshold used to decide membership based on predicted probability.", + "parameter_scope": "attack", + "value_type": "float", + "notes": null + }, + "attack_model": { + "label": "Attack model class", + "description": "Classifier used by the attacker.", + "parameter_scope": "attack", + "value_type": "string", + "notes": null + }, + "n_shadow_models": { + "label": "Number of shadow models", + "description": "How many shadow models are trained to emulate the target.", + "parameter_scope": "attack", + "value_type": "integer", + "notes": null + }, + "aggregation_method": { + "label": "Aggregation method", + "description": "Method used to aggregate results from multiple attacks.", + "parameter_scope": "attack", + "value_type": "string", + "notes": null + }, + "mode": { + 
"label": "Execution mode", + "description": "Mode in which the attack operates, e.g. offline or online.", + "parameter_scope": "attack", + "value_type": "string", + "notes": null + }, + "fix_variance": { + "label": "Fix variance", + "description": "Whether to fix variance when estimating likelihood ratios.", + "parameter_scope": "attack", + "value_type": "boolean", + "notes": null + }, + "report_individual": { + "label": "Report individual results", + "description": "Whether to report per-instance attack results.", + "parameter_scope": "attack", + "value_type": "boolean", + "notes": null + }, + "output_dir": { + "label": "Output directory", + "description": "Directory in which attack artefacts are written.", + "parameter_scope": "global", + "value_type": "string", + "notes": null + }, + "write_report": { + "label": "Write report", + "description": "Whether a report should be written to disk.", + "parameter_scope": "global", + "value_type": "boolean", + "notes": null + }, + "reproduce_split": { + "label": "Reproduce split seed", + "description": "Seed used to reproduce the train/test split.", + "parameter_scope": "attack", + "value_type": "integer", + "notes": null + }, + "n_dummy_reps": { + "label": "Number of dummy repetitions", + "description": "Number of dummy or null repetitions used for baseline estimation.", + "parameter_scope": "attack", + "value_type": "integer", + "notes": null + }, + "train_beta": { + "label": "Training beta", + "description": "Beta parameter applied during training phase of the attack.", + "parameter_scope": "attack", + "value_type": "number", + "notes": null + }, + "test_beta": { + "label": "Testing beta", + "description": "Beta parameter applied during testing phase of the attack.", + "parameter_scope": "attack", + "value_type": "number", + "notes": null + }, + "test_prop": { + "label": "Test proportion", + "description": "Proportion of data reserved for testing.", + "parameter_scope": "training", + "value_type": "number", + "notes": 
null + }, + "include_model_correct_feature": { + "label": "Include model correctness feature", + "description": "Whether a feature indicating model correctness is included.", + "parameter_scope": "attack", + "value_type": "boolean", + "notes": null + }, + "sort_probs": { + "label": "Sort probabilities", + "description": "Whether predicted probabilities are sorted before analysis.", + "parameter_scope": "attack", + "value_type": "boolean", + "notes": null + }, + "attack_model_params": { + "label": "Attack model parameters", + "description": "Hyperparameters for the attack model.", + "parameter_scope": "attack", + "value_type": "object", + "notes": "Null indicates default parameters." + } + } + }, + "attack_category_catalog": { + "version": "categories_v1", + "categories": { + "structural": { + "label": "Structural attacks", + "description": "Attacks exploiting deterministic or structural properties of the target model.", + "order": 10 + }, + "probabilistic": { + "label": "Probabilistic attacks", + "description": "Attacks based on confidence scores or probability distributions.", + "order": 20 + }, + "meta": { + "label": "Meta attacks", + "description": "Attacks that combine or summarise the results of multiple other attacks.", + "order": 30 + } + } + }, + "attack_catalog": { + "version": "attacks_sacroml_v1", + "attacks": { + "WorstCase attack": { + "label": "Worst\u2011Case Membership Inference Attack", + "description": "Assumes the attacker has maximal auxiliary knowledge.", + "attack_category": "structural", + "key_metrics": [ + "AUC", + "TPR" + ], + "attack_params": { + "n_reps": { + "label": "Number of repetitions", + "description": "Independent repetitions of the attack." + }, + "p_thresh": { + "label": "Probability threshold", + "description": "Decision threshold for membership prediction." 
+ } + } + }, + "ShadowModel attack": { + "label": "Shadow Model Attack", + "description": "Uses shadow models to approximate the target model\u2019s behaviour.", + "attack_category": "probabilistic", + "key_metrics": [ + "AUC" + ], + "attack_params": { + "n_shadow_models": { + "label": "Number of shadow models", + "description": "How many shadow models to train." + }, + "attack_model": { + "label": "Attack model", + "description": "Classifier trained on shadow model outputs." + } + } + }, + "AggregateRisk attack": { + "label": "Aggregate Risk Meta\u2011Attack", + "description": "Aggregates the outputs of multiple attacks into a single risk score.", + "attack_category": "meta", + "key_metrics": [ + "risk_score" + ], + "attack_params": { + "aggregation_method": { + "label": "Aggregation method", + "description": "Strategy used to combine individual attack results." + } + } + }, + "LiRA Attack": { + "label": "Likelihood Ratio Attack (LiRA)", + "description": "Membership inference attack based on likelihood ratios derived from shadow models.", + "attack_category": "probabilistic", + "key_metrics": [ + "AUC", + "TPR" + ], + "attack_params": { + "n_shadow_models": { + "label": "Number of shadow models", + "description": "Number of shadow models used to estimate likelihoods." + }, + "p_thresh": { + "label": "Probability threshold", + "description": "Threshold used for membership decision." + }, + "mode": { + "label": "Execution mode", + "description": "Whether the attack is run in offline or online mode." + }, + "fix_variance": { + "label": "Fix variance", + "description": "Whether variance is fixed when computing likelihood ratios." + }, + "report_individual": { + "label": "Report individual", + "description": "Whether to report individual-instance results." 
+ } + } + }, + "Structural Attack": { + "label": "Structural Attack", + "description": "Static analysis of the target model's structure and training data against pre-defined disclosure-risk rules and thresholds.", + "attack_category": "structural", + "key_metrics": [ + "dof_risk", + "k_anonymity_risk", + "class_disclosure_risk" + ], + "attack_params": { + "output_dir": { + "label": "Output directory", + "description": "Directory in which attack artefacts are written." + } + } + }, + "Attribute inference attack": { + "label": "Attribute Inference Attack", + "description": "Attempts to infer sensitive attribute values of training records from the target model's outputs.", + "attack_category": "probabilistic", + "key_metrics": [ + "risk_score" + ], + "attack_params": { + "output_dir": { + "label": "Output directory", + "description": "Directory in which attack artefacts are written." + } + } + } + } + } +} diff --git a/sacroml/reporting/convert.py b/sacroml/reporting/convert.py new file mode 100644 index 00000000..53658961 --- /dev/null +++ b/sacroml/reporting/convert.py @@ -0,0 +1,377 @@ +"""Convert legacy SACRO-ML ``report.json`` files into the new report format. + +Older SACRO-ML versions write a flat JSON object keyed by +``"_"``. The new reporting pipeline expects a nested, +catalog-enriched document validated by +``sacroml/reporting/sacroml_attack_report.schema.json``. + +The conversion: + +1. wraps the legacy experiments under a top-level ``"attacks"`` key; +2. injects the four catalogs (``metric_catalog``, ``parameter_catalog``, + ``attack_category_catalog``, ``attack_catalog``) from the bundled + ``catalog_definitions.json``; +3. validates the result against the JSON schema; and +4. warns about any metric, parameter, attack or attack category observed in the + report that is not present in the catalogs (the conversion still succeeds). 
+ +Two small normalisations keep real-world legacy reports schema-valid: a +placeholder ``sacroml_version`` is injected when missing, and an empty +``attack_instance_logger`` is supplied for instance-less attacks (e.g. +structural attacks). Anything else that does not match the schema -- such as a +non-scalar metric or a stray metadata key -- is surfaced as a schema error +rather than silently rewritten. + +Curve-valued arrays (``fpr`` / ``tpr`` / ``roc_thresh``) are passed through +untouched; ``roc_thresh`` legitimately starts with ``null``, which the schema +does not yet permit, so such violations are reported as notices rather than +errors. +""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from jsonschema import Draft7Validator + +REPORT_SCHEMA_VERSION = "1.2" + +_HERE = Path(__file__).parent +SCHEMA_PATH = _HERE / "sacroml_attack_report.schema.json" +CATALOG_DEFINITIONS_PATH = _HERE / "catalog_definitions.json" + +# Top-level keys that are *not* experiments in an already-converted report. +_NON_EXPERIMENT_KEYS = frozenset( + { + "report_schema_version", + "metric_catalog", + "parameter_catalog", + "attack_category_catalog", + "attack_catalog", + "attacks", + } +) + + +@dataclass +class ConversionResult: + """Outcome of converting a legacy report. + + Attributes + ---------- + report : dict + The converted report (new format). + warnings : list[str] + Non-fatal issues: uncatalogued metrics/parameters/attacks/categories + and the normalisations that were applied. + curve_warnings : list[str] + Schema violations attributable solely to curve-valued arrays + (``fpr`` / ``tpr`` / ``roc_thresh``); these are expected and benign. + schema_errors : list[str] + Genuine schema violations that make the output invalid. + coverage : dict + Per-dimension list of uncatalogued names. 
def _load_json(path: Path) -> dict[str, Any]:
    """Load a JSON data file and return the parsed content.

    Parameters
    ----------
    path : Path
        Location of the JSON file; it is read as UTF-8.

    Returns
    -------
    dict
        The parsed JSON document.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _is_experiment(value: Any) -> bool:
    """Return whether a top-level value looks like a legacy experiment.

    A value qualifies when it is a dict carrying at least one of the keys
    ``attack_experiment_logger``, ``metadata`` or ``log_id``.
    """
    if not isinstance(value, dict):
        return False
    return any(
        key in value for key in ("attack_experiment_logger", "metadata", "log_id")
    )


def _extract_experiments(data: dict[str, Any]) -> dict[str, Any]:
    """Return the experiment mapping, whether the input is flat or wrapped.

    Parameters
    ----------
    data : dict
        Parsed report: either a flat legacy mapping of experiments or a
        document that already has an ``attacks`` mapping.

    Returns
    -------
    dict
        A new dict of experiment key to experiment (values are not copied).
    """
    wrapped = data.get("attacks")
    if isinstance(wrapped, dict):
        # Already wrapped: return a shallow copy so conversion is idempotent.
        return dict(wrapped)
    return {
        key: value
        for key, value in data.items()
        if key not in _NON_EXPERIMENT_KEYS and _is_experiment(value)
    }


def _normalise_experiment(
    exp_key: str, experiment: dict[str, Any], warnings: list[str]
) -> dict[str, Any]:
    """Apply the minimal normalisations needed for schema validity.

    The input experiment is not mutated; a shallow copy is returned.

    Parameters
    ----------
    exp_key : str
        Key of the experiment in the report; used as the default ``log_id``
        and to label warnings.
    experiment : dict
        The legacy experiment.
    warnings : list[str]
        Collector that receives one message per normalisation that changes
        what the report states (placeholder version, supplied empty logger).

    Returns
    -------
    dict
        The experiment with ``log_id``, ``log_time``, ``metadata`` and
        ``attack_experiment_logger`` present. The logger is rebuilt so that it
        contains only the ``attack_instance_logger`` mapping.
    """
    exp = dict(experiment)
    exp.setdefault("log_id", exp_key)
    exp.setdefault("log_time", "")

    metadata = dict(exp.get("metadata") or {})
    if "sacroml_version" not in metadata:
        metadata["sacroml_version"] = "unknown"
        warnings.append(
            f"Experiment '{exp_key}': metadata.sacroml_version was missing; "
            "set to 'unknown'."
        )
    exp["metadata"] = metadata

    # Instance-less attacks (e.g. structural attacks) have no per-instance
    # metrics; supply an empty logger so the result still satisfies the schema.
    logger = exp.get("attack_experiment_logger")
    instances = (
        logger.get("attack_instance_logger") if isinstance(logger, dict) else None
    )
    if not isinstance(instances, dict):
        instances = {}
        # Record the substitution: a malformed logger would otherwise be
        # discarded without any trace in the conversion result.
        warnings.append(
            f"Experiment '{exp_key}': attack_instance_logger was missing or "
            "not an object; an empty logger was supplied."
        )
    exp["attack_experiment_logger"] = {"attack_instance_logger": instances}
    return exp


def _compile_pattern_metrics(metric_catalog: dict[str, Any]) -> list[re.Pattern]:
    """Compile the regexes declared in ``metric_catalog.pattern_metrics``.

    Entries that are not dicts, or that have no non-empty ``pattern``, are
    skipped. A missing or ``null`` ``pattern_metrics`` yields an empty list.
    """
    entries = metric_catalog.get("pattern_metrics") or []
    return [
        re.compile(entry["pattern"])
        for entry in entries
        if isinstance(entry, dict) and entry.get("pattern")
    ]


def _uncatalogued(
    seen: set[str], explicit: set[str], patterns: list[re.Pattern] | None = None
) -> list[str]:
    """Return the sorted names in ``seen`` that no catalog entry covers.

    Parameters
    ----------
    seen : set[str]
        Names observed in the report.
    explicit : set[str]
        Names listed verbatim in the catalog.
    patterns : list[re.Pattern], optional
        Compiled regexes; a name is covered when any pattern matches it from
        the start of the string (``Pattern.match``).

    Returns
    -------
    list[str]
        Uncovered names in sorted order.
    """
    patterns = patterns or []
    return sorted(
        name
        for name in seen
        if name not in explicit and not any(p.match(name) for p in patterns)
    )
catalogs["attack_catalog"].get("attacks", {}) + # Categories referenced by the attacks we can resolve in the catalog. + seen_categories = { + catalog_attacks[name]["attack_category"] + for name in seen_attacks + if name in catalog_attacks and "attack_category" in catalog_attacks[name] + } + + coverage = { + "metrics": _uncatalogued( + seen_metrics, + set(metric_catalog.get("metrics", {})), + _compile_pattern_metrics(metric_catalog), + ), + "parameters": _uncatalogued( + seen_params, set(catalogs["parameter_catalog"].get("parameters", {})) + ), + "attacks": _uncatalogued(seen_attacks, set(catalog_attacks)), + "attack_categories": _uncatalogued( + seen_categories, + set(catalogs["attack_category_catalog"].get("categories", {})), + ), + } + + labels = { + "metrics": "metric", + "parameters": "parameter", + "attacks": "attack", + "attack_categories": "attack category", + } + warnings = [ + f"Uncatalogued {labels[dim]}: '{name}' is not present in the {dim} catalog." + for dim, names in coverage.items() + for name in names + ] + return coverage, warnings + + +def _is_curve_violation(error: Any, report: dict[str, Any]) -> bool: + """Return whether a schema error is caused solely by a curve array. + + Curve-valued metrics (``fpr`` / ``tpr`` / ``roc_thresh``) are stored as + raw arrays that may contain ``null`` (e.g. ``roc_thresh[0]``). These are + a known schema limitation and are reported as warnings, not errors. 
+ """ + path = list(error.absolute_path) + # Expected path: attacks / / attack_experiment_logger / + # attack_instance_logger / instance_ / [/ idx] + if len(path) < 6: + return False + if path[0] != "attacks" or path[2] != "attack_experiment_logger": + return False + if path[3] != "attack_instance_logger": + return False + metric = path[5] + if not isinstance(metric, str): + return False + try: + value = report["attacks"][path[1]]["attack_experiment_logger"][ + "attack_instance_logger" + ][path[4]][metric] + except (KeyError, TypeError, IndexError): + return False + return isinstance(value, list) + + +def _validate(report: dict[str, Any]) -> tuple[list[str], list[str]]: + """Validate ``report`` against the schema. + + Returns + ------- + tuple[list[str], list[str]] + ``(schema_errors, curve_warnings)`` -- human-readable messages. + """ + validator = Draft7Validator(_load_json(SCHEMA_PATH)) + schema_errors: list[str] = [] + curve_warnings: list[str] = [] + for error in sorted(validator.iter_errors(report), key=str): + location = "/".join(str(p) for p in error.absolute_path) or "" + detail = error.message + if len(detail) > 200: + detail = detail[:200] + "... (truncated)" + message = f"{location}: {detail}" + if _is_curve_violation(error, report): + curve_warnings.append(message) + else: + schema_errors.append(message) + return schema_errors, curve_warnings + + +def convert_report(data: dict[str, Any], *, validate: bool = True) -> ConversionResult: + """Convert an in-memory legacy report dictionary to the new format. + + Parameters + ---------- + data : dict + The parsed legacy ``report.json`` (flat) or an already-wrapped report. + validate : bool, default True + Whether to validate the converted report against the JSON schema. + + Returns + ------- + ConversionResult + The converted report plus warnings, schema errors and a coverage + summary. 
+ """ + warnings: list[str] = [] + experiments: dict[str, Any] = {} + if isinstance(data, dict): + experiments = _extract_experiments(data) + else: + warnings.append( + f"Top-level report was a {type(data).__name__}, not an object; " + "no experiments could be extracted." + ) + + catalogs = _load_json(CATALOG_DEFINITIONS_PATH) + report: dict[str, Any] = { + "report_schema_version": REPORT_SCHEMA_VERSION, + "metric_catalog": catalogs["metric_catalog"], + "parameter_catalog": catalogs["parameter_catalog"], + "attack_category_catalog": catalogs["attack_category_catalog"], + "attack_catalog": catalogs["attack_catalog"], + "attacks": { + key: _normalise_experiment(key, exp, warnings) + for key, exp in experiments.items() + if isinstance(exp, dict) + }, + } + + coverage, coverage_warnings = _compute_coverage(report, catalogs) + warnings.extend(coverage_warnings) + + schema_errors: list[str] = [] + curve_warnings: list[str] = [] + if validate: + schema_errors, curve_warnings = _validate(report) + + return ConversionResult( + report=report, + warnings=warnings, + curve_warnings=curve_warnings, + schema_errors=schema_errors, + coverage=coverage, + ) + + +def convert_report_file( + input_path: str | Path, + output_path: str | Path, + *, + validate: bool = True, + indent: int = 2, +) -> ConversionResult: + """Convert a legacy ``report.json`` file and write the new format to disk. + + Parameters + ---------- + input_path : str | Path + Path to the legacy report. + output_path : str | Path + Path to write the converted report to. + validate : bool, default True + Whether to validate the converted report against the JSON schema. + indent : int, default 2 + Indentation for the written JSON. + + Returns + ------- + ConversionResult + The conversion outcome. 
+ """ + with open(input_path, encoding="utf-8") as fh: + data = json.load(fh) + + result = convert_report(data, validate=validate) + + output_path = Path(output_path) + if output_path.parent and not output_path.parent.exists(): + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as fh: + json.dump(result.report, fh, indent=indent) + fh.write("\n") + + return result diff --git a/sacroml/reporting/sacroml_attack_report.schema.json b/sacroml/reporting/sacroml_attack_report.schema.json new file mode 100644 index 00000000..564b6592 --- /dev/null +++ b/sacroml/reporting/sacroml_attack_report.schema.json @@ -0,0 +1,396 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "SacroML Attack Report", + "description": "A nested, attack-agnostic schema for SacroML attack reports. Includes human-readable catalogs (metrics, parameters, attack categories, attacks).", + "type": "object", + "properties": { + "report_schema_version": { + "type": "string" + }, + "metric_catalog": { + "type": "object", + "properties": { + "version": { + "type": "string" + }, + "metrics": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9_@.eE+\\-]+$": { + "$ref": "#/$defs/MetricDefinition" + } + }, + "additionalProperties": false + }, + "pattern_metrics": { + "type": "array", + "items": { + "$ref": "#/$defs/PatternMetricDefinition" + } + } + }, + "required": [ + "version", + "metrics" + ], + "additionalProperties": false + }, + "parameter_catalog": { + "type": "object", + "properties": { + "version": { + "type": "string" + }, + "parameters": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9_\\-]+$": { + "$ref": "#/$defs/ParameterDefinition" + } + }, + "additionalProperties": false + } + }, + "required": [ + "version", + "parameters" + ], + "additionalProperties": false + }, + "attack_category_catalog": { + "type": "object", + "description": "Human-readable definitions of high-level attack categories.", + 
"properties": { + "version": { + "type": "string" + }, + "categories": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9_\\-]+$": { + "$ref": "#/$defs/AttackCategoryDefinition" + } + }, + "additionalProperties": false + } + }, + "required": [ + "version", + "categories" + ], + "additionalProperties": false + }, + "attack_catalog": { + "type": "object", + "properties": { + "version": { + "type": "string" + }, + "attacks": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9 _\\-]+$": { + "$ref": "#/$defs/AttackKindDefinition" + } + }, + "additionalProperties": false + } + }, + "required": [ + "attacks" + ], + "additionalProperties": false + }, + "attacks": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9 _\\-]+$": { + "type": "object", + "properties": { + "log_id": { + "type": "string" + }, + "log_time": { + "type": "string" + }, + "metadata": { + "type": "object", + "properties": { + "sacroml_version": { + "type": "string" + }, + "attack_name": { + "type": "string" + }, + "attack_params": { + "type": "object", + "additionalProperties": true + }, + "global_metrics": { + "type": "object", + "additionalProperties": { + "type": [ + "number", + "string", + "boolean", + "null" + ] + } + }, + "baseline_global_metrics": { + "type": "object", + "additionalProperties": { + "type": [ + "number", + "string", + "boolean", + "null" + ] + } + }, + "target_model": { + "type": "string" + }, + "target_model_params": { + "type": "object", + "additionalProperties": true + }, + "target_train_params": { + "type": "object", + "additionalProperties": true + } + }, + "required": [ + "sacroml_version", + "attack_name", + "attack_params", + "global_metrics" + ], + "additionalProperties": false + }, + "attack_experiment_logger": { + "type": "object", + "properties": { + "attack_instance_logger": { + "type": "object", + "patternProperties": { + "^instance_[0-9]+$": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "number" + }, + { + 
"type": "string" + }, + { + "type": "array", + "items": { + "type": "number" + } + }, + { + "type": "null" + } + ] + } + } + }, + "additionalProperties": false + } + }, + "required": [ + "attack_instance_logger" + ], + "additionalProperties": false + } + }, + "required": [ + "log_id", + "log_time", + "metadata", + "attack_experiment_logger" + ], + "additionalProperties": true + } + }, + "additionalProperties": false + } + }, + "required": [ + "metric_catalog", + "parameter_catalog", + "attack_category_catalog", + "attack_catalog", + "attacks" + ], + "additionalProperties": false, + "$defs": { + "AttackCategoryDefinition": { + "type": "object", + "description": "Human-readable definition of an attack category, including report ordering.", + "properties": { + "label": { + "type": "string", + "description": "Human-readable category label." + }, + "description": { + "type": "string", + "description": "Explanation of what this category represents." + }, + "order": { + "type": "integer", + "description": "Defines the order in which this category should be presented in reports (lower first)." + } + }, + "required": [ + "label", + "description", + "order" + ], + "additionalProperties": false + }, + "AttackKindDefinition": { + "type": "object", + "description": "Attack kind definition with category, key metrics, and obligatory attack-specific parameter declarations.", + "properties": { + "label": { + "type": "string", + "description": "Human-readable attack name." + }, + "description": { + "type": "string", + "description": "Description of what the attack does." + }, + "attack_category": { + "type": "string", + "description": "Key into attack_category_catalog.categories." 
+ }, + "key_metrics": { + "type": "array", + "items": { + "type": "string" + }, + "minItems": 1 + }, + "attack_params": { + "type": "object", + "description": "Human-readable declarations of all parameters supported by this attack kind.", + "patternProperties": { + "^[A-Za-z0-9_\\-]+$": { + "type": "object", + "properties": { + "label": { + "type": "string" + }, + "description": { + "type": "string" + } + }, + "required": [ + "label", + "description" + ], + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "required": [ + "label", + "description", + "attack_category", + "key_metrics", + "attack_params" + ], + "additionalProperties": false + }, + "MetricDefinition": { + "type": "object", + "description": "Human-readable description of a metric.", + "properties": { + "label": { + "type": "string" + }, + "description": { + "type": "string" + }, + "units": { + "type": [ + "string", + "null" + ] + }, + "higher_is_better": { + "type": [ + "boolean", + "null" + ] + }, + "category": { + "type": [ + "string", + "null" + ] + }, + "typical_range": { + "oneOf": [ + { + "type": "array", + "items": { + "type": "number" + }, + "minItems": 2, + "maxItems": 2 + }, + { + "type": "null" + } + ] + }, + "allowed_aggregations": { + "type": "array", + "description": "List of aggregation operations that are valid and meaningful for this metric.", + "items": { + "type": "string", + "enum": [ + "sum", + "mean", + "min", + "max", + "median", + "stddev", + "var", + "quantile", + "count" + ] + }, + "minItems": 0, + "uniqueItems": true + }, + "notes": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "label", + "description", + "allowed_aggregations" + ], + "additionalProperties": false + }, + "PatternMetricDefinition": { + "type": "object", + "additionalProperties": true + }, + "ParameterDefinition": { + "type": "object", + "additionalProperties": true + } + } +} diff --git a/tests/reporting/__init__.py b/tests/reporting/__init__.py new file mode 
100644 index 00000000..e3c7d084 --- /dev/null +++ b/tests/reporting/__init__.py @@ -0,0 +1 @@ +"""Tests for the sacroml.reporting package.""" diff --git a/tests/reporting/fixtures/legacy_basic.json b/tests/reporting/fixtures/legacy_basic.json new file mode 100644 index 00000000..7566660e --- /dev/null +++ b/tests/reporting/fixtures/legacy_basic.json @@ -0,0 +1,91 @@ +{ + "WorstCase attack_de2c5fc0-fb0c-4925-ac4d-26662fe7f786": { + "log_id": "de2c5fc0-fb0c-4925-ac4d-26662fe7f786", + "log_time": "05/11/2025 16:54:34", + "metadata": { + "sacroml_version": "1.4.0", + "attack_name": "WorstCase attack", + "attack_params": { + "output_dir": "training_artefacts", + "write_report": true, + "n_reps": 2, + "p_thresh": 0.05 + }, + "global_metrics": { + "null_auc_3sd_range": "0.4629 -> 0.5371", + "n_sig_auc_p_vals": 2 + }, + "baseline_global_metrics": { + "n_sig_auc_p_vals": 0 + }, + "target_model": "RandomForestClassifier", + "target_model_params": {}, + "target_train_params": {} + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "instance_0": { + "TPR": 0.019, + "FPR": 0.007, + "ACC": 0.7, + "AUC": 0.57, + "P_HIGHER_AUC": 3.05e-09, + "FDIF01": 0.64, + "PDIF01": 31.5, + "TPR@0.1": 0.43 + }, + "instance_1": { + "TPR": 0.012, + "FPR": 0.004, + "ACC": 0.7, + "AUC": 0.58, + "P_HIGHER_AUC": 1.8e-13, + "FDIF01": 0.57, + "PDIF01": 28.1, + "TPR@0.1": 0.39 + } + } + } + }, + "LiRA Attack_4e9b5db3-f93f-43dc-9a67-1035a68b892a": { + "log_id": "4e9b5db3-f93f-43dc-9a67-1035a68b892a", + "log_time": "05/11/2025 16:55:01", + "metadata": { + "sacroml_version": "1.4.0", + "attack_name": "LiRA Attack", + "attack_params": { + "output_dir": "training_artefacts", + "n_shadow_models": 100, + "p_thresh": 0.05 + }, + "global_metrics": { + "AUC_sig": "Significant at p=0.05" + }, + "target_model": "RandomForestClassifier", + "target_model_params": {}, + "target_train_params": {} + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "instance_0": { + "TPR": 0.76, + "FPR": 
"""Tests for legacy report.json -> new format conversion."""

from __future__ import annotations

import json
from pathlib import Path

import pytest

from sacroml.main import main
from sacroml.reporting import ConversionResult, convert_report, convert_report_file
from sacroml.reporting.convert import _is_curve_violation

FIXTURES = Path(__file__).parent / "fixtures"
DOCS_EXAMPLES = Path(__file__).parents[2] / "docs" / "source" / "attacks"


def _load(name: str) -> dict:
    """Return the parsed content of the fixture file called ``name``."""
    return json.loads((FIXTURES / name).read_text(encoding="utf-8"))


def _experiment(metadata: dict | None = None, instances: dict | None = None) -> dict:
    """Build a minimal legacy experiment; ``metadata`` overrides the defaults.

    The ``attack_experiment_logger`` key is only added when ``instances`` is
    given.
    """
    defaults = {
        "sacroml_version": "1.4.0",
        "attack_name": "LiRA Attack",
        "attack_params": {},
        "global_metrics": {"AUC_sig": "Significant at p=0.05"},
    }
    experiment: dict = {
        "log_id": "id",
        "log_time": "now",
        "metadata": {**defaults, **(metadata or {})},
    }
    if instances is not None:
        experiment["attack_experiment_logger"] = {"attack_instance_logger": instances}
    return experiment


@pytest.fixture
def basic_result() -> ConversionResult:
    """Conversion result for the well-formed legacy fixture."""
    return convert_report(_load("legacy_basic.json"))


# --- R1/R2: wrapping and catalog injection -------------------------------


def test_basic_conversion_structure(basic_result: ConversionResult) -> None:
    """A clean legacy report wraps under 'attacks' with all catalogs."""
    converted = basic_result.report
    assert converted["report_schema_version"]
    expected_sections = (
        "metric_catalog",
        "parameter_catalog",
        "attack_category_catalog",
        "attack_catalog",
        "attacks",
    )
    for section in expected_sections:
        assert section in converted
    assert set(converted["attacks"]) == {
        "WorstCase attack_de2c5fc0-fb0c-4925-ac4d-26662fe7f786",
        "LiRA Attack_4e9b5db3-f93f-43dc-9a67-1035a68b892a",
    }


# --- R3: schema validation -----------------------------------------------


def test_basic_conversion_is_schema_valid(basic_result: ConversionResult) -> None:
    """A clean legacy report converts to a schema-valid document."""
    assert basic_result.is_valid
    assert basic_result.schema_errors == []
    assert basic_result.warnings == []
    assert basic_result.curve_warnings == []


def test_conversion_is_idempotent(basic_result: ConversionResult) -> None:
    """Converting an already-converted report is a no-op for the payload."""
    second_pass = convert_report(basic_result.report)
    assert second_pass.report["attacks"] == basic_result.report["attacks"]
    assert second_pass.is_valid


@pytest.mark.parametrize(
    "example",
    ["report_example_lira.json", "report_example_worstcase.json"],
)
def test_real_docs_examples_validate(example: str) -> None:
    """Real example reports convert, validate, and trip the load-bearing paths.

    Both real reports lack ``sacroml_version`` (injected) and contain
    ``roc_thresh``/curve arrays that start with ``null`` (downgraded to curve
    notices); without either behaviour the result would not be schema-valid.
    """
    raw = json.loads((DOCS_EXAMPLES / example).read_text(encoding="utf-8"))
    result = convert_report(raw)
    assert result.is_valid, result.schema_errors
    assert result.curve_warnings  # roc/curve null arrays present
    versions = [
        exp["metadata"]["sacroml_version"] for exp in result.report["attacks"].values()
    ]
    assert all(version == "unknown" for version in versions)


def test_non_curve_schema_error_is_reported() -> None:
    """An instance metric with an object value yields a real schema error."""
    bad_instances = {"instance_0": {"AUC": {"unexpected": "object"}}}
    result = convert_report({"LiRA Attack_yy": _experiment(instances=bad_instances)})
    assert not result.is_valid
    assert result.schema_errors


def test_curve_array_is_warning_not_error() -> None:
    """An roc_thresh array starting with null is a curve notice, not an error."""
    curve_instances = {"instance_0": {"AUC": 0.75, "roc_thresh": [None, 1.0, 0.0]}}
    result = convert_report({"LiRA Attack_zz": _experiment(instances=curve_instances)})
    assert result.is_valid
    assert result.curve_warnings
    assert all("roc_thresh" in notice for notice in result.curve_warnings)


class _FakeError:
    """Minimal stand-in for a jsonschema ValidationError (path only)."""

    def __init__(self, path: list) -> None:
        self.absolute_path = path


_AEL = "attack_experiment_logger"
_AIL = "attack_instance_logger"


@pytest.mark.parametrize(
    "path",
    [
        ["attacks"],  # path too short
        ["other", "e", _AEL, _AIL, "i", "AUC"],  # not the attacks subtree
        ["attacks", "e", _AEL, "wrong", "i", "AUC"],  # not the instance logger
        ["attacks", "e", _AEL, _AIL, "i", 0],  # metric name is not a string
        ["attacks", "x", _AEL, _AIL, "i", "AUC"],  # value not present in report
    ],
)
def test_is_curve_violation_rejects_non_curve_paths(path: list) -> None:
    """Only a well-formed instance-metric path with a list value is a curve."""
    empty_report = {"attacks": {}}
    assert _is_curve_violation(_FakeError(path), empty_report) is False


# --- minimal normalisation (load-bearing on real reports) ----------------


def test_missing_sacroml_version_injected() -> None:
    """Reports without sacroml_version get an 'unknown' placeholder."""
    experiment = _experiment()
    experiment["metadata"].pop("sacroml_version")
    result = convert_report({"LiRA Attack_x": experiment})
    converted_meta = result.report["attacks"]["LiRA Attack_x"]["metadata"]
    assert converted_meta["sacroml_version"] == "unknown"
    assert any("sacroml_version was missing" in w for w in result.warnings)


def test_structural_attack_gets_empty_logger() -> None:
    """An instance-less attack gets an empty instance logger, not a crash."""
    structural = {
        "log_id": "aaaa",
        "metadata": {
            "sacroml_version": "1.4.0",
            "attack_name": "Structural Attack",
            "attack_params": {},
            "global_metrics": {},
        },
    }
    result = convert_report({"Structural Attack_aaaa": structural})
    converted = result.report["attacks"]["Structural Attack_aaaa"]
    assert converted["attack_experiment_logger"]["attack_instance_logger"] == {}
    assert result.is_valid


def test_non_dict_instance_logger_becomes_empty() -> None:
    """A non-dict attack_instance_logger is replaced with an empty one."""
    malformed = {
        "log_id": "qq",
        "metadata": {"sacroml_version": "1.0", "attack_name": "LiRA Attack"},
        "attack_experiment_logger": {"attack_instance_logger": "not a dict"},
    }
    result = convert_report({"LiRA Attack_qq": malformed}, validate=False)
    converted = result.report["attacks"]["LiRA Attack_qq"]
    assert converted["attack_experiment_logger"]["attack_instance_logger"] == {}


# --- R4: uncatalogued warnings -------------------------------------------


def test_basic_coverage_all_catalogued(basic_result: ConversionResult) -> None:
    """All metrics/params/attacks in the clean fixture are catalogued."""
    for missing in basic_result.coverage.values():
        assert not missing


def test_uncatalogued_entries_warn() -> None:
    """Unknown metrics/params/attacks are reported but not fatal."""
    mystery = _experiment(
        metadata={
            "attack_name": "Mystery attack",
            "attack_params": {"weird_param": True},
        },
        instances={"instance_0": {"AUC": 0.6, "weird_metric": 0.42}},
    )
    result = convert_report({"Mystery attack_bbbb": mystery})
    assert "weird_metric" in result.coverage["metrics"]
    assert "weird_param" in result.coverage["parameters"]
    assert "Mystery attack" in result.coverage["attacks"]
    for fragment in (
        "Uncatalogued metric",
        "Uncatalogued parameter",
        "Uncatalogued attack",
    ):
        assert any(fragment in w for w in result.warnings)


def test_no_validate_skips_validation() -> None:
    """Validate=False produces no schema errors but keeps coverage warnings."""
    mystery = _experiment(
        metadata={"attack_name": "Mystery attack"},
        instances={"instance_0": {"weird_metric": 0.42}},
    )
    result = convert_report({"Mystery attack_bbbb": mystery}, validate=False)
    assert result.schema_errors == []
    assert result.curve_warnings == []
    assert result.coverage["metrics"]  # coverage is independent of validation


# --- robustness: non-object input must not crash (regression guard) ------


@pytest.mark.parametrize("payload", [[1, 2, 3], "a string", 42, None])
def test_non_dict_top_level_yields_empty_report(payload: object) -> None:
    """A non-object top-level report converts to empty attacks with a warning."""
    result = convert_report(payload, validate=False)
    assert result.report["attacks"] == {}
    assert any("not an object" in w for w in result.warnings)


# --- file + CLI entry points ---------------------------------------------


def test_convert_report_file_roundtrip(tmp_path: Path) -> None:
    """Convert_report_file writes valid JSON and returns a result."""
    target = tmp_path / "nested" / "converted.json"
    result = convert_report_file(FIXTURES / "legacy_basic.json", target)
    assert target.is_file()
    on_disk = json.loads(target.read_text(encoding="utf-8"))
    assert on_disk["attacks"] == result.report["attacks"]
    assert result.is_valid


def test_cli_convert_report(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """The `sacroml convert-report` CLI converts and exits cleanly."""
    target = tmp_path / "converted.json"
    argv = [
        "sacroml",
        "convert-report",
        str(FIXTURES / "legacy_basic.json"),
        str(target),
    ]
    monkeypatch.setattr("sys.argv", argv)
    with pytest.raises(SystemExit) as exc:
        main()
    assert exc.value.code == 0
    assert target.is_file()
    assert "schema-valid" in capsys.readouterr().out


def test_cli_convert_report_malformed_json(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """The CLI reports a friendly error and exits 1 on malformed JSON input."""
    broken = tmp_path / "bad.json"
    broken.write_text("{ not valid json", encoding="utf-8")
    target = tmp_path / "out.json"
    monkeypatch.setattr(
        "sys.argv",
        ["sacroml", "convert-report", str(broken), str(target)],
    )
    with pytest.raises(SystemExit) as exc:
        main()
    assert exc.value.code == 1
    assert "Could not parse" in capsys.readouterr().out
    assert not target.exists()


def test_cli_convert_report_missing_input(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The CLI exits non-zero when the input file does not exist."""
    argv = [
        "sacroml",
        "convert-report",
        str(tmp_path / "nope.json"),
        str(tmp_path / "out.json"),
    ]
    monkeypatch.setattr("sys.argv", argv)
    with pytest.raises(SystemExit) as exc:
        main()
    assert exc.value.code == 1