Merged
9 changes: 5 additions & 4 deletions README.md
@@ -15,7 +15,7 @@

**ocean** is a full package dedicated to counterfactual explanations for **tree ensembles**.
It builds on the paper *Optimal Counterfactual Explanations in Tree Ensemble* by Axel Parmentier and Thibaut Vidal in the *Proceedings of the thirty-eighth International Conference on Machine Learning*, 2021, in press. The article is [available here](http://proceedings.mlr.press/v139/parmentier21a/parmentier21a.pdf).
Beyond the original MIP approach, ocean includes a new **constraint programming (CP)** method and will grow to cover additional formulations and heuristics.
Beyond the original MIP approach, ocean also includes **constraint programming (CP)** and **weighted MaxSAT** backends for exact counterfactual search on the same parsed tree ensembles.

## Installation

@@ -54,14 +54,15 @@ rf.fit(data, target)
y = int(rf.predict(x).item())
x = x.to_numpy().flatten()

# Explain the prediction using MIPEXplainer
# Explain the prediction using the MIP backend
mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
mip_explanation = mip_model.explain(x, y=1 - y, norm=1)

# Explain the prediction using CPEExplainer
# Explain the prediction using the CP backend
cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
cp_explanation = cp_model.explain(x, y=1 - y, norm=1)

# Explain the prediction using the MaxSAT backend
maxsat_model = MaxSATExplainer(rf, mapper=mapper)
maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1)

@@ -151,4 +152,4 @@ See the [examples folder](https://github.com/vidalt/OCEAN/tree/main/examples) fo

- Axel Parmentier and Thibaut Vidal. 2021. Optimal Counterfactual Explanations in Tree Ensembles. In *Proceedings of the thirty-eighth International Conference on Machine Learning*. PMLR, 8276–8286. [Available here](http://proceedings.mlr.press/v139/parmentier21a/parmentier21a.pdf).
- Alesya Raevskaya and Tuomo Lehtonen. 2025. Optimal Counterfactual Explanations for Random Forests with MaxSAT. DOI: 10.3233/FAIA250895. [Available here](https://aaltodoc.aalto.fi/server/api/core/bitstreams/36760903-9b05-491d-b744-ea4309bdf538/content).


1 change: 1 addition & 0 deletions docs/_ext/__init__.py
@@ -0,0 +1 @@

10 changes: 9 additions & 1 deletion docs/_ext/ocean_docs_stubs.py
@@ -1,4 +1,4 @@
# ruff: noqa
# ruff: noqa: BLE001, C901, N801, PLC0415, PLR6301, PYI034, RUF012

from __future__ import annotations

@@ -69,6 +69,9 @@ class Var(_GenericAliasMixin):
class Constr(_GenericAliasMixin):
pass

class MConstr(_GenericAliasMixin):
pass

class LinExpr(_GenericAliasMixin):
pass

@@ -104,6 +107,7 @@ class GRB:
module.Env = Env
module.Var = Var
module.Constr = Constr
module.MConstr = MConstr
module.LinExpr = LinExpr
module.QuadExpr = QuadExpr
module.MVar = MVar
@@ -252,6 +256,9 @@ class PBEnc:
def atleast(*_args: object, **_kwargs: object) -> _PBResult:
return _PBResult()

class EncType:
adder = "adder"

class RC2(_GenericAliasMixin):
def __init__(self, *_args: object, **_kwargs: object) -> None:
self.cost = 0
@@ -271,6 +278,7 @@ def compute(self) -> list[int]:

formula.WCNF = WCNF
formula.IDPool = IDPool
pb.EncType = EncType
pb.PBEnc = PBEnc
rc2.RC2 = RC2

5 changes: 5 additions & 0 deletions docs/api/ocean.maxsat.rst
@@ -1,6 +1,11 @@
Weighted MaxSAT Backend
=======================

The :mod:`ocean.maxsat` module exposes the weighted MaxSAT formulation backed
by PySAT. The main public entry point is :class:`ocean.maxsat.Explainer`; the
remaining classes document the underlying Boolean model, variable bundles, and
manager helpers used to build that formulation.

.. automodule:: ocean.maxsat
:members:
:undoc-members:
5 changes: 5 additions & 0 deletions docs/api/ocean.mip.rst
@@ -1,6 +1,11 @@
Mixed-integer Programming Backend
=================================

The :mod:`ocean.mip` module is the Gurobi-backed formulation. Most users start
from :class:`ocean.mip.Explainer`, while :class:`ocean.mip.Model` and the
variable classes are useful when you want to inspect or extend the lower-level
formulation directly.

.. automodule:: ocean.mip
:members:
:undoc-members:
5 changes: 5 additions & 0 deletions docs/api/ocean.typing.rst
@@ -1,6 +1,11 @@
Shared Typing Helpers
=====================

The :mod:`ocean.typing` module collects the shared protocols and type aliases
used across the public explainers, tree parsing helpers, and internal backend
implementations. It is the reference page for supported ensemble types,
processed-array aliases, and the common explanation and explainer protocols.

.. automodule:: ocean.typing
:members:
:undoc-members:
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -16,7 +16,7 @@

project = "OCEAN"
author = "OCEAN contributors"
copyright = "2026, Awa Khouna and the OCEAN contributors"
copyright = "2026, Awa Khouna and the OCEAN contributors" # noqa: A001

try:
release = version("oceanpy")
232 changes: 206 additions & 26 deletions docs/custom-dataset.rst
@@ -1,43 +1,223 @@
Custom Dataset Example
======================

This example mirrors the dataset-generation style used in the test suite: mix
continuous features, ordered discrete features, binary flags, and unordered
categorical features, parse them with OCEAN, train a tree ensemble, and
explain one instance.
This page turns the synthetic custom-dataset example into a notebook-style
walkthrough. It starts from a raw pandas dataframe, parses the feature types
with OCEAN, trains a random forest, and explains one class-``0`` prediction
with the constraint-programming backend.

Why this example matters
------------------------

The packaged dataset loaders are convenient, but most real integrations start
from a custom pandas dataframe. This example shows the full path from raw data
to a readable counterfactual without depending on any external dataset.
The packaged dataset loaders are useful when you want a quick start, but most
real integrations begin with a dataframe that you already own. This example
shows the full workflow on mixed feature types:

Running the example
-------------------
- ordered discrete values through ``credit_lines``,
- binary flags through ``owns_home`` and ``has_guarantor``,
- continuous ratios through ``income_ratio``, ``debt_ratio``, and
``savings_ratio``,
- unordered nominal features through ``job_type`` and ``region``.

.. code-block:: bash
Cell 1: Build a mixed-type dataframe
------------------------------------

python examples/custom_dataset.py
.. code-block:: python

What it does
------------
import numpy as np
import pandas as pd

1. Generates a synthetic credit-style dataset with multiple feature types.
2. Uses :func:`ocean.feature.parse_features` to build the processed matrix and
mapper.
3. Trains a random forest on that processed matrix.
4. Selects a query that the model predicts as class ``0``.
5. Uses ``ocean.ConstraintProgrammingExplainer`` to search for the closest
class-``1`` counterfactual.
6. Prints both the original raw instance and the decoded explanation.
rng = np.random.default_rng(42)
raw = pd.DataFrame({
"credit_lines": rng.choice([0, 1, 2, 4], size=300),
"owns_home": rng.integers(0, 2, size=300),
"has_guarantor": rng.integers(0, 2, size=300),
"income_ratio": rng.uniform(-0.4, 0.8, size=300),
"debt_ratio": rng.uniform(0.0, 1.0, size=300),
"savings_ratio": rng.uniform(-0.5, 0.6, size=300),
"job_type": rng.choice(
["office", "manual", "service", "student"],
size=300,
),
"region": rng.choice(
["north", "south", "east", "west"],
size=300,
),
})

In this example, ``credit_lines`` is treated as an ordered discrete feature,
while ``job_type`` and ``region`` are treated as unordered categories and are
therefore one-hot encoded.
score = (
(raw["credit_lines"] >= 2).astype(int)
+ raw["owns_home"].astype(int)
+ raw["has_guarantor"].astype(int)
+ (raw["income_ratio"] > 0.1).astype(int)
+ (raw["savings_ratio"] > 0.0).astype(int)
+ raw["job_type"].isin(["office", "service"]).astype(int)
+ raw["region"].isin(["north", "east"]).astype(int)
- (raw["debt_ratio"] > 0.55).astype(int)
)
target = (score >= 4).astype(int).rename("approved")

Source
------
Cell 2: Parse the features with OCEAN
-------------------------------------

.. code-block:: python

from ocean.feature import parse_features

data, mapper = parse_features(raw, discretes=("credit_lines",))
print(data.columns)

.. code-block:: text

MultiIndex([( 'credit_lines', ''),
( 'owns_home', ''),
( 'has_guarantor', ''),
( 'income_ratio', ''),
( 'debt_ratio', ''),
( 'savings_ratio', ''),
( 'job_type', 'manual'),
( 'job_type', 'office'),
( 'job_type', 'service'),
( 'job_type', 'student'),
( 'region', 'east'),
( 'region', 'north'),
( 'region', 'south'),
( 'region', 'west')],
)

The important part is that ``credit_lines`` stays ordered and numeric, while
``job_type`` and ``region`` expand into one-hot blocks.
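For intuition, the one-hot expansion applied to the nominal columns is conceptually similar to ``pandas.get_dummies`` (an illustrative comparison only; the real mapper also records the feature metadata that the explainers need):

```python
import pandas as pd

# A tiny frame with one nominal column, mirroring "job_type" above.
frame = pd.DataFrame({"job_type": ["office", "manual", "office", "student"]})

# get_dummies expands the column into one indicator column per category,
# much like the ("job_type", <category>) columns in the parsed matrix.
encoded = pd.get_dummies(frame, columns=["job_type"])
print(sorted(encoded.columns))
```

Exactly one indicator is active per row, which is the invariant the explainers rely on when they flip a category in a counterfactual.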

Cell 3: Fit a classifier and choose a query
-------------------------------------------

.. code-block:: python

import pandas as pd
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(
n_estimators=40,
max_depth=4,
random_state=42,
)
model.fit(data, target)

predictions = pd.Series(model.predict(data), index=data.index)
query_index = predictions[predictions == 0].index[0]
query = data.loc[query_index].to_numpy(dtype=float).flatten()
query_frame = data.loc[[query_index]]
raw_query = raw.loc[query_index]

print(raw_query)
print()
print("Model prediction:", int(model.predict(query_frame).item()))

.. code-block:: text

credit_lines 0
owns_home 0
has_guarantor 1
income_ratio -0.349179
debt_ratio 0.260349
savings_ratio 0.234634
job_type student
region west
Name: 0, dtype: object

Model prediction: 0

Cell 4: Explain the query
-------------------------

.. code-block:: python

from ocean import ConstraintProgrammingExplainer

explainer = ConstraintProgrammingExplainer(model, mapper=mapper)
explanation = explainer.explain(
query,
y=1,
norm=1,
max_time=10,
num_workers=1,
random_seed=42,
)
if explanation is None:
raise RuntimeError("No counterfactual was found for the synthetic example.")

counterfactual_frame = pd.DataFrame(
[explanation.to_numpy()],
columns=data.columns,
)

print("Target class:", 1)
print("Counterfactual prediction:", int(model.predict(counterfactual_frame).item()))

.. code-block:: text

Target class: 1
Counterfactual prediction: 1

Cell 5: Inspect the decoded explanation
---------------------------------------

.. code-block:: python

print(explanation)

.. code-block:: text

Explanation:
credit_lines : 0.0
owns_home : 0
has_guarantor : 1
income_ratio : -0.2833683341741562
debt_ratio : -0.1887158378958702
savings_ratio : 0.29842646420001984
job_type : student
region : north

This decoded view is usually the most readable one: categorical one-hot blocks
are mapped back to labels, and the keys match the original dataframe columns.
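The decoding step for a one-hot block can be sketched with plain numpy (a simplified illustration, not the library's internal code):

```python
import numpy as np

# One-hot values for the "region" block, ordered (east, north, south, west).
region_block = np.array([0.0, 1.0, 0.0, 0.0])
labels = ["east", "north", "south", "west"]

# Decoding picks the label whose indicator is active.
decoded = labels[int(np.argmax(region_block))]
print(decoded)  # north
```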

Cell 6: Inspect the processed vector and the final distance
-----------------------------------------------------------

.. code-block:: python

print(explanation.to_series())
print()
print("Distance:", explainer.get_distance())

.. code-block:: text

credit_lines 0.000000
owns_home 0.000000
has_guarantor 1.000000
income_ratio -0.283368
debt_ratio -0.188716
savings_ratio 0.298426
job_type manual 0.000000
office 0.000000
service 0.000000
student 1.000000
region east 0.000000
north 1.000000
south 0.000000
west 0.000000
dtype: float64

Distance: 1.3566914800296987

``get_distance()`` is the user-facing metric to report here: it reconstructs
the post-processed :math:`L_1` distance between the original query and the
decoded counterfactual, including the half-weight treatment for one-hot blocks.

Full script
-----------

If you want the exact runnable version behind this page:

.. literalinclude:: ../examples/custom_dataset.py
:language: python