From c1ea34f30971af75998693a929017b5a3fe555ce Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Tue, 19 May 2026 12:19:28 +0100 Subject: [PATCH 1/7] feat(reporting): add legacy report.json conversion tooling Add `sacroml convert-report` CLI plus a `sacroml.reporting` subpackage that converts legacy flat report.json files into the new nested, catalog-enriched format and validates them against the bundled JSON schema. Includes catalog definitions, schema, tests, and docs. --- docs/source/attacks/report.rst | 41 + pyproject.toml | 2 + sacroml/main.py | 63 ++ sacroml/reporting/__init__.py | 21 + sacroml/reporting/catalog_definitions.json | 780 ++++++++++++++++++ sacroml/reporting/convert.py | 518 ++++++++++++ .../sacroml_attack_report.schema.json | 396 +++++++++ tests/reporting/__init__.py | 1 + tests/reporting/fixtures/legacy_basic.json | 91 ++ tests/reporting/fixtures/legacy_edge.json | 80 ++ tests/reporting/test_convert.py | 227 +++++ 11 files changed, 2220 insertions(+) create mode 100644 sacroml/reporting/__init__.py create mode 100644 sacroml/reporting/catalog_definitions.json create mode 100644 sacroml/reporting/convert.py create mode 100644 sacroml/reporting/sacroml_attack_report.schema.json create mode 100644 tests/reporting/__init__.py create mode 100644 tests/reporting/fixtures/legacy_basic.json create mode 100644 tests/reporting/fixtures/legacy_edge.json create mode 100644 tests/reporting/test_convert.py diff --git a/docs/source/attacks/report.rst b/docs/source/attacks/report.rst index 78bbd848..02a7789c 100644 --- a/docs/source/attacks/report.rst +++ b/docs/source/attacks/report.rst @@ -3,3 +3,44 @@ Report .. automodule:: sacroml.attacks.report :members: + +Converting legacy reports +------------------------- + +Older SACRO-ML versions write a flat ``report.json`` keyed by +``"<attack_name>_<id>"``. The new reporting pipeline expects a nested, +catalog-enriched document validated by the bundled JSON schema +(``sacroml/reporting/sacroml_attack_report.schema.json``). 
+ +Use the ``convert-report`` command to convert an existing report to the new +format without re-running any attacks: + +.. prompt:: bash + + sacroml convert-report report.json report_new.json + +The converter: + +* wraps the legacy experiments under a top-level ``attacks`` key; +* normalises experiment metadata so it satisfies the schema (injecting a + placeholder ``sacroml_version`` for very old reports, ensuring an + ``attack_instance_logger`` exists for instance-less structural attacks, + serialising non-scalar ``global_metrics`` values, etc.); +* injects the four human-readable catalogs (``metric_catalog``, + ``parameter_catalog``, ``attack_category_catalog``, ``attack_catalog``) + from the bundled common definitions in + ``sacroml/reporting/catalog_definitions.json``; +* **warns** when a metric, parameter, attack or attack category observed in + the report is not present in the catalogs (conversion still succeeds); and +* validates the result against the JSON schema. + +Curve-valued arrays (``fpr`` / ``tpr`` / ``roc_thresh``) are passed through +unchanged. ``roc_thresh`` legitimately starts with ``null``, which the schema +does not yet permit, so such violations are reported as notices rather than +errors. Pass ``--no-validate`` to skip schema validation. + +To extend the catalogs with site-specific metrics, parameters or attacks, +edit ``sacroml/reporting/catalog_definitions.json``. + +.. 
automodule:: sacroml.reporting.convert + :members: convert_report, convert_report_file, ConversionResult diff --git a/pyproject.toml b/pyproject.toml index 9e69c592..7f109dd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "dictdiffer", "torch", "prompt-toolkit", + "jsonschema", ] [tool.setuptools.dynamic] @@ -87,6 +88,7 @@ packages = {find = {exclude = ["docs*", "examples*", "tests*", "user_stories*"]} [tool.setuptools.package-data] "sacroml.safemodel" = ["rules.json"] +"sacroml.reporting" = ["*.json"] [tool.ruff] indent-width = 4 diff --git a/sacroml/main.py b/sacroml/main.py index 2ee5434f..40c58333 100644 --- a/sacroml/main.py +++ b/sacroml/main.py @@ -4,10 +4,59 @@ import argparse import os +import sys from sacroml.attacks.factory import run_attacks from sacroml.config.attack import prompt_for_attack from sacroml.config.target import prompt_for_target +from sacroml.reporting import convert_report_file + + +def _run_convert_report(args: argparse.Namespace) -> int: + """Convert a legacy report.json into the new report format. + + Returns + ------- + int + Process exit code (0 on success, 1 if the output is schema-invalid). + """ + if not os.path.isfile(args.input): + print(f"Input report not found: {args.input}") + return 1 + + result = convert_report_file(args.input, args.output, validate=not args.no_validate) + + print(f"Converted '{args.input}' -> '{args.output}'") + for dim, summary in result.coverage.items(): + print( + f" {dim}: {len(summary['covered'])} catalogued, " + f"{len(summary['missing'])} uncatalogued" + ) + + if result.warnings: + print(f"\nWarnings ({len(result.warnings)}):") + for warning in result.warnings: + print(f" - {warning}") + + if result.curve_warnings: + print( + f"\nCurve-array notices ({len(result.curve_warnings)}): " + "fpr/tpr/roc_thresh arrays are passed through unchanged and do " + "not strictly validate yet." 
+ ) + for warning in result.curve_warnings: + print(f" - {warning}") + + if result.schema_errors: + print(f"\nSchema errors ({len(result.schema_errors)}):") + for error in result.schema_errors: + print(f" - {error}") + print("\nConverted report is NOT schema-valid.") + return 1 + + if not args.no_validate: + print("\nConverted report is schema-valid.") + return 0 def main() -> None: @@ -26,6 +75,18 @@ def main() -> None: subparsers.add_parser("gen-target", help="Generate Target YAML config") subparsers.add_parser("gen-attack", help="Generate Attack YAML config") + convert = subparsers.add_parser( + "convert-report", + help="Convert a legacy report.json into the new report format", + ) + convert.add_argument("input", type=str, help="Path to the legacy report.json") + convert.add_argument("output", type=str, help="Path to write the new report") + convert.add_argument( + "--no-validate", + action="store_true", + help="Skip JSON schema validation of the converted report", + ) + args = parser.parse_args() if args.cmd == "run": @@ -37,6 +98,8 @@ def main() -> None: prompt_for_target() elif args.cmd == "gen-attack": prompt_for_attack() + elif args.cmd == "convert-report": + sys.exit(_run_convert_report(args)) if __name__ == "__main__": # pragma:no cover diff --git a/sacroml/reporting/__init__.py b/sacroml/reporting/__init__.py new file mode 100644 index 00000000..8a0b8877 --- /dev/null +++ b/sacroml/reporting/__init__.py @@ -0,0 +1,21 @@ +"""Reporting utilities for SACRO-ML attack reports. + +This subpackage contains the canonical attack-report JSON schema, a set of +common catalog definitions, and tooling to convert legacy ``report.json`` +files produced by older SACRO-ML versions into the new nested, +catalog-enriched report format. 
+""" + +from __future__ import annotations + +from sacroml.reporting.convert import ( + ConversionResult, + convert_report, + convert_report_file, +) + +__all__ = [ + "ConversionResult", + "convert_report", + "convert_report_file", +] diff --git a/sacroml/reporting/catalog_definitions.json b/sacroml/reporting/catalog_definitions.json new file mode 100644 index 00000000..dcbcb244 --- /dev/null +++ b/sacroml/reporting/catalog_definitions.json @@ -0,0 +1,780 @@ +{ + "metric_catalog": { + "version": "metrics_sacroml_v1", + "metrics": { + "AUC": { + "label": "Area Under ROC Curve", + "description": "Measures the ability of the attack to distinguish members from non-members.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "TPR": { + "label": "True Positive Rate", + "description": "Fraction of true members correctly identified.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "FPR": { + "label": "False Positive Rate", + "description": "Fraction of non-members incorrectly identified as members.", + "units": null, + "higher_is_better": false, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "risk_score": { + "label": "Aggregate Risk Score", + "description": "Summary privacy risk score aggregated from multiple attacks.", + "units": null, + "higher_is_better": true, + "category": "summary", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean" + ], + "notes": "Interpretation depends on aggregation method." 
+ }, + "Advantage": { + "label": "Attack advantage", + "description": "Difference between true positive rate and false positive rate.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": null, + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "F1score": { + "label": "F1 score", + "description": "Harmonic mean of precision and recall.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "ACC": { + "label": "Accuracy", + "description": "Overall classification accuracy.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "AUC_sig": { + "label": "AUC significance", + "description": "Statistical significance of AUC relative to baseline.", + "units": null, + "higher_is_better": null, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Qualitative significance annotation." + }, + "PDIF_sig": { + "label": "PDIF significance", + "description": "Statistical significance of PDIF relative to baseline.", + "units": null, + "higher_is_better": null, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Qualitative significance annotation." 
+ }, + "FAR": { + "label": "False Acceptance Rate", + "description": "Fraction of non-members incorrectly accepted as members.", + "units": null, + "higher_is_better": false, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "TNR": { + "label": "True Negative Rate", + "description": "Fraction of non-members correctly identified as non-members.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "PPV": { + "label": "Positive Predictive Value", + "description": "Probability that an instance predicted as a member is truly a member.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": "Also known as precision." 
+ }, + "NPV": { + "label": "Negative Predictive Value", + "description": "Probability that an instance predicted as a non-member is truly a non-member.", + "units": null, + "higher_is_better": true, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "FNR": { + "label": "False Negative Rate", + "description": "Fraction of members incorrectly classified as non-members.", + "units": null, + "higher_is_better": false, + "category": "performance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": null + }, + "P_HIGHER_AUC": { + "label": "Probability of Higher AUC", + "description": "Estimated probability that the observed AUC exceeds the null or baseline AUC.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": [ + 0.0, + 1.0 + ], + "allowed_aggregations": [ + "mean", + "min", + "max" + ], + "notes": "Often derived from permutation or bootstrap tests." + }, + "pred_prob_var": { + "label": "Prediction Probability Variance", + "description": "Variance of predicted membership probabilities.", + "units": null, + "higher_is_better": null, + "category": "diagnostic", + "typical_range": null, + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": "Used as a stability or confidence diagnostic." + }, + "null_auc_3sd_range": { + "label": "Null AUC \u00b13\u03c3 range", + "description": "Expected AUC range under the null hypothesis, defined as \u00b13 standard deviations.", + "units": null, + "higher_is_better": null, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Represented as a string range." 
+ }, + "n_sig_auc_p_vals": { + "label": "Number of significant AUC p-values", + "description": "Count of repetitions where AUC p-values are statistically significant.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "sum", + "min", + "max" + ], + "notes": null + }, + "n_sig_auc_p_vals_corrected": { + "label": "Number of significant AUC p-values (corrected)", + "description": "Count of statistically significant AUC p-values after multiple-testing correction.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "sum", + "min", + "max" + ], + "notes": null + }, + "n_sig_pdif_vals": { + "label": "Number of significant PDIF values", + "description": "Count of PDIF values that are statistically significant.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "sum", + "min", + "max" + ], + "notes": null + }, + "n_sig_pdif_vals_corrected": { + "label": "Number of significant PDIF values (corrected)", + "description": "Count of statistically significant PDIF values after multiple-testing correction.", + "units": null, + "higher_is_better": true, + "category": "significance", + "typical_range": null, + "allowed_aggregations": [ + "sum", + "min", + "max" + ], + "notes": null + }, + "n_pos_test_examples": { + "label": "Number of positive test examples", + "description": "Total number of member (positive) examples in the test set.", + "units": null, + "higher_is_better": null, + "category": "dataset", + "typical_range": null, + "allowed_aggregations": [ + "min", + "max", + "mean" + ], + "notes": "Used to contextualise attack performance metrics." 
+ }, + "n_neg_test_examples": { + "label": "Number of negative test examples", + "description": "Total number of non-member (negative) examples in the test set.", + "units": null, + "higher_is_better": null, + "category": "dataset", + "typical_range": null, + "allowed_aggregations": [ + "min", + "max", + "mean" + ], + "notes": "Used to contextualise attack performance metrics." + }, + "n_normal": { + "label": "Number of normal examples", + "description": "Number of examples considered normal or non-attacked in the evaluation.", + "units": null, + "higher_is_better": null, + "category": "dataset", + "typical_range": null, + "allowed_aggregations": [ + "min", + "max", + "mean" + ], + "notes": "Interpretation depends on attack context." + }, + "roc_thresh": { + "label": "ROC decision thresholds", + "description": "Decision thresholds used to compute ROC-based metrics.", + "units": null, + "higher_is_better": null, + "category": "diagnostic", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Typically used only for curve construction and diagnostics." + }, + "dof_risk": { + "label": "Dof Risk", + "description": "Degrees-of-freedom risk: model is overly complex relative to the training data size.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." + }, + "k_anonymity_risk": { + "label": "K Anonymity Risk", + "description": "K-anonymity risk: training records fall into equivalence classes smaller than k.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." 
+ }, + "class_disclosure_risk": { + "label": "Class Disclosure Risk", + "description": "Class disclosure risk: outputs may reveal small disclosive groups.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." + }, + "unnecessary_risk": { + "label": "Unnecessary Risk", + "description": "Unnecessary risk: hyperparameters set to values associated with higher MIA risk.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." + }, + "lowvals_cd_risk": { + "label": "Lowvals Cd Risk", + "description": "Low-values class disclosure risk: class frequency within an equivalence class below a safe threshold.", + "units": null, + "higher_is_better": false, + "category": "structural_risk", + "typical_range": null, + "allowed_aggregations": [ + "count" + ], + "notes": "Boolean structural-attack risk flag; True indicates a potential vulnerability." + } + }, + "pattern_metrics": [ + { + "pattern": "^(fpr|tpr)$", + "label_template": "{metric_name} curve", + "description": "Curve-valued metric represented as an array of rates evaluated across decision thresholds. Intended for paired use (e.g. fpr vs tpr).", + "units": null, + "higher_is_better": null, + "category": "curve", + "allowed_aggregations": [ + "min", + "max" + ], + "notes": "Curve-valued metric. Should not be aggregated directly; use to derive scalar summaries such as AUC." 
+ }, + { + "pattern": "^(FMAX|FMIN|FDIF|PDIF)[0-9]+$", + "label_template": "{metric_name}", + "description": "Threshold-based extremum or difference metric computed at a specified operating point.", + "units": null, + "higher_is_better": true, + "category": "threshold_analysis", + "allowed_aggregations": [ + "mean", + "min", + "max" + ], + "notes": "Metric name suffix indicates the threshold or operating point." + }, + { + "pattern": "^TPR@.+$", + "label_template": "TPR at threshold {metric_name}", + "description": "True Positive Rate evaluated at a specific decision threshold indicated in the metric name.", + "units": null, + "higher_is_better": true, + "category": "performance", + "allowed_aggregations": [ + "mean", + "min", + "max", + "stddev", + "var" + ], + "notes": "Metric suffix indicates the decision threshold (e.g. TPR@0.5, TPR@0.1)." + } + ] + }, + "parameter_catalog": { + "version": "params_v1", + "parameters": { + "n_reps": { + "label": "Number of repetitions", + "description": "Number of independent repetitions used to estimate variability.", + "parameter_scope": "attack", + "value_type": "integer", + "notes": null + }, + "p_thresh": { + "label": "Probability threshold", + "description": "Threshold used to decide membership based on predicted probability.", + "parameter_scope": "attack", + "value_type": "float", + "notes": null + }, + "attack_model": { + "label": "Attack model class", + "description": "Classifier used by the attacker.", + "parameter_scope": "attack", + "value_type": "string", + "notes": null + }, + "n_shadow_models": { + "label": "Number of shadow models", + "description": "How many shadow models are trained to emulate the target.", + "parameter_scope": "attack", + "value_type": "integer", + "notes": null + }, + "aggregation_method": { + "label": "Aggregation method", + "description": "Method used to aggregate results from multiple attacks.", + "parameter_scope": "attack", + "value_type": "string", + "notes": null + }, + "mode": { + 
"label": "Execution mode", + "description": "Mode in which the attack operates, e.g. offline or online.", + "parameter_scope": "attack", + "value_type": "string", + "notes": null + }, + "fix_variance": { + "label": "Fix variance", + "description": "Whether to fix variance when estimating likelihood ratios.", + "parameter_scope": "attack", + "value_type": "boolean", + "notes": null + }, + "report_individual": { + "label": "Report individual results", + "description": "Whether to report per-instance attack results.", + "parameter_scope": "attack", + "value_type": "boolean", + "notes": null + }, + "output_dir": { + "label": "Output directory", + "description": "Directory in which attack artefacts are written.", + "parameter_scope": "global", + "value_type": "string", + "notes": null + }, + "write_report": { + "label": "Write report", + "description": "Whether a report should be written to disk.", + "parameter_scope": "global", + "value_type": "boolean", + "notes": null + }, + "reproduce_split": { + "label": "Reproduce split seed", + "description": "Seed used to reproduce the train/test split.", + "parameter_scope": "attack", + "value_type": "integer", + "notes": null + }, + "n_dummy_reps": { + "label": "Number of dummy repetitions", + "description": "Number of dummy or null repetitions used for baseline estimation.", + "parameter_scope": "attack", + "value_type": "integer", + "notes": null + }, + "train_beta": { + "label": "Training beta", + "description": "Beta parameter applied during training phase of the attack.", + "parameter_scope": "attack", + "value_type": "number", + "notes": null + }, + "test_beta": { + "label": "Testing beta", + "description": "Beta parameter applied during testing phase of the attack.", + "parameter_scope": "attack", + "value_type": "number", + "notes": null + }, + "test_prop": { + "label": "Test proportion", + "description": "Proportion of data reserved for testing.", + "parameter_scope": "training", + "value_type": "number", + "notes": 
null + }, + "include_model_correct_feature": { + "label": "Include model correctness feature", + "description": "Whether a feature indicating model correctness is included.", + "parameter_scope": "attack", + "value_type": "boolean", + "notes": null + }, + "sort_probs": { + "label": "Sort probabilities", + "description": "Whether predicted probabilities are sorted before analysis.", + "parameter_scope": "attack", + "value_type": "boolean", + "notes": null + }, + "attack_model_params": { + "label": "Attack model parameters", + "description": "Hyperparameters for the attack model.", + "parameter_scope": "attack", + "value_type": "object", + "notes": "Null indicates default parameters." + } + } + }, + "attack_category_catalog": { + "version": "categories_v1", + "categories": { + "structural": { + "label": "Structural attacks", + "description": "Attacks exploiting deterministic or structural properties of the target model.", + "order": 10 + }, + "probabilistic": { + "label": "Probabilistic attacks", + "description": "Attacks based on confidence scores or probability distributions.", + "order": 20 + }, + "meta": { + "label": "Meta attacks", + "description": "Attacks that combine or summarise the results of multiple other attacks.", + "order": 30 + } + } + }, + "attack_catalog": { + "version": "attacks_sacroml_v1", + "attacks": { + "WorstCase attack": { + "label": "Worst\u2011Case Membership Inference Attack", + "description": "Assumes the attacker has maximal auxiliary knowledge.", + "attack_category": "structural", + "key_metrics": [ + "AUC", + "TPR" + ], + "attack_params": { + "n_reps": { + "label": "Number of repetitions", + "description": "Independent repetitions of the attack." + }, + "p_thresh": { + "label": "Probability threshold", + "description": "Decision threshold for membership prediction." 
+ } + } + }, + "ShadowModel attack": { + "label": "Shadow Model Attack", + "description": "Uses shadow models to approximate the target model\u2019s behaviour.", + "attack_category": "probabilistic", + "key_metrics": [ + "AUC" + ], + "attack_params": { + "n_shadow_models": { + "label": "Number of shadow models", + "description": "How many shadow models to train." + }, + "attack_model": { + "label": "Attack model", + "description": "Classifier trained on shadow model outputs." + } + } + }, + "AggregateRisk attack": { + "label": "Aggregate Risk Meta\u2011Attack", + "description": "Aggregates the outputs of multiple attacks into a single risk score.", + "attack_category": "meta", + "key_metrics": [ + "risk_score" + ], + "attack_params": { + "aggregation_method": { + "label": "Aggregation method", + "description": "Strategy used to combine individual attack results." + } + } + }, + "LiRA Attack": { + "label": "Likelihood Ratio Attack (LiRA)", + "description": "Membership inference attack based on likelihood ratios derived from shadow models.", + "attack_category": "probabilistic", + "key_metrics": [ + "AUC", + "TPR" + ], + "attack_params": { + "n_shadow_models": { + "label": "Number of shadow models", + "description": "Number of shadow models used to estimate likelihoods." + }, + "p_thresh": { + "label": "Probability threshold", + "description": "Threshold used for membership decision." + }, + "mode": { + "label": "Execution mode", + "description": "Whether the attack is run in offline or online mode." + }, + "fix_variance": { + "label": "Fix variance", + "description": "Whether variance is fixed when computing likelihood ratios." + }, + "report_individual": { + "label": "Report individual", + "description": "Whether to report individual-instance results." 
+ } + } + }, + "Structural Attack": { + "label": "Structural Attack", + "description": "Static analysis of the target model's structure and training data against pre-defined disclosure-risk rules and thresholds.", + "attack_category": "structural", + "key_metrics": [ + "dof_risk", + "k_anonymity_risk", + "class_disclosure_risk" + ], + "attack_params": { + "output_dir": { + "label": "Output directory", + "description": "Directory in which attack artefacts are written." + } + } + }, + "Attribute inference attack": { + "label": "Attribute Inference Attack", + "description": "Attempts to infer sensitive attribute values of training records from the target model's outputs.", + "attack_category": "probabilistic", + "key_metrics": [ + "risk_score" + ], + "attack_params": { + "output_dir": { + "label": "Output directory", + "description": "Directory in which attack artefacts are written." + } + } + } + } + } +} diff --git a/sacroml/reporting/convert.py b/sacroml/reporting/convert.py new file mode 100644 index 00000000..93187dd8 --- /dev/null +++ b/sacroml/reporting/convert.py @@ -0,0 +1,518 @@ +"""Convert legacy SACRO-ML ``report.json`` files into the new report format. + +Older SACRO-ML versions write a flat JSON object keyed by +``"<attack_name>_<id>"``. The new reporting pipeline expects a nested, +catalog-enriched document validated by +``sacroml/reporting/sacroml_attack_report.schema.json``. + +This module performs that conversion: + +1. Wrap the top-level experiments under an ``"attacks"`` key. +2. Normalise experiment metadata so it satisfies the schema (injecting a + placeholder ``sacroml_version`` for very old reports, ensuring an + ``attack_instance_logger`` exists for instance-less structural attacks, + coercing non-scalar ``global_metrics`` values, etc.). +3. Inject the four human-readable catalogs (``metric_catalog``, + ``parameter_catalog``, ``attack_category_catalog``, ``attack_catalog``) + from a bundled common-definitions file. +4. 
Diff everything observed in the report against the catalogs and emit + warnings for metrics, parameters, attacks and attack categories that are + not catalogued (the conversion still succeeds). +5. Validate the result against the JSON schema. Schema violations caused + purely by curve-valued arrays (``fpr`` / ``tpr`` / ``roc_thresh`` -- the + latter legitimately starts with ``null``) are downgraded to warnings, + since they are a known limitation tracked separately; any other schema + violation is reported as an error. + +The conversion never mutates curve arrays: they are passed through verbatim. +""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from jsonschema import Draft7Validator + +REPORT_SCHEMA_VERSION = "1.2" + +_HERE = Path(__file__).parent +SCHEMA_PATH = _HERE / "sacroml_attack_report.schema.json" +CATALOG_DEFINITIONS_PATH = _HERE / "catalog_definitions.json" + +# Top-level keys that are *not* experiments in an already-converted report. +_NON_EXPERIMENT_KEYS = frozenset( + { + "report_schema_version", + "metric_catalog", + "parameter_catalog", + "attack_category_catalog", + "attack_catalog", + "attacks", + } +) + +# Metadata keys permitted by the schema (``additionalProperties: false``). +_ALLOWED_METADATA_KEYS = frozenset( + { + "sacroml_version", + "attack_name", + "attack_params", + "global_metrics", + "baseline_global_metrics", + "target_model", + "target_model_params", + "target_train_params", + } +) +_ATTACK_KEY_RE = re.compile(r"^[A-Za-z0-9 _\-]+$") +_INSTANCE_KEY_RE = re.compile(r"^instance_[0-9]+$") + + +@dataclass +class ConversionResult: + """Outcome of converting a legacy report. + + Attributes + ---------- + report : dict + The converted report (new format). + warnings : list[str] + Non-fatal issues: uncatalogued metrics/parameters/attacks/categories + and structural normalisations that were applied. 
+ curve_warnings : list[str] + Schema violations attributable solely to curve-valued arrays + (``fpr`` / ``tpr`` / ``roc_thresh``); these are expected and benign. + schema_errors : list[str] + Genuine schema violations that make the output invalid. + coverage : dict + Per-dimension ``{"covered": [...], "missing": [...]}`` summary. + """ + + report: dict[str, Any] + warnings: list[str] = field(default_factory=list) + curve_warnings: list[str] = field(default_factory=list) + schema_errors: list[str] = field(default_factory=list) + coverage: dict[str, dict[str, list[str]]] = field(default_factory=dict) + + @property + def is_valid(self) -> bool: + """Return whether the converted report passes schema validation. + + Curve-array violations do not count as failures. + """ + return not self.schema_errors + + +def _load_catalog_definitions() -> dict[str, Any]: + """Load the bundled common catalog definitions.""" + with open(CATALOG_DEFINITIONS_PATH, encoding="utf-8") as fh: + return json.load(fh) + + +def _load_schema() -> dict[str, Any]: + """Load the bundled attack-report JSON schema.""" + with open(SCHEMA_PATH, encoding="utf-8") as fh: + return json.load(fh) + + +def _is_experiment(value: Any) -> bool: + """Return whether a top-level value looks like a legacy experiment.""" + return isinstance(value, dict) and ( + "attack_experiment_logger" in value or "metadata" in value or "log_id" in value + ) + + +def _extract_experiments(data: dict[str, Any], warnings: list[str]) -> dict[str, Any]: + """Return the experiment mapping, whether the input is flat or wrapped.""" + if isinstance(data.get("attacks"), dict): + # Already wrapped (idempotent path). + return dict(data["attacks"]) + + experiments: dict[str, Any] = {} + for key, value in data.items(): + if key in _NON_EXPERIMENT_KEYS: + continue + if _is_experiment(value): + experiments[key] = value + else: + warnings.append( + f"Top-level key '{key}' does not look like an experiment " + "and was dropped during conversion." 
+ ) + return experiments + + +def _coerce_scalar_metrics( + metrics: Any, where: str, warnings: list[str] +) -> dict[str, Any]: + """Ensure a (baseline_)global_metrics mapping holds only scalars.""" + if not isinstance(metrics, dict): + warnings.append(f"{where} was not an object; replaced with empty object.") + return {} + cleaned: dict[str, Any] = {} + for key, value in metrics.items(): + if value is None or isinstance(value, (str, bool, int, float)): + cleaned[key] = value + else: + cleaned[key] = json.dumps(value, default=str) + warnings.append( + f"{where}['{key}'] held a non-scalar value; " + "serialised to a JSON string to satisfy the schema." + ) + return cleaned + + +def _normalise_instance_logger( + logger: Any, exp_key: str, warnings: list[str] +) -> dict[str, Any]: + """Return an ``attack_instance_logger`` with schema-conforming keys.""" + if not isinstance(logger, dict): + return {} + normalised: dict[str, Any] = {} + next_index = 0 + for key, value in logger.items(): + if _INSTANCE_KEY_RE.match(key): + normalised[key] = value + else: + new_key = f"instance_{next_index}" + while new_key in logger or new_key in normalised: + next_index += 1 + new_key = f"instance_{next_index}" + normalised[new_key] = value + warnings.append( + f"Experiment '{exp_key}': instance key '{key}' did not match " + f"'instance_'; renamed to '{new_key}'." + ) + next_index += 1 + return normalised + + +def _normalise_metadata( + exp_key: str, raw_metadata: Any, warnings: list[str] +) -> dict[str, Any]: + """Normalise an experiment's metadata so it conforms to the schema.""" + metadata = dict(raw_metadata or {}) + + for key in sorted(set(metadata) - _ALLOWED_METADATA_KEYS): + warnings.append( + f"Experiment '{exp_key}': metadata key '{key}' is not permitted " + "by the schema and was dropped." 
+ ) + metadata.pop(key, None) + + if "sacroml_version" not in metadata: + metadata["sacroml_version"] = "unknown" + warnings.append( + f"Experiment '{exp_key}': metadata.sacroml_version was missing; " + "set to 'unknown'." + ) + else: + metadata["sacroml_version"] = str(metadata["sacroml_version"]) + + metadata["attack_name"] = str( + metadata.get("attack_name", exp_key.rsplit("_", 1)[0]) + ) + if not isinstance(metadata.get("attack_params"), dict): + metadata["attack_params"] = {} + + metadata["global_metrics"] = _coerce_scalar_metrics( + metadata.get("global_metrics", {}), + f"Experiment '{exp_key}': metadata.global_metrics", + warnings, + ) + if "baseline_global_metrics" in metadata: + metadata["baseline_global_metrics"] = _coerce_scalar_metrics( + metadata["baseline_global_metrics"], + f"Experiment '{exp_key}': metadata.baseline_global_metrics", + warnings, + ) + + if "target_model" in metadata and not isinstance(metadata["target_model"], str): + metadata["target_model"] = str(metadata["target_model"]) + for opt in ("target_model_params", "target_train_params"): + if opt in metadata and not isinstance(metadata[opt], dict): + metadata[opt] = {} + return metadata + + +def _normalise_attack_experiment_logger( + exp_key: str, raw_logger: Any, warnings: list[str] +) -> dict[str, Any]: + """Return an ``attack_experiment_logger`` holding only the instance log.""" + ael = raw_logger if isinstance(raw_logger, dict) else {} + instance_logger = _normalise_instance_logger( + ael.get("attack_instance_logger", {}), exp_key, warnings + ) + for key in sorted(set(ael) - {"attack_instance_logger"}): + warnings.append( + f"Experiment '{exp_key}': attack_experiment_logger.{key} is not " + "permitted by the schema and was dropped." 
+ ) + return {"attack_instance_logger": instance_logger} + + +def _normalise_experiment( + exp_key: str, experiment: Any, warnings: list[str] +) -> dict[str, Any]: + """Normalise a single experiment so it conforms to the schema.""" + if not isinstance(experiment, dict): + warnings.append(f"Experiment '{exp_key}' was not an object and was skipped.") + return {} + + exp = dict(experiment) + exp["log_id"] = str(exp.get("log_id", exp_key)) + if "log_time" not in exp or not isinstance(exp["log_time"], str): + exp["log_time"] = str(exp.get("log_time", "unknown")) + exp["metadata"] = _normalise_metadata(exp_key, exp.get("metadata", {}), warnings) + exp["attack_experiment_logger"] = _normalise_attack_experiment_logger( + exp_key, exp.get("attack_experiment_logger"), warnings + ) + return exp + + +def _compile_pattern_metrics(metric_catalog: dict[str, Any]) -> list[re.Pattern]: + """Compile the regexes declared in ``metric_catalog.pattern_metrics``.""" + patterns: list[re.Pattern] = [] + for entry in metric_catalog.get("pattern_metrics", []): + pattern = entry.get("pattern") if isinstance(entry, dict) else None + if pattern: + patterns.append(re.compile(pattern)) + return patterns + + +def _diff( + seen: set[str], explicit: set[str], patterns: list[re.Pattern] | None = None +) -> tuple[list[str], list[str]]: + """Split ``seen`` names into catalogued (covered) and missing.""" + patterns = patterns or [] + + def catalogued(name: str) -> bool: + if name in explicit: + return True + return any(p.match(name) for p in patterns) + + missing = sorted(n for n in seen if not catalogued(n)) + covered = sorted(seen - set(missing)) + return covered, missing + + +def _compute_coverage( + report: dict[str, Any], catalogs: dict[str, Any] +) -> tuple[dict[str, dict[str, list[str]]], list[str]]: + """Diff observed metrics/parameters/attacks/categories vs the catalogs.""" + seen_metrics: set[str] = set() + seen_params: set[str] = set() + seen_attacks: set[str] = set() + + for exp in 
report["attacks"].values(): + instances = exp["attack_experiment_logger"]["attack_instance_logger"] + for inst in instances.values(): + if isinstance(inst, dict): + seen_metrics.update(inst.keys()) + metadata = exp.get("metadata", {}) + seen_metrics.update(metadata.get("global_metrics", {}).keys()) + seen_metrics.update(metadata.get("baseline_global_metrics", {}).keys()) + seen_params.update(metadata.get("attack_params", {}).keys()) + seen_attacks.add(metadata.get("attack_name", "")) + + seen_attacks.discard("") + + metric_catalog = catalogs["metric_catalog"] + explicit_metrics = set(metric_catalog.get("metrics", {}).keys()) + pattern_regexes = _compile_pattern_metrics(metric_catalog) + m_covered, m_missing = _diff(seen_metrics, explicit_metrics, pattern_regexes) + + explicit_params = set(catalogs["parameter_catalog"].get("parameters", {}).keys()) + p_covered, p_missing = _diff(seen_params, explicit_params) + + catalog_attacks = catalogs["attack_catalog"].get("attacks", {}) + a_covered, a_missing = _diff(seen_attacks, set(catalog_attacks.keys())) + + # Categories referenced by the attacks we *can* resolve in the catalog. 
+ known_categories = set( + catalogs["attack_category_catalog"].get("categories", {}).keys() + ) + seen_categories = { + catalog_attacks[name]["attack_category"] + for name in seen_attacks + if name in catalog_attacks and "attack_category" in catalog_attacks[name] + } + c_covered, c_missing = _diff(seen_categories, known_categories) + + coverage = { + "metrics": {"covered": m_covered, "missing": m_missing}, + "parameters": {"covered": p_covered, "missing": p_missing}, + "attacks": {"covered": a_covered, "missing": a_missing}, + "attack_categories": {"covered": c_covered, "missing": c_missing}, + } + + warnings: list[str] = [] + labels = { + "metrics": "metric", + "parameters": "parameter", + "attacks": "attack", + "attack_categories": "attack category", + } + for dim, label in labels.items(): + for name in coverage[dim]["missing"]: + warnings.append( + f"Uncatalogued {label}: '{name}' is not present in the {dim} catalog." + ) + return coverage, warnings + + +def _is_curve_violation(error: Any, report: dict[str, Any]) -> bool: + """Return whether a schema error is caused solely by a curve array. + + Curve-valued metrics (``fpr`` / ``tpr`` / ``roc_thresh``) are stored as + raw arrays that may contain ``null`` (e.g. ``roc_thresh[0]``). These are + a known schema limitation and are reported as warnings, not errors. 
+ """ + path = list(error.absolute_path) + # Expected path: attacks / / attack_experiment_logger / + # attack_instance_logger / instance_ / [/ idx] + if len(path) < 6: + return False + if path[0] != "attacks" or path[2] != "attack_experiment_logger": + return False + if path[3] != "attack_instance_logger": + return False + metric = path[5] + if not isinstance(metric, str): + return False + try: + value = report["attacks"][path[1]]["attack_experiment_logger"][ + "attack_instance_logger" + ][path[4]][metric] + except (KeyError, TypeError, IndexError): + return False + return isinstance(value, list) + + +def _validate( + report: dict[str, Any], +) -> tuple[list[str], list[str]]: + """Validate ``report`` against the schema. + + Returns + ------- + tuple[list[str], list[str]] + ``(schema_errors, curve_warnings)`` -- human-readable messages. + """ + validator = Draft7Validator(_load_schema()) + schema_errors: list[str] = [] + curve_warnings: list[str] = [] + for error in sorted(validator.iter_errors(report), key=str): + location = "/".join(str(p) for p in error.absolute_path) or "" + detail = error.message + if len(detail) > 200: + detail = detail[:200] + "... (truncated)" + message = f"{location}: {detail}" + if _is_curve_violation(error, report): + curve_warnings.append(message) + else: + schema_errors.append(message) + return schema_errors, curve_warnings + + +def convert_report(data: dict[str, Any], *, validate: bool = True) -> ConversionResult: + """Convert an in-memory legacy report dictionary to the new format. + + Parameters + ---------- + data : dict + The parsed legacy ``report.json`` (flat) or an already-wrapped report. + validate : bool, default True + Whether to validate the converted report against the JSON schema. + + Returns + ------- + ConversionResult + The converted report plus warnings, schema errors and a coverage + summary. 
+ """ + warnings: list[str] = [] + experiments = _extract_experiments(data, warnings) + + converted_attacks: dict[str, Any] = {} + for exp_key, experiment in experiments.items(): + safe_key = exp_key + if not _ATTACK_KEY_RE.match(exp_key): + safe_key = re.sub(r"[^A-Za-z0-9 _\-]", "_", exp_key) + warnings.append( + f"Experiment key '{exp_key}' contained characters not allowed " + f"by the schema; renamed to '{safe_key}'." + ) + normalised = _normalise_experiment(exp_key, experiment, warnings) + if normalised: + converted_attacks[safe_key] = normalised + + catalogs = _load_catalog_definitions() + report: dict[str, Any] = { + "report_schema_version": REPORT_SCHEMA_VERSION, + "metric_catalog": catalogs["metric_catalog"], + "parameter_catalog": catalogs["parameter_catalog"], + "attack_category_catalog": catalogs["attack_category_catalog"], + "attack_catalog": catalogs["attack_catalog"], + "attacks": converted_attacks, + } + + coverage, coverage_warnings = _compute_coverage(report, catalogs) + warnings.extend(coverage_warnings) + + schema_errors: list[str] = [] + curve_warnings: list[str] = [] + if validate: + schema_errors, curve_warnings = _validate(report) + + return ConversionResult( + report=report, + warnings=warnings, + curve_warnings=curve_warnings, + schema_errors=schema_errors, + coverage=coverage, + ) + + +def convert_report_file( + input_path: str | Path, + output_path: str | Path, + *, + validate: bool = True, + indent: int = 2, +) -> ConversionResult: + """Convert a legacy ``report.json`` file and write the new format to disk. + + Parameters + ---------- + input_path : str | Path + Path to the legacy report. + output_path : str | Path + Path to write the converted report to. + validate : bool, default True + Whether to validate the converted report against the JSON schema. + indent : int, default 2 + Indentation for the written JSON. + + Returns + ------- + ConversionResult + The conversion outcome. 
+ """ + with open(input_path, encoding="utf-8") as fh: + data = json.load(fh) + + result = convert_report(data, validate=validate) + + output_path = Path(output_path) + if output_path.parent and not output_path.parent.exists(): + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as fh: + json.dump(result.report, fh, indent=indent) + fh.write("\n") + + return result diff --git a/sacroml/reporting/sacroml_attack_report.schema.json b/sacroml/reporting/sacroml_attack_report.schema.json new file mode 100644 index 00000000..564b6592 --- /dev/null +++ b/sacroml/reporting/sacroml_attack_report.schema.json @@ -0,0 +1,396 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "SacroML Attack Report", + "description": "A nested, attack-agnostic schema for SacroML attack reports. Includes human-readable catalogs (metrics, parameters, attack categories, attacks).", + "type": "object", + "properties": { + "report_schema_version": { + "type": "string" + }, + "metric_catalog": { + "type": "object", + "properties": { + "version": { + "type": "string" + }, + "metrics": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9_@.eE+\\-]+$": { + "$ref": "#/$defs/MetricDefinition" + } + }, + "additionalProperties": false + }, + "pattern_metrics": { + "type": "array", + "items": { + "$ref": "#/$defs/PatternMetricDefinition" + } + } + }, + "required": [ + "version", + "metrics" + ], + "additionalProperties": false + }, + "parameter_catalog": { + "type": "object", + "properties": { + "version": { + "type": "string" + }, + "parameters": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9_\\-]+$": { + "$ref": "#/$defs/ParameterDefinition" + } + }, + "additionalProperties": false + } + }, + "required": [ + "version", + "parameters" + ], + "additionalProperties": false + }, + "attack_category_catalog": { + "type": "object", + "description": "Human-readable definitions of high-level attack categories.", + 
"properties": { + "version": { + "type": "string" + }, + "categories": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9_\\-]+$": { + "$ref": "#/$defs/AttackCategoryDefinition" + } + }, + "additionalProperties": false + } + }, + "required": [ + "version", + "categories" + ], + "additionalProperties": false + }, + "attack_catalog": { + "type": "object", + "properties": { + "version": { + "type": "string" + }, + "attacks": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9 _\\-]+$": { + "$ref": "#/$defs/AttackKindDefinition" + } + }, + "additionalProperties": false + } + }, + "required": [ + "attacks" + ], + "additionalProperties": false + }, + "attacks": { + "type": "object", + "patternProperties": { + "^[A-Za-z0-9 _\\-]+$": { + "type": "object", + "properties": { + "log_id": { + "type": "string" + }, + "log_time": { + "type": "string" + }, + "metadata": { + "type": "object", + "properties": { + "sacroml_version": { + "type": "string" + }, + "attack_name": { + "type": "string" + }, + "attack_params": { + "type": "object", + "additionalProperties": true + }, + "global_metrics": { + "type": "object", + "additionalProperties": { + "type": [ + "number", + "string", + "boolean", + "null" + ] + } + }, + "baseline_global_metrics": { + "type": "object", + "additionalProperties": { + "type": [ + "number", + "string", + "boolean", + "null" + ] + } + }, + "target_model": { + "type": "string" + }, + "target_model_params": { + "type": "object", + "additionalProperties": true + }, + "target_train_params": { + "type": "object", + "additionalProperties": true + } + }, + "required": [ + "sacroml_version", + "attack_name", + "attack_params", + "global_metrics" + ], + "additionalProperties": false + }, + "attack_experiment_logger": { + "type": "object", + "properties": { + "attack_instance_logger": { + "type": "object", + "patternProperties": { + "^instance_[0-9]+$": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "number" + }, + { + 
"type": "string" + }, + { + "type": "array", + "items": { + "type": "number" + } + }, + { + "type": "null" + } + ] + } + } + }, + "additionalProperties": false + } + }, + "required": [ + "attack_instance_logger" + ], + "additionalProperties": false + } + }, + "required": [ + "log_id", + "log_time", + "metadata", + "attack_experiment_logger" + ], + "additionalProperties": true + } + }, + "additionalProperties": false + } + }, + "required": [ + "metric_catalog", + "parameter_catalog", + "attack_category_catalog", + "attack_catalog", + "attacks" + ], + "additionalProperties": false, + "$defs": { + "AttackCategoryDefinition": { + "type": "object", + "description": "Human-readable definition of an attack category, including report ordering.", + "properties": { + "label": { + "type": "string", + "description": "Human-readable category label." + }, + "description": { + "type": "string", + "description": "Explanation of what this category represents." + }, + "order": { + "type": "integer", + "description": "Defines the order in which this category should be presented in reports (lower first)." + } + }, + "required": [ + "label", + "description", + "order" + ], + "additionalProperties": false + }, + "AttackKindDefinition": { + "type": "object", + "description": "Attack kind definition with category, key metrics, and obligatory attack-specific parameter declarations.", + "properties": { + "label": { + "type": "string", + "description": "Human-readable attack name." + }, + "description": { + "type": "string", + "description": "Description of what the attack does." + }, + "attack_category": { + "type": "string", + "description": "Key into attack_category_catalog.categories." 
+ }, + "key_metrics": { + "type": "array", + "items": { + "type": "string" + }, + "minItems": 1 + }, + "attack_params": { + "type": "object", + "description": "Human-readable declarations of all parameters supported by this attack kind.", + "patternProperties": { + "^[A-Za-z0-9_\\-]+$": { + "type": "object", + "properties": { + "label": { + "type": "string" + }, + "description": { + "type": "string" + } + }, + "required": [ + "label", + "description" + ], + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "required": [ + "label", + "description", + "attack_category", + "key_metrics", + "attack_params" + ], + "additionalProperties": false + }, + "MetricDefinition": { + "type": "object", + "description": "Human-readable description of a metric.", + "properties": { + "label": { + "type": "string" + }, + "description": { + "type": "string" + }, + "units": { + "type": [ + "string", + "null" + ] + }, + "higher_is_better": { + "type": [ + "boolean", + "null" + ] + }, + "category": { + "type": [ + "string", + "null" + ] + }, + "typical_range": { + "oneOf": [ + { + "type": "array", + "items": { + "type": "number" + }, + "minItems": 2, + "maxItems": 2 + }, + { + "type": "null" + } + ] + }, + "allowed_aggregations": { + "type": "array", + "description": "List of aggregation operations that are valid and meaningful for this metric.", + "items": { + "type": "string", + "enum": [ + "sum", + "mean", + "min", + "max", + "median", + "stddev", + "var", + "quantile", + "count" + ] + }, + "minItems": 0, + "uniqueItems": true + }, + "notes": { + "type": [ + "string", + "null" + ] + } + }, + "required": [ + "label", + "description", + "allowed_aggregations" + ], + "additionalProperties": false + }, + "PatternMetricDefinition": { + "type": "object", + "additionalProperties": true + }, + "ParameterDefinition": { + "type": "object", + "additionalProperties": true + } + } +} diff --git a/tests/reporting/__init__.py b/tests/reporting/__init__.py new file mode 
100644 index 00000000..e3c7d084 --- /dev/null +++ b/tests/reporting/__init__.py @@ -0,0 +1 @@ +"""Tests for the sacroml.reporting package.""" diff --git a/tests/reporting/fixtures/legacy_basic.json b/tests/reporting/fixtures/legacy_basic.json new file mode 100644 index 00000000..7566660e --- /dev/null +++ b/tests/reporting/fixtures/legacy_basic.json @@ -0,0 +1,91 @@ +{ + "WorstCase attack_de2c5fc0-fb0c-4925-ac4d-26662fe7f786": { + "log_id": "de2c5fc0-fb0c-4925-ac4d-26662fe7f786", + "log_time": "05/11/2025 16:54:34", + "metadata": { + "sacroml_version": "1.4.0", + "attack_name": "WorstCase attack", + "attack_params": { + "output_dir": "training_artefacts", + "write_report": true, + "n_reps": 2, + "p_thresh": 0.05 + }, + "global_metrics": { + "null_auc_3sd_range": "0.4629 -> 0.5371", + "n_sig_auc_p_vals": 2 + }, + "baseline_global_metrics": { + "n_sig_auc_p_vals": 0 + }, + "target_model": "RandomForestClassifier", + "target_model_params": {}, + "target_train_params": {} + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "instance_0": { + "TPR": 0.019, + "FPR": 0.007, + "ACC": 0.7, + "AUC": 0.57, + "P_HIGHER_AUC": 3.05e-09, + "FDIF01": 0.64, + "PDIF01": 31.5, + "TPR@0.1": 0.43 + }, + "instance_1": { + "TPR": 0.012, + "FPR": 0.004, + "ACC": 0.7, + "AUC": 0.58, + "P_HIGHER_AUC": 1.8e-13, + "FDIF01": 0.57, + "PDIF01": 28.1, + "TPR@0.1": 0.39 + } + } + } + }, + "LiRA Attack_4e9b5db3-f93f-43dc-9a67-1035a68b892a": { + "log_id": "4e9b5db3-f93f-43dc-9a67-1035a68b892a", + "log_time": "05/11/2025 16:55:01", + "metadata": { + "sacroml_version": "1.4.0", + "attack_name": "LiRA Attack", + "attack_params": { + "output_dir": "training_artefacts", + "n_shadow_models": 100, + "p_thresh": 0.05 + }, + "global_metrics": { + "AUC_sig": "Significant at p=0.05" + }, + "target_model": "RandomForestClassifier", + "target_model_params": {}, + "target_train_params": {} + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "instance_0": { + "TPR": 0.76, + "FPR": 
0.49, + "AUC": 0.75, + "P_HIGHER_AUC": 0.0, + "fpr": [ + 0.0, + 0.5, + 1.0 + ], + "tpr": [ + 0.0, + 0.7, + 1.0 + ], + "n_pos_test_examples": 398, + "n_neg_test_examples": 171 + } + } + } + } +} diff --git a/tests/reporting/fixtures/legacy_edge.json b/tests/reporting/fixtures/legacy_edge.json new file mode 100644 index 00000000..37bc7e50 --- /dev/null +++ b/tests/reporting/fixtures/legacy_edge.json @@ -0,0 +1,80 @@ +{ + "LiRA Attack_0d8cf41e": { + "log_id": "0d8cf41e", + "log_time": "30/06/2024 18:34:53", + "metadata": { + "attack_name": "LiRA Attack", + "attack_params": { + "output_dir": "outputs_lira", + "n_shadow_models": 100, + "weird_param": true + }, + "global_metrics": { + "AUC_sig": "Significant at p=0.05", + "nested_metric": { + "a": 1, + "b": 2 + } + }, + "target_model": "RandomForestClassifier", + "target_model_params": { + "n_estimators": 100 + }, + "stray_metadata_key": "should be dropped" + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "instance_0": { + "AUC": 0.75, + "weird_metric": 0.42, + "roc_thresh": [ + null, + 1.0, + 0.5, + 0.0 + ] + }, + "rep1": { + "AUC": 0.71 + } + }, + "dummy_attack_experiment_logger": { + "attack_instance_logger": {} + } + } + }, + "Structural Attack_aaaa": { + "log_id": "aaaa", + "log_time": "30/06/2024 18:40:00", + "metadata": { + "sacroml_version": "1.4.0", + "attack_name": "Structural Attack", + "attack_params": { + "output_dir": "outputs_struct" + }, + "global_metrics": { + "dof_risk": true, + "k_anonymity_risk": false + } + } + }, + "Mystery attack_bbbb": { + "log_id": "bbbb", + "log_time": "30/06/2024 18:45:00", + "metadata": { + "sacroml_version": "1.4.0", + "attack_name": "Mystery attack", + "attack_params": {}, + "global_metrics": { + "AUC": 0.6 + } + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "instance_0": { + "AUC": 0.6 + } + } + } + } +} diff --git a/tests/reporting/test_convert.py b/tests/reporting/test_convert.py new file mode 100644 index 00000000..694f752a --- 
/dev/null +++ b/tests/reporting/test_convert.py @@ -0,0 +1,227 @@ +"""Tests for legacy report.json -> new format conversion.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from sacroml.main import main +from sacroml.reporting import ConversionResult, convert_report, convert_report_file + +FIXTURES = Path(__file__).parent / "fixtures" +DOCS_EXAMPLES = Path(__file__).parents[2] / "docs" / "source" / "attacks" + + +def _load(name: str) -> dict: + """Load a fixture JSON file by name.""" + with open(FIXTURES / name, encoding="utf-8") as fh: + return json.load(fh) + + +@pytest.fixture +def basic_result() -> ConversionResult: + """Conversion result for the well-formed legacy fixture.""" + return convert_report(_load("legacy_basic.json")) + + +@pytest.fixture +def edge_result() -> ConversionResult: + """Conversion result for the edge-case legacy fixture.""" + return convert_report(_load("legacy_edge.json")) + + +def test_basic_conversion_structure(basic_result: ConversionResult) -> None: + """A clean legacy report wraps under 'attacks' with all catalogs.""" + report = basic_result.report + assert report["report_schema_version"] + for catalog in ( + "metric_catalog", + "parameter_catalog", + "attack_category_catalog", + "attack_catalog", + "attacks", + ): + assert catalog in report + assert set(report["attacks"]) == { + "WorstCase attack_de2c5fc0-fb0c-4925-ac4d-26662fe7f786", + "LiRA Attack_4e9b5db3-f93f-43dc-9a67-1035a68b892a", + } + + +def test_basic_conversion_is_schema_valid( + basic_result: ConversionResult, +) -> None: + """A clean legacy report converts to a schema-valid document.""" + assert basic_result.is_valid + assert basic_result.schema_errors == [] + assert basic_result.warnings == [] + assert basic_result.curve_warnings == [] + + +def test_basic_coverage_all_catalogued( + basic_result: ConversionResult, +) -> None: + """All metrics/params/attacks in the clean fixture are catalogued.""" + for dim in 
basic_result.coverage.values(): + assert dim["missing"] == [] + + +def test_conversion_is_idempotent(basic_result: ConversionResult) -> None: + """Converting an already-converted report is a no-op for the payload.""" + again = convert_report(basic_result.report) + assert again.report["attacks"] == basic_result.report["attacks"] + assert again.is_valid + + +def test_missing_sacroml_version_injected( + edge_result: ConversionResult, +) -> None: + """Old reports without sacroml_version get an 'unknown' placeholder.""" + lira = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"] + assert lira["metadata"]["sacroml_version"] == "unknown" + assert any("sacroml_version was missing" in w for w in edge_result.warnings) + + +def test_stray_metadata_key_dropped(edge_result: ConversionResult) -> None: + """Metadata keys not permitted by the schema are dropped with a warning.""" + lira_meta = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"]["metadata"] + assert "stray_metadata_key" not in lira_meta + assert any("stray_metadata_key" in w for w in edge_result.warnings) + + +def test_non_scalar_global_metric_serialised( + edge_result: ConversionResult, +) -> None: + """Non-scalar global_metrics values are serialised to keep schema valid.""" + gm = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"]["metadata"][ + "global_metrics" + ] + assert isinstance(gm["nested_metric"], str) + assert json.loads(gm["nested_metric"]) == {"a": 1, "b": 2} + + +def test_extra_experiment_logger_sibling_dropped( + edge_result: ConversionResult, +) -> None: + """Attack_experiment_logger keeps only attack_instance_logger.""" + ael = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"][ + "attack_experiment_logger" + ] + assert set(ael) == {"attack_instance_logger"} + assert any("dummy_attack_experiment_logger" in w for w in edge_result.warnings) + + +def test_non_conforming_instance_key_renamed( + edge_result: ConversionResult, +) -> None: + """Instance keys not matching 'instance_' are 
renamed.""" + logger = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"][ + "attack_experiment_logger" + ]["attack_instance_logger"] + assert all(k.startswith("instance_") for k in logger) + assert any("renamed to 'instance_" in w for w in edge_result.warnings) + + +def test_structural_attack_gets_empty_logger( + edge_result: ConversionResult, +) -> None: + """An instance-less structural attack gets an empty instance logger.""" + struct = edge_result.report["attacks"]["Structural Attack_aaaa"] + assert struct["attack_experiment_logger"]["attack_instance_logger"] == {} + + +def test_uncatalogued_entries_warn(edge_result: ConversionResult) -> None: + """Unknown metrics/params/attacks are reported but not fatal.""" + cov = edge_result.coverage + assert "weird_metric" in cov["metrics"]["missing"] + assert "weird_param" in cov["parameters"]["missing"] + assert "Mystery attack" in cov["attacks"]["missing"] + assert any("Uncatalogued metric" in w for w in edge_result.warnings) + assert any("Uncatalogued parameter" in w for w in edge_result.warnings) + assert any("Uncatalogued attack" in w for w in edge_result.warnings) + + +def test_curve_array_is_warning_not_error( + edge_result: ConversionResult, +) -> None: + """Roc_thresh starting with null is a curve warning, not a schema error.""" + assert edge_result.curve_warnings + assert all("roc_thresh" in w for w in edge_result.curve_warnings) + assert edge_result.is_valid + + +def test_non_experiment_top_level_dropped() -> None: + """Top-level keys that are not experiments are dropped with a warning.""" + result = convert_report({"junk": 123, **_load("legacy_basic.json")}) + assert "junk" not in result.report["attacks"] + assert any("does not look like an experiment" in w for w in result.warnings) + + +@pytest.mark.parametrize( + "example", + ["report_example_lira.json", "report_example_worstcase.json"], +) +def test_real_docs_examples_validate(example: str) -> None: + """The real example reports shipped in docs convert 
and validate.""" + with open(DOCS_EXAMPLES / example, encoding="utf-8") as fh: + data = json.load(fh) + result = convert_report(data) + assert result.is_valid, result.schema_errors + + +def test_no_validate_skips_validation() -> None: + """Validate=False produces no schema errors or curve warnings.""" + result = convert_report(_load("legacy_edge.json"), validate=False) + assert result.schema_errors == [] + assert result.curve_warnings == [] + # Coverage warnings are independent of schema validation. + assert result.coverage["metrics"]["missing"] + + +def test_convert_report_file_roundtrip(tmp_path: Path) -> None: + """Convert_report_file writes valid JSON and returns a result.""" + out = tmp_path / "nested" / "converted.json" + result = convert_report_file(FIXTURES / "legacy_basic.json", out) + assert out.is_file() + written = json.loads(out.read_text(encoding="utf-8")) + assert written["attacks"] == result.report["attacks"] + assert result.is_valid + + +def test_cli_convert_report( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + """The `sacroml convert-report` CLI converts and exits cleanly.""" + out = tmp_path / "converted.json" + monkeypatch.setattr( + "sys.argv", + ["sacroml", "convert-report", str(FIXTURES / "legacy_basic.json"), str(out)], + ) + with pytest.raises(SystemExit) as exc: + main() + assert exc.value.code == 0 + assert out.is_file() + assert "schema-valid" in capsys.readouterr().out + + +def test_cli_convert_report_missing_input( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The CLI exits non-zero when the input file does not exist.""" + monkeypatch.setattr( + "sys.argv", + [ + "sacroml", + "convert-report", + str(tmp_path / "nope.json"), + str(tmp_path / "out.json"), + ], + ) + with pytest.raises(SystemExit) as exc: + main() + assert exc.value.code == 1 From 18eb44c385f26c143018fc0d6a357c6a16c2bc47 Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Wed, 20 May 2026 10:57:20 
+0100 Subject: [PATCH 2/7] feat(reporting): add legacy report.json conversion tooling Add `sacroml convert-report` CLI plus a `sacroml.reporting` subpackage that converts legacy flat report.json files into the new nested, catalog-enriched format and validates them against the bundled JSON schema. Includes catalog definitions, schema, tests, and docs. --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7f109dd9..2c24f7a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,10 @@ packages = {find = {exclude = ["docs*", "examples*", "tests*", "user_stories*"]} [tool.setuptools.package-data] "sacroml.safemodel" = ["rules.json"] -"sacroml.reporting" = ["*.json"] +"sacroml.reporting" = [ + "sacroml_attack_report.schema.json", + "catalog_definitions.json", +] [tool.ruff] indent-width = 4 From d0b2646e84b13f0b7acfeb5bd2686fecca16b22c Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Wed, 20 May 2026 11:42:47 +0100 Subject: [PATCH 3/7] test(reporting): cover defensive normalisation branches in convert.py Add tests that drive every remaining defensive isinstance/fallback path in sacroml.reporting.convert: non-dict global_metrics, attack_params, target_model_params, target_train_params and attack_experiment_logger; non-string target_model and log_time; experiment keys with illegal characters; non-dict attack_instance_logger; instance-key collision during renaming; an already-wrapped report whose experiment value is not a dict; a non-curve schema error in an instance metric; and the short/unrelated/missing-lookup branches of _is_curve_violation. 
--- tests/reporting/test_convert.py | 157 ++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) diff --git a/tests/reporting/test_convert.py b/tests/reporting/test_convert.py index 694f752a..a2c38d03 100644 --- a/tests/reporting/test_convert.py +++ b/tests/reporting/test_convert.py @@ -9,6 +9,7 @@ from sacroml.main import main from sacroml.reporting import ConversionResult, convert_report, convert_report_file +from sacroml.reporting.convert import _is_curve_violation FIXTURES = Path(__file__).parent / "fixtures" DOCS_EXAMPLES = Path(__file__).parents[2] / "docs" / "source" / "attacks" @@ -208,6 +209,162 @@ def test_cli_convert_report( assert "schema-valid" in capsys.readouterr().out +def test_defensive_normalisation_paths() -> None: + """Malformed legacy inputs hit every defensive normalisation branch.""" + legacy = { + "report_schema_version": "0.9", + "WorstCase attack/with!chars_xyz": { + "log_id": 12345, + "log_time": 17181920, + "metadata": { + "attack_name": "WorstCase attack", + "attack_params": "not a dict", + "global_metrics": "not a dict", + "target_model": 42, + "target_model_params": "not a dict", + "target_train_params": "not a dict", + }, + "attack_experiment_logger": "not a dict", + }, + } + result = convert_report(legacy) + assert "WorstCase attack_with_chars_xyz" in result.report["attacks"] + converted = result.report["attacks"]["WorstCase attack_with_chars_xyz"] + meta = converted["metadata"] + assert meta["attack_params"] == {} + assert meta["global_metrics"] == {} + assert meta["target_model"] == "42" + assert meta["target_model_params"] == {} + assert meta["target_train_params"] == {} + assert isinstance(converted["log_time"], str) + assert converted["attack_experiment_logger"]["attack_instance_logger"] == {} + assert any("renamed to" in w for w in result.warnings) + + +def test_non_dict_instance_logger_becomes_empty() -> None: + """A non-dict ``attack_instance_logger`` is replaced with an empty one.""" + legacy = { + "WorstCase 
attack_qq": { + "log_id": "qq", + "log_time": "now", + "metadata": { + "sacroml_version": "1.0", + "attack_name": "WorstCase attack", + "attack_params": {}, + "global_metrics": {"AUC": 0.5}, + }, + "attack_experiment_logger": {"attack_instance_logger": "not a dict"}, + } + } + result = convert_report(legacy) + logger = result.report["attacks"]["WorstCase attack_qq"][ + "attack_experiment_logger" + ]["attack_instance_logger"] + assert logger == {} + + +def test_instance_key_collision_during_renaming() -> None: + """Renaming a non-conforming key bumps past an existing matching key.""" + legacy = { + "LiRA Attack_zz": { + "log_id": "zz", + "log_time": "now", + "metadata": { + "sacroml_version": "1.0", + "attack_name": "LiRA Attack", + "attack_params": {}, + "global_metrics": {"AUC": 0.5}, + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "weird": {"AUC": 0.1}, + "instance_0": {"AUC": 0.2}, + } + }, + } + } + result = convert_report(legacy) + logger = result.report["attacks"]["LiRA Attack_zz"]["attack_experiment_logger"][ + "attack_instance_logger" + ] + assert set(logger) == {"instance_0", "instance_1"} + + +def test_already_wrapped_with_non_dict_experiment() -> None: + """An already-wrapped report whose experiment isn't a dict is skipped.""" + result = convert_report( + {"attacks": {"BogusAttack_xx": "not a dict"}}, validate=False + ) + assert result.report["attacks"] == {} + assert any("was not an object" in w for w in result.warnings) + + +def test_non_curve_schema_error_is_reported() -> None: + """An instance metric with an object value yields a real schema error.""" + legacy = { + "LiRA Attack_yy": { + "log_id": "yy", + "log_time": "now", + "metadata": { + "sacroml_version": "1.0", + "attack_name": "LiRA Attack", + "attack_params": {}, + "global_metrics": {"AUC": 0.5}, + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "instance_0": {"AUC": {"unexpected": "object"}} + } + }, + } + } + result = convert_report(legacy) + assert 
not result.is_valid + assert result.schema_errors + + +class _FakeError: + """Minimal stand-in for ``jsonschema.exceptions.ValidationError``.""" + + def __init__(self, path: list) -> None: + self.absolute_path = path + + +@pytest.mark.parametrize( + "path", + [ + ["attacks"], # path < 6 + ["other", "exp", "attack_experiment_logger", "ail", "i0", "AUC"], + ["attacks", "exp", "wrong", "ail", "i0", "AUC"], + ["attacks", "exp", "attack_experiment_logger", "wrong", "i0", "AUC"], + [ + "attacks", + "exp", + "attack_experiment_logger", + "attack_instance_logger", + "i0", + 0, + ], + ], +) +def test_is_curve_violation_rejects_unrelated_paths(path: list) -> None: + """Errors outside the instance-metric path are never curve violations.""" + assert _is_curve_violation(_FakeError(path), {"attacks": {}}) is False + + +def test_is_curve_violation_handles_missing_lookup() -> None: + """A well-shaped path with no matching value is not a curve violation.""" + path = [ + "attacks", + "missing_exp", + "attack_experiment_logger", + "attack_instance_logger", + "instance_0", + "AUC", + ] + assert _is_curve_violation(_FakeError(path), {"attacks": {}}) is False + + def test_cli_convert_report_missing_input( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, From cb90dba61dcc4c1ece7ab2247c84f06f9540ef12 Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Thu, 21 May 2026 11:13:38 +0100 Subject: [PATCH 4/7] refactor(reporting): add ConversionResult.summary helpers and DRY the CLI --- sacroml/main.py | 33 +------------ sacroml/reporting/convert.py | 82 +++++++++++++++++++++++++-------- tests/reporting/test_convert.py | 82 +++++++++++++++++++++++++++++++++ 3 files changed, 148 insertions(+), 49 deletions(-) diff --git a/sacroml/main.py b/sacroml/main.py index 40c58333..dec24112 100644 --- a/sacroml/main.py +++ b/sacroml/main.py @@ -25,38 +25,9 @@ def _run_convert_report(args: argparse.Namespace) -> int: return 1 result = convert_report_file(args.input, args.output, validate=not args.no_validate) - 
print(f"Converted '{args.input}' -> '{args.output}'") - for dim, summary in result.coverage.items(): - print( - f" {dim}: {len(summary['covered'])} catalogued, " - f"{len(summary['missing'])} uncatalogued" - ) - - if result.warnings: - print(f"\nWarnings ({len(result.warnings)}):") - for warning in result.warnings: - print(f" - {warning}") - - if result.curve_warnings: - print( - f"\nCurve-array notices ({len(result.curve_warnings)}): " - "fpr/tpr/roc_thresh arrays are passed through unchanged and do " - "not strictly validate yet." - ) - for warning in result.curve_warnings: - print(f" - {warning}") - - if result.schema_errors: - print(f"\nSchema errors ({len(result.schema_errors)}):") - for error in result.schema_errors: - print(f" - {error}") - print("\nConverted report is NOT schema-valid.") - return 1 - - if not args.no_validate: - print("\nConverted report is schema-valid.") - return 0 + print(result.summary(validated=not args.no_validate)) + return 0 if result.is_valid else 1 def main() -> None: diff --git a/sacroml/reporting/convert.py b/sacroml/reporting/convert.py index 93187dd8..6a63e34f 100644 --- a/sacroml/reporting/convert.py +++ b/sacroml/reporting/convert.py @@ -106,6 +106,55 @@ def is_valid(self) -> bool: """ return not self.schema_errors + def summary_dict(self) -> dict[str, Any]: + """Return a machine-readable summary of the conversion outcome.""" + return { + "is_valid": self.is_valid, + "warnings": len(self.warnings), + "curve_warnings": len(self.curve_warnings), + "schema_errors": len(self.schema_errors), + "coverage": { + dim: { + "covered": len(summary["covered"]), + "missing": len(summary["missing"]), + } + for dim, summary in self.coverage.items() + }, + } + + def summary(self, *, validated: bool = True) -> str: + """Return a human-readable, multi-line summary of the conversion. 
+ + Parameters + ---------- + validated : bool, default True + Whether schema validation was performed; controls the final + "schema-valid" / "NOT schema-valid" line. + """ + lines: list[str] = [] + for dim, summary in self.coverage.items(): + lines.append( + f" {dim}: {len(summary['covered'])} catalogued, " + f"{len(summary['missing'])} uncatalogued" + ) + if self.warnings: + lines.append(f"\nWarnings ({len(self.warnings)}):") + lines.extend(f" - {w}" for w in self.warnings) + if self.curve_warnings: + lines.append( + f"\nCurve-array notices ({len(self.curve_warnings)}): " + "fpr/tpr/roc_thresh arrays are passed through unchanged and " + "do not strictly validate yet." + ) + lines.extend(f" - {w}" for w in self.curve_warnings) + if self.schema_errors: + lines.append(f"\nSchema errors ({len(self.schema_errors)}):") + lines.extend(f" - {e}" for e in self.schema_errors) + lines.append("\nConverted report is NOT schema-valid.") + elif validated: + lines.append("\nConverted report is schema-valid.") + return "\n".join(lines) + def _load_catalog_definitions() -> dict[str, Any]: """Load the bundled common catalog definitions.""" @@ -177,17 +226,15 @@ def _normalise_instance_logger( for key, value in logger.items(): if _INSTANCE_KEY_RE.match(key): normalised[key] = value - else: - new_key = f"instance_{next_index}" - while new_key in logger or new_key in normalised: - next_index += 1 - new_key = f"instance_{next_index}" - normalised[new_key] = value - warnings.append( - f"Experiment '{exp_key}': instance key '{key}' did not match " - f"'instance_'; renamed to '{new_key}'." - ) + continue + while (new_key := f"instance_{next_index}") in logger or new_key in normalised: + next_index += 1 + normalised[new_key] = value next_index += 1 + warnings.append( + f"Experiment '{exp_key}': instance key '{key}' did not match " + f"'instance_'; renamed to '{new_key}'." 
+ ) return normalised @@ -289,14 +336,13 @@ def _diff( ) -> tuple[list[str], list[str]]: """Split ``seen`` names into catalogued (covered) and missing.""" patterns = patterns or [] - - def catalogued(name: str) -> bool: - if name in explicit: - return True - return any(p.match(name) for p in patterns) - - missing = sorted(n for n in seen if not catalogued(n)) - covered = sorted(seen - set(missing)) + covered: list[str] = [] + missing: list[str] = [] + for name in sorted(seen): + if name in explicit or any(p.match(name) for p in patterns): + covered.append(name) + else: + missing.append(name) return covered, missing diff --git a/tests/reporting/test_convert.py b/tests/reporting/test_convert.py index a2c38d03..3fbe9d79 100644 --- a/tests/reporting/test_convert.py +++ b/tests/reporting/test_convert.py @@ -382,3 +382,85 @@ def test_cli_convert_report_missing_input( with pytest.raises(SystemExit) as exc: main() assert exc.value.code == 1 + + +def test_summary_dict_counts(edge_result: ConversionResult) -> None: + """``summary_dict`` reduces the result to plain counts.""" + summary = edge_result.summary_dict() + assert summary["is_valid"] is True + assert summary["warnings"] == len(edge_result.warnings) + assert summary["curve_warnings"] == len(edge_result.curve_warnings) + assert summary["schema_errors"] == 0 + for dim, counts in summary["coverage"].items(): + assert counts["covered"] == len(edge_result.coverage[dim]["covered"]) + assert counts["missing"] == len(edge_result.coverage[dim]["missing"]) + + +def test_summary_text_includes_sections(edge_result: ConversionResult) -> None: + """``summary`` text surfaces coverage, warnings, curve notices and validity.""" + text = edge_result.summary() + assert "metrics:" in text + assert "Warnings (" in text + assert "Curve-array notices (" in text + assert "Converted report is schema-valid." 
in text + + +def test_summary_text_when_validation_skipped( + edge_result: ConversionResult, +) -> None: + """``validated=False`` suppresses the trailing schema-valid line.""" + text = edge_result.summary(validated=False) + assert "schema-valid" not in text + + +def test_summary_text_for_schema_errors() -> None: + """A schema-invalid result includes the NOT schema-valid footer.""" + legacy = { + "LiRA Attack_zz": { + "log_id": "zz", + "log_time": "now", + "metadata": { + "sacroml_version": "1.0", + "attack_name": "LiRA Attack", + "attack_params": {}, + "global_metrics": {"AUC": 0.5}, + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "instance_0": {"AUC": {"unexpected": "object"}} + } + }, + } + } + result = convert_report(legacy) + text = result.summary() + assert "Schema errors (" in text + assert "NOT schema-valid" in text + + +def test_renamed_instance_keys_are_compact() -> None: + """All non-conforming keys get consecutive instance_ slots.""" + legacy = { + "LiRA Attack_zz": { + "log_id": "zz", + "log_time": "now", + "metadata": { + "sacroml_version": "1.0", + "attack_name": "LiRA Attack", + "attack_params": {}, + "global_metrics": {"AUC": 0.5}, + }, + "attack_experiment_logger": { + "attack_instance_logger": { + "first": {"AUC": 0.1}, + "second": {"AUC": 0.2}, + "third": {"AUC": 0.3}, + } + }, + } + } + result = convert_report(legacy) + logger = result.report["attacks"]["LiRA Attack_zz"]["attack_experiment_logger"][ + "attack_instance_logger" + ] + assert list(logger) == ["instance_0", "instance_1", "instance_2"] From bd2008e282453fea3b3e382d3fd2be98371bed2e Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Wed, 27 May 2026 14:52:27 +0100 Subject: [PATCH 5/7] feat(reporting): handle malformed and non-object report input gracefully --- sacroml/main.py | 12 +++++++++++- sacroml/reporting/convert.py | 8 +++++++- tests/reporting/test_convert.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 2 deletions(-) diff --git 
a/sacroml/main.py b/sacroml/main.py index dec24112..da4b8e5b 100644 --- a/sacroml/main.py +++ b/sacroml/main.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse +import json import os import sys @@ -24,7 +25,16 @@ def _run_convert_report(args: argparse.Namespace) -> int: print(f"Input report not found: {args.input}") return 1 - result = convert_report_file(args.input, args.output, validate=not args.no_validate) + try: + result = convert_report_file( + args.input, args.output, validate=not args.no_validate + ) + except json.JSONDecodeError as exc: + print(f"Could not parse '{args.input}' as JSON: {exc}") + return 1 + except OSError as exc: + print(f"Could not read or write report: {exc}") + return 1 print(f"Converted '{args.input}' -> '{args.output}'") print(result.summary(validated=not args.no_validate)) return 0 if result.is_valid else 1 diff --git a/sacroml/reporting/convert.py b/sacroml/reporting/convert.py index 6a63e34f..4e534c09 100644 --- a/sacroml/reporting/convert.py +++ b/sacroml/reporting/convert.py @@ -175,8 +175,14 @@ def _is_experiment(value: Any) -> bool: ) -def _extract_experiments(data: dict[str, Any], warnings: list[str]) -> dict[str, Any]: +def _extract_experiments(data: Any, warnings: list[str]) -> dict[str, Any]: """Return the experiment mapping, whether the input is flat or wrapped.""" + if not isinstance(data, dict): + warnings.append( + f"Top-level report was a {type(data).__name__}, not an object; " + "no experiments could be extracted." + ) + return {} if isinstance(data.get("attacks"), dict): # Already wrapped (idempotent path). 
return dict(data["attacks"]) diff --git a/tests/reporting/test_convert.py b/tests/reporting/test_convert.py index 3fbe9d79..1de47fd0 100644 --- a/tests/reporting/test_convert.py +++ b/tests/reporting/test_convert.py @@ -365,6 +365,34 @@ def test_is_curve_violation_handles_missing_lookup() -> None: assert _is_curve_violation(_FakeError(path), {"attacks": {}}) is False +@pytest.mark.parametrize("payload", [[1, 2, 3], "a string", 42, None]) +def test_non_dict_top_level_yields_empty_report(payload: object) -> None: + """A non-object top-level report converts to empty attacks with a warning.""" + result = convert_report(payload, validate=False) + assert result.report["attacks"] == {} + assert any("not an object" in w for w in result.warnings) + + +def test_cli_convert_report_malformed_json( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + """The CLI reports a friendly error and exits 1 on malformed JSON input.""" + bad = tmp_path / "bad.json" + bad.write_text("{ not valid json", encoding="utf-8") + out = tmp_path / "out.json" + monkeypatch.setattr( + "sys.argv", + ["sacroml", "convert-report", str(bad), str(out)], + ) + with pytest.raises(SystemExit) as exc: + main() + assert exc.value.code == 1 + assert "Could not parse" in capsys.readouterr().out + assert not out.exists() + + def test_cli_convert_report_missing_input( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, From 798196f17c0c810b8a77a94433b9e5a70bdd95f2 Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Fri, 29 May 2026 12:36:35 +0100 Subject: [PATCH 6/7] refactor(reporting): slim legacy report converter to its load-bearing core Address review feedback that the tooling outgrew the issue. 
Remove speculative normalisation and the post-hoc ConversionResult.summary/summary_dict API, keeping only what R1-R4 and the real example reports require: wrap, catalog injection, schema validation, uncatalogued warnings, sacroml_version injection and the curve-violation downgrade. convert.py 570->377, tests 40->20, legacy_edge fixture removed. --- docs/source/attacks/report.rst | 11 +- sacroml/main.py | 21 +- sacroml/reporting/convert.py | 397 +++++------------- tests/reporting/fixtures/legacy_edge.json | 80 ---- tests/reporting/test_convert.py | 475 ++++++---------------- 5 files changed, 260 insertions(+), 724 deletions(-) delete mode 100644 tests/reporting/fixtures/legacy_edge.json diff --git a/docs/source/attacks/report.rst b/docs/source/attacks/report.rst index 02a7789c..4fd1336c 100644 --- a/docs/source/attacks/report.rst +++ b/docs/source/attacks/report.rst @@ -22,10 +22,6 @@ without re-running any attacks: The converter: * wraps the legacy experiments under a top-level ``attacks`` key; -* normalises experiment metadata so it satisfies the schema (injecting a - placeholder ``sacroml_version`` for very old reports, ensuring an - ``attack_instance_logger`` exists for instance-less structural attacks, - serialising non-scalar ``global_metrics`` values, etc.); * injects the four human-readable catalogs (``metric_catalog``, ``parameter_catalog``, ``attack_category_catalog``, ``attack_catalog``) from the bundled common definitions in @@ -34,6 +30,13 @@ The converter: the report is not present in the catalogs (conversion still succeeds); and * validates the result against the JSON schema. +Two small normalisations keep real-world legacy reports schema-valid: a +placeholder ``sacroml_version`` is injected when missing, and an empty +``attack_instance_logger`` is supplied for instance-less attacks (e.g. +structural attacks). 
Anything else that does not match the schema -- such as a +non-scalar metric value or a stray metadata key -- is reported as a schema +error rather than silently rewritten. + Curve-valued arrays (``fpr`` / ``tpr`` / ``roc_thresh``) are passed through unchanged. ``roc_thresh`` legitimately starts with ``null``, which the schema does not yet permit, so such violations are reported as notices rather than diff --git a/sacroml/main.py b/sacroml/main.py index da4b8e5b..59e4c0eb 100644 --- a/sacroml/main.py +++ b/sacroml/main.py @@ -35,9 +35,26 @@ def _run_convert_report(args: argparse.Namespace) -> int: except OSError as exc: print(f"Could not read or write report: {exc}") return 1 + print(f"Converted '{args.input}' -> '{args.output}'") - print(result.summary(validated=not args.no_validate)) - return 0 if result.is_valid else 1 + for dim, missing in result.coverage.items(): + if missing: + print(f" {len(missing)} uncatalogued {dim.replace('_', ' ')}: {missing}") + if result.curve_warnings: + print( + f" {len(result.curve_warnings)} curve-array notice(s); " + "fpr/tpr/roc_thresh passed through unchanged." + ) + if args.no_validate: + return 0 + if result.is_valid: + print("Converted report is schema-valid.") + return 0 + n_errors = len(result.schema_errors) + print(f"Converted report is NOT schema-valid ({n_errors} error(s)):") + for err in result.schema_errors: + print(f" - {err}") + return 1 def main() -> None: diff --git a/sacroml/reporting/convert.py b/sacroml/reporting/convert.py index 4e534c09..53658961 100644 --- a/sacroml/reporting/convert.py +++ b/sacroml/reporting/convert.py @@ -5,26 +5,27 @@ catalog-enriched document validated by ``sacroml/reporting/sacroml_attack_report.schema.json``. -This module performs that conversion: - -1. Wrap the top-level experiments under an ``"attacks"`` key. -2. 
Normalise experiment metadata so it satisfies the schema (injecting a - placeholder ``sacroml_version`` for very old reports, ensuring an - ``attack_instance_logger`` exists for instance-less structural attacks, - coercing non-scalar ``global_metrics`` values, etc.). -3. Inject the four human-readable catalogs (``metric_catalog``, - ``parameter_catalog``, ``attack_category_catalog``, ``attack_catalog``) - from a bundled common-definitions file. -4. Diff everything observed in the report against the catalogs and emit - warnings for metrics, parameters, attacks and attack categories that are - not catalogued (the conversion still succeeds). -5. Validate the result against the JSON schema. Schema violations caused - purely by curve-valued arrays (``fpr`` / ``tpr`` / ``roc_thresh`` -- the - latter legitimately starts with ``null``) are downgraded to warnings, - since they are a known limitation tracked separately; any other schema - violation is reported as an error. - -The conversion never mutates curve arrays: they are passed through verbatim. +The conversion: + +1. wraps the legacy experiments under a top-level ``"attacks"`` key; +2. injects the four catalogs (``metric_catalog``, ``parameter_catalog``, + ``attack_category_catalog``, ``attack_catalog``) from the bundled + ``catalog_definitions.json``; +3. validates the result against the JSON schema; and +4. warns about any metric, parameter, attack or attack category observed in the + report that is not present in the catalogs (the conversion still succeeds). + +Two small normalisations keep real-world legacy reports schema-valid: a +placeholder ``sacroml_version`` is injected when missing, and an empty +``attack_instance_logger`` is supplied for instance-less attacks (e.g. +structural attacks). Anything else that does not match the schema -- such as a +non-scalar metric or a stray metadata key -- is surfaced as a schema error +rather than silently rewritten. 
+ +Curve-valued arrays (``fpr`` / ``tpr`` / ``roc_thresh``) are passed through +untouched; ``roc_thresh`` legitimately starts with ``null``, which the schema +does not yet permit, so such violations are reported as notices rather than +errors. """ from __future__ import annotations @@ -55,22 +56,6 @@ } ) -# Metadata keys permitted by the schema (``additionalProperties: false``). -_ALLOWED_METADATA_KEYS = frozenset( - { - "sacroml_version", - "attack_name", - "attack_params", - "global_metrics", - "baseline_global_metrics", - "target_model", - "target_model_params", - "target_train_params", - } -) -_ATTACK_KEY_RE = re.compile(r"^[A-Za-z0-9 _\-]+$") -_INSTANCE_KEY_RE = re.compile(r"^instance_[0-9]+$") - @dataclass class ConversionResult: @@ -82,21 +67,21 @@ class ConversionResult: The converted report (new format). warnings : list[str] Non-fatal issues: uncatalogued metrics/parameters/attacks/categories - and structural normalisations that were applied. + and the normalisations that were applied. curve_warnings : list[str] Schema violations attributable solely to curve-valued arrays (``fpr`` / ``tpr`` / ``roc_thresh``); these are expected and benign. schema_errors : list[str] Genuine schema violations that make the output invalid. coverage : dict - Per-dimension ``{"covered": [...], "missing": [...]}`` summary. + Per-dimension list of uncatalogued names. 
""" report: dict[str, Any] warnings: list[str] = field(default_factory=list) curve_warnings: list[str] = field(default_factory=list) schema_errors: list[str] = field(default_factory=list) - coverage: dict[str, dict[str, list[str]]] = field(default_factory=dict) + coverage: dict[str, list[str]] = field(default_factory=dict) @property def is_valid(self) -> bool: @@ -106,65 +91,10 @@ def is_valid(self) -> bool: """ return not self.schema_errors - def summary_dict(self) -> dict[str, Any]: - """Return a machine-readable summary of the conversion outcome.""" - return { - "is_valid": self.is_valid, - "warnings": len(self.warnings), - "curve_warnings": len(self.curve_warnings), - "schema_errors": len(self.schema_errors), - "coverage": { - dim: { - "covered": len(summary["covered"]), - "missing": len(summary["missing"]), - } - for dim, summary in self.coverage.items() - }, - } - - def summary(self, *, validated: bool = True) -> str: - """Return a human-readable, multi-line summary of the conversion. - - Parameters - ---------- - validated : bool, default True - Whether schema validation was performed; controls the final - "schema-valid" / "NOT schema-valid" line. - """ - lines: list[str] = [] - for dim, summary in self.coverage.items(): - lines.append( - f" {dim}: {len(summary['covered'])} catalogued, " - f"{len(summary['missing'])} uncatalogued" - ) - if self.warnings: - lines.append(f"\nWarnings ({len(self.warnings)}):") - lines.extend(f" - {w}" for w in self.warnings) - if self.curve_warnings: - lines.append( - f"\nCurve-array notices ({len(self.curve_warnings)}): " - "fpr/tpr/roc_thresh arrays are passed through unchanged and " - "do not strictly validate yet." 
- ) - lines.extend(f" - {w}" for w in self.curve_warnings) - if self.schema_errors: - lines.append(f"\nSchema errors ({len(self.schema_errors)}):") - lines.extend(f" - {e}" for e in self.schema_errors) - lines.append("\nConverted report is NOT schema-valid.") - elif validated: - lines.append("\nConverted report is schema-valid.") - return "\n".join(lines) - - -def _load_catalog_definitions() -> dict[str, Any]: - """Load the bundled common catalog definitions.""" - with open(CATALOG_DEFINITIONS_PATH, encoding="utf-8") as fh: - return json.load(fh) - -def _load_schema() -> dict[str, Any]: - """Load the bundled attack-report JSON schema.""" - with open(SCHEMA_PATH, encoding="utf-8") as fh: +def _load_json(path: Path) -> dict[str, Any]: + """Load a bundled JSON data file.""" + with open(path, encoding="utf-8") as fh: return json.load(fh) @@ -175,155 +105,41 @@ def _is_experiment(value: Any) -> bool: ) -def _extract_experiments(data: Any, warnings: list[str]) -> dict[str, Any]: +def _extract_experiments(data: dict[str, Any]) -> dict[str, Any]: """Return the experiment mapping, whether the input is flat or wrapped.""" - if not isinstance(data, dict): - warnings.append( - f"Top-level report was a {type(data).__name__}, not an object; " - "no experiments could be extracted." - ) - return {} if isinstance(data.get("attacks"), dict): - # Already wrapped (idempotent path). - return dict(data["attacks"]) - - experiments: dict[str, Any] = {} - for key, value in data.items(): - if key in _NON_EXPERIMENT_KEYS: - continue - if _is_experiment(value): - experiments[key] = value - else: - warnings.append( - f"Top-level key '{key}' does not look like an experiment " - "and was dropped during conversion." 
- ) - return experiments - - -def _coerce_scalar_metrics( - metrics: Any, where: str, warnings: list[str] -) -> dict[str, Any]: - """Ensure a (baseline_)global_metrics mapping holds only scalars.""" - if not isinstance(metrics, dict): - warnings.append(f"{where} was not an object; replaced with empty object.") - return {} - cleaned: dict[str, Any] = {} - for key, value in metrics.items(): - if value is None or isinstance(value, (str, bool, int, float)): - cleaned[key] = value - else: - cleaned[key] = json.dumps(value, default=str) - warnings.append( - f"{where}['{key}'] held a non-scalar value; " - "serialised to a JSON string to satisfy the schema." - ) - return cleaned - - -def _normalise_instance_logger( - logger: Any, exp_key: str, warnings: list[str] -) -> dict[str, Any]: - """Return an ``attack_instance_logger`` with schema-conforming keys.""" - if not isinstance(logger, dict): - return {} - normalised: dict[str, Any] = {} - next_index = 0 - for key, value in logger.items(): - if _INSTANCE_KEY_RE.match(key): - normalised[key] = value - continue - while (new_key := f"instance_{next_index}") in logger or new_key in normalised: - next_index += 1 - normalised[new_key] = value - next_index += 1 - warnings.append( - f"Experiment '{exp_key}': instance key '{key}' did not match " - f"'instance_'; renamed to '{new_key}'." 
- ) - return normalised + return dict(data["attacks"]) # already wrapped (idempotent path) + return { + key: value + for key, value in data.items() + if key not in _NON_EXPERIMENT_KEYS and _is_experiment(value) + } -def _normalise_metadata( - exp_key: str, raw_metadata: Any, warnings: list[str] +def _normalise_experiment( + exp_key: str, experiment: dict[str, Any], warnings: list[str] ) -> dict[str, Any]: - """Normalise an experiment's metadata so it conforms to the schema.""" - metadata = dict(raw_metadata or {}) - - for key in sorted(set(metadata) - _ALLOWED_METADATA_KEYS): - warnings.append( - f"Experiment '{exp_key}': metadata key '{key}' is not permitted " - "by the schema and was dropped." - ) - metadata.pop(key, None) + """Apply the minimal normalisations needed for schema validity.""" + exp = dict(experiment) + exp.setdefault("log_id", exp_key) + exp.setdefault("log_time", "") + metadata = dict(exp.get("metadata") or {}) if "sacroml_version" not in metadata: metadata["sacroml_version"] = "unknown" warnings.append( f"Experiment '{exp_key}': metadata.sacroml_version was missing; " "set to 'unknown'." 
) - else: - metadata["sacroml_version"] = str(metadata["sacroml_version"]) - - metadata["attack_name"] = str( - metadata.get("attack_name", exp_key.rsplit("_", 1)[0]) - ) - if not isinstance(metadata.get("attack_params"), dict): - metadata["attack_params"] = {} - - metadata["global_metrics"] = _coerce_scalar_metrics( - metadata.get("global_metrics", {}), - f"Experiment '{exp_key}': metadata.global_metrics", - warnings, - ) - if "baseline_global_metrics" in metadata: - metadata["baseline_global_metrics"] = _coerce_scalar_metrics( - metadata["baseline_global_metrics"], - f"Experiment '{exp_key}': metadata.baseline_global_metrics", - warnings, - ) - - if "target_model" in metadata and not isinstance(metadata["target_model"], str): - metadata["target_model"] = str(metadata["target_model"]) - for opt in ("target_model_params", "target_train_params"): - if opt in metadata and not isinstance(metadata[opt], dict): - metadata[opt] = {} - return metadata - - -def _normalise_attack_experiment_logger( - exp_key: str, raw_logger: Any, warnings: list[str] -) -> dict[str, Any]: - """Return an ``attack_experiment_logger`` holding only the instance log.""" - ael = raw_logger if isinstance(raw_logger, dict) else {} - instance_logger = _normalise_instance_logger( - ael.get("attack_instance_logger", {}), exp_key, warnings - ) - for key in sorted(set(ael) - {"attack_instance_logger"}): - warnings.append( - f"Experiment '{exp_key}': attack_experiment_logger.{key} is not " - "permitted by the schema and was dropped." 
- ) - return {"attack_instance_logger": instance_logger} - - -def _normalise_experiment( - exp_key: str, experiment: Any, warnings: list[str] -) -> dict[str, Any]: - """Normalise a single experiment so it conforms to the schema.""" - if not isinstance(experiment, dict): - warnings.append(f"Experiment '{exp_key}' was not an object and was skipped.") - return {} - - exp = dict(experiment) - exp["log_id"] = str(exp.get("log_id", exp_key)) - if "log_time" not in exp or not isinstance(exp["log_time"], str): - exp["log_time"] = str(exp.get("log_time", "unknown")) - exp["metadata"] = _normalise_metadata(exp_key, exp.get("metadata", {}), warnings) - exp["attack_experiment_logger"] = _normalise_attack_experiment_logger( - exp_key, exp.get("attack_experiment_logger"), warnings - ) + exp["metadata"] = metadata + + # Instance-less attacks (e.g. structural attacks) have no per-instance + # metrics; supply an empty logger so the result still satisfies the schema. + logger = exp.get("attack_experiment_logger") + instances = logger.get("attack_instance_logger") if isinstance(logger, dict) else {} + if not isinstance(instances, dict): + instances = {} + exp["attack_experiment_logger"] = {"attack_instance_logger": instances} return exp @@ -337,24 +153,21 @@ def _compile_pattern_metrics(metric_catalog: dict[str, Any]) -> list[re.Pattern] return patterns -def _diff( +def _uncatalogued( seen: set[str], explicit: set[str], patterns: list[re.Pattern] | None = None -) -> tuple[list[str], list[str]]: - """Split ``seen`` names into catalogued (covered) and missing.""" +) -> list[str]: + """Return the sorted names in ``seen`` that no catalog entry covers.""" patterns = patterns or [] - covered: list[str] = [] - missing: list[str] = [] - for name in sorted(seen): - if name in explicit or any(p.match(name) for p in patterns): - covered.append(name) - else: - missing.append(name) - return covered, missing + return sorted( + name + for name in seen + if name not in explicit and not 
any(p.match(name) for p in patterns) + ) def _compute_coverage( report: dict[str, Any], catalogs: dict[str, Any] -) -> tuple[dict[str, dict[str, list[str]]], list[str]]: +) -> tuple[dict[str, list[str]], list[str]]: """Diff observed metrics/parameters/attacks/categories vs the catalogs.""" seen_metrics: set[str] = set() seen_params: set[str] = set() @@ -366,54 +179,52 @@ def _compute_coverage( if isinstance(inst, dict): seen_metrics.update(inst.keys()) metadata = exp.get("metadata", {}) - seen_metrics.update(metadata.get("global_metrics", {}).keys()) - seen_metrics.update(metadata.get("baseline_global_metrics", {}).keys()) - seen_params.update(metadata.get("attack_params", {}).keys()) + for key in ("global_metrics", "baseline_global_metrics"): + values = metadata.get(key, {}) + if isinstance(values, dict): + seen_metrics.update(values) + params = metadata.get("attack_params", {}) + if isinstance(params, dict): + seen_params.update(params) seen_attacks.add(metadata.get("attack_name", "")) - seen_attacks.discard("") metric_catalog = catalogs["metric_catalog"] - explicit_metrics = set(metric_catalog.get("metrics", {}).keys()) - pattern_regexes = _compile_pattern_metrics(metric_catalog) - m_covered, m_missing = _diff(seen_metrics, explicit_metrics, pattern_regexes) - - explicit_params = set(catalogs["parameter_catalog"].get("parameters", {}).keys()) - p_covered, p_missing = _diff(seen_params, explicit_params) - catalog_attacks = catalogs["attack_catalog"].get("attacks", {}) - a_covered, a_missing = _diff(seen_attacks, set(catalog_attacks.keys())) - - # Categories referenced by the attacks we *can* resolve in the catalog. - known_categories = set( - catalogs["attack_category_catalog"].get("categories", {}).keys() - ) + # Categories referenced by the attacks we can resolve in the catalog. 
seen_categories = { catalog_attacks[name]["attack_category"] for name in seen_attacks if name in catalog_attacks and "attack_category" in catalog_attacks[name] } - c_covered, c_missing = _diff(seen_categories, known_categories) coverage = { - "metrics": {"covered": m_covered, "missing": m_missing}, - "parameters": {"covered": p_covered, "missing": p_missing}, - "attacks": {"covered": a_covered, "missing": a_missing}, - "attack_categories": {"covered": c_covered, "missing": c_missing}, + "metrics": _uncatalogued( + seen_metrics, + set(metric_catalog.get("metrics", {})), + _compile_pattern_metrics(metric_catalog), + ), + "parameters": _uncatalogued( + seen_params, set(catalogs["parameter_catalog"].get("parameters", {})) + ), + "attacks": _uncatalogued(seen_attacks, set(catalog_attacks)), + "attack_categories": _uncatalogued( + seen_categories, + set(catalogs["attack_category_catalog"].get("categories", {})), + ), } - warnings: list[str] = [] labels = { "metrics": "metric", "parameters": "parameter", "attacks": "attack", "attack_categories": "attack category", } - for dim, label in labels.items(): - for name in coverage[dim]["missing"]: - warnings.append( - f"Uncatalogued {label}: '{name}' is not present in the {dim} catalog." - ) + warnings = [ + f"Uncatalogued {labels[dim]}: '{name}' is not present in the {dim} catalog." + for dim, names in coverage.items() + for name in names + ] return coverage, warnings @@ -445,9 +256,7 @@ def _is_curve_violation(error: Any, report: dict[str, Any]) -> bool: return isinstance(value, list) -def _validate( - report: dict[str, Any], -) -> tuple[list[str], list[str]]: +def _validate(report: dict[str, Any]) -> tuple[list[str], list[str]]: """Validate ``report`` against the schema. Returns @@ -455,7 +264,7 @@ def _validate( tuple[list[str], list[str]] ``(schema_errors, curve_warnings)`` -- human-readable messages. 
""" - validator = Draft7Validator(_load_schema()) + validator = Draft7Validator(_load_json(SCHEMA_PATH)) schema_errors: list[str] = [] curve_warnings: list[str] = [] for error in sorted(validator.iter_errors(report), key=str): @@ -488,29 +297,27 @@ def convert_report(data: dict[str, Any], *, validate: bool = True) -> Conversion summary. """ warnings: list[str] = [] - experiments = _extract_experiments(data, warnings) - - converted_attacks: dict[str, Any] = {} - for exp_key, experiment in experiments.items(): - safe_key = exp_key - if not _ATTACK_KEY_RE.match(exp_key): - safe_key = re.sub(r"[^A-Za-z0-9 _\-]", "_", exp_key) - warnings.append( - f"Experiment key '{exp_key}' contained characters not allowed " - f"by the schema; renamed to '{safe_key}'." - ) - normalised = _normalise_experiment(exp_key, experiment, warnings) - if normalised: - converted_attacks[safe_key] = normalised - - catalogs = _load_catalog_definitions() + experiments: dict[str, Any] = {} + if isinstance(data, dict): + experiments = _extract_experiments(data) + else: + warnings.append( + f"Top-level report was a {type(data).__name__}, not an object; " + "no experiments could be extracted." 
+ ) + + catalogs = _load_json(CATALOG_DEFINITIONS_PATH) report: dict[str, Any] = { "report_schema_version": REPORT_SCHEMA_VERSION, "metric_catalog": catalogs["metric_catalog"], "parameter_catalog": catalogs["parameter_catalog"], "attack_category_catalog": catalogs["attack_category_catalog"], "attack_catalog": catalogs["attack_catalog"], - "attacks": converted_attacks, + "attacks": { + key: _normalise_experiment(key, exp, warnings) + for key, exp in experiments.items() + if isinstance(exp, dict) + }, } coverage, coverage_warnings = _compute_coverage(report, catalogs) diff --git a/tests/reporting/fixtures/legacy_edge.json b/tests/reporting/fixtures/legacy_edge.json deleted file mode 100644 index 37bc7e50..00000000 --- a/tests/reporting/fixtures/legacy_edge.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "LiRA Attack_0d8cf41e": { - "log_id": "0d8cf41e", - "log_time": "30/06/2024 18:34:53", - "metadata": { - "attack_name": "LiRA Attack", - "attack_params": { - "output_dir": "outputs_lira", - "n_shadow_models": 100, - "weird_param": true - }, - "global_metrics": { - "AUC_sig": "Significant at p=0.05", - "nested_metric": { - "a": 1, - "b": 2 - } - }, - "target_model": "RandomForestClassifier", - "target_model_params": { - "n_estimators": 100 - }, - "stray_metadata_key": "should be dropped" - }, - "attack_experiment_logger": { - "attack_instance_logger": { - "instance_0": { - "AUC": 0.75, - "weird_metric": 0.42, - "roc_thresh": [ - null, - 1.0, - 0.5, - 0.0 - ] - }, - "rep1": { - "AUC": 0.71 - } - }, - "dummy_attack_experiment_logger": { - "attack_instance_logger": {} - } - } - }, - "Structural Attack_aaaa": { - "log_id": "aaaa", - "log_time": "30/06/2024 18:40:00", - "metadata": { - "sacroml_version": "1.4.0", - "attack_name": "Structural Attack", - "attack_params": { - "output_dir": "outputs_struct" - }, - "global_metrics": { - "dof_risk": true, - "k_anonymity_risk": false - } - } - }, - "Mystery attack_bbbb": { - "log_id": "bbbb", - "log_time": "30/06/2024 18:45:00", - 
"metadata": { - "sacroml_version": "1.4.0", - "attack_name": "Mystery attack", - "attack_params": {}, - "global_metrics": { - "AUC": 0.6 - } - }, - "attack_experiment_logger": { - "attack_instance_logger": { - "instance_0": { - "AUC": 0.6 - } - } - } - } -} diff --git a/tests/reporting/test_convert.py b/tests/reporting/test_convert.py index 1de47fd0..f86f6739 100644 --- a/tests/reporting/test_convert.py +++ b/tests/reporting/test_convert.py @@ -9,7 +9,6 @@ from sacroml.main import main from sacroml.reporting import ConversionResult, convert_report, convert_report_file -from sacroml.reporting.convert import _is_curve_violation FIXTURES = Path(__file__).parent / "fixtures" DOCS_EXAMPLES = Path(__file__).parents[2] / "docs" / "source" / "attacks" @@ -21,16 +20,29 @@ def _load(name: str) -> dict: return json.load(fh) +def _experiment(metadata: dict | None = None, instances: dict | None = None) -> dict: + """Build a minimal, schema-valid legacy experiment for inline tests.""" + meta = { + "sacroml_version": "1.4.0", + "attack_name": "LiRA Attack", + "attack_params": {}, + "global_metrics": {"AUC_sig": "Significant at p=0.05"}, + } + if metadata is not None: + meta.update(metadata) + exp: dict = {"log_id": "id", "log_time": "now", "metadata": meta} + if instances is not None: + exp["attack_experiment_logger"] = {"attack_instance_logger": instances} + return exp + + @pytest.fixture def basic_result() -> ConversionResult: """Conversion result for the well-formed legacy fixture.""" return convert_report(_load("legacy_basic.json")) -@pytest.fixture -def edge_result() -> ConversionResult: - """Conversion result for the edge-case legacy fixture.""" - return convert_report(_load("legacy_edge.json")) +# --- R1/R2: wrapping and catalog injection ------------------------------- def test_basic_conversion_structure(basic_result: ConversionResult) -> None: @@ -51,9 +63,10 @@ def test_basic_conversion_structure(basic_result: ConversionResult) -> None: } -def 
test_basic_conversion_is_schema_valid( - basic_result: ConversionResult, -) -> None: +# --- R3: schema validation ----------------------------------------------- + + +def test_basic_conversion_is_schema_valid(basic_result: ConversionResult) -> None: """A clean legacy report converts to a schema-valid document.""" assert basic_result.is_valid assert basic_result.schema_errors == [] @@ -61,14 +74,6 @@ def test_basic_conversion_is_schema_valid( assert basic_result.curve_warnings == [] -def test_basic_coverage_all_catalogued( - basic_result: ConversionResult, -) -> None: - """All metrics/params/attacks in the clean fixture are catalogued.""" - for dim in basic_result.coverage.values(): - assert dim["missing"] == [] - - def test_conversion_is_idempotent(basic_result: ConversionResult) -> None: """Converting an already-converted report is a no-op for the payload.""" again = convert_report(basic_result.report) @@ -76,109 +81,139 @@ def test_conversion_is_idempotent(basic_result: ConversionResult) -> None: assert again.is_valid -def test_missing_sacroml_version_injected( - edge_result: ConversionResult, -) -> None: - """Old reports without sacroml_version get an 'unknown' placeholder.""" - lira = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"] - assert lira["metadata"]["sacroml_version"] == "unknown" - assert any("sacroml_version was missing" in w for w in edge_result.warnings) - +@pytest.mark.parametrize( + "example", + ["report_example_lira.json", "report_example_worstcase.json"], +) +def test_real_docs_examples_validate(example: str) -> None: + """Real example reports convert, validate, and trip the load-bearing paths. 
-def test_stray_metadata_key_dropped(edge_result: ConversionResult) -> None: - """Metadata keys not permitted by the schema are dropped with a warning.""" - lira_meta = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"]["metadata"] - assert "stray_metadata_key" not in lira_meta - assert any("stray_metadata_key" in w for w in edge_result.warnings) + Both real reports lack ``sacroml_version`` (injected) and contain + ``roc_thresh``/curve arrays that start with ``null`` (downgraded to curve + notices); without either behaviour the result would not be schema-valid. + """ + with open(DOCS_EXAMPLES / example, encoding="utf-8") as fh: + data = json.load(fh) + result = convert_report(data) + assert result.is_valid, result.schema_errors + assert result.curve_warnings # roc/curve null arrays present + assert all( + exp["metadata"]["sacroml_version"] == "unknown" + for exp in result.report["attacks"].values() + ) -def test_non_scalar_global_metric_serialised( - edge_result: ConversionResult, -) -> None: - """Non-scalar global_metrics values are serialised to keep schema valid.""" - gm = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"]["metadata"][ - "global_metrics" - ] - assert isinstance(gm["nested_metric"], str) - assert json.loads(gm["nested_metric"]) == {"a": 1, "b": 2} +def test_non_curve_schema_error_is_reported() -> None: + """An instance metric with an object value yields a real schema error.""" + legacy = { + "LiRA Attack_yy": _experiment( + instances={"instance_0": {"AUC": {"unexpected": "object"}}} + ) + } + result = convert_report(legacy) + assert not result.is_valid + assert result.schema_errors -def test_extra_experiment_logger_sibling_dropped( - edge_result: ConversionResult, -) -> None: - """Attack_experiment_logger keeps only attack_instance_logger.""" - ael = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"][ - "attack_experiment_logger" - ] - assert set(ael) == {"attack_instance_logger"} - assert any("dummy_attack_experiment_logger" in w for w 
in edge_result.warnings) +def test_curve_array_is_warning_not_error() -> None: + """An roc_thresh array starting with null is a curve notice, not an error.""" + legacy = { + "LiRA Attack_zz": _experiment( + instances={"instance_0": {"AUC": 0.75, "roc_thresh": [None, 1.0, 0.0]}} + ) + } + result = convert_report(legacy) + assert result.is_valid + assert result.curve_warnings + assert all("roc_thresh" in w for w in result.curve_warnings) -def test_non_conforming_instance_key_renamed( - edge_result: ConversionResult, -) -> None: - """Instance keys not matching 'instance_' are renamed.""" - logger = edge_result.report["attacks"]["LiRA Attack_0d8cf41e"][ - "attack_experiment_logger" - ]["attack_instance_logger"] - assert all(k.startswith("instance_") for k in logger) - assert any("renamed to 'instance_" in w for w in edge_result.warnings) +# --- minimal normalisation (load-bearing on real reports) ---------------- -def test_structural_attack_gets_empty_logger( - edge_result: ConversionResult, -) -> None: - """An instance-less structural attack gets an empty instance logger.""" - struct = edge_result.report["attacks"]["Structural Attack_aaaa"] - assert struct["attack_experiment_logger"]["attack_instance_logger"] == {} +def test_missing_sacroml_version_injected() -> None: + """Reports without sacroml_version get an 'unknown' placeholder.""" + legacy = {"LiRA Attack_x": _experiment(metadata={"sacroml_version": None})} + del legacy["LiRA Attack_x"]["metadata"]["sacroml_version"] + result = convert_report(legacy) + meta = result.report["attacks"]["LiRA Attack_x"]["metadata"] + assert meta["sacroml_version"] == "unknown" + assert any("sacroml_version was missing" in w for w in result.warnings) -def test_uncatalogued_entries_warn(edge_result: ConversionResult) -> None: - """Unknown metrics/params/attacks are reported but not fatal.""" - cov = edge_result.coverage - assert "weird_metric" in cov["metrics"]["missing"] - assert "weird_param" in cov["parameters"]["missing"] - 
assert "Mystery attack" in cov["attacks"]["missing"] - assert any("Uncatalogued metric" in w for w in edge_result.warnings) - assert any("Uncatalogued parameter" in w for w in edge_result.warnings) - assert any("Uncatalogued attack" in w for w in edge_result.warnings) +def test_structural_attack_gets_empty_logger() -> None: + """An instance-less attack gets an empty instance logger, not a crash.""" + legacy = { + "Structural Attack_aaaa": { + "log_id": "aaaa", + "metadata": { + "sacroml_version": "1.4.0", + "attack_name": "Structural Attack", + "attack_params": {}, + "global_metrics": {}, + }, + } + } + result = convert_report(legacy) + struct = result.report["attacks"]["Structural Attack_aaaa"] + assert struct["attack_experiment_logger"]["attack_instance_logger"] == {} + assert result.is_valid -def test_curve_array_is_warning_not_error( - edge_result: ConversionResult, -) -> None: - """Roc_thresh starting with null is a curve warning, not a schema error.""" - assert edge_result.curve_warnings - assert all("roc_thresh" in w for w in edge_result.curve_warnings) - assert edge_result.is_valid +# --- R4: uncatalogued warnings ------------------------------------------- -def test_non_experiment_top_level_dropped() -> None: - """Top-level keys that are not experiments are dropped with a warning.""" - result = convert_report({"junk": 123, **_load("legacy_basic.json")}) - assert "junk" not in result.report["attacks"] - assert any("does not look like an experiment" in w for w in result.warnings) +def test_basic_coverage_all_catalogued(basic_result: ConversionResult) -> None: + """All metrics/params/attacks in the clean fixture are catalogued.""" + assert all(not missing for missing in basic_result.coverage.values()) -@pytest.mark.parametrize( - "example", - ["report_example_lira.json", "report_example_worstcase.json"], -) -def test_real_docs_examples_validate(example: str) -> None: - """The real example reports shipped in docs convert and validate.""" - with 
open(DOCS_EXAMPLES / example, encoding="utf-8") as fh: - data = json.load(fh) - result = convert_report(data) - assert result.is_valid, result.schema_errors +def test_uncatalogued_entries_warn() -> None: + """Unknown metrics/params/attacks are reported but not fatal.""" + legacy = { + "Mystery attack_bbbb": _experiment( + metadata={ + "attack_name": "Mystery attack", + "attack_params": {"weird_param": True}, + }, + instances={"instance_0": {"AUC": 0.6, "weird_metric": 0.42}}, + ) + } + result = convert_report(legacy) + assert "weird_metric" in result.coverage["metrics"] + assert "weird_param" in result.coverage["parameters"] + assert "Mystery attack" in result.coverage["attacks"] + assert any("Uncatalogued metric" in w for w in result.warnings) + assert any("Uncatalogued parameter" in w for w in result.warnings) + assert any("Uncatalogued attack" in w for w in result.warnings) def test_no_validate_skips_validation() -> None: - """Validate=False produces no schema errors or curve warnings.""" - result = convert_report(_load("legacy_edge.json"), validate=False) + """Validate=False produces no schema errors but keeps coverage warnings.""" + legacy = { + "Mystery attack_bbbb": _experiment( + metadata={"attack_name": "Mystery attack"}, + instances={"instance_0": {"weird_metric": 0.42}}, + ) + } + result = convert_report(legacy, validate=False) assert result.schema_errors == [] assert result.curve_warnings == [] - # Coverage warnings are independent of schema validation. 
- assert result.coverage["metrics"]["missing"] + assert result.coverage["metrics"] # coverage is independent of validation + + +# --- robustness: non-object input must not crash (regression guard) ------ + + +@pytest.mark.parametrize("payload", [[1, 2, 3], "a string", 42, None]) +def test_non_dict_top_level_yields_empty_report(payload: object) -> None: + """A non-object top-level report converts to empty attacks with a warning.""" + result = convert_report(payload, validate=False) + assert result.report["attacks"] == {} + assert any("not an object" in w for w in result.warnings) + + +# --- file + CLI entry points --------------------------------------------- def test_convert_report_file_roundtrip(tmp_path: Path) -> None: @@ -209,170 +244,6 @@ def test_cli_convert_report( assert "schema-valid" in capsys.readouterr().out -def test_defensive_normalisation_paths() -> None: - """Malformed legacy inputs hit every defensive normalisation branch.""" - legacy = { - "report_schema_version": "0.9", - "WorstCase attack/with!chars_xyz": { - "log_id": 12345, - "log_time": 17181920, - "metadata": { - "attack_name": "WorstCase attack", - "attack_params": "not a dict", - "global_metrics": "not a dict", - "target_model": 42, - "target_model_params": "not a dict", - "target_train_params": "not a dict", - }, - "attack_experiment_logger": "not a dict", - }, - } - result = convert_report(legacy) - assert "WorstCase attack_with_chars_xyz" in result.report["attacks"] - converted = result.report["attacks"]["WorstCase attack_with_chars_xyz"] - meta = converted["metadata"] - assert meta["attack_params"] == {} - assert meta["global_metrics"] == {} - assert meta["target_model"] == "42" - assert meta["target_model_params"] == {} - assert meta["target_train_params"] == {} - assert isinstance(converted["log_time"], str) - assert converted["attack_experiment_logger"]["attack_instance_logger"] == {} - assert any("renamed to" in w for w in result.warnings) - - -def 
test_non_dict_instance_logger_becomes_empty() -> None: - """A non-dict ``attack_instance_logger`` is replaced with an empty one.""" - legacy = { - "WorstCase attack_qq": { - "log_id": "qq", - "log_time": "now", - "metadata": { - "sacroml_version": "1.0", - "attack_name": "WorstCase attack", - "attack_params": {}, - "global_metrics": {"AUC": 0.5}, - }, - "attack_experiment_logger": {"attack_instance_logger": "not a dict"}, - } - } - result = convert_report(legacy) - logger = result.report["attacks"]["WorstCase attack_qq"][ - "attack_experiment_logger" - ]["attack_instance_logger"] - assert logger == {} - - -def test_instance_key_collision_during_renaming() -> None: - """Renaming a non-conforming key bumps past an existing matching key.""" - legacy = { - "LiRA Attack_zz": { - "log_id": "zz", - "log_time": "now", - "metadata": { - "sacroml_version": "1.0", - "attack_name": "LiRA Attack", - "attack_params": {}, - "global_metrics": {"AUC": 0.5}, - }, - "attack_experiment_logger": { - "attack_instance_logger": { - "weird": {"AUC": 0.1}, - "instance_0": {"AUC": 0.2}, - } - }, - } - } - result = convert_report(legacy) - logger = result.report["attacks"]["LiRA Attack_zz"]["attack_experiment_logger"][ - "attack_instance_logger" - ] - assert set(logger) == {"instance_0", "instance_1"} - - -def test_already_wrapped_with_non_dict_experiment() -> None: - """An already-wrapped report whose experiment isn't a dict is skipped.""" - result = convert_report( - {"attacks": {"BogusAttack_xx": "not a dict"}}, validate=False - ) - assert result.report["attacks"] == {} - assert any("was not an object" in w for w in result.warnings) - - -def test_non_curve_schema_error_is_reported() -> None: - """An instance metric with an object value yields a real schema error.""" - legacy = { - "LiRA Attack_yy": { - "log_id": "yy", - "log_time": "now", - "metadata": { - "sacroml_version": "1.0", - "attack_name": "LiRA Attack", - "attack_params": {}, - "global_metrics": {"AUC": 0.5}, - }, - 
"attack_experiment_logger": { - "attack_instance_logger": { - "instance_0": {"AUC": {"unexpected": "object"}} - } - }, - } - } - result = convert_report(legacy) - assert not result.is_valid - assert result.schema_errors - - -class _FakeError: - """Minimal stand-in for ``jsonschema.exceptions.ValidationError``.""" - - def __init__(self, path: list) -> None: - self.absolute_path = path - - -@pytest.mark.parametrize( - "path", - [ - ["attacks"], # path < 6 - ["other", "exp", "attack_experiment_logger", "ail", "i0", "AUC"], - ["attacks", "exp", "wrong", "ail", "i0", "AUC"], - ["attacks", "exp", "attack_experiment_logger", "wrong", "i0", "AUC"], - [ - "attacks", - "exp", - "attack_experiment_logger", - "attack_instance_logger", - "i0", - 0, - ], - ], -) -def test_is_curve_violation_rejects_unrelated_paths(path: list) -> None: - """Errors outside the instance-metric path are never curve violations.""" - assert _is_curve_violation(_FakeError(path), {"attacks": {}}) is False - - -def test_is_curve_violation_handles_missing_lookup() -> None: - """A well-shaped path with no matching value is not a curve violation.""" - path = [ - "attacks", - "missing_exp", - "attack_experiment_logger", - "attack_instance_logger", - "instance_0", - "AUC", - ] - assert _is_curve_violation(_FakeError(path), {"attacks": {}}) is False - - -@pytest.mark.parametrize("payload", [[1, 2, 3], "a string", 42, None]) -def test_non_dict_top_level_yields_empty_report(payload: object) -> None: - """A non-object top-level report converts to empty attacks with a warning.""" - result = convert_report(payload, validate=False) - assert result.report["attacks"] == {} - assert any("not an object" in w for w in result.warnings) - - def test_cli_convert_report_malformed_json( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, @@ -410,85 +281,3 @@ def test_cli_convert_report_missing_input( with pytest.raises(SystemExit) as exc: main() assert exc.value.code == 1 - - -def test_summary_dict_counts(edge_result: 
ConversionResult) -> None: - """``summary_dict`` reduces the result to plain counts.""" - summary = edge_result.summary_dict() - assert summary["is_valid"] is True - assert summary["warnings"] == len(edge_result.warnings) - assert summary["curve_warnings"] == len(edge_result.curve_warnings) - assert summary["schema_errors"] == 0 - for dim, counts in summary["coverage"].items(): - assert counts["covered"] == len(edge_result.coverage[dim]["covered"]) - assert counts["missing"] == len(edge_result.coverage[dim]["missing"]) - - -def test_summary_text_includes_sections(edge_result: ConversionResult) -> None: - """``summary`` text surfaces coverage, warnings, curve notices and validity.""" - text = edge_result.summary() - assert "metrics:" in text - assert "Warnings (" in text - assert "Curve-array notices (" in text - assert "Converted report is schema-valid." in text - - -def test_summary_text_when_validation_skipped( - edge_result: ConversionResult, -) -> None: - """``validated=False`` suppresses the trailing schema-valid line.""" - text = edge_result.summary(validated=False) - assert "schema-valid" not in text - - -def test_summary_text_for_schema_errors() -> None: - """A schema-invalid result includes the NOT schema-valid footer.""" - legacy = { - "LiRA Attack_zz": { - "log_id": "zz", - "log_time": "now", - "metadata": { - "sacroml_version": "1.0", - "attack_name": "LiRA Attack", - "attack_params": {}, - "global_metrics": {"AUC": 0.5}, - }, - "attack_experiment_logger": { - "attack_instance_logger": { - "instance_0": {"AUC": {"unexpected": "object"}} - } - }, - } - } - result = convert_report(legacy) - text = result.summary() - assert "Schema errors (" in text - assert "NOT schema-valid" in text - - -def test_renamed_instance_keys_are_compact() -> None: - """All non-conforming keys get consecutive instance_ slots.""" - legacy = { - "LiRA Attack_zz": { - "log_id": "zz", - "log_time": "now", - "metadata": { - "sacroml_version": "1.0", - "attack_name": "LiRA Attack", - 
"attack_params": {}, - "global_metrics": {"AUC": 0.5}, - }, - "attack_experiment_logger": { - "attack_instance_logger": { - "first": {"AUC": 0.1}, - "second": {"AUC": 0.2}, - "third": {"AUC": 0.3}, - } - }, - } - } - result = convert_report(legacy) - logger = result.report["attacks"]["LiRA Attack_zz"]["attack_experiment_logger"][ - "attack_instance_logger" - ] - assert list(logger) == ["instance_0", "instance_1", "instance_2"] From d68e4769dbbb3ddbe28e28ce8a7abfe1709bcfba Mon Sep 17 00:00:00 2001 From: ssrhaso Date: Fri, 29 May 2026 13:06:03 +0100 Subject: [PATCH 7/7] test(reporting): cover _is_curve_violation guards and non-dict instance logger --- tests/reporting/test_convert.py | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/reporting/test_convert.py b/tests/reporting/test_convert.py index f86f6739..e76e368a 100644 --- a/tests/reporting/test_convert.py +++ b/tests/reporting/test_convert.py @@ -9,6 +9,7 @@ from sacroml.main import main from sacroml.reporting import ConversionResult, convert_report, convert_report_file +from sacroml.reporting.convert import _is_curve_violation FIXTURES = Path(__file__).parent / "fixtures" DOCS_EXAMPLES = Path(__file__).parents[2] / "docs" / "source" / "attacks" @@ -128,6 +129,32 @@ def test_curve_array_is_warning_not_error() -> None: assert all("roc_thresh" in w for w in result.curve_warnings) +class _FakeError: + """Minimal stand-in for a jsonschema ValidationError (path only).""" + + def __init__(self, path: list) -> None: + self.absolute_path = path + + +_AEL = "attack_experiment_logger" +_AIL = "attack_instance_logger" + + +@pytest.mark.parametrize( + "path", + [ + ["attacks"], # path too short + ["other", "e", _AEL, _AIL, "i", "AUC"], # not the attacks subtree + ["attacks", "e", _AEL, "wrong", "i", "AUC"], # not the instance logger + ["attacks", "e", _AEL, _AIL, "i", 0], # metric name is not a string + ["attacks", "x", _AEL, _AIL, "i", "AUC"], # value not present in report + ], +) 
+def test_is_curve_violation_rejects_non_curve_paths(path: list) -> None: + """Only a well-formed instance-metric path with a list value is a curve.""" + assert _is_curve_violation(_FakeError(path), {"attacks": {}}) is False + + # --- minimal normalisation (load-bearing on real reports) ---------------- @@ -160,6 +187,20 @@ def test_structural_attack_gets_empty_logger() -> None: assert result.is_valid +def test_non_dict_instance_logger_becomes_empty() -> None: + """A non-dict attack_instance_logger is replaced with an empty one.""" + legacy = { + "LiRA Attack_qq": { + "log_id": "qq", + "metadata": {"sacroml_version": "1.0", "attack_name": "LiRA Attack"}, + "attack_experiment_logger": {"attack_instance_logger": "not a dict"}, + } + } + result = convert_report(legacy, validate=False) + logger = result.report["attacks"]["LiRA Attack_qq"]["attack_experiment_logger"] + assert logger["attack_instance_logger"] == {} + + # --- R4: uncatalogued warnings -------------------------------------------