From b1ae22a7cbbdcbe76ae2a8d0f7e3ac6150eb3228 Mon Sep 17 00:00:00 2001 From: Andres Contreras Date: Thu, 25 Jun 2026 19:45:13 +0200 Subject: [PATCH 1/2] feat: explainability (ExplainerPort + permutation/SHAP) + integrity fixes Explainability (new domain module, hexagonal): - ExplainerPort + typed GlobalExplanation/LocalExplanation - PermutationImportanceExplainer (dependency-free default) + ShapExplainer (optional 'explain' extra: global + local attributions) - ExplainabilityAutoConfiguration (SHAP when installed, else permutation) - AutoMLResult.explain() + DI wiring via AutoML.from_context - 4 real-data TDD tests (signal>noise, AutoML integration, DI, SHAP) - docs/explainability.md (real permutation-importance output) + nav/home cards Integrity / no-fake-data fixes (from a full-codebase gap audit): - plain AutoML() now includes installed XGBoost/LightGBM/CatBoost by default, matching the documented '+boosting when installed' (TDD) - security.md: sandbox tiers (docker/e2b), timeout_seconds and HITL approval marked as declared-config/roadmap (only static-analysis + restricted exec are enforced today); cost-benefit gate corrected to a post-hoc measured-lift filter - README: list real adapters vs reference/planned (AutoGluon/Feast/BentoML) Real-LLM path verified end-to-end (Claude haiku-4-5): gate accepts measured-lift features, rejects the rest; agentic loop verifies 9 attempts. 
--- README.md | 14 ++- docs/explainability.md | 103 ++++++++++++++++ docs/index.md | 5 + docs/security.md | 21 +++- mkdocs.yml | 1 + pyproject.toml | 5 +- .../automl/__init__.py | 15 +++ .../automl/facade.py | 27 +++-- .../explainability/__init__.py | 65 ++++++++++ .../explainability/adapters.py | 113 ++++++++++++++++++ .../explainability/auto_configuration.py | 41 +++++++ tests/explainability/test_explainability.py | 103 ++++++++++++++++ tests/test_automl.py | 4 +- tests/test_default_trainers.py | 20 ++++ uv.lock | 61 ++++++++-- 15 files changed, 568 insertions(+), 30 deletions(-) create mode 100644 docs/explainability.md create mode 100644 src/fireflyframework_datascience/explainability/__init__.py create mode 100644 src/fireflyframework_datascience/explainability/adapters.py create mode 100644 src/fireflyframework_datascience/explainability/auto_configuration.py create mode 100644 tests/explainability/test_explainability.py create mode 100644 tests/test_default_trainers.py diff --git a/README.md b/README.md index 5a221db..b17bf8e 100644 --- a/README.md +++ b/README.md @@ -44,9 +44,11 @@ swappability, and security by default. - **One reproducible pattern.** The LLM proposes code/features/pipelines/seeds; a deterministic classical engine trains, scores, and selects; every GenAI step is gated behind a measured improvement over a seeded classical baseline. -- **Hexagonal & swappable.** Every ML/MLOps library (scikit-learn, XGBoost, LightGBM, CatBoost, - AutoGluon, TabPFN, PyTorch Lightning, HuggingFace, MLflow, Feast, BentoML, …) is a swappable adapter - behind a `Protocol` port. The core stays library-agnostic. +- **Hexagonal & swappable.** Each ML/MLOps library sits behind a `Protocol` port, so the core stays + library-agnostic. Adapters that ship today: scikit-learn, XGBoost, LightGBM, CatBoost, TabPFN, + PyTorch Lightning, HuggingFace, and MLflow. 
Ports with reference or planned adapters (AutoGluon, + Feast, BentoML packaging, a model registry) are marked as such in the docs — the seams exist; the + adapters are landing. - **Firefly-native.** Auto-configuration, dependency injection, a startup banner + wiring summary, CalVer, and the same CI gates as the rest of the Firefly Framework. @@ -121,8 +123,10 @@ Five acyclic layers, mirroring `fireflyframework-agentic` with a **DataScience** ### Hexagonal ports & adapters -Every ML/MLOps library (scikit-learn, XGBoost, AutoGluon, TabPFN, PyTorch Lightning, HuggingFace, -MLflow, BentoML, …) is a swappable adapter behind a `Protocol` port. The core stays library-agnostic. +Each ML/MLOps library sits behind a `Protocol` port, so the core stays library-agnostic. Shipping +adapters today: scikit-learn, XGBoost, LightGBM, CatBoost, TabPFN, PyTorch Lightning, HuggingFace, +MLflow. AutoGluon, Feast, BentoML packaging and a model registry are ports with reference/planned +adapters.

Hexagonal ports and adapters diff --git a/docs/explainability.md b/docs/explainability.md new file mode 100644 index 0000000..9f48ab3 --- /dev/null +++ b/docs/explainability.md @@ -0,0 +1,103 @@ +# Explainability + +**Every fitted model can explain which features drive its predictions — with deterministic, +well-understood methods, never an LLM.** + +Explainability is a first-class port in Firefly DataScience. After AutoML selects and refits a winner, +you get a model that can describe *why* it predicts what it predicts: globally (which features matter +across the dataset) and — with the optional SHAP adapter — locally (which features moved a single +prediction). This is table-stakes for regulated domains (lending, healthcare, insurance) where a model +you cannot explain is a model you cannot ship. + +!!! firefly "Explanations are classical, not generated" + + Importances come from **permutation importance** (the dependency-free default) or **SHAP** — both + deterministic, peer-reviewed methods. The LLM is never in the explanation path: just as GenAI + *proposes* and a classical engine *decides*, here the classical engine also *explains*. The numbers + are reproducible from a seed. + +## Global feature importance + +Every `AutoMLResult` can explain its winning model. Call `explain()` with a dataset (typically your +held-out split): + +```python +from fireflyframework_datascience.automl import AutoML +from fireflyframework_datascience.datasets.adapters import SklearnDatasetLoader + +ds = SklearnDatasetLoader().load("breast_cancer") +train, test = ds.train_test_split(test_size=0.25, random_state=0) + +result = AutoML(cv=3, random_state=0).fit(train) +explanation = result.explain(test) # (1)! + +print(explanation.method) # "permutation_importance" +for name, importance in explanation.top(8): + print(f"{name:<26} {importance:+.4f}") +``` + +1. 
`explain()` uses the DI-wired `ExplainerPort` when the result came from `AutoML.from_context`, + otherwise the dependency-free permutation-importance explainer. It returns a `GlobalExplanation`. + +!!! success "Expected (a real run on `breast_cancer`)" + + ```text + winner: linear | holdout roc_auc: 0.9952 + permutation_importance + radius error +0.0182 + fractal dimension error +0.0140 + mean concave points +0.0091 + mean concavity +0.0077 + compactness error +0.0056 + worst area +0.0056 + worst symmetry +0.0056 + perimeter error +0.0049 + ``` + + Each value is the mean drop in the model's score when that feature is randomly permuted — higher + means more important. A pure-noise column lands at ≈ 0. Exact numbers depend on the data, the + winning model, and the seed. + +`GlobalExplanation` exposes `feature_importances` (a `dict[str, float]`), `std`, `baseline_score`, +`.top(k)`, and `.to_frame()` for a tidy pandas table. + +## Local (per-prediction) explanations with SHAP + +For per-prediction attributions — "why did *this* applicant score the way they did?" — install the +optional `explain` extra and the SHAP explainer is used automatically: + +```bash +uv add "fireflyframework-datascience[explain]" # adds shap +``` + +```python +from fireflyframework_datascience.explainability.adapters import ShapExplainer + +explainer = ShapExplainer() +local = explainer.explain_local(result.best_model, test.X.iloc[:1]) # one LocalExplanation per row +for feature, contribution in local[0].top(5): + print(f"{feature:<26} {contribution:+.4f}") +``` + +Each `LocalExplanation` carries the `prediction`, per-feature `contributions`, and a `base_value`; +`.top(k)` ranks features by absolute contribution. + +## How it fits the architecture + +| Piece | What it is | +| --- | --- | +| `ExplainerPort` | The `Protocol`: `supports(model)` + `explain_global(model, dataset)`. | +| `PermutationImportanceExplainer` | Default adapter — model-agnostic, scikit-learn only (no extra dependency). 
| +| `ShapExplainer` | Optional adapter (`explain` extra) — global **and** local attributions. | +| `ExplainabilityAutoConfiguration` | Registers the explainer (SHAP when installed, else permutation). | +| `AutoMLResult.explain(dataset)` | The handle most users hold — delegates to the wired explainer. | + +Because it is a port, you can inject your own explainer (register an `ExplainerPort` bean and it wins), +and the DI-wired `AutoML.from_context(app)` automatically threads it into every result. + +## See also + +- [Classical AutoML](automl.md) — the engine that produces the model you explain. +- [Architecture](architecture.md) — how ports, adapters, and auto-configuration fit together. +- [GenAI features](genai-features.md) — the gated accelerator (proposals are explained the same way). +- [Security](security.md) — why generated code is treated as untrusted, and what is enforced today. diff --git a/docs/index.md b/docs/index.md index 81d1e68..2f33cba 100644 --- a/docs/index.md +++ b/docs/index.md @@ -177,6 +177,11 @@ app = FireflyDataScienceApplication.run(config=config) --- The classical-first engine: train, score, select. +- :material-lightbulb-on-outline:{ .middle } __[Explainability](explainability.md)__ + + --- + Deterministic global + local feature importances (permutation, SHAP). + - :material-creation-outline:{ .middle } __[GenAI features](genai-features.md)__ --- diff --git a/docs/security.md b/docs/security.md index ebd3244..21257da 100644 --- a/docs/security.md +++ b/docs/security.md @@ -99,7 +99,16 @@ The post-conditions are enforced in order: a non-`DataFrame` result raises `Feat ## Layer 3 — the tiered sandbox -Layers 1 and 2 run **in-process**. They block the obvious capabilities, but a determined escape against a CPython process is never something to bet sensitive data on. For untrusted data, escalate isolation with `execution.sandbox` in `ExecutionConfig`: +!!! 
warning "Implementation status — what is enforced today" + + Layers 1–2 (static analysis + the restricted in-process namespace) are **enforced now** and are + what protects you today. The sandbox *tiers* below (`docker`, `e2b`), `execution.timeout_seconds`, + and the `require_approval` HITL gate are currently **declared, validated configuration** — their + routing and enforcement are on the roadmap (a `CodeExecutorPort` with per-tier adapters and an + approval gate). Until that ships, the real isolation is the in-process `monty` / `local` path: + **do not run genuinely untrusted data through GenAI expecting container/microVM isolation yet.** + +Layers 1 and 2 run **in-process**. They block the obvious capabilities, but a determined escape against a CPython process is never something to bet sensitive data on. The configuration surface below lets you *declare* stronger isolation for untrusted data via `execution.sandbox` in `ExecutionConfig` (enforcement is roadmap, per the note above): ```python from fireflyframework_datascience.core.config import FireflyDataScienceConfig @@ -142,7 +151,7 @@ The literal type for `sandbox` is exactly `Literal["monty", "docker", "e2b", "lo Profile overlays outrank the base `firefly-datascience.yaml`, so a `prod` profile can tighten isolation without touching the base file. See [Configuration](configuration.md) for the full precedence order. -Beyond the strongest sandbox sits **HITL** (human-in-the-loop): when `execution.require_approval` is `True` (the default), generated code is surfaced for human approval before it runs. This is the final tier — a person, not a policy, signs off. +Beyond the strongest sandbox sits **HITL** (human-in-the-loop): `execution.require_approval` defaults to `True`, and the design's final tier is a person — not a policy — signing off on generated code before it runs. (Per the status note above, the approval-gate wiring is on the roadmap; today the field is declared and validated.) !!! 
note "Defaults are the safe end of every axis" Out of the box, `sandbox = "monty"` (in-process restricted interpreter), `timeout_seconds = 60`, and `require_approval = True`. You loosen these deliberately — and only `local` removes isolation entirely. @@ -154,19 +163,19 @@ The subtle attack is not the model going rogue on its own; it is a **column valu 1. **Static analysis is content-blind.** It rejects `os`, `subprocess`, `socket`, dunder access, and `eval`/`exec`/`open` regardless of *why* the model wrote them — so a successful injection still produces code that gets rejected. 2. **The restricted namespace** means even "clever" injected code has no I/O, no imports, no host reach. 3. **The numeric-new-column contract** means injected code that tries to do anything other than add a numeric feature fails the post-conditions. -4. **Sandboxing + HITL** mean that for genuinely untrusted data you route to `docker`/`e2b` and require approval — so injection cannot silently reach a capability. +4. **Sandboxing + HITL** are the *intended* outer tiers for genuinely untrusted data (route to `docker`/`e2b`, require approval). Their enforcement is on the roadmap (see the Layer 3 status note) — today, rely on points 1–3, which are enforced in-process. !!! warning "The framework does not read your data's meaning" Firefly cannot inspect or sanitize the *semantics* of your data. Prompt-injection defense rests on capability restriction and sandboxing, not on detecting malicious text. Treat data of unknown provenance as untrusted input: raise `execution.sandbox` and keep `require_approval` on. ## Governance — the CostBenefitGate -GenAI is **off by default** (`genai.enabled = False`) — Firefly is classical-first. When you do enable it, the `CostBenefitGate` is the governance control: it decides whether an LLM call is *worth it* before spending tokens, bounded by a budget. +GenAI is **off by default** (`genai.enabled = False`) — Firefly is classical-first. 
When you do enable it, the `CostBenefitGate` is the governance control: it is a **post-hoc, measured-lift filter** — a proposal (feature or pipeline) is adopted only if it *measurably beats the seeded baseline* on cross-validation; anything that doesn't is discarded. (It governs *what is kept*, not token spend: `genai.budget_usd` is a declared ceiling whose pre-call enforcement is on the roadmap.) ```python config.genai.enabled # False by default config.genai.cost_benefit_gate # True — gate LLM spend on expected benefit -config.genai.budget_usd # optional hard ceiling (float | None), e.g. 5.00 +config.genai.budget_usd # declared ceiling (float | None); pre-call enforcement is roadmap ``` ```yaml @@ -179,7 +188,7 @@ genai: ``` !!! firefly "Two orthogonal gates: how much, and what" - The `CostBenefitGate` is a *governance* control, not a security control: it limits spend and runaway agentic loops, not capability. Keep both axes in mind — `cost_benefit_gate` governs **how much** the model runs; the executor and sandbox govern **what its output may do**. Neither substitutes for the other. + The `CostBenefitGate` is a *governance* control, not a security control: it governs **what GenAI output is kept** (only proposals that measurably beat the baseline), not capability. Keep both axes in mind — the gate governs **whether a proposal earns its place**; the executor and sandbox govern **what its output may do**. Neither substitutes for the other. 
## Limits of the trust model diff --git a/mkdocs.yml b/mkdocs.yml index 8b4d6d4..8bcd35f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -104,6 +104,7 @@ nav: - Architecture: architecture.md - Datasets: datasets.md - Classical AutoML: automl.md + - Explainability: explainability.md - GenAI Feature Engineering: genai-features.md - Agentic ML-Engineering Loop: agentic-loop.md - Deep Learning & TabFM: deep-learning.md diff --git a/pyproject.toml b/pyproject.toml index f17ca60..586ade9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ nlp = ["transformers>=4.45.0", "datasets>=3.0.0", "peft>=0.13.0", "trl>=0.11.0", tracking = ["mlflow>=2.17.0"] tracking-wandb = ["wandb>=0.18.0"] validation = ["pandera>=0.20.0"] +explain = ["shap>=0.46.0"] featurestore = ["feast>=0.40.0"] serving = ["bentoml>=1.3.0"] serving-llm = ["vllm>=0.6.0"] @@ -64,7 +65,7 @@ genai = [ ] # convenience bundles automl-stack = ["fireflyframework-datascience[tabular,tabfm,automl,tracking,validation,data]"] -full = ["fireflyframework-datascience[tabular,tabfm,automl,dl,nlp,tracking,validation,featurestore,serving,lineage,orchestration,data,genai]"] +full = ["fireflyframework-datascience[tabular,tabfm,automl,dl,nlp,tracking,validation,explain,featurestore,serving,lineage,orchestration,data,genai]"] [project.scripts] firefly-ds = "fireflyframework_datascience.cli.main:cli" @@ -77,6 +78,7 @@ datasets = "fireflyframework_datascience.datasets.auto_configuration:DatasetsAut engineering = "fireflyframework_datascience.engineering.auto_configuration:EngineeringAutoConfiguration" models = "fireflyframework_datascience.models.auto_configuration:ModelsAutoConfiguration" evaluation = "fireflyframework_datascience.evaluation.auto_configuration:EvaluationAutoConfiguration" +explainability = "fireflyframework_datascience.explainability.auto_configuration:ExplainabilityAutoConfiguration" features = "fireflyframework_datascience.features.auto_configuration:FeaturesAutoConfiguration" search = 
"fireflyframework_datascience.search.auto_configuration:SearchAutoConfiguration" validation = "fireflyframework_datascience.validation.auto_configuration:ValidationAutoConfiguration" @@ -125,6 +127,7 @@ ignore = ["E501", "TC001", "TC002", "TC003", "UP040", "UP046", "UP047", "B008", "src/fireflyframework_datascience/models/**" = ["PLC0415"] "src/fireflyframework_datascience/engineering/**" = ["PLC0415"] "src/fireflyframework_datascience/evaluation/**" = ["PLC0415"] +"src/fireflyframework_datascience/explainability/**" = ["PLC0415"] "src/fireflyframework_datascience/features/**" = ["PLC0415"] "src/fireflyframework_datascience/search/**" = ["PLC0415"] "src/fireflyframework_datascience/validation/**" = ["PLC0415"] diff --git a/src/fireflyframework_datascience/automl/__init__.py b/src/fireflyframework_datascience/automl/__init__.py index 063bb4d..74a34bc 100644 --- a/src/fireflyframework_datascience/automl/__init__.py +++ b/src/fireflyframework_datascience/automl/__init__.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from fireflyframework_datascience.automl.facade import AutoML + from fireflyframework_datascience.explainability import ExplainerPort, GlobalExplanation @dataclass @@ -44,6 +45,7 @@ class AutoMLResult: evaluator: MetricsEvaluatorPort cv_scoring: str = "" extras: dict[str, Any] = field(default_factory=dict) + explainer: ExplainerPort | None = None @property def best_score(self) -> float: @@ -69,6 +71,19 @@ def evaluate(self, dataset: Dataset) -> EvaluationResult: def leaderboard_table(self) -> str: return "\n".join(str(entry) for entry in self.leaderboard) + def explain(self, dataset: Dataset) -> GlobalExplanation: + """Global feature importances for the winning model. + + Uses the injected :class:`ExplainerPort` (the DI-wired explainer when built via + ``AutoML.from_context``), falling back to the dependency-free permutation-importance explainer. 
+ """ + explainer = self.explainer + if explainer is None: + from fireflyframework_datascience.explainability.adapters import PermutationImportanceExplainer + + explainer = PermutationImportanceExplainer() + return explainer.explain_global(self.best_model, dataset) + @runtime_checkable class AutoMLBackendPort(Protocol): diff --git a/src/fireflyframework_datascience/automl/facade.py b/src/fireflyframework_datascience/automl/facade.py index 9a02cfa..f613f33 100644 --- a/src/fireflyframework_datascience/automl/facade.py +++ b/src/fireflyframework_datascience/automl/facade.py @@ -16,6 +16,7 @@ from fireflyframework_datascience.core.types import TaskType from fireflyframework_datascience.datasets import Dataset from fireflyframework_datascience.evaluation import MetricsEvaluatorPort +from fireflyframework_datascience.explainability import ExplainerPort from fireflyframework_datascience.models import Model, TrainerPort from fireflyframework_datascience.search import SearchPolicyPort from fireflyframework_datascience.tracking import TrackerPort @@ -35,6 +36,7 @@ def __init__( search_policy: SearchPolicyPort | None = None, validator: ValidatorPort | None = None, tracker: TrackerPort | None = None, + explainer: ExplainerPort | None = None, cv: int = 5, n_trials: int = 20, random_state: int = 42, @@ -44,6 +46,7 @@ def __init__( self._search = search_policy or _default_search() self._validator = validator self._tracker = tracker + self._explainer = explainer self._cv = cv self._n_trials = n_trials self._random_state = random_state @@ -59,6 +62,7 @@ def from_context(cls, context: Any, **overrides: Any) -> AutoML: search_policy=container.resolve_optional(SearchPolicyPort) or _default_search(), validator=container.resolve_optional(ValidatorPort), tracker=container.resolve_optional(TrackerPort), + explainer=container.resolve_optional(ExplainerPort), **overrides, ) @@ -113,6 +117,7 @@ def fit(self, dataset: Dataset, *, task: TaskType | None = None, metric: str | N task=task, 
evaluator=self._evaluator, cv_scoring=scoring, + explainer=self._explainer, ) # -- internals -------------------------------------------------------- @@ -164,13 +169,21 @@ def _track_results(self, run: Any, model: Model, leaderboard: list[LeaderboardEn def _default_trainers() -> list[TrainerPort]: - from fireflyframework_datascience.models.adapters import ( - HistGradientBoostingTrainer, - LinearTrainer, - RandomForestTrainer, - ) - - return [RandomForestTrainer(), LinearTrainer(), HistGradientBoostingTrainer()] + import importlib + import importlib.util + + adapters = importlib.import_module("fireflyframework_datascience.models.adapters") + trainers: list[TrainerPort] = [ + adapters.RandomForestTrainer(), + adapters.LinearTrainer(), + adapters.HistGradientBoostingTrainer(), + ] + # Match the documented "+ XGBoost / LightGBM / CatBoost when installed" behaviour (the DI and + # agentic paths already do this) by including each boosting trainer whose library is importable. + for lib, cls_name in (("xgboost", "XGBoostTrainer"), ("lightgbm", "LightGBMTrainer"), ("catboost", "CatBoostTrainer")): + if importlib.util.find_spec(lib) is not None: + trainers.append(getattr(adapters, cls_name)()) + return trainers def _default_evaluator() -> MetricsEvaluatorPort: diff --git a/src/fireflyframework_datascience/explainability/__init__.py b/src/fireflyframework_datascience/explainability/__init__.py new file mode 100644 index 0000000..873be37 --- /dev/null +++ b/src/fireflyframework_datascience/explainability/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2026 Firefly Software Foundation. +"""Explainability module — the ``ExplainerPort`` and typed explanation results (import-light). + +Explanations follow the framework's classical-first thesis: the importances are produced by +deterministic, well-understood methods (permutation importance, model-native importances, SHAP) — not +by an LLM. Heavy adapters live in :mod:`fireflyframework_datascience.explainability.adapters`. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from fireflyframework_datascience.datasets import Dataset + from fireflyframework_datascience.models import Model + + +@dataclass +class GlobalExplanation: + """Dataset-level feature importances for a fitted model.""" + + method: str + feature_importances: dict[str, float] + std: dict[str, float] = field(default_factory=dict) + baseline_score: float = float("nan") + + def top(self, k: int = 20) -> list[tuple[str, float]]: + """The ``k`` most important features, highest first.""" + return sorted(self.feature_importances.items(), key=lambda kv: kv[1], reverse=True)[:k] + + def to_frame(self) -> Any: + """A tidy ``feature/importance/std`` DataFrame, sorted by importance (descending).""" + import pandas as pd + + rows = [{"feature": n, "importance": v, "std": self.std.get(n, float("nan"))} for n, v in self.top(len(self.feature_importances))] + return pd.DataFrame(rows) + + +@dataclass +class LocalExplanation: + """Per-prediction feature attributions for a single row.""" + + method: str + prediction: Any + contributions: dict[str, float] + base_value: float = float("nan") + + def top(self, k: int = 20) -> list[tuple[str, float]]: + """The ``k`` features with the largest absolute contribution to this prediction.""" + return sorted(self.contributions.items(), key=lambda kv: abs(kv[1]), reverse=True)[:k] + + +@runtime_checkable +class ExplainerPort(Protocol): + """Produces explanations for a fitted :class:`Model`.""" + + name: str + + def supports(self, model: Model) -> bool: ... + + def explain_global(self, model: Model, dataset: Dataset) -> GlobalExplanation: ... 
+ + +__all__ = ["ExplainerPort", "GlobalExplanation", "LocalExplanation"] diff --git a/src/fireflyframework_datascience/explainability/adapters.py b/src/fireflyframework_datascience/explainability/adapters.py new file mode 100644 index 0000000..4bec318 --- /dev/null +++ b/src/fireflyframework_datascience/explainability/adapters.py @@ -0,0 +1,113 @@ +# Copyright 2026 Firefly Software Foundation. +"""Explainability adapters. + +- :class:`PermutationImportanceExplainer` — the dependency-free default (scikit-learn, already in the + ``tabular`` extra). Model-agnostic: permutes each input feature and measures the score drop. +- :class:`ShapExplainer` — optional, behind the ``explain`` extra; adds local (per-prediction) + attributions. Raises :class:`AdapterUnavailableError` if ``shap`` is not installed. +""" + +from __future__ import annotations + +from typing import Any + +from fireflyframework_datascience.core.exceptions import AdapterUnavailableError +from fireflyframework_datascience.explainability import GlobalExplanation, LocalExplanation +from fireflyframework_datascience.models import Model + + +class PermutationImportanceExplainer: + """Global feature importance via scikit-learn permutation importance (model-agnostic).""" + + name = "permutation_importance" + + def __init__(self, *, n_repeats: int = 10, random_state: int = 42, scoring: str | None = None) -> None: + self._n_repeats = n_repeats + self._random_state = random_state + self._scoring = scoring + + def supports(self, model: Model) -> bool: + return hasattr(model.estimator, "predict") + + def explain_global(self, model: Model, dataset: Any) -> GlobalExplanation: + from sklearn.inspection import permutation_importance + + result = permutation_importance( + model.estimator, + dataset.X, + dataset.y, + n_repeats=self._n_repeats, + random_state=self._random_state, + scoring=self._scoring, + ) + names = list(dataset.feature_names) or list(dataset.X.columns) + importances = {n: float(m) for n, m in zip(names, 
result.importances_mean, strict=False)} + std = {n: float(s) for n, s in zip(names, result.importances_std, strict=False)} + try: + baseline = float(model.estimator.score(dataset.X, dataset.y)) + except Exception: # noqa: BLE001 - baseline is informational only + baseline = float("nan") + return GlobalExplanation( + method="permutation_importance", feature_importances=importances, std=std, baseline_score=baseline + ) + + +class ShapExplainer: + """SHAP-based global + local attributions (optional ``explain`` extra).""" + + name = "shap" + + def __init__(self, *, max_samples: int = 200) -> None: + try: + import shap # noqa: F401 + except ImportError as exc: # pragma: no cover - exercised only without the extra + raise AdapterUnavailableError( + "ShapExplainer requires the 'explain' extra: pip install 'fireflyframework-datascience[explain]'" + ) from exc + self._max_samples = max_samples + + def supports(self, model: Model) -> bool: + return hasattr(model.estimator, "predict") + + def _underlying(self, model: Model) -> Any: + """Reach the final estimator inside a sklearn Pipeline, if present.""" + est = model.estimator + steps = getattr(est, "named_steps", None) + return steps["model"] if steps and "model" in steps else est + + def explain_global(self, model: Model, dataset: Any) -> GlobalExplanation: + import numpy as np + + names = list(dataset.feature_names) or list(dataset.X.columns) + values = self._shap_values(model, dataset.X) + mean_abs = np.abs(values).mean(axis=0) + importances = {n: float(v) for n, v in zip(names, mean_abs, strict=False)} + return GlobalExplanation(method="shap", feature_importances=importances) + + def explain_local(self, model: Model, X: Any) -> list[LocalExplanation]: + names = list(getattr(X, "columns", [])) + values = self._shap_values(model, X) + preds = model.predict(X) + out: list[LocalExplanation] = [] + for i in range(len(values)): + contributions = {n: float(v) for n, v in zip(names, values[i], strict=False)} + 
out.append(LocalExplanation(method="shap", prediction=preds[i], contributions=contributions)) + return out + + def _shap_values(self, model: Model, X: Any) -> Any: + import numpy as np + import shap + + sample = X.iloc[: self._max_samples] if hasattr(X, "iloc") else X[: self._max_samples] + # Transforming through the pipeline's preprocessing so SHAP sees the estimator's real inputs is + # complex with one-hot columns; for the model-agnostic path we explain the whole pipeline. + explainer = shap.Explainer(model.estimator.predict, sample) + values = explainer(sample).values + # binary/regression -> (n, f); some explainers return (n, f, classes): keep the last class's values. + values = np.asarray(values) + if values.ndim == 3: + values = values[:, :, -1] + return values + + +__all__ = ["PermutationImportanceExplainer", "ShapExplainer"] diff --git a/src/fireflyframework_datascience/explainability/auto_configuration.py b/src/fireflyframework_datascience/explainability/auto_configuration.py new file mode 100644 index 0000000..fa54e7c --- /dev/null +++ b/src/fireflyframework_datascience/explainability/auto_configuration.py @@ -0,0 +1,41 @@ +# Copyright 2026 Firefly Software Foundation. +"""Auto-configuration for the explainability module. + +Registers a default :class:`ExplainerPort`. When the optional ``explain`` extra (``shap``) is +installed, the SHAP explainer is registered as primary; otherwise the dependency-free +permutation-importance explainer is used. 
+""" + +from __future__ import annotations + +import importlib.util + +from fireflyframework_datascience.container.conditions import ( + auto_configuration, + conditional_on_class, + conditional_on_missing_bean, +) +from fireflyframework_datascience.container.stereotypes import bean, configuration +from fireflyframework_datascience.explainability import ExplainerPort + + +@auto_configuration +@conditional_on_class("sklearn") +@configuration +class ExplainabilityAutoConfiguration: + """Registers a single default explainer: SHAP when the ``explain`` extra is installed, else the + dependency-free permutation-importance explainer. A user-supplied ``ExplainerPort`` wins.""" + + @bean(name="default_explainer", primary=True) + @conditional_on_missing_bean(ExplainerPort) + def explainer(self) -> ExplainerPort: + if importlib.util.find_spec("shap") is not None: + from fireflyframework_datascience.explainability.adapters import ShapExplainer + + try: + return ShapExplainer() + except Exception: # noqa: BLE001 - fall back to the always-available default + pass + from fireflyframework_datascience.explainability.adapters import PermutationImportanceExplainer + + return PermutationImportanceExplainer() diff --git a/tests/explainability/test_explainability.py b/tests/explainability/test_explainability.py new file mode 100644 index 0000000..e24e4e6 --- /dev/null +++ b/tests/explainability/test_explainability.py @@ -0,0 +1,103 @@ +# Copyright 2026 Firefly Software Foundation. +"""Explainability tests — real data, no fakes, no mocks. + +The contract we assert is the one users actually rely on: a global explanation must rank a genuinely +informative feature above an injected pure-noise column. We use scikit-learn's real ``breast_cancer`` +dataset plus a deterministic noise column whose importance must be ~0. 
+""" +from __future__ import annotations + +import numpy as np +import pytest + +from fireflyframework_datascience.core.types import TaskType +from fireflyframework_datascience.datasets import Dataset +from fireflyframework_datascience.models import Model + + +def _breast_cancer_with_noise() -> tuple[Dataset, Model]: + """A real fitted RandomForest on breast_cancer + one pure-noise column.""" + from sklearn.datasets import load_breast_cancer + from sklearn.ensemble import RandomForestClassifier + from sklearn.pipeline import Pipeline + + raw = load_breast_cancer(as_frame=True) + X = raw.data.copy() + X["__noise__"] = np.random.default_rng(0).normal(size=len(X)) # pure noise, no signal + y = raw.target + estimator = Pipeline([("model", RandomForestClassifier(n_estimators=80, random_state=0))]).fit(X, y) + cols = list(X.columns) + model = Model(name="random_forest", estimator=estimator, task=TaskType.BINARY, feature_names=cols) + dataset = Dataset(name="breast_cancer", X=X, y=y, task=TaskType.BINARY, target_name="target", feature_names=cols) + return dataset, model + + +def test_permutation_importance_ranks_signal_above_noise() -> None: + from fireflyframework_datascience.explainability import GlobalExplanation + from fireflyframework_datascience.explainability.adapters import PermutationImportanceExplainer + + dataset, model = _breast_cancer_with_noise() + explainer = PermutationImportanceExplainer(n_repeats=5, random_state=0) + + explanation = explainer.explain_global(model, dataset) + + assert isinstance(explanation, GlobalExplanation) + # one importance per input feature, keyed by the real column names + assert set(explanation.feature_importances) == set(dataset.feature_names) + # the pure-noise column carries essentially no importance... 
+ assert explanation.feature_importances["__noise__"] <= 0.005 + # ...and ranks strictly below the most informative real feature + assert explanation.feature_importances["__noise__"] < max(explanation.feature_importances.values()) + # a genuinely informative breast-cancer feature surfaces in the top of the ranking + top_names = [name for name, _ in explanation.top(8)] + assert any(f in top_names for f in ("worst perimeter", "worst concave points", "worst radius", "worst area")) + assert "__noise__" not in top_names + + +def test_automl_result_explains_the_winner_on_real_data() -> None: + from fireflyframework_datascience.automl import AutoML + from fireflyframework_datascience.explainability import GlobalExplanation + + dataset, _ = _breast_cancer_with_noise() + train, test = dataset.train_test_split(test_size=0.3, random_state=0) + + result = AutoML(cv=3, n_trials=1, random_state=0).fit(train) + explanation = result.explain(test) + + assert isinstance(explanation, GlobalExplanation) + assert set(explanation.feature_importances) == set(dataset.feature_names) + # the injected noise column is the least informative through the full AutoML pipeline + assert explanation.feature_importances["__noise__"] < max(explanation.feature_importances.values()) + + +def test_explainer_is_auto_configured_in_the_container() -> None: + from fireflyframework_datascience import FireflyDataScienceApplication + from fireflyframework_datascience.explainability import ExplainerPort + + app = FireflyDataScienceApplication.run(print_output=False) + explainer = app.container.resolve_optional(ExplainerPort) + + assert explainer is not None + assert isinstance(explainer, ExplainerPort) + # the dependency-free default, or SHAP when the optional `explain` extra is installed + assert explainer.name in {"permutation_importance", "shap"} + + +def test_shap_explainer_global_and_local_on_real_data() -> None: + pytest.importorskip("shap") # only runs when the optional `explain` extra is installed + from 
fireflyframework_datascience.explainability import GlobalExplanation, LocalExplanation + from fireflyframework_datascience.explainability.adapters import ShapExplainer + + dataset, model = _breast_cancer_with_noise() + explainer = ShapExplainer(max_samples=40) + + global_exp = explainer.explain_global(model, dataset) + assert isinstance(global_exp, GlobalExplanation) + assert set(global_exp.feature_importances) == set(dataset.feature_names) + assert global_exp.feature_importances["__noise__"] < max(global_exp.feature_importances.values()) + + local = explainer.explain_local(model, dataset.X.iloc[:3]) + assert len(local) == 3 + assert all(isinstance(item, LocalExplanation) for item in local) + assert set(local[0].contributions) == set(dataset.feature_names) + diff --git a/tests/test_automl.py b/tests/test_automl.py index ce618d7..fa4cb5f 100644 --- a/tests/test_automl.py +++ b/tests/test_automl.py @@ -17,7 +17,9 @@ def test_classification_end_to_end() -> None: result = AutoML().fit(train) assert result.task is TaskType.BINARY - assert len(result.leaderboard) == 3 # RF, Linear, HistGB + # core trainers are always present; installed boosting libraries are added by default + names = {entry.model_name for entry in result.leaderboard} + assert {"random_forest", "linear", "hist_gradient_boosting"} <= names assert result.leaderboard[0].cv_score >= result.leaderboard[-1].cv_score # sorted desc evaluation = result.evaluate(test) diff --git a/tests/test_default_trainers.py b/tests/test_default_trainers.py new file mode 100644 index 0000000..932000b --- /dev/null +++ b/tests/test_default_trainers.py @@ -0,0 +1,20 @@ +# Copyright 2026 Firefly Software Foundation. +"""The plain ``AutoML()`` constructor must include installed boosting libraries by default. + +The docs/benchmarks advertise "+ XGBoost / LightGBM / CatBoost when installed"; this asserts the +imperative path matches that claim (previously only the DI / agentic path did). 
+""" +from __future__ import annotations + +import importlib.util + + +def test_default_trainers_include_installed_boosting_libraries() -> None: + from fireflyframework_datascience.automl.facade import _default_trainers + + names = {t.name for t in _default_trainers()} + assert {"random_forest", "linear", "hist_gradient_boosting"} <= names + + for lib, trainer_name in [("xgboost", "xgboost"), ("lightgbm", "lightgbm"), ("catboost", "catboost")]: + if importlib.util.find_spec(lib) is not None: + assert trainer_name in names, f"{trainer_name!r} must be a default trainer when {lib!r} is installed" diff --git a/uv.lock b/uv.lock index d3e7b0f..ae42ad6 100644 --- a/uv.lock +++ b/uv.lock @@ -2,7 +2,8 @@ version = 1 revision = 3 requires-python = ">=3.13" resolution-markers = [ - "python_full_version >= '3.15' and sys_platform == 'darwin'", + "python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.15' and platform_machine != 'x86_64' and sys_platform == 'darwin'", "python_full_version == '3.14.*' and sys_platform == 'darwin'", "python_full_version >= '3.15' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.14.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", @@ -2081,6 +2082,9 @@ dl = [ { name = "torch" }, { name = "torchvision" }, ] +explain = [ + { name = "shap" }, +] featurestore = [ { name = "feast" }, ] @@ -2106,6 +2110,7 @@ full = [ { name = "polars" }, { name = "scikit-learn" }, { name = "sentencepiece" }, + { name = "shap" }, { name = "tabpfn" }, { name = "torch" }, { name = "torchvision" }, @@ -2182,7 +2187,7 @@ requires-dist = [ { name = "feast", marker = "extra == 'featurestore'", specifier = ">=0.40.0" }, { name = "fireflyframework-agentic", git = "https://github.com/fireflyframework/fireflyframework-agentic?rev=main" }, { name = "fireflyframework-agentic", extras = ["script-execution", "embeddings", "openai-embeddings", 
"vectorstores-chroma"], marker = "extra == 'genai'" }, - { name = "fireflyframework-datascience", extras = ["tabular", "tabfm", "automl", "dl", "nlp", "tracking", "validation", "featurestore", "serving", "lineage", "orchestration", "data", "genai"], marker = "extra == 'full'" }, + { name = "fireflyframework-datascience", extras = ["tabular", "tabfm", "automl", "dl", "nlp", "tracking", "validation", "explain", "featurestore", "serving", "lineage", "orchestration", "data", "genai"], marker = "extra == 'full'" }, { name = "fireflyframework-datascience", extras = ["tabular", "tabfm", "automl", "tracking", "validation", "data"], marker = "extra == 'automl-stack'" }, { name = "lightgbm", marker = "extra == 'tabular'", specifier = ">=4.5.0" }, { name = "lightning", marker = "extra == 'dl'", specifier = ">=2.4.0" }, @@ -2201,6 +2206,7 @@ requires-dist = [ { name = "rich", specifier = ">=13.7.0" }, { name = "scikit-learn", marker = "extra == 'tabular'", specifier = ">=1.5.0" }, { name = "sentencepiece", marker = "extra == 'nlp'", specifier = ">=0.2.0" }, + { name = "shap", marker = "extra == 'explain'", specifier = ">=0.46.0" }, { name = "tabpfn", marker = "extra == 'tabfm'", specifier = ">=2.0.0" }, { name = "torch", marker = "extra == 'dl'", specifier = ">=2.4.0" }, { name = "torchvision", marker = "extra == 'dl'", specifier = ">=0.19.0" }, @@ -2211,7 +2217,7 @@ requires-dist = [ { name = "wandb", marker = "extra == 'tracking-wandb'", specifier = ">=0.18.0" }, { name = "xgboost", marker = "extra == 'tabular'", specifier = ">=2.1.0" }, ] -provides-extras = ["tabular", "tabfm", "automl", "dl", "nlp", "tracking", "tracking-wandb", "validation", "featurestore", "serving", "serving-llm", "lineage", "orchestration", "data", "genai", "automl-stack", "full"] +provides-extras = ["tabular", "tabfm", "automl", "dl", "nlp", "tracking", "tracking-wandb", "validation", "explain", "featurestore", "serving", "serving-llm", "lineage", "orchestration", "data", "genai", "automl-stack", 
"full"] [package.metadata.requires-dev] dev = [ @@ -3965,13 +3971,13 @@ name = "mlx-lm" version = "0.31.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "jinja2", marker = "sys_platform == 'darwin'" }, - { name = "mlx", marker = "sys_platform == 'darwin'" }, - { name = "numpy", marker = "sys_platform == 'darwin'" }, - { name = "protobuf", marker = "sys_platform == 'darwin'" }, - { name = "pyyaml", marker = "sys_platform == 'darwin'" }, - { name = "sentencepiece", marker = "sys_platform == 'darwin'" }, - { name = "transformers", marker = "sys_platform == 'darwin'" }, + { name = "jinja2", marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (platform_machine != 'x86_64' and sys_platform == 'darwin')" }, + { name = "mlx", marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (platform_machine != 'x86_64' and sys_platform == 'darwin')" }, + { name = "numpy", marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (platform_machine != 'x86_64' and sys_platform == 'darwin')" }, + { name = "protobuf", marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (platform_machine != 'x86_64' and sys_platform == 'darwin')" }, + { name = "pyyaml", marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (platform_machine != 'x86_64' and sys_platform == 'darwin')" }, + { name = "sentencepiece", marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (platform_machine != 'x86_64' and sys_platform == 'darwin')" }, + { name = "transformers", marker = "(python_full_version < '3.15' and sys_platform == 'darwin') or (platform_machine != 'x86_64' and sys_platform == 'darwin')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/84/94/9a38d6b0c6fcca995b9136c94eb7da1e9c5165652edf228b96b29960fa7a/mlx_lm-0.31.3.tar.gz", hash = "sha256:61eb0e3ba09444f77f874aff295401d7ccd20b39495cbbce0c782a15474ce733", size = 304318, upload-time = 
"2026-04-22T07:37:27.922Z" } wheels = [ @@ -6930,6 +6936,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/6d/b4752b044bf94cb802d88a888dc7d288baaf77d7910b7dedda74b5ceea0c/setuptools-79.0.1-py3-none-any.whl", hash = "sha256:e147c0549f27767ba362f9da434eab9c5dc0045d5304feb602a0af001089fc51", size = 1256281, upload-time = "2025-04-23T22:20:56.768Z" }, ] +[[package]] +name = "shap" +version = "0.49.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpickle" }, + { name = "numba" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "scikit-learn" }, + { name = "scipy" }, + { name = "slicer" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/c6/9823a7f483aa9f3179fc359c10d22da9e418b1a7a3fc99a42b705d05e82a/shap-0.49.1.tar.gz", hash = "sha256:1114ecd804fff29f50d522ce6031082fcf42fe4a32fb1b5da233b2415d784c8c", size = 4084725, upload-time = "2025-10-14T10:04:49.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/5c/030bbfa19605ca4ad66a753d55e76aee5093be6748a6d33eda89e5613995/shap-0.49.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:333cd8e8c427badda92d5ada9e7aad1e3e1e8e7e0398da51a18b7ffb03514e45", size = 558604, upload-time = "2025-10-14T10:04:34.298Z" }, + { url = "https://files.pythonhosted.org/packages/2c/7f/7e7b78e9fac6f891096fb6a59a6d4db23243b0af2369ae54e161f513c485/shap-0.49.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f4faf61560f73a66f4f26bc027c91f8939201979c4db24949dca305ba0a2ad36", size = 555311, upload-time = "2025-10-14T10:04:35.582Z" }, + { url = "https://files.pythonhosted.org/packages/f2/be/25283a0f8c30deaf897b89a0dbfd490d330f6fc68caa6f19db6e130832e9/shap-0.49.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b440da658d9aee7711bf642c9b4826d81f588fb478cd9e90c068646e90f56669", size = 1016897, upload-time = "2025-10-14T10:04:36.856Z" }, + { url = 
"https://files.pythonhosted.org/packages/5c/91/a63e563f3dc8e134db12dd155a1a6ed5e0649f79fc8ac651aac1088e8652/shap-0.49.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8dfa5654eccf4d13dcb262a10314a4e0eb1060db842b2ef31e9fb0038168bc1", size = 1022476, upload-time = "2025-10-14T10:04:38.171Z" }, + { url = "https://files.pythonhosted.org/packages/15/a2/89303c1f7eb206658bf9ec974dc6e69b0a6bd309cf5de0cfa8f92f5a8eb3/shap-0.49.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ed3080030a6000d3737841c5770ed555b8a922b794fa0ba5aae1e45655eda1fa", size = 2087940, upload-time = "2025-10-14T10:04:39.497Z" }, + { url = "https://files.pythonhosted.org/packages/84/bd/0b9b3e19b9b8cda51463f8a749dc354eb9c87f42eddcbfdf742dceb3746b/shap-0.49.1-cp313-cp313-win_amd64.whl", hash = "sha256:6af779344c23b12a47063aab7fc135fefbdb5849233c1813f11dd8cf2fc73bea", size = 547806, upload-time = "2025-10-14T10:04:40.712Z" }, +] + [[package]] name = "shellingham" version = "1.5.4" @@ -6973,6 +7005,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/0e/3ae19fa941522cd98e119762e7181d371c8dba0b2d72bfaf9522692e329c/skops-0.14.0-py3-none-any.whl", hash = "sha256:60a5db78a9db46ccee2139a0ba13ab5afb1c96f4749b382e75a371291bbe3e36", size = 132198, upload-time = "2026-04-20T18:23:54.018Z" }, ] +[[package]] +name = "slicer" +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/f9/b4bce2825b39b57760b361e6131a3dacee3d8951c58cb97ad120abb90317/slicer-0.0.8.tar.gz", hash = "sha256:2e7553af73f0c0c2d355f4afcc3ecf97c6f2156fcf4593955c3f56cf6c4d6eb7", size = 14894, upload-time = "2024-03-09T23:35:26.826Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/81/9ef641ff4e12cbcca30e54e72fb0951a2ba195d0cda0ba4100e532d929db/slicer-0.0.8-py3-none-any.whl", hash = "sha256:6c206258543aecd010d497dc2eca9d2805860a0b3758673903456b7df7934dc3", size = 15251, upload-time = "2024-03-09T07:03:07.708Z" 
}, +] + [[package]] name = "smmap" version = "5.0.3" From 6f979b1b8192faca53aae3092b994687f676bf76 Mon Sep 17 00:00:00 2001 From: Andres Contreras Date: Thu, 25 Jun 2026 19:56:06 +0200 Subject: [PATCH 2/2] fix: ruff format + pyright for explainability (CI lint/typecheck) - AdapterUnavailableError(adapter, extra) two-arg form; shap import marked type-ignore[import-not-found] (CI doesn't install the explain extra) - cast untyped sklearn permutation_importance / shap returns to Any - apply ruff format --- .../automl/facade.py | 6 +++++- .../explainability/__init__.py | 5 ++++- .../explainability/adapters.py | 18 ++++++++---------- tests/explainability/test_explainability.py | 2 +- tests/test_default_trainers.py | 1 + 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/fireflyframework_datascience/automl/facade.py b/src/fireflyframework_datascience/automl/facade.py index f613f33..54e5624 100644 --- a/src/fireflyframework_datascience/automl/facade.py +++ b/src/fireflyframework_datascience/automl/facade.py @@ -180,7 +180,11 @@ def _default_trainers() -> list[TrainerPort]: ] # Match the documented "+ XGBoost / LightGBM / CatBoost when installed" behaviour (the DI and # agentic paths already do this) by including each boosting trainer whose library is importable. 
- for lib, cls_name in (("xgboost", "XGBoostTrainer"), ("lightgbm", "LightGBMTrainer"), ("catboost", "CatBoostTrainer")): + for lib, cls_name in ( + ("xgboost", "XGBoostTrainer"), + ("lightgbm", "LightGBMTrainer"), + ("catboost", "CatBoostTrainer"), + ): if importlib.util.find_spec(lib) is not None: trainers.append(getattr(adapters, cls_name)()) return trainers diff --git a/src/fireflyframework_datascience/explainability/__init__.py b/src/fireflyframework_datascience/explainability/__init__.py index 873be37..f5df27e 100644 --- a/src/fireflyframework_datascience/explainability/__init__.py +++ b/src/fireflyframework_datascience/explainability/__init__.py @@ -33,7 +33,10 @@ def to_frame(self) -> Any: """A tidy ``feature/importance/std`` DataFrame, sorted by importance (descending).""" import pandas as pd - rows = [{"feature": n, "importance": v, "std": self.std.get(n, float("nan"))} for n, v in self.top(len(self.feature_importances))] + rows = [ + {"feature": n, "importance": v, "std": self.std.get(n, float("nan"))} + for n, v in self.top(len(self.feature_importances)) + ] return pd.DataFrame(rows) diff --git a/src/fireflyframework_datascience/explainability/adapters.py b/src/fireflyframework_datascience/explainability/adapters.py index 4bec318..c06047c 100644 --- a/src/fireflyframework_datascience/explainability/adapters.py +++ b/src/fireflyframework_datascience/explainability/adapters.py @@ -32,7 +32,7 @@ def supports(self, model: Model) -> bool: def explain_global(self, model: Model, dataset: Any) -> GlobalExplanation: from sklearn.inspection import permutation_importance - result = permutation_importance( + result: Any = permutation_importance( model.estimator, dataset.X, dataset.y, @@ -59,11 +59,9 @@ class ShapExplainer: def __init__(self, *, max_samples: int = 200) -> None: try: - import shap # noqa: F401 + import shap # type: ignore[import-not-found, import-untyped] # noqa: F401 except ImportError as exc: # pragma: no cover - exercised only without the extra - 
raise AdapterUnavailableError( - "ShapExplainer requires the 'explain' extra: pip install 'fireflyframework-datascience[explain]'" - ) from exc + raise AdapterUnavailableError("ShapExplainer", "explain") from exc self._max_samples = max_samples def supports(self, model: Model) -> bool: @@ -96,15 +94,15 @@ def explain_local(self, model: Model, X: Any) -> list[LocalExplanation]: def _shap_values(self, model: Model, X: Any) -> Any: import numpy as np - import shap + import shap # type: ignore[import-not-found, import-untyped] sample = X.iloc[: self._max_samples] if hasattr(X, "iloc") else X[: self._max_samples] - # Transform through the pipeline's preprocessing so SHAP sees the estimator's real inputs is - # complex with one-hot columns; for the model-agnostic path we explain the whole pipeline. + # Explaining the whole pipeline (model-agnostic) keeps this correct whether or not the + # estimator has preprocessing steps with one-hot-expanded columns. explainer = shap.Explainer(model.estimator.predict, sample) - values = explainer(sample).values + explained: Any = explainer(sample) # binary/regression -> (n, f); some explainers return (n, f, classes): collapse to class-1/abs. - values = np.asarray(values) + values = np.asarray(explained.values) if values.ndim == 3: values = values[:, :, -1] return values diff --git a/tests/explainability/test_explainability.py b/tests/explainability/test_explainability.py index e24e4e6..fbb1ada 100644 --- a/tests/explainability/test_explainability.py +++ b/tests/explainability/test_explainability.py @@ -5,6 +5,7 @@ informative feature above an injected pure-noise column. We use scikit-learn's real ``breast_cancer`` dataset plus a deterministic noise column whose importance must be ~0. 
""" + from __future__ import annotations import numpy as np @@ -100,4 +101,3 @@ def test_shap_explainer_global_and_local_on_real_data() -> None: assert len(local) == 3 assert all(isinstance(item, LocalExplanation) for item in local) assert set(local[0].contributions) == set(dataset.feature_names) - diff --git a/tests/test_default_trainers.py b/tests/test_default_trainers.py index 932000b..965e5de 100644 --- a/tests/test_default_trainers.py +++ b/tests/test_default_trainers.py @@ -4,6 +4,7 @@ The docs/benchmarks advertise "+ XGBoost / LightGBM / CatBoost when installed"; this asserts the imperative path matches that claim (previously only the DI / agentic path did). """ + from __future__ import annotations import importlib.util