diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b978bf..b71a2db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,24 @@ Versioning: [SemVer](https://semver.org/spec/v2.0.0.html). ### Added +- **Per-metric treatment effects.** New optional `target_metric` field + on the treatment surface (set on `SegmentInput.treatment` in the + builder; mirrored as `Entity.treatment_target_metric` in the engine). + When set, the configured `treatment_lift_log_odds` applies only to + the named metric's evaluation for treatment-arm entities — every + other metric in the same period is drawn identically to its + control-arm counterpart. Lets a single intervention be modelled as + shifting one outcome (e.g. revenue) while leaving the rest of the + metric set as a placebo. Default `null` preserves the prior + trajectory-wide behaviour byte-for-byte. Config-time validation + rejects target names that don't match any declared metric. + Correlated metrics do not inherit the lift via the copula: the + residual transform is centred on each metric's own (un-shifted) + centre, so the targeted metric shifts and the correlated metric + stays at its control distribution. Manifest schema bumps 1.8 → 1.9 + with the additive `target_metric` field on the per-entity + `treatment` and per-cohort `treatment_cohorts` records. + - **Heteroscedastic gaussian noise.** Optional `scale_with_trajectory` flag on `NoiseConfig` (mirror on the builder's `NoiseInput`). When `true`, each cell's gaussian standard deviation becomes diff --git a/docs/site/manifest-reference.md b/docs/site/manifest-reference.md index 4ab149f..b6dff82 100644 --- a/docs/site/manifest-reference.md +++ b/docs/site/manifest-reference.md @@ -70,7 +70,7 @@ produces a byte-identical `manifest.json`. Encoding: UTF-8, | Field | Type | Description | |---|---|---| -| `schema_version` | `str` | Wire-shape version. 
Currently `"1.8"` (bumped over time as new additive sections — `causal_graph`, `correlations`, `outlier_injections`, multi-source mappings, `parent_child_relations`, `noise_config` — landed; 1.7 → 1.8 extended `noise_config` with `noise_family` / `degrees_of_freedom`) | +| `schema_version` | `str` | Wire-shape version. Currently `"1.9"` (bumped over time as new additive sections — `causal_graph`, `correlations`, `outlier_injections`, multi-source mappings, `parent_child_relations`, `noise_config` — landed; 1.7 → 1.8 extended `noise_config` with `noise_family` / `degrees_of_freedom`; 1.8 → 1.9 added the optional `target_metric` field on the per-entity `treatment` and per-cohort `treatment_cohorts` records) | | `seed` | `int` | The seed used for generation — `config.seed` | | `config_sha256` | `str` | Full SHA-256 hex of the JSON-serialized config. Detects config drift between generation and consumption | | `archetype_assignments` | array | One entry per entity; see below | diff --git a/docs/site/user-guide/experiments-and-cohorts.md b/docs/site/user-guide/experiments-and-cohorts.md index bccc8ec..b0198e5 100644 --- a/docs/site/user-guide/experiments-and-cohorts.md +++ b/docs/site/user-guide/experiments-and-cohorts.md @@ -171,6 +171,7 @@ segments: start_period: 6 # rollout date treatment_label: new_onboarding control_label: original_onboarding + target_metric: mrr # optional — see below ``` ### What the lift does @@ -188,6 +189,36 @@ behaviour: a `+0.5` lift moves `p=0.5` to ~0.62, but only moves `p=0.9` to ~0.94. Same intervention, less impact when the metric is already near saturation. +### Targeting a single metric (`target_metric`) + +By default the lift applies to **every** metric for the treatment arm — +useful when modelling an intervention that shifts overall trajectory +position (a global engagement boost, a churn-reduction programme). +Add `target_metric: ` to restrict the lift to a single +named metric. 
Every other metric in the same period is drawn +identically to the control arm, even for entities in the treatment +cohort. + +```yaml +treatment: + fraction: 0.5 + lift_log_odds: 0.6 + start_period: 6 + target_metric: mrr # only mrr shifts; engagement, churn_risk, etc. stay flat +``` + +Use the targeted form when the experimental hypothesis names one +outcome metric ("the pricing experiment lifts revenue, not +engagement"), or when you want a placebo metric in the dataset whose +mean must be statistically identical across arms. Omit it for a +trajectory-wide intervention. + +Correlated metrics: if `target_metric` names a metric that participates +in a `connections` correlation, the copula still operates on residuals +around each metric's own (un-shifted) centre — so the lift does **not** +propagate to the correlated metric's mean. The targeted metric shifts, +the correlated metric stays at its control distribution. + ### Pre-treatment baseline At `period_index < treatment_start_period`, the shift is `0.0` for @@ -218,15 +249,22 @@ changing one feature's shape doesn't shift another feature's outputs. ### Manifest -Two new manifest fields land at schema version `1.5`: +Two manifest fields surface treatment ground-truth: - `EntityArchetypeAssignment.treatment` — per-entity assignment record. Carries the entity's group label, lift (or `None` for - control), and `start_period`. `null` for entities with no treatment + control), `start_period`, and `target_metric` (`null` for the + trajectory-wide default). `null` for entities with no treatment fields set. - `ManifestSchema.treatment_cohorts` — aggregate per-cohort records. One entry per distinct `treatment_group` label. Reports the cohort - size, mean lift, and modal `start_period`. + size, mean lift, modal `start_period`, and modal `target_metric` + (`null` when every entity in the cohort uses the trajectory-wide + default). 
+ +The `target_metric` field on both records is additive; manifests +emitted for configs that do not set `target_metric` keep the field +`null`, so older readers continue to parse cleanly. ### Validator @@ -234,6 +272,10 @@ Rejected at config load: - `treatment_start_period >= n_periods` (the lift would never apply). - `treatment_lift_log_odds = ±inf` or `nan` (would propagate NaN cells). +- `target_metric` set to a name that doesn't match any declared metric. + Without this check a typo would silently fall through the per-metric + gate (no metric matches → the lift is never applied) and the + treatment would be invisible in the generated data. **NOT** rejected (intentionally): diff --git a/plotsim-schema.json b/plotsim-schema.json index e647175..48f042f 100644 --- a/plotsim-schema.json +++ b/plotsim-schema.json @@ -560,6 +560,18 @@ "minimum": 0, "title": "Treatment Start Period", "type": "integer" + }, + "treatment_target_metric": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Treatment Target Metric" } }, "required": [ diff --git a/plotsim/builder/input.py b/plotsim/builder/input.py index 579b880..9e482ca 100644 --- a/plotsim/builder/input.py +++ b/plotsim/builder/input.py @@ -388,6 +388,15 @@ class TreatmentConfig(BaseModel): positions — the AC for "pre-treatment baseline is identical". * ``treatment_label`` / ``control_label`` — cohort labels for the manifest. Defaults match the conventional A/B labelling. + * ``target_metric`` — optional name of a single metric. When set, + the lift only affects that metric's effective-position + evaluation; every other metric is byte-identical to its + control-arm draw. ``None`` (default) = trajectory-wide + application (every metric sees the lift, the pre-M24 + behaviour). 
The interpreter copies this value onto every + expanded entity in the segment (treatment AND control arms) + for ground-truth symmetry; control entities have no lift to + gate, so the field is harmless there. RNG isolation: the interpreter draws treatment assignments from a distinct ``np.random.default_rng(seed ^ TREATMENT_SALT)`` stream, @@ -404,6 +413,7 @@ class TreatmentConfig(BaseModel): start_period: int = Field(default=0, ge=0) treatment_label: str = Field(default="treatment", min_length=1) control_label: str = Field(default="control", min_length=1) + target_metric: Optional[str] = None @field_validator("lift_log_odds") @classmethod diff --git a/plotsim/builder/interpreter.py b/plotsim/builder/interpreter.py index 6e23caa..d8a5b8c 100644 --- a/plotsim/builder/interpreter.py +++ b/plotsim/builder/interpreter.py @@ -487,6 +487,9 @@ def _build_archetypes_and_entities( treatment_start_period=( s.treatment.start_period if s.treatment is not None else 0 ), + treatment_target_metric=( + s.treatment.target_metric if s.treatment is not None else None + ), ) ) diff --git a/plotsim/config.py b/plotsim/config.py index 2d21be6..9885108 100644 --- a/plotsim/config.py +++ b/plotsim/config.py @@ -1332,6 +1332,14 @@ class Entity(_Frozen): # out a baseline window where treatment and control entities share # identical metric distributions (the AC for "pre-treatment baseline # is identical across groups"). + # * ``treatment_target_metric`` — optional name of a single metric. + # When set, the logit shift only applies to that metric's + # effective-position evaluation; every other metric in the same + # period sees ``treatment_shift=0.0`` and is byte-identical to its + # control-arm draw. ``None`` (default) = trajectory-wide application + # (every metric sees the lift, the pre-M24 behaviour). The named + # metric must exist in ``config.metrics``; the validator + # ``validate_treatment_assignments`` enforces this at load time. 
# # The label is decoupled from the lift so a "control" entity can carry # a label without applying a shift, AND so the user can opt out of @@ -1342,6 +1350,7 @@ class Entity(_Frozen): treatment_group: Optional[str] = None treatment_lift_log_odds: Optional[float] = None treatment_start_period: int = Field(default=0, ge=0) + treatment_target_metric: Optional[str] = None class FKDistribution(_Frozen): diff --git a/plotsim/manifest.py b/plotsim/manifest.py index bb62ba1..fc1ccb5 100644 --- a/plotsim/manifest.py +++ b/plotsim/manifest.py @@ -117,7 +117,15 @@ # records the realized noise family whenever it diverges from the # historical lane, not only when heteroscedastic amplitude is on. # Default-family default-amplitude runs still emit ``noise_config=None``. -MANIFEST_SCHEMA_VERSION = "1.8" +# 0.6-M24: bumped 1.8 → 1.9 for the additive ``target_metric`` field on +# both ``TreatmentAssignment`` (per-entity) and ``TreatmentCohort`` +# (per-cohort). Defaults to ``None`` (trajectory-wide lift, the pre-M24 +# behaviour), so configs without per-metric targeting emit a 1.9 +# manifest byte-equivalent to 1.8 modulo the schema version string and +# the new field's null default. Populated when the entity's +# ``treatment_target_metric`` names a metric — the lift then applies +# only to that metric's effective-position evaluation. +MANIFEST_SCHEMA_VERSION = "1.9" class _ManifestBase(BaseModel): @@ -151,7 +159,7 @@ class ActiveWindow(_ManifestBase): class TreatmentAssignment(_ManifestBase): """0.6-M8c: an entity's treatment / control assignment. - Three fields, all sourced from the matching ``Entity`` fields: + Four fields, all sourced from the matching ``Entity`` fields: * ``group`` — the cohort label (e.g. ``"treatment"`` / ``"control"``). Plotsim treats it as opaque metadata. @@ -160,16 +168,23 @@ class TreatmentAssignment(_ManifestBase): * ``start_period`` — the absolute period index at which the lift kicks in. 
Pre-treatment periods (``period_index < start_period``) see the same trajectory as the control arm. + * ``target_metric`` — M24 per-metric targeting. ``None`` means + the lift applies trajectory-wide (every metric sees it, the + pre-M24 default); a metric name restricts the lift to that + metric only. Carried on both treatment and control entities + for ground-truth symmetry — control arms have no lift to gate. Emitted only for entities with at least one treatment field set. - Default-only entities (no group label, no lift, no start period) get - ``treatment=None`` on their ``EntityArchetypeAssignment`` so the - M8c manifest field is invisible to non-A/B test datasets. + Default-only entities (no group label, no lift, no start period, + no target metric) get ``treatment=None`` on their + ``EntityArchetypeAssignment`` so the M8c manifest field is + invisible to non-A/B test datasets. """ group: Optional[str] lift_log_odds: Optional[float] start_period: int + target_metric: Optional[str] = None class EntityArchetypeAssignment(_ManifestBase): @@ -211,12 +226,22 @@ class TreatmentCohort(_ManifestBase): for the cohort. Most A/B tests use one start period per cohort, so this is the headline value; if the cohort has heterogeneous starts (rare, but supported), pick the most common. + * ``target_metric`` — M24 per-metric targeting. ``None`` when + every entity in the cohort applies the lift trajectory-wide + (the pre-M24 default), or when no entity in the cohort + declares a target metric. Otherwise the modal target metric + across the cohort — heterogeneous cohorts (rare; segments + normally map 1:1 to cohort labels and carry one + ``TreatmentConfig.target_metric``) report their most-common + value and downstream consumers can cross-reference per-entity + records for outliers. 
""" label: str n_entities: int mean_lift_log_odds: Optional[float] start_period: int + target_metric: Optional[str] = None class TrajectorySample(_ManifestBase): @@ -918,12 +943,14 @@ def _treatment_assignment_for(entity: Any) -> Optional[TreatmentAssignment]: entity.treatment_group is None and entity.treatment_lift_log_odds is None and entity.treatment_start_period == 0 + and entity.treatment_target_metric is None ): return None return TreatmentAssignment( group=entity.treatment_group, lift_log_odds=entity.treatment_lift_log_odds, start_period=entity.treatment_start_period, + target_metric=entity.treatment_target_metric, ) @@ -970,12 +997,23 @@ def _build_treatment_cohorts(entities: list) -> list[TreatmentCohort]: starts = Counter(m.treatment_start_period for m in members) modal_start = starts.most_common(1)[0][0] + # M24: modal target_metric across the cohort. ``None`` when no + # member declares one (the pre-M24 default — trajectory-wide + # lift). Counted across non-None values only; if every member + # has ``treatment_target_metric=None`` the cohort reports + # ``None`` (trajectory-wide), matching the pre-M24 manifest + # shape for that cohort. + targets = Counter( + m.treatment_target_metric for m in members if m.treatment_target_metric is not None + ) + modal_target: Optional[str] = targets.most_common(1)[0][0] if targets else None cohorts.append( TreatmentCohort( label=label, n_entities=len(members), mean_lift_log_odds=mean_lift, start_period=modal_start, + target_metric=modal_target, ) ) return cohorts diff --git a/plotsim/metrics.py b/plotsim/metrics.py index 5e2b222..58c1ef0 100644 --- a/plotsim/metrics.py +++ b/plotsim/metrics.py @@ -1268,6 +1268,7 @@ def generate_metrics_for_period( seasonal_global: float = 0.0, entity_seasonal_sensitivity: float = 1.0, treatment_shift: float = 0.0, + treatment_target_metric: Optional[str] = None, ) -> dict[str, Optional[float]]: """Generate every metric for one entity at one time step. 
@@ -1285,6 +1286,18 @@ def generate_metrics_for_period( 6. apply Cholesky correlation on residuals (if correlations given) 7. apply noise (if noise config given): gaussian → outlier → MCAR 8. clamp to value_range, round poisson to int + + M24: ``treatment_target_metric`` gates the per-metric shift. When + ``None`` (default) every metric in the loop sees the caller's + ``treatment_shift`` — the pre-M24 trajectory-wide behaviour, byte- + identical to before. When set to a metric name, only that metric's + ``_compute_effective_position`` call receives the shift; every other + metric in the same period sees ``0.0`` and is byte-identical to its + control-arm draw. The validator + ``validate_treatment_assignments`` guarantees the named metric + exists, so a non-matching name here is silent dead-weight only if + the caller bypassed validation (e.g. constructed ``PlotsimConfig`` + programmatically and pushed it into the engine directly). """ effective = [_apply_archetype_overrides(m, archetype) for m in metrics] centers: dict[str, float] = {} @@ -1292,12 +1305,17 @@ def generate_metrics_for_period( correlations_active = bool(correlations) for em in effective: + em_shift = ( + treatment_shift + if (treatment_target_metric is None or em.name == treatment_target_metric) + else 0.0 + ) eff_pos = _compute_effective_position( trajectory_position, em, lag_buffer, period_index, - treatment_shift=treatment_shift, + treatment_shift=em_shift, ) if lag_buffer is not None: # Append this metric's effective position BEFORE moving on to @@ -1374,6 +1392,7 @@ def generate_entity_metrics( entity_seasonal_sensitivity: float = 1.0, treatment_lift_log_odds: Optional[float] = None, treatment_start_period: int = 0, + treatment_target_metric: Optional[str] = None, ) -> dict[str, np.ndarray]: """Generate every metric's full time series for one entity. 
@@ -1471,6 +1490,7 @@ def generate_entity_metrics( seasonal_global=seasonal_global_t, entity_seasonal_sensitivity=entity_seasonal_sensitivity, treatment_shift=shift_t, + treatment_target_metric=treatment_target_metric, ) for m in sorted_metrics: collected[m.name].append(period_out[m.name]) diff --git a/plotsim/tables.py b/plotsim/tables.py index dc8a21f..0d58862 100644 --- a/plotsim/tables.py +++ b/plotsim/tables.py @@ -367,6 +367,7 @@ def _compute_entity_metrics( entity_seasonal_sensitivity=entity.seasonal_sensitivity, treatment_lift_log_odds=entity.treatment_lift_log_odds, treatment_start_period=entity.treatment_start_period, + treatment_target_metric=entity.treatment_target_metric, ) return entity_metrics @@ -468,6 +469,7 @@ def _compute_entity_metrics( entity_seasonal_sensitivity=entity.seasonal_sensitivity, treatment_lift_log_odds=entity.treatment_lift_log_odds, treatment_start_period=entity.treatment_start_period, + treatment_target_metric=entity.treatment_target_metric, ) return entity_metrics_v diff --git a/plotsim/validation.py b/plotsim/validation.py index 84845fc..319dad2 100644 --- a/plotsim/validation.py +++ b/plotsim/validation.py @@ -269,6 +269,12 @@ def validate_treatment_assignments(config: PlotsimConfig) -> list[str]: 2. ``treatment_lift_log_odds`` must be finite when set. ``inf`` or ``nan`` would propagate through the logit shift to NaN cell values for every post-treatment row. + 3. ``treatment_target_metric`` must match a metric name declared + in ``config.metrics`` when set. A typo'd or stale metric name + would silently fall through the per-metric gate (no metric + matches, so the shift never applies) and produce a dataset + where the lift is invisible — the same silent-dead-weight + failure mode that gate 1 closes for ``treatment_start_period``. Note on ``treatment_start_period < entity.start_period``: this is NOT a gate. 
An entity that arrives at period 6 with @@ -287,11 +293,13 @@ def validate_treatment_assignments(config: PlotsimConfig) -> list[str]: """ errors: list[str] = [] n_periods = config.time_window.period_count() + metric_names = {m.name for m in config.metrics} for entity in config.entities: if ( entity.treatment_lift_log_odds is None and entity.treatment_group is None and entity.treatment_start_period == 0 + and entity.treatment_target_metric is None ): # No treatment fields set — the no-op default. Skip every # gate so a config that doesn't use the M8c surface is @@ -311,6 +319,16 @@ def validate_treatment_assignments(config: PlotsimConfig) -> list[str]: f"{lift} is non-finite; the logit shift would " f"propagate NaN into every post-treatment cell" ) + if ( + entity.treatment_target_metric is not None + and entity.treatment_target_metric not in metric_names + ): + errors.append( + f"entity {entity.name!r}: treatment_target_metric=" + f"{entity.treatment_target_metric!r} does not match any " + f"declared metric in config.metrics; the per-metric gate " + f"would never fire and the lift would be invisible" + ) return errors diff --git a/tests/test_heteroscedastic_noise.py b/tests/test_heteroscedastic_noise.py index 0632391..e683f9d 100644 --- a/tests/test_heteroscedastic_noise.py +++ b/tests/test_heteroscedastic_noise.py @@ -264,9 +264,11 @@ def test_manifest_omits_noise_config_when_off(): assert manifest.noise_config is None -def test_manifest_schema_version_pins_1_8(): +def test_manifest_schema_version_pins_1_9(): """0.6-M22 bumped the manifest schema version 1.6 → 1.7; 0.6-M23 bumped - 1.7 → 1.8. The test_schema_version_bumped_to_1_8 test in + 1.7 → 1.8; 0.6-M24 bumped 1.8 → 1.9 for the additive + ``target_metric`` field on ``TreatmentAssignment`` / + ``TreatmentCohort``. 
The test_schema_version_bumped_to_1_9 test in tests/test_manifest.py is the authoritative pin; this assertion is a load-bearing reminder that the heteroscedastic-emitting path participates in the schema-version contract too (so a future mission that adds a @@ -275,7 +277,7 @@ def test_manifest_schema_version_pins_1_8(): warnings.simplefilter("ignore") cfg = _build_small_config(scale_with_trajectory=True) _tables, manifest = _generate_and_manifest(cfg) - assert manifest.schema_version == "1.8" + assert manifest.schema_version == "1.9" # --- End-to-end byte-identity for default-off engine path ------------------- diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 0adc40a..1caaa71 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -365,7 +365,7 @@ def test_all_bundled_templates_produce_valid_manifest(template, tmp_path): # --- 0.6-M5: causal_graph --------------------------------------------------- -def test_schema_version_bumped_to_1_8(): +def test_schema_version_bumped_to_1_9(): """0.6-M5 added causal_graph / correlations / outlier_injections (1.0 → 1.1). 0.6-M8a added per-entity ``active_window`` on EntityArchetypeAssignment (1.1 → 1.2). 0.6-M8c added per-entity ``treatment`` and the top-level @@ -380,13 +380,16 @@ def test_schema_version_bumped_to_1_8(): (1.6 → 1.7). 0.6-M23 extended ``NoiseConfigInfo`` with ``noise_family`` and ``degrees_of_freedom`` for heavy-tailed noise families and broadened the emission criterion to cover non-default families - (1.7 → 1.8). + (1.7 → 1.8). 0.6-M24 added the optional ``target_metric`` field on + both per-entity ``TreatmentAssignment`` and per-cohort + ``TreatmentCohort`` records (1.8 → 1.9); ``None`` default preserves + byte-equivalence for pre-M24 configs modulo the version string. 
The version pin lives in this test rather than just the manifest module - so a downstream consumer pinning ``schema_version >= "1.8"`` has a + so a downstream consumer pinning ``schema_version >= "1.9"`` has a direct on-disk contract test it can reference. """ - assert MANIFEST_SCHEMA_VERSION == "1.8" + assert MANIFEST_SCHEMA_VERSION == "1.9" def test_causal_graph_emits_one_edge_per_metric_with_lag(saas_run): diff --git a/tests/test_multi_source.py b/tests/test_multi_source.py index 8d34c28..a4a8547 100644 --- a/tests/test_multi_source.py +++ b/tests/test_multi_source.py @@ -306,15 +306,17 @@ def test_manifest_source_entity_mappings_complete(): assert field in canonical_columns -def test_manifest_schema_bumped_to_1_8(): +def test_manifest_schema_bumped_to_1_9(): # 1.5 introduced the source_entity_mappings list (0.6-M13); 1.6 added # the parent_child_relations list (0.6-M18); 1.7 adds the optional # ``noise_config`` field (0.6-M22); 1.8 extends ``NoiseConfigInfo`` # with ``noise_family`` / ``degrees_of_freedom`` and broadens its - # emission criterion to cover non-gaussian families (0.6-M23). This - # module's contract tracks the pin at the schema level, not the field - # semantics. - assert MANIFEST_SCHEMA_VERSION == "1.8" + # emission criterion to cover non-gaussian families (0.6-M23); 1.9 + # adds the optional ``target_metric`` field on ``TreatmentAssignment`` + # / ``TreatmentCohort`` for per-metric treatment effects (0.6-M24). + # This module's contract tracks the pin at the schema level, not the + # field semantics. 
+ assert MANIFEST_SCHEMA_VERSION == "1.9" # ── AC6: single-source configs unchanged (no multi_source block) ────────── @@ -404,7 +406,7 @@ def test_bundled_template_loads_and_validates(tmp_path: Path): assert (out_dir / "dim_company_crm.csv").is_file() assert (out_dir / "dim_company_billing.csv").is_file() manifest_payload = json.loads((out_dir / "manifest.json").read_text(encoding="utf-8")) - assert manifest_payload["schema_version"] == "1.8" + assert manifest_payload["schema_version"] == "1.9" # 20 entities × 2 sources = 40 mapping records. assert len(manifest_payload["source_entity_mappings"]) == 40 diff --git a/tests/test_per_metric_treatment.py b/tests/test_per_metric_treatment.py new file mode 100644 index 0000000..39dcdf9 --- /dev/null +++ b/tests/test_per_metric_treatment.py @@ -0,0 +1,590 @@ +"""0.6-M24: per-metric treatment effects. + +Extends the M8c treatment surface so a configured +``treatment_lift_log_odds`` can target a single named metric instead of +applying trajectory-wide. The per-metric gate lives in +``generate_metrics_for_period``: when ``treatment_target_metric`` is set, +only the metric whose name matches receives the shift; every other +metric sees ``treatment_shift=0.0`` and is byte-identical to its +control-arm draw. + +Locks in: + +* Default behaviour (``treatment_target_metric=None``) preserves + pre-M24 output byte-for-byte. Every metric sees the lift, the same + trajectory-wide behaviour M8c shipped. +* Per-metric targeting: with ``target_metric="m1"`` and a positive + lift, only ``m1``'s realized values shift in the treatment cohort; + ``m2`` stays at the control distribution. Under + ``correlations=[]`` + zero-noise, the non-targeted metric is + byte-identical between treatment and control arms (the strongest + pin — no residual leakage at the sample-draw level). +* Correlation pipeline does not leak the lift to a correlated + non-targeted metric. 
With ``target_metric="m1"`` and a strong + baseline correlation ``m1 ↔ m2``, ``m2``'s post-treatment mean + remains within statistical noise of its control mean — the copula + operates on residuals around each metric's own (un-shifted) center, + so the lift on ``m1`` does not propagate. +* Validator rejects ``target_metric`` values that do not match any + declared metric name — same silent-dead-weight failure mode that + M8c's ``treatment_start_period >= n_periods`` gate closes. +* Manifest emits ``target_metric`` per entity AND per cohort. Schema + bumped 1.8 → 1.9 (additive; defaults to ``None``, so pre-M24 readers + parse 1.9 manifests cleanly except for the new field). +* Builder ``TreatmentConfig.target_metric`` propagates to every + expanded entity in the segment (treatment AND control arms — the + field is harmless on control entities because they have no lift to + gate, and carrying it on both arms preserves ground-truth symmetry). +""" + +from __future__ import annotations + +import warnings + +import numpy as np +import pytest +from pydantic import ValidationError + +from plotsim import create, generate_tables_with_state +from plotsim.builder.input import TreatmentConfig +from plotsim.config import ( + Archetype, + Column, + CorrelationPair, + CurveSegment, + Domain, + Entity, + Metric, + OutputConfig, + PlotsimConfig, + SurrogateKeyWarning, + Table, + TimeWindow, +) +from plotsim.manifest import MANIFEST_SCHEMA_VERSION, build_manifest + + +# --- Helpers --------------------------------------------------------------- + + +def _two_metric_engine_config( + entities: list[Entity], + *, + correlations: list[CorrelationPair] | None = None, +) -> PlotsimConfig: + """Engine-direct config with two metrics on a flat archetype. + + Mirrors ``_engine_config`` from ``test_treatment_control.py`` but + declares two metrics so the per-metric gate has something to gate + against. 
Both metrics use a beta distribution so the realized + values lie in [0,1] and the logit shift's effect is cleanly + recoverable via a difference of means. + """ + arch = Archetype( + name="flat", + label="flat", + description="constant 0.5 plateau", + curve_segments=[ + CurveSegment( + curve="plateau", + params={"level": 0.5}, + start_pct=0.0, + end_pct=1.0, + ), + ], + ) + m1 = Metric( + name="m1", + label="m1", + distribution="beta", + params={"alpha": 2.0, "beta": 5.0}, + polarity="positive", + ) + m2 = Metric( + name="m2", + label="m2", + distribution="beta", + params={"alpha": 2.0, "beta": 5.0}, + polarity="positive", + ) + fct = Table( + name="fct_m", + type="fact", + grain="per_entity_per_period", + primary_key=["date_key", "entity_id"], + foreign_keys=["dim_date.date_key", "dim_entity.entity_id"], + columns=[ + Column(name="date_key", dtype="id", source="fk:dim_date.date_key"), + Column(name="entity_id", dtype="id", source="fk:dim_entity.entity_id"), + Column(name="m1", dtype="float", source="metric:m1"), + Column(name="m2", dtype="float", source="metric:m2"), + ], + ) + dim_date = Table( + name="dim_date", + type="dim", + grain="per_period", + primary_key="date_key", + columns=[ + Column(name="date_key", dtype="id", source="pk"), + Column(name="date", dtype="date", source="generated:date_key"), + ], + ) + dim_entity = Table( + name="dim_entity", + type="dim", + grain="per_entity", + primary_key="entity_id", + columns=[ + Column(name="entity_id", dtype="id", source="pk"), + ], + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", SurrogateKeyWarning) + return PlotsimConfig( + domain=Domain( + name="t", + description="t", + entity_type="entity", + entity_label="Entities", + ), + time_window=TimeWindow( + start="2024-01", + end="2024-12", + granularity="monthly", + ), + seed=0, + metrics=[m1, m2], + archetypes=[arch], + entities=entities, + tables=[dim_date, dim_entity, fct], + correlations=correlations or [], + 
output=OutputConfig(format="csv", directory="out/m24"), + ) + + +def _split_cohort_means( + tables: dict, ctrl_pks: set, trt_pks: set +) -> dict[str, tuple[float, float]]: + """Return ``{metric_name: (ctrl_mean, trt_mean)}`` over the fact table.""" + fct = tables["fct_m"] + ctrl_rows = fct[fct["entity_id"].isin(ctrl_pks)] + trt_rows = fct[fct["entity_id"].isin(trt_pks)] + return { + col: (float(ctrl_rows[col].mean()), float(trt_rows[col].mean())) for col in ("m1", "m2") + } + + +# --- Default behaviour ------------------------------------------------------ + + +def test_entity_default_target_metric_is_none(): + """Sanity: the new field is opt-in. An ``Entity`` constructed + without ``treatment_target_metric`` carries ``None`` — the no-op + default that drops the entity into the pre-M24 trajectory-wide + lane. + """ + e = Entity(name="x", archetype="flat", size=1) + assert e.treatment_target_metric is None + + +def test_target_metric_none_lifts_every_metric(): + """The pre-M24 contract: with ``treatment_target_metric=None`` and + a positive lift, every metric in the treatment cohort shifts + upward. This pins the no-op-default behaviour at the table level — + if the gate ever started defaulting to ``"first_metric"`` or + similar, this test would catch it. 
+ """ + ctrl = [ + Entity(name=f"c_{i}", archetype="flat", size=1, treatment_group="control") + for i in range(30) + ] + trt = [ + Entity( + name=f"t_{i}", + archetype="flat", + size=1, + treatment_group="treatment", + treatment_lift_log_odds=1.5, + treatment_start_period=0, + treatment_target_metric=None, + ) + for i in range(30) + ] + cfg = _two_metric_engine_config(ctrl + trt) + rng = np.random.default_rng(cfg.seed) + tables, _state = generate_tables_with_state(cfg, rng) + de = tables["dim_entity"] + ctrl_pks = set(de.iloc[:30]["entity_id"]) + trt_pks = set(de.iloc[30:]["entity_id"]) + means = _split_cohort_means(tables, ctrl_pks, trt_pks) + for metric in ("m1", "m2"): + ctrl_mean, trt_mean = means[metric] + assert trt_mean - ctrl_mean > 0.05, ( + f"target_metric=None failed to lift {metric}: " + f"ctrl_mean={ctrl_mean}, trt_mean={trt_mean}" + ) + + +# --- Per-metric targeting -------------------------------------------------- + + +def test_target_metric_shifts_only_named_metric(): + """AC #1: with ``treatment_target_metric="m1"`` and a positive lift, + the treatment cohort's ``m1`` mean shifts UP versus control; its + ``m2`` mean stays within statistical noise of control (no + propagation). 30+30 sample with no correlations and no noise so + the per-metric gate is the only thing the test pins. 
+    """
+    ctrl = [
+        Entity(name=f"c_{i}", archetype="flat", size=1, treatment_group="control")
+        for i in range(30)
+    ]
+    trt = [
+        Entity(
+            name=f"t_{i}",
+            archetype="flat",
+            size=1,
+            treatment_group="treatment",
+            treatment_lift_log_odds=1.5,
+            treatment_start_period=0,
+            treatment_target_metric="m1",
+        )
+        for i in range(30)
+    ]
+    cfg = _two_metric_engine_config(ctrl + trt)
+    rng = np.random.default_rng(cfg.seed)
+    tables, _state = generate_tables_with_state(cfg, rng)
+    de = tables["dim_entity"]
+    ctrl_pks = set(de.iloc[:30]["entity_id"])
+    trt_pks = set(de.iloc[30:]["entity_id"])
+    means = _split_cohort_means(tables, ctrl_pks, trt_pks)
+    m1_ctrl, m1_trt = means["m1"]
+    m2_ctrl, m2_trt = means["m2"]
+    assert (
+        m1_trt - m1_ctrl > 0.05
+    ), f"target_metric='m1' failed to lift m1: ctrl={m1_ctrl}, trt={m1_trt}"
+    # AC #1 envelope: non-targeted metric must stay within 5% of
+    # control mean. Picked over a t-test to keep the assertion
+    # threshold concrete and seed-independent.
+    rel_delta = abs(m2_trt - m2_ctrl) / max(m2_ctrl, 1e-9)
+    assert rel_delta < 0.05, (
+        f"target_metric='m1' leaked into m2: ctrl={m2_ctrl}, trt={m2_trt}, "
+        f"relative delta={rel_delta}"
+    )
+
+
+def test_target_metric_named_byte_identical_non_targeted_under_no_correlations():
+    """Strongest pin on the gate: under ``correlations=[]`` and
+    zero-noise, the non-targeted metric must be BYTE-identical between
+    a config that targets ``m1`` and an otherwise-identical config
+    with no treatment at all. Both runs walk the per-entity RNG
+    forward through the same number of bytes per cell (distribution
+    draws consume a fixed number of bytes per call regardless of
+    loc/scale), so if the gate correctly zeros the shift on ``m2``,
+    its series is bit-equal across runs.
+
+    If the gate ever leaked the shift to ``m2`` — even by a single
+    logit unit — this test fails immediately.
+    """
+    e_trt = Entity(
+        name="solo",
+        archetype="flat",
+        size=1,
+        treatment_lift_log_odds=2.0,
+        treatment_start_period=0,
+        treatment_target_metric="m1",
+    )
+    e_none = Entity(name="solo", archetype="flat", size=1)
+    cfg_trt = _two_metric_engine_config([e_trt])
+    cfg_none = _two_metric_engine_config([e_none])
+    rng_trt = np.random.default_rng(cfg_trt.seed)
+    rng_none = np.random.default_rng(cfg_none.seed)
+    tables_trt, _ = generate_tables_with_state(cfg_trt, rng_trt)
+    tables_none, _ = generate_tables_with_state(cfg_none, rng_none)
+    m2_trt = tables_trt["fct_m"]["m2"].to_numpy()
+    m2_none = tables_none["fct_m"]["m2"].to_numpy()
+    assert np.array_equal(m2_trt, m2_none), (
+        "non-targeted metric m2 diverged between target_metric='m1' and "
+        "no-treatment runs — the per-metric gate leaked the shift"
+    )
+
+
+def test_target_metric_named_lifts_targeted_under_no_correlations():
+    """Symmetric pin: ``m1`` IS shifted when targeted. Run the same
+    pair of configs as the byte-identity test and assert the ``m1``
+    arrays DIFFER — guards against a wholesale-gate bug where the
+    shift is zeroed for every metric, not just the non-targeted ones.
+    """
+    e_trt = Entity(
+        name="solo",
+        archetype="flat",
+        size=1,
+        treatment_lift_log_odds=2.0,
+        treatment_start_period=0,
+        treatment_target_metric="m1",
+    )
+    e_none = Entity(name="solo", archetype="flat", size=1)
+    cfg_trt = _two_metric_engine_config([e_trt])
+    cfg_none = _two_metric_engine_config([e_none])
+    rng_trt = np.random.default_rng(cfg_trt.seed)
+    rng_none = np.random.default_rng(cfg_none.seed)
+    tables_trt, _ = generate_tables_with_state(cfg_trt, rng_trt)
+    tables_none, _ = generate_tables_with_state(cfg_none, rng_none)
+    m1_trt = tables_trt["fct_m"]["m1"].to_numpy()
+    m1_none = tables_none["fct_m"]["m1"].to_numpy()
+    assert not np.array_equal(m1_trt, m1_none), (
+        "targeted metric m1 was NOT shifted under target_metric='m1' — "
+        "the per-metric gate also zeroed the shift on the named metric"
+    )
+    assert m1_trt.mean() > m1_none.mean(), (
+        f"targeted metric m1 shifted in the wrong direction: "
+        f"trt_mean={m1_trt.mean()} vs none_mean={m1_none.mean()}"
+    )
+
+
+# --- Correlation leakage probe ---------------------------------------------
+
+
+def test_target_metric_no_leakage_through_correlated_pair():
+    """Dispatch Decision 3: with a strong baseline correlation
+    ``m1 ↔ m2`` and ``target_metric="m1"``, does the copula propagate
+    the m1 lift to m2 at the population mean? The copula at
+    ``apply_correlations`` operates on residuals around each metric's
+    own (un-shifted) center, so m2's center stays unchanged and the
+    correlated-residual transform preserves the mean. This test pins
+    that analytic claim empirically.
+    """
+    ctrl = [
+        Entity(name=f"c_{i}", archetype="flat", size=1, treatment_group="control")
+        for i in range(60)
+    ]
+    trt = [
+        Entity(
+            name=f"t_{i}",
+            archetype="flat",
+            size=1,
+            treatment_group="treatment",
+            treatment_lift_log_odds=1.5,
+            treatment_start_period=0,
+            treatment_target_metric="m1",
+        )
+        for i in range(60)
+    ]
+    cfg = _two_metric_engine_config(
+        ctrl + trt,
+        correlations=[CorrelationPair(metric_a="m1", metric_b="m2", coefficient=0.8)],
+    )
+    rng = np.random.default_rng(cfg.seed)
+    tables, _state = generate_tables_with_state(cfg, rng)
+    de = tables["dim_entity"]
+    ctrl_pks = set(de.iloc[:60]["entity_id"])
+    trt_pks = set(de.iloc[60:]["entity_id"])
+    means = _split_cohort_means(tables, ctrl_pks, trt_pks)
+    m1_ctrl, m1_trt = means["m1"]
+    m2_ctrl, m2_trt = means["m2"]
+    # m1 should still shift even with the copula active.
+    assert m1_trt - m1_ctrl > 0.03, (
+        f"correlated m1 didn't shift under target_metric: " f"ctrl={m1_ctrl}, trt={m1_trt}"
+    )
+    # m2 should NOT shift via correlation leakage. Tolerance is wider
+    # than the no-correlations pin because the copula's residual
+    # transform introduces a small sample-mean drift even at lift=0,
+    # but it should stay well below the m1 shift.
+    rel_leak = abs(m2_trt - m2_ctrl) / max(m2_ctrl, 1e-9)
+    m1_shift = (m1_trt - m1_ctrl) / max(m1_ctrl, 1e-9)
+    assert rel_leak < 0.10, (
+        f"m2 leaked through correlation: ctrl={m2_ctrl}, trt={m2_trt}, "
+        f"relative delta={rel_leak} (>10%)"
+    )
+    assert rel_leak < m1_shift, (
+        f"m2 leakage ({rel_leak}) exceeded m1 shift ({m1_shift}) — "
+        f"correlation propagation broke the per-metric gate's intent"
+    )
+
+
+# --- Validator -------------------------------------------------------------
+
+
+def test_validator_rejects_unknown_target_metric():
+    """AC #3: a config naming a metric that doesn't exist must raise at
+    load time. Catches typos and stale references that would otherwise
+    silently fall through the per-metric gate (no metric matches → the
+    lift is never applied) and produce a dataset where the treatment
+    is invisible.
+    """
+    entity = Entity(
+        name="t",
+        archetype="flat",
+        size=1,
+        treatment_group="treatment",
+        treatment_lift_log_odds=1.0,
+        treatment_target_metric="not_a_metric",
+    )
+    with pytest.raises(ValidationError) as excinfo:
+        _two_metric_engine_config([entity])
+    msg = str(excinfo.value)
+    assert "not_a_metric" in msg, f"validator did not name the offending metric: {msg}"
+    assert "treatment_target_metric" in msg, f"validator did not name the offending field: {msg}"
+
+
+def test_validator_accepts_known_target_metric():
+    """Sanity: a config naming a declared metric loads cleanly."""
+    entity = Entity(
+        name="t",
+        archetype="flat",
+        size=1,
+        treatment_group="treatment",
+        treatment_lift_log_odds=1.0,
+        treatment_target_metric="m1",
+    )
+    cfg = _two_metric_engine_config([entity])
+    # Round-trips through the validator without raising.
+    assert cfg.entities[0].treatment_target_metric == "m1"
+
+
+def test_validator_skips_when_no_treatment_fields():
+    """The no-op-default skip predicate must include
+    ``treatment_target_metric`` — an entity with EVERY treatment field
+    unset (including the M24 addition) must still bypass the gate so
+    pre-M24 configs remain validator-invisible.
+    """
+    entity = Entity(name="t", archetype="flat", size=1)
+    cfg = _two_metric_engine_config([entity])
+    assert cfg.entities[0].treatment_target_metric is None
+    assert cfg.entities[0].treatment_lift_log_odds is None
+
+
+# --- Manifest --------------------------------------------------------------
+
+
+def test_manifest_records_target_metric_per_entity():
+    """AC #4: the per-entity manifest record carries ``target_metric``.
+    Configs that don't use the M24 surface continue to emit
+    ``target_metric=None`` so 1.8 readers parse cleanly.
+    """
+    entities = [
+        Entity(name="c", archetype="flat", size=1, treatment_group="control"),
+        Entity(
+            name="t",
+            archetype="flat",
+            size=1,
+            treatment_group="treatment",
+            treatment_lift_log_odds=1.0,
+            treatment_target_metric="m1",
+        ),
+    ]
+    cfg = _two_metric_engine_config(entities)
+    rng = np.random.default_rng(cfg.seed)
+    tables, state = generate_tables_with_state(cfg, rng)
+    manifest = build_manifest(cfg, state.trajectories, tables)
+    by_entity = {a.entity: a for a in manifest.archetype_assignments}
+    assert by_entity["c"].treatment is not None
+    assert by_entity["c"].treatment.target_metric is None
+    assert by_entity["t"].treatment is not None
+    assert by_entity["t"].treatment.target_metric == "m1"
+
+
+def test_manifest_records_target_metric_per_cohort():
+    """AC #4: the per-cohort manifest record carries ``target_metric``.
+    Homogeneous cohort (every entity in the cohort shares the same
+    target metric, which is the canonical segment-driven shape) reports
+    that value directly.
+    """
+    entities = [
+        Entity(name=f"c_{i}", archetype="flat", size=1, treatment_group="control") for i in range(5)
+    ] + [
+        Entity(
+            name=f"t_{i}",
+            archetype="flat",
+            size=1,
+            treatment_group="treatment",
+            treatment_lift_log_odds=1.0,
+            treatment_target_metric="m2",
+        )
+        for i in range(5)
+    ]
+    cfg = _two_metric_engine_config(entities)
+    rng = np.random.default_rng(cfg.seed)
+    tables, state = generate_tables_with_state(cfg, rng)
+    manifest = build_manifest(cfg, state.trajectories, tables)
+    by_label = {c.label: c for c in manifest.treatment_cohorts}
+    assert by_label["control"].target_metric is None
+    assert by_label["treatment"].target_metric == "m2"
+
+
+def test_manifest_schema_version_bumped_for_m24():
+    """Schema version pin: M24's additive field bumps the manifest
+    schema. Pre-M24 readers see a 1.9 manifest's new ``target_metric``
+    field default to ``None`` so they parse cleanly — but the schema
+    string itself must advance to signal that the new field exists.
+    """
+    assert MANIFEST_SCHEMA_VERSION == "1.9"
+
+
+# --- Builder propagation ---------------------------------------------------
+
+
+def test_builder_treatment_config_target_metric_propagates_to_entities():
+    """``TreatmentConfig.target_metric`` set on a segment lands on
+    every entity expanded from that segment — both treatment AND
+    control arms. The field is harmless on control entities (they
+    have no lift to gate), but carrying it on both arms preserves
+    ground-truth symmetry so a downstream analyst can recover the
+    full experiment design from a single ``Entity`` record.
+    """
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        cfg = create(
+            about="m24 test",
+            unit="company",
+            window=("2024-01", "2024-12"),
+            metrics=[
+                {"name": "engagement", "type": "score", "polarity": "positive"},
+                {"name": "satisfaction", "type": "score", "polarity": "positive"},
+            ],
+            segments=[
+                {
+                    "name": "s",
+                    "count": 10,
+                    "archetype": "flat",
+                    "treatment": TreatmentConfig(
+                        fraction=0.5,
+                        lift_log_odds=1.0,
+                        target_metric="engagement",
+                    ),
+                },
+            ],
+            seed=42,
+        )
+    for e in cfg.entities:
+        assert (
+            e.treatment_target_metric == "engagement"
+        ), f"entity {e.name!r} did not inherit target_metric from segment"
+
+
+def test_builder_treatment_config_target_metric_defaults_to_none():
+    """A ``TreatmentConfig`` constructed without ``target_metric`` (the
+    pre-M24 shape) leaves every expanded entity with
+    ``treatment_target_metric=None`` — preserves the trajectory-wide
+    behaviour every existing builder-using template depends on.
+    """
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        cfg = create(
+            about="m24 default test",
+            unit="company",
+            window=("2024-01", "2024-12"),
+            metrics=[{"name": "engagement", "type": "score", "polarity": "positive"}],
+            segments=[
+                {
+                    "name": "s",
+                    "count": 10,
+                    "archetype": "flat",
+                    "treatment": TreatmentConfig(fraction=0.5, lift_log_odds=1.0),
+                },
+            ],
+            seed=42,
+        )
+    for e in cfg.entities:
+        assert e.treatment_target_metric is None
diff --git a/tests/test_time_varying_correlations.py b/tests/test_time_varying_correlations.py
index 42b9a37..423c2c5 100644
--- a/tests/test_time_varying_correlations.py
+++ b/tests/test_time_varying_correlations.py
@@ -679,13 +679,15 @@ def test_project_phase_correlation_or_issue_invalid_index(self):
 class TestManifestIntegration:
     """Manifest carries per-phase entries and the new top-level summary."""
 
-    def test_schema_version_is_1_8(self):
+    def test_schema_version_is_1_9(self):
         # 0.6-M13 bumped 1.4 → 1.5 for ``source_entity_mappings``; 0.6-M18
         # bumped 1.5 → 1.6 for ``parent_child_relations``; 0.6-M22 bumped
         # 1.6 → 1.7 for the optional ``noise_config`` field; 0.6-M23
         # bumped 1.7 → 1.8 for ``noise_family`` / ``degrees_of_freedom``
-        # on ``NoiseConfigInfo`` and broadened its emission criterion.
-        assert MANIFEST_SCHEMA_VERSION == "1.8"
+        # on ``NoiseConfigInfo`` and broadened its emission criterion;
+        # 0.6-M24 bumped 1.8 → 1.9 for the additive ``target_metric``
+        # field on ``TreatmentAssignment`` / ``TreatmentCohort``.
+        assert MANIFEST_SCHEMA_VERSION == "1.9"
 
     def test_no_phases_yields_empty_correlation_phases_list(self):
         cfg = _two_metric_config(