diff --git a/README.md b/README.md index b3fe857..3882843 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ **ocean** is a full package dedicated to counterfactual explanations for **tree ensembles**. It builds on the paper *Optimal Counterfactual Explanations in Tree Ensemble* by Axel Parmentier and Thibaut Vidal in the *Proceedings of the thirty-eighth International Conference on Machine Learning*, 2021, in press. The article is [available here](http://proceedings.mlr.press/v139/parmentier21a/parmentier21a.pdf). -Beyond the original MIP approach, ocean includes a new **constraint programming (CP)** method and will grow to cover additional formulations and heuristics. +Beyond the original MIP approach, ocean also includes **constraint programming (CP)** and **weighted MaxSAT** backends for exact counterfactual search on the same parsed tree ensembles. ## Installation @@ -54,14 +54,15 @@ rf.fit(data, target) y = int(rf.predict(x).item()) x = x.to_numpy().flatten() -# Explain the prediction using MIPEXplainer +# Explain the prediction using the MIP backend mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper) mip_explanation = mip_model.explain(x, y=1 - y, norm=1) -# Explain the prediction using CPEExplainer +# Explain the prediction using the CP backend cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper) cp_explanation = cp_model.explain(x, y=1 - y, norm=1) +# Explain the prediction using the MaxSAT backend maxsat_model = MaxSATExplainer(rf, mapper=mapper) maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1) @@ -151,4 +152,4 @@ See the [examples folder](https://github.com/vidalt/OCEAN/tree/main/examples) fo - Axel Parmentier and Thibaut Vidal. 2021. Optimal Counterfactual Explanations in Tree Ensembles. In *Proceedings of the thirty-eighth International Conference on Machine Learning*. PMLR, 8276–8286. [Available here](http://proceedings.mlr.press/v139/parmentier21a/parmentier21a.pdf). - Raevskaya, Alesya & Lehtonen, Tuomo. (2025). Optimal Counterfactual Explanations for Random Forests with MaxSAT. 10.3233/FAIA250895. [Available here](https://aaltodoc.aalto.fi/server/api/core/bitstreams/36760903-9b05-491d-b744-ea4309bdf538/content). - \ No newline at end of file + diff --git a/docs/_ext/__init__.py b/docs/_ext/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/_ext/__init__.py @@ -0,0 +1 @@ + diff --git a/docs/_ext/ocean_docs_stubs.py b/docs/_ext/ocean_docs_stubs.py index 47d0def..82b28e2 100644 --- a/docs/_ext/ocean_docs_stubs.py +++ b/docs/_ext/ocean_docs_stubs.py @@ -1,4 +1,4 @@ -# ruff: noqa +# ruff: noqa: BLE001, C901, N801, PLC0415, PLR6301, PYI034, RUF012 from __future__ import annotations @@ -69,6 +69,9 @@ class Var(_GenericAliasMixin): class Constr(_GenericAliasMixin): pass + class MConstr(_GenericAliasMixin): + pass + class LinExpr(_GenericAliasMixin): pass @@ -104,6 +107,7 @@ class GRB: module.Env = Env module.Var = Var module.Constr = Constr + module.MConstr = MConstr module.LinExpr = LinExpr module.QuadExpr = QuadExpr module.MVar = MVar @@ -252,6 +256,9 @@ class PBEnc: def atleast(*_args: object, **_kwargs: object) -> _PBResult: return _PBResult() + class EncType: + adder = "adder" + class RC2(_GenericAliasMixin): def __init__(self, *_args: object, **_kwargs: object) -> None: self.cost = 0 @@ -271,6 +278,7 @@ def compute(self) -> list[int]: formula.WCNF = WCNF formula.IDPool = IDPool + pb.EncType = EncType pb.PBEnc = PBEnc rc2.RC2 = RC2 diff --git a/docs/api/ocean.maxsat.rst b/docs/api/ocean.maxsat.rst index ca4424b..fddb028 100644 --- a/docs/api/ocean.maxsat.rst +++ b/docs/api/ocean.maxsat.rst @@ -1,6 +1,11 @@ Weighted MaxSAT Backend ======================= +The :mod:`ocean.maxsat` module exposes the weighted MaxSAT formulation backed +by PySAT. The main public entry point is :class:`ocean.maxsat.Explainer`; the +remaining classes document the underlying Boolean model, variable bundles, and +manager helpers used to build that formulation. + .. automodule:: ocean.maxsat :members: :undoc-members: diff --git a/docs/api/ocean.mip.rst b/docs/api/ocean.mip.rst index ff59a2e..37c6e76 100644 --- a/docs/api/ocean.mip.rst +++ b/docs/api/ocean.mip.rst @@ -1,6 +1,11 @@ Mixed-integer Programming Backend ================================= +The :mod:`ocean.mip` module is the Gurobi-backed formulation. Most users start +from :class:`ocean.mip.Explainer`, while :class:`ocean.mip.Model` and the +variable classes are useful when you want to inspect or extend the lower-level +formulation directly. + .. automodule:: ocean.mip :members: :undoc-members: diff --git a/docs/api/ocean.typing.rst b/docs/api/ocean.typing.rst index 38634a3..7e13c6c 100644 --- a/docs/api/ocean.typing.rst +++ b/docs/api/ocean.typing.rst @@ -1,6 +1,11 @@ Shared Typing Helpers ===================== +The :mod:`ocean.typing` module collects the shared protocols and type aliases +used across the public explainers, tree parsing helpers, and internal backend +implementations. It is the reference page for supported ensemble types, +processed-array aliases, and the common explanation and explainer protocols. + .. automodule:: ocean.typing :members: :undoc-members: diff --git a/docs/conf.py b/docs/conf.py index e1b7d39..d549005 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,7 @@ project = "OCEAN" author = "OCEAN contributors" -copyright = "2026, Awa Khouna and the OCEAN contributors" +copyright = "2026, Awa Khouna and the OCEAN contributors" # noqa: A001 try: release = version("oceanpy") diff --git a/docs/custom-dataset.rst b/docs/custom-dataset.rst index 7631042..adea80e 100644 --- a/docs/custom-dataset.rst +++ b/docs/custom-dataset.rst @@ -1,43 +1,223 @@ Custom Dataset Example ====================== -This example mirrors the dataset-generation style used in the test suite: mix -continuous features, ordered discrete features, binary flags, and unordered -categorical features, parse them with OCEAN, train a tree ensemble, and -explain one instance. +This page turns the synthetic custom-dataset example into a notebook-style +walkthrough. It starts from a raw pandas dataframe, parses the feature types +with OCEAN, trains a random forest, and explains one class-``0`` prediction +with the constraint-programming backend. Why this example matters ------------------------ -The packaged dataset loaders are convenient, but most real integrations start -from a custom pandas dataframe. This example shows the full path from raw data -to a readable counterfactual without depending on any external dataset. +The packaged dataset loaders are useful when you want a quick start, but most +real integrations begin with a dataframe that you already own. This example +shows the full workflow on mixed feature types: -Running the example -------------------- +- ordered discrete values through ``credit_lines``, +- binary flags through ``owns_home`` and ``has_guarantor``, +- continuous ratios through ``income_ratio``, ``debt_ratio``, and + ``savings_ratio``, +- unordered nominal features through ``job_type`` and ``region``. -.. code-block:: bash +Cell 1: Build a mixed-type dataframe +------------------------------------ - python examples/custom_dataset.py +.. code-block:: python -What it does ------------- + import numpy as np + import pandas as pd -1. Generates a synthetic credit-style dataset with multiple feature types. -2. Uses :func:`ocean.feature.parse_features` to build the processed matrix and - mapper. -3. Trains a random forest on that processed matrix. -4. Selects a query that the model predicts as class ``0``. -5. Uses ``ocean.ConstraintProgrammingExplainer`` to search for the closest - class-``1`` counterfactual. -6. Prints both the original raw instance and the decoded explanation. + rng = np.random.default_rng(42) + raw = pd.DataFrame({ + "credit_lines": rng.choice([0, 1, 2, 4], size=300), + "owns_home": rng.integers(0, 2, size=300), + "has_guarantor": rng.integers(0, 2, size=300), + "income_ratio": rng.uniform(-0.4, 0.8, size=300), + "debt_ratio": rng.uniform(0.0, 1.0, size=300), + "savings_ratio": rng.uniform(-0.5, 0.6, size=300), + "job_type": rng.choice( + ["office", "manual", "service", "student"], + size=300, + ), + "region": rng.choice( + ["north", "south", "east", "west"], + size=300, + ), + }) -In this example, ``credit_lines`` is treated as an ordered discrete feature, -while ``job_type`` and ``region`` are treated as unordered categories and are -therefore one-hot encoded. + score = ( + (raw["credit_lines"] >= 2).astype(int) + + raw["owns_home"].astype(int) + + raw["has_guarantor"].astype(int) + + (raw["income_ratio"] > 0.1).astype(int) + + (raw["savings_ratio"] > 0.0).astype(int) + + raw["job_type"].isin(["office", "service"]).astype(int) + + raw["region"].isin(["north", "east"]).astype(int) + - (raw["debt_ratio"] > 0.55).astype(int) + ) + target = (score >= 4).astype(int).rename("approved") -Source ------- +Cell 2: Parse the features with OCEAN +------------------------------------- + +.. code-block:: python + + from ocean.feature import parse_features + + data, mapper = parse_features(raw, discretes=("credit_lines",)) + print(data.columns) + +.. code-block:: text + + MultiIndex([( 'credit_lines', ''), + ( 'owns_home', ''), + ( 'has_guarantor', ''), + ( 'income_ratio', ''), + ( 'debt_ratio', ''), + ( 'savings_ratio', ''), + ( 'job_type', 'manual'), + ( 'job_type', 'office'), + ( 'job_type', 'service'), + ( 'job_type', 'student'), + ( 'region', 'east'), + ( 'region', 'north'), + ( 'region', 'south'), + ( 'region', 'west')], + ) + +The important part is that ``credit_lines`` stays ordered and numeric, while +``job_type`` and ``region`` expand into one-hot blocks. + +Cell 3: Fit a classifier and choose a query +------------------------------------------- + +.. code-block:: python + + import pandas as pd + from sklearn.ensemble import RandomForestClassifier + + model = RandomForestClassifier( + n_estimators=40, + max_depth=4, + random_state=42, + ) + model.fit(data, target) + + predictions = pd.Series(model.predict(data), index=data.index) + query_index = predictions[predictions == 0].index[0] + query = data.loc[query_index].to_numpy(dtype=float).flatten() + query_frame = data.loc[[query_index]] + raw_query = raw.loc[query_index] + + print(raw_query) + print() + print("Model prediction:", int(model.predict(query_frame).item())) + +.. code-block:: text + + credit_lines 0 + owns_home 0 + has_guarantor 1 + income_ratio -0.349179 + debt_ratio 0.260349 + savings_ratio 0.234634 + job_type student + region west + Name: 0, dtype: object + + Model prediction: 0 + +Cell 4: Explain the query +------------------------- + +.. code-block:: python + + from ocean import ConstraintProgrammingExplainer + + explainer = ConstraintProgrammingExplainer(model, mapper=mapper) + explanation = explainer.explain( + query, + y=1, + norm=1, + max_time=10, + num_workers=1, + random_seed=42, + ) + if explanation is None: + raise RuntimeError("No counterfactual was found for the synthetic example.") + + counterfactual_frame = pd.DataFrame( + [explanation.to_numpy()], + columns=data.columns, + ) + + print("Target class:", 1) + print("Counterfactual prediction:", int(model.predict(counterfactual_frame).item())) + +.. code-block:: text + + Target class: 1 + Counterfactual prediction: 1 + +Cell 5: Inspect the decoded explanation +--------------------------------------- + +.. code-block:: python + + print(explanation) + +.. code-block:: text + + Explanation: + credit_lines : 0.0 + owns_home : 0 + has_guarantor : 1 + income_ratio : -0.2833683341741562 + debt_ratio : -0.1887158378958702 + savings_ratio : 0.29842646420001984 + job_type : student + region : north + +This decoded view is usually the most readable one: categorical one-hot blocks +are mapped back to labels, and the keys match the original dataframe columns. + +Cell 6: Inspect the processed vector and the final distance +----------------------------------------------------------- + +.. code-block:: python + + print(explanation.to_series()) + print() + print("Distance:", explainer.get_distance()) + +.. code-block:: text + + credit_lines 0.000000 + owns_home 0.000000 + has_guarantor 1.000000 + income_ratio -0.283368 + debt_ratio -0.188716 + savings_ratio 0.298426 + job_type manual 0.000000 + office 0.000000 + service 0.000000 + student 1.000000 + region east 0.000000 + north 1.000000 + south 0.000000 + west 0.000000 + dtype: float64 + + Distance: 1.3566914800296987 + +``get_distance()`` is the user-facing metric to report here: it reconstructs +the post-processed :math:`L_1` distance between the original query and the +decoded counterfactual, including the half-weight treatment for one-hot blocks. + +Full script +----------- + +If you want the exact runnable version behind this page: .. literalinclude:: ../examples/custom_dataset.py :language: python diff --git a/docs/explainer-guide.rst b/docs/explainer-guide.rst index b82964f..8ea5b2a 100644 --- a/docs/explainer-guide.rst +++ b/docs/explainer-guide.rst @@ -42,7 +42,8 @@ Every explainer exposes an ``explain`` method with the same core arguments. Parallel worker count when the backend exposes it. ``random_seed`` - Solver seed for more repeatable runs. + Solver seed for more repeatable runs. The MaxSAT backend currently accepts + this argument for API compatibility but does not use it. ``verbose`` Whether to print solver logs. @@ -66,7 +67,7 @@ Backend-specific behavior - Yes - No public callback list * - Automatic cleanup after solve - - No + - Yes by default - Yes - Yes * - Isolation forest support @@ -77,12 +78,13 @@ Backend-specific behavior Repeated solves --------------- -If you solve multiple queries with the same MIP explainer instance, call -``cleanup()`` after each solve to remove the temporary objective and class -constraints created for the previous query. +All three explainers default to ``clean_up=True`` inside ``explain``. That +means query-specific objectives and target-class constraints are removed +automatically after each solve unless you opt out. -The CP and MaxSAT explainers already clear those query-specific constraints -inside ``explain``. +Call ``cleanup()`` manually only when you deliberately run ``explain(..., +clean_up=False)`` and want to clear the previous query state yourself before +reusing the same explainer instance. Inspecting the result --------------------- @@ -93,6 +95,8 @@ useful. - ``explanation.x`` gives the processed numerical vector. - ``explanation.to_series()`` keeps the processed column names. - ``explanation.value`` decodes one-hot groups into original category labels. +- ``explainer.get_objective_value()`` returns the backend objective value for + the last solve. - ``explainer.get_distance()`` returns the post-processed distance between the query and the decoded counterfactual using the norm from the last ``explain(...)`` call. diff --git a/docs/installation.rst b/docs/installation.rst index 6a965b4..111350f 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -49,6 +49,10 @@ for the reference site. pip install .[docs] sphinx-build -W -b html docs docs/_build/html +Use Python 3.12 or newer when building from the source tree, since the +package uses Python 3.12 type-alias syntax in modules such as +``ocean.typing``. + Or through tox: .. code-block:: bash diff --git a/docs/overview.rst b/docs/overview.rst index 9a6e26c..f4a6aa8 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -89,5 +89,7 @@ surface is intentionally similar. the most readable form for reports and notebooks. If you want to solve multiple queries with the same MIP explainer instance, -call ``cleanup()`` between solves. The CP and MaxSAT explainers already clear -query-specific constraints after each solve. +you usually do not need any extra step because all three explainers default to +``clean_up=True`` inside ``explain``. Call ``cleanup()`` manually only when +you disabled that behavior with ``clean_up=False`` and want to reuse the same +instance safely. diff --git a/examples/custom_dataset.py b/examples/custom_dataset.py index af94d70..ac54f23 100644 --- a/examples/custom_dataset.py +++ b/examples/custom_dataset.py @@ -81,7 +81,10 @@ def main() -> None: if explanation is None: msg = "No counterfactual was found for the synthetic example." raise RuntimeError(msg) - counterfactual_frame = explanation.to_numpy().reshape(1, -1) + counterfactual_frame = pd.DataFrame( + [explanation.to_numpy()], + columns=data.columns, + ) print("Original raw instance:") print(raw_query) @@ -98,7 +101,7 @@ def main() -> None: print("Counterfactual vector:") print(explanation.to_series()) print() - print("Objective value:", explainer.get_objective_value()) + print("Distance:", explainer.get_distance()) if __name__ == "__main__": diff --git a/examples/simple_example_both.py b/examples/simple_example_both.py index cd04b99..473440d 100644 --- a/examples/simple_example_both.py +++ b/examples/simple_example_both.py @@ -1,68 +1,125 @@ import time -from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier + from ocean import ConstraintProgrammingExplainer, MixedIntegerProgramExplainer from ocean.datasets import load_adult -from xgboost import XGBClassifier + print_paths = True plot_anytime_distances = True -num_workers = 8 # Both CP and MILP solving support multithreading +num_workers = 8 random_state = 0 -timeout = 3600 # Maximum running time given to the (CP or MILP) solver - -# Load the adult dataset -(data, target), mapper = load_adult( - scale=True -) # scale=True to perform normalization +timeout = 3600 + + +def to_frame( + x: np.ndarray[tuple[int], np.dtype[np.float64]], + columns: pd.Index | pd.MultiIndex, +) -> pd.DataFrame: + return pd.DataFrame([x], columns=columns) + + +def print_close_threshold_paths( + model: RandomForestClassifier | AdaBoostClassifier, + cf: np.ndarray[tuple[int], np.dtype[np.float64]] | None, + *, + columns: pd.Index | pd.MultiIndex, + query_pred: int, + label: str, +) -> None: + if cf is None: + print(f"{label}: No CF found.") + return + + cf_frame = to_frame(cf, columns) + if int(model.predict(cf_frame)[0]) != query_pred: + print(f"{label} Valid CF.") + return + + print(f"INVALID {label} CF: decision path of the CF found by {label}") + for i, estimator in enumerate(model.estimators_): + if int(estimator.predict(cf_frame)[0]) != query_pred: + continue + + feature = estimator.tree_.feature + threshold = estimator.tree_.threshold + node_indicator = estimator.decision_path(cf_frame) + leaf_id = estimator.apply(cf_frame) + sample_id = 0 + start = node_indicator.indptr[sample_id] + stop = node_indicator.indptr[sample_id + 1] + node_index = node_indicator.indices[start:stop] + + print(node_index) + print( + f"[Tree {i}] Rules used to predict sample {sample_id} " + "with features close to a threshold:\n" + ) + for node_id in node_index: + if leaf_id[sample_id] == node_id: + continue + + threshold_sign = ( + "<=" if cf[feature[node_id]] <= threshold[node_id] else ">" + ) + if np.abs(cf[feature[node_id]] - threshold[node_id]) < 1e-3: + print( + f"decision node {node_id}: " + f"cf[{feature[node_id]}] = {cf[feature[node_id]]} " + f"{threshold_sign} {threshold[node_id]}" + ) + + +def unpack_anytime( + anytime: list[dict[str, float]] | None, +) -> tuple[list[float], list[float]]: + if anytime is None: + return [], [] + objectives = [entry.get("objective_value", 0.0) for entry in anytime] + times = [entry.get("time", 0.0) for entry in anytime] + return objectives, times + + +(data, target), mapper = load_adult(scale=True) X_train, X_test, y_train, y_test = train_test_split( - data, target, test_size=0.2, random_state=random_state + data, + target, + test_size=0.2, + random_state=random_state, ) -# Train a RF -#rf = RandomForestClassifier(n_estimators=5, max_depth=2, random_state=random_state) -rf = AdaBoostClassifier(estimator=DecisionTreeClassifier(max_depth=1), n_estimators=100, random_state=random_state) -#rf = XGBClassifier(n_estimators=5, max_depth=2, random_state=random_state) +rf = AdaBoostClassifier( + estimator=DecisionTreeClassifier(max_depth=1), + n_estimators=100, + random_state=random_state, +) rf.fit(X_train, y_train) -''' -from sklearn.tree import plot_tree -import matplotlib.pyplot as plt -# Plot the first tree of the forest -plt.figure(figsize=(20, 10)) -plot_tree(rf.estimators_[0], filled=True) -plt.title("First tree of the Random Forest") -plt.savefig("./first_tree_rf.png") -plt.close() -''' - print("RF train acc= ", rf.score(X_train, y_train)) print("RF test acc= ", rf.score(X_test, y_test)) -if isinstance(rf, RandomForestClassifier) or isinstance(rf, AdaBoostClassifier): +if isinstance(rf, (RandomForestClassifier, AdaBoostClassifier)): print( "RF size= ", - sum(a_tree.tree_.node_count for a_tree in rf.estimators_), + sum(tree.tree_.node_count for tree in rf.estimators_), " nodes.", ) -# Define a CF query using the qid-th element of the test set -# qid = 1 -# query = X_test.iloc[qid] -import numpy as np - qid = 10 -query = X_test.iloc[qid] -query_pred = rf.predict([np.asarray(query)])[0] -print("Query: ", query, "(class ", query_pred, ")") - +query_frame = X_test.iloc[[qid]] +query_series = query_frame.iloc[0] +query = query_series.to_numpy(dtype=float).flatten() +query_pred = int(rf.predict(query_frame)[0]) +print("Query: ", query_series, "(class ", query_pred, ")") -# Use the CP formulation to generate a CF cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper) - start_ = time.time() -explanation_oceancp = cp_model.explain( +cp_explanation = cp_model.explain( query, y=1 - query_pred, norm=1, @@ -70,87 +127,29 @@ num_workers=num_workers, random_seed=random_state, max_time=timeout, - verbose=False + verbose=False, ) cp_time = time.time() - start_ +cp_cf = cp_explanation.to_numpy() if cp_explanation is not None else None -if explanation_oceancp is not None: - print( - "CP : ", - explanation_oceancp, - "(class ", - rf.predict([explanation_oceancp.to_numpy()])[0], - ")", - ) - # print("CP Sollist = ", cp_model.get_anytime_solutions()) +if cp_explanation is not None and cp_cf is not None: + cp_pred = int(rf.predict(to_frame(cp_cf, data.columns))[0]) + print("CP : ", cp_explanation, "(class ", cp_pred, ")") else: print("CP: No CF found.") -# debug CP ------------------------------------------------------- if print_paths: - cf = explanation_oceancp.to_numpy() - if cf is not None: - if rf.predict([cf])[0] == query_pred: - print("INVALID CP CF : decision path of the CF found by CP") - for i, clf in enumerate(rf.estimators_): - if clf.predict([cf])[0] == query_pred: - n_nodes = clf.tree_.node_count - children_left = clf.tree_.children_left - children_right = clf.tree_.children_right - feature = clf.tree_.feature - threshold = clf.tree_.threshold - values = clf.tree_.value - - node_indicator = clf.decision_path([cf]) - leaf_id = clf.apply([cf]) - sample_id = 0 - # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id` - node_index = node_indicator.indices[ - node_indicator.indptr[ - sample_id - ] : node_indicator.indptr[sample_id + 1] - ] - print(node_index) - print( - "[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format( - i=i, id=sample_id - ) - ) - for node_id in node_index: - # continue to the next node if it is a leaf node - if leaf_id[sample_id] == node_id: - continue - - # check if value of the split feature for sample 0 is below threshold - if cf[feature[node_id]] <= threshold[node_id]: - threshold_sign = "<=" - else: - threshold_sign = ">" - if ( - np.abs(cf[feature[node_id]] - threshold[node_id]) - < 1e-3 - ): - print( - "decision node {node} : (cf[{feature}] = {value}) " - "{inequality} {threshold})".format( - node=node_id, - sample=sample_id, - feature=feature[node_id], - value=cf[feature[node_id]], - inequality=threshold_sign, - threshold=threshold[node_id], - ) - ) - else: - print("CP Valid CF.") -# debug CP ------------------------------------------------------- - + print_close_threshold_paths( + rf, + cp_cf, + columns=data.columns, + query_pred=query_pred, + label="CP", + ) -# Use the MILP formulation to generate a CF milp_model = MixedIntegerProgramExplainer(rf, mapper=mapper) -# print("milp_model._num_epsilon", milp_model._num_epsilon) start_ = time.time() -explanation_ocean = milp_model.explain( +milp_explanation = milp_model.explain( query, y=1 - query_pred, norm=1, @@ -160,124 +159,61 @@ max_time=timeout, ) milp_time = time.time() - start_ -cf = explanation_ocean -# cf[4] += 0.0001 -if explanation_ocean is not None: - print( - "MILP : ", - explanation_ocean, - "(class ", - rf.predict([explanation_ocean.to_numpy()])[0], - ")", - ) - # print("MILP Sollist = ", milp_model.get_anytime_solutions()) +milp_cf = milp_explanation.to_numpy() if milp_explanation is not None else None + +if milp_explanation is not None and milp_cf is not None: + milp_pred = int(rf.predict(to_frame(milp_cf, data.columns))[0]) + print("MILP : ", milp_explanation, "(class ", milp_pred, ")") else: print("MILP: No CF found.") -# debug MILP ------------------------------------------------------- if print_paths: - cf = explanation_ocean.to_numpy() - if cf is not None: - if rf.predict([cf])[0] == query_pred: - print("INVALID MILP CF : decision path of the CF found by MILP") - for i, clf in enumerate(rf.estimators_): - if clf.predict([cf])[0] == query_pred: - n_nodes = clf.tree_.node_count - children_left = clf.tree_.children_left - children_right = clf.tree_.children_right - feature = clf.tree_.feature - threshold = clf.tree_.threshold - values = clf.tree_.value - - node_indicator = clf.decision_path([cf]) - leaf_id = clf.apply([cf]) - - sample_id = 0 - # obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id` - node_index = node_indicator.indices[ - node_indicator.indptr[ - sample_id - ] : node_indicator.indptr[sample_id + 1] - ] - - print( - "[Tree {i}] Rules used to predict sample {id} with features values close to threshold:\n".format( - i=i, id=sample_id - ) - ) - for node_id in node_index: - # continue to the next node if it is a leaf node - if leaf_id[sample_id] == node_id: - continue - - # check if value of the split feature for sample 0 is below threshold - if cf[feature[node_id]] <= threshold[node_id]: - threshold_sign = "<=" - else: - threshold_sign = ">" - if ( - np.abs(cf[feature[node_id]] - threshold[node_id]) - < 1e-3 - ): - print( - "decision node {node} : (cf[{feature}] = {value}) " - "{inequality} {threshold})".format( - node=node_id, - sample=sample_id, - feature=feature[node_id], - value=cf[feature[node_id]], - inequality=threshold_sign, - threshold=threshold[node_id], - ) - ) - else: - print("MILP Valid CF.") -# debug MILP ------------------------------------------------------- - + print_close_threshold_paths( + rf, + milp_cf, + columns=data.columns, + query_pred=query_pred, + label="MILP", + ) -# Display summary statistics print(f"Runtime: CP {cp_time:.3f} s, MILP {milp_time:.3f} s") print( - f"Distance: CP {cp_model.get_objective_value():.10f},", - f" MILP {milp_model.get_objective_value():.10f}", + f"Distance: CP {cp_model.get_objective_value():.10f}, " + f"MILP {milp_model.get_objective_value():.10f}" ) print( - f"Status: CP {cp_model.get_solving_status()},", - f" MILP {milp_model.get_solving_status()}", + f"Status: CP {cp_model.get_solving_status()}, " + f"MILP {milp_model.get_solving_status()}" ) if plot_anytime_distances: - import matplotlib.pyplot as plt - - anytime_solution = {} - anytime_solution["cp"] = cp_model.get_anytime_solutions() - anytime_solution["mip"] = milp_model.get_anytime_solutions() - cpobjectives = [] - cptimes = [] - for dic in anytime_solution.get("cp", []): - cpobjectives.append(dic.get("objective_value", 0)) - cptimes.append(dic.get("time", 0)) - milpobjectives = [] - milptimes = [] - for dic in anytime_solution.get("mip", []): - milpobjectives.append(dic.get("objective_value", 0)) - milptimes.append(dic.get("time", 0)) + cp_objectives, cp_times = unpack_anytime(cp_model.get_anytime_solutions()) + milp_objectives, milp_times = unpack_anytime( + milp_model.get_anytime_solutions() + ) - plt.plot(milptimes, milpobjectives, marker="x", label="MILP", c="b") - if milp_model.get_solving_status() == "OPTIMAL": + plt.plot(milp_times, milp_objectives, marker="x", label="MILP", c="b") + if milp_times and milp_model.get_solving_status() == "OPTIMAL": plt.plot( - milptimes[-1], milpobjectives[-1], marker="*", c="b", markersize=15 + milp_times[-1], + milp_objectives[-1], + marker="*", + c="b", + markersize=15, ) - plt.plot(cptimes, cpobjectives, marker="x", label="CP", c="r") - if cp_model.get_solving_status() == "OPTIMAL": + plt.plot(cp_times, cp_objectives, marker="x", label="CP", c="r") + if cp_times and cp_model.get_solving_status() == "OPTIMAL": plt.plot( - cptimes[-1], cpobjectives[-1], marker="*", c="r", markersize=15 + cp_times[-1], + cp_objectives[-1], + marker="*", + c="r", + markersize=15, ) plt.legend() plt.ylabel("CF distance from query") plt.xlabel("Running time (second)") - plt.title("Anytime CF distance comparison.") plt.savefig("./anytime_distances_cp_vs_milp.pdf") diff --git a/ocean/cp/_explainer.py b/ocean/cp/_explainer.py index 083a910..10bec97 100644 --- a/ocean/cp/_explainer.py +++ b/ocean/cp/_explainer.py @@ -49,6 +49,15 @@ def __init__( self.solver = ENV.solver def get_objective_value(self) -> float: + """ + Return the scaled objective value of the last CP-SAT solve. + + Returns + ------- + float + Objective value rescaled back to the user-facing distance units. + + """ return self.solver.ObjectiveValue() / self._obj_scale def get_distance(self) -> float: @@ -95,9 +104,29 @@ def get_distance(self) -> float: return float(distance) def get_solving_status(self) -> str: + """ + Return the status string from the latest CP-SAT solve. + + Returns + ------- + str + Solver status such as ``"OPTIMAL"``, ``"FEASIBLE"``, or + ``"INFEASIBLE"``. + + """ return self.Status def get_anytime_solutions(self) -> list[dict[str, float]] | None: + """ + Return intermediate solutions collected during the last solve. + + Returns + ------- + list[dict[str, float]] | None + Time-stamped incumbent objective values when ``return_callback`` + was enabled in :meth:`explain`, otherwise ``None``. + + """ if self.callback is not None: return self.callback.sollist return None @@ -115,6 +144,42 @@ def explain( random_seed: int = 42, clean_up: bool = True, ) -> Explanation | None: + """ + Solve one counterfactual query with the CP-SAT backend. + + Parameters + ---------- + x + Query instance in the processed feature space. + y + Target class enforced by the counterfactual. + norm + Integer distance norm used by the CP objective. + return_callback + Whether to record incumbent solutions during the search. + verbose + Whether to enable CP-SAT search logging. + max_time + Time limit in seconds. + num_workers + Optional number of CP-SAT workers. + random_seed + Random seed passed to CP-SAT. + clean_up + Whether to remove query-specific constraints after the solve. + + Returns + ------- + Explanation | None + The decoded counterfactual, or ``None`` when no feasible + counterfactual is found within the given limits. + + Raises + ------ + RuntimeError + If CP-SAT reports an invalid model or an unexpected status. + + """ self.solver.parameters.log_search_progress = verbose self.solver.parameters.max_time_in_seconds = max_time self.solver.parameters.random_seed = random_seed diff --git a/ocean/maxsat/_explainer.py b/ocean/maxsat/_explainer.py index 5339c32..7c37cc6 100644 --- a/ocean/maxsat/_explainer.py +++ b/ocean/maxsat/_explainer.py @@ -58,6 +58,15 @@ def __init__( self.solver = ENV.solver def get_objective_value(self) -> float: + """ + Return the weighted MaxSAT objective value of the last solve. + + Returns + ------- + float + Objective value rescaled back to the user-facing distance units. + + """ return self.solver.cost / self._obj_scale def get_distance(self) -> float: @@ -104,11 +113,29 @@ def get_distance(self) -> float: return float(distance) def get_solving_status(self) -> str: + """ + Return the status of the latest MaxSAT solve. + + Returns + ------- + str + Status string such as ``"OPTIMAL"`` or ``"INFEASIBLE"``. + + """ return self.Status def get_anytime_solutions(self) -> list[dict[str, float]] | None: - """MaxSAT currently exposes only the final optimal solution.""" - raise NotImplementedError + """ + Return the intermediate solution trace for the last MaxSAT solve. + + Returns + ------- + None + The MaxSAT backend currently exposes only the final solution. + + """ + _ = self.Status + return None def explain( self, @@ -123,6 +150,42 @@ def explain( random_seed: int = 42, clean_up: bool = True, ) -> Explanation | None: + """ + Solve one counterfactual query with the weighted MaxSAT backend. + + Parameters + ---------- + x + Query instance in the processed feature space. + y + Target class enforced by the counterfactual. + norm + Distance norm. The MaxSAT backend currently supports only ``1``. + return_callback + Accepted for API compatibility but ignored by this backend. + verbose + Whether to enable RC2 logging. + max_time + Time limit in seconds. + num_workers + Optional thread count forwarded to the MaxSAT solver. + random_seed + Accepted for API compatibility but currently ignored. + clean_up + Whether to remove query-specific clauses after the solve. + + Returns + ------- + Explanation | None + The decoded counterfactual, or ``None`` when no feasible + counterfactual is found within the given limits. + + Raises + ------ + RuntimeError + If the MaxSAT solver raises an error that is not UNSAT or timeout. + + """ if return_callback: default_seed = 42 msg = "There are no callbacks for maxsat." diff --git a/ocean/maxsat/_model.py b/ocean/maxsat/_model.py index 3146ba0..6ea69d6 100644 --- a/ocean/maxsat/_model.py +++ b/ocean/maxsat/_model.py @@ -8,7 +8,7 @@ try: from pysat.pb import EncType, PBEnc -except AssertionError: +except (AssertionError, ImportError): EncType = None PBEnc = None diff --git a/ocean/mip/_explainer.py b/ocean/mip/_explainer.py index 431682e..9b1d1df 100644 --- a/ocean/mip/_explainer.py +++ b/ocean/mip/_explainer.py @@ -1,5 +1,6 @@ import time import warnings +from typing import cast import gurobipy as gp from sklearn.ensemble import AdaBoostClassifier, IsolationForest @@ -58,6 +59,15 @@ def __init__( self.build() def get_objective_value(self) -> float: + """ + Return the solver objective value of the last optimization run. + + Returns + ------- + float + Objective value reported by Gurobi for the latest solve. + + """ return self.ObjVal def get_distance(self) -> float: @@ -104,6 +114,15 @@ def get_distance(self) -> float: return float(distance) def get_solving_status(self) -> str: + """ + Return the latest Gurobi solve status as a readable string. + + Returns + ------- + str + Current model status such as ``"OPTIMAL"`` or ``"TIME_LIMIT"``. + + """ gurobi_statuses = { 1: "LOADED", 2: "OPTIMAL", @@ -125,7 +144,23 @@ def get_solving_status(self) -> str: return gurobi_statuses[self.Status] def get_anytime_solutions(self) -> list[dict[str, float]] | None: - return self.callback.sollist + """ + Return incumbent solutions collected during the last solve. + + Returns + ------- + list[dict[str, float]] | None + Time-stamped incumbent objective values when ``return_callback`` + was enabled in :meth:`explain`, otherwise ``None``. + + """ + callback = cast( + "SolutionCallback | None", + getattr(self, "callback", None), + ) + if callback is None: + return None + return callback.sollist def explain( self, @@ -140,6 +175,43 @@ def explain( random_seed: int = 42, clean_up: bool = True, ) -> Explanation | None: + """ + Solve one counterfactual query with the MIP backend. + + Parameters + ---------- + x + Query instance in the processed feature space. + y + Target class enforced by the counterfactual. + norm + Distance norm. The MIP backend supports ``1`` and ``2``. + return_callback + Whether to collect incumbent solutions through a Gurobi callback. + verbose + Whether to print Gurobi logs. + max_time + Time limit in seconds. + num_workers + Optional Gurobi thread count. + random_seed + Random seed passed to Gurobi. + clean_up + Whether to remove query-specific constraints after the solve. + + Returns + ------- + Explanation | None + The decoded counterfactual, or ``None`` when no feasible + counterfactual is found within the given limits. + + Raises + ------ + RuntimeError + If the solver stops for an unexpected status that is not handled + by the explainer. + + """ self.setParam("LogToConsole", int(verbose)) self.setParam("TimeLimit", max_time) self.setParam("Seed", random_seed) diff --git a/ocean/typing/__init__.py b/ocean/typing/__init__.py index a075f09..beb4f50 100644 --- a/ocean/typing/__init__.py +++ b/ocean/typing/__init__.py @@ -96,6 +96,8 @@ class SKLearnTree(Protocol): class BaseExplanation(Protocol): """Protocol implemented by explanation containers returned by explainers.""" + def to_numpy(self) -> Array1D: ... + def to_series(self) -> pd.Series: ... @property def x(self) -> Array1D: ... @property @@ -107,6 +109,11 @@ def query(self) -> Array1D: ... class BaseExplainer(Protocol): """Protocol implemented by all public OCEAN explainers.""" + def get_objective_value(self) -> float: ... + def get_distance(self) -> float: ... + def get_solving_status(self) -> str: ... + def get_anytime_solutions(self) -> list[dict[str, float]] | None: ... + def explain( self, x: Array1D, @@ -128,6 +135,9 @@ def cleanup(self) -> None: ... "Array", "Array1D", "Array2D", + "BaseExplainableEnsemble", + "BaseExplainer", + "BaseExplanation", "Dtype", "Index", "Index1L", @@ -138,6 +148,8 @@ def cleanup(self) -> None: ... "Key", "NodeId", "NodeIdArray1D", + "NodeIdDtype", + "NonNegative", "NonNegativeArray", "NonNegativeArray1D", "NonNegativeArray2D", @@ -147,9 +159,12 @@ def cleanup(self) -> None: ... "NonNegativeIntArray1D", "NonNegativeIntArray2D", "NonNegativeIntDtype", + "NonNegativeNumber", "Number", "ParsableEnsemble", "PositiveInt", + "SKLearnTree", "Unit", "UnitO", + "XGBTree", ]