Commit 032df3e

Merge pull request #36 from vidalt/fix-docs

Enhance documentation and examples for OCEAN

2 parents 415901d + 2739afa

18 files changed: 623 additions & 254 deletions

README.md

Lines changed: 5 additions & 4 deletions

@@ -15,7 +15,7 @@

 **ocean** is a full package dedicated to counterfactual explanations for **tree ensembles**.
 It builds on the paper *Optimal Counterfactual Explanations in Tree Ensembles* by Axel Parmentier and Thibaut Vidal in the *Proceedings of the Thirty-Eighth International Conference on Machine Learning*, 2021. The article is [available here](http://proceedings.mlr.press/v139/parmentier21a/parmentier21a.pdf).
-Beyond the original MIP approach, ocean includes a new **constraint programming (CP)** method and will grow to cover additional formulations and heuristics.
+Beyond the original MIP approach, ocean also includes **constraint programming (CP)** and **weighted MaxSAT** backends for exact counterfactual search on the same parsed tree ensembles.

 ## Installation

@@ -54,14 +54,15 @@ rf.fit(data, target)
 y = int(rf.predict(x).item())
 x = x.to_numpy().flatten()

-# Explain the prediction using MIPEXplainer
+# Explain the prediction using the MIP backend
 mip_model = MixedIntegerProgramExplainer(rf, mapper=mapper)
 mip_explanation = mip_model.explain(x, y=1 - y, norm=1)

-# Explain the prediction using CPEExplainer
+# Explain the prediction using the CP backend
 cp_model = ConstraintProgrammingExplainer(rf, mapper=mapper)
 cp_explanation = cp_model.explain(x, y=1 - y, norm=1)

+# Explain the prediction using the MaxSAT backend
 maxsat_model = MaxSATExplainer(rf, mapper=mapper)
 maxsat_explanation = maxsat_model.explain(x, y=1 - y, norm=1)

@@ -151,4 +152,4 @@ See the [examples folder](https://github.com/vidalt/OCEAN/tree/main/examples) fo

 - Axel Parmentier and Thibaut Vidal. 2021. Optimal Counterfactual Explanations in Tree Ensembles. In *Proceedings of the Thirty-Eighth International Conference on Machine Learning*. PMLR, 8276–8286. [Available here](http://proceedings.mlr.press/v139/parmentier21a/parmentier21a.pdf).
 - Raevskaya, Alesya, and Tuomo Lehtonen. 2025. Optimal Counterfactual Explanations for Random Forests with MaxSAT. doi:10.3233/FAIA250895. [Available here](https://aaltodoc.aalto.fi/server/api/core/bitstreams/36760903-9b05-491d-b744-ea4309bdf538/content).
-
+
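The three backends in this README snippet all answer the same question: what is the closest input, under the chosen norm, that the model assigns to the target class ``1 - y``? As a purely illustrative sketch of that idea (not OCEAN's API, and with a hand-written stump standing in for a real tree ensemble), the same search can be brute-forced on a toy one-feature model:

```python
# Illustrative sketch only: brute-force the counterfactual search that the
# MIP, CP, and MaxSAT backends solve exactly. `stump_predict` is a
# hand-written stand-in for a tree ensemble, not an OCEAN model.

def stump_predict(x: float) -> int:
    """One-feature decision stump: class 1 iff x > 0.6."""
    return 1 if x > 0.6 else 0

def brute_force_counterfactual(x: float, target: int) -> float:
    """Return the grid point in [0, 1] closest to x (L1 distance)
    that the stump classifies as `target`."""
    candidates = [i / 100 for i in range(101)]
    feasible = [c for c in candidates if stump_predict(c) == target]
    return min(feasible, key=lambda c: abs(c - x))

x = 0.30
y = stump_predict(x)                       # original prediction: 0
cf = brute_force_counterfactual(x, 1 - y)  # closest class-1 point
print(cf, stump_predict(cf))               # 0.61 1
```

The exact solvers replace this grid scan with an optimality-guaranteed search over the ensemble's actual split thresholds.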

docs/_ext/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+

docs/_ext/ocean_docs_stubs.py

Lines changed: 9 additions & 1 deletion

@@ -1,4 +1,4 @@
-# ruff: noqa
+# ruff: noqa: BLE001, C901, N801, PLC0415, PLR6301, PYI034, RUF012

 from __future__ import annotations

@@ -69,6 +69,9 @@ class Var(_GenericAliasMixin):
 class Constr(_GenericAliasMixin):
     pass

+class MConstr(_GenericAliasMixin):
+    pass
+
 class LinExpr(_GenericAliasMixin):
     pass

@@ -104,6 +107,7 @@ class GRB:
 module.Env = Env
 module.Var = Var
 module.Constr = Constr
+module.MConstr = MConstr
 module.LinExpr = LinExpr
 module.QuadExpr = QuadExpr
 module.MVar = MVar

@@ -252,6 +256,9 @@ class PBEnc:
     def atleast(*_args: object, **_kwargs: object) -> _PBResult:
         return _PBResult()

+class EncType:
+    adder = "adder"
+
 class RC2(_GenericAliasMixin):
     def __init__(self, *_args: object, **_kwargs: object) -> None:
         self.cost = 0

@@ -271,6 +278,7 @@ def compute(self) -> list[int]:

 formula.WCNF = WCNF
 formula.IDPool = IDPool
+pb.EncType = EncType
 pb.PBEnc = PBEnc
 rc2.RC2 = RC2

docs/api/ocean.maxsat.rst

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,11 @@
 Weighted MaxSAT Backend
 =======================

+The :mod:`ocean.maxsat` module exposes the weighted MaxSAT formulation backed
+by PySAT. The main public entry point is :class:`ocean.maxsat.Explainer`; the
+remaining classes document the underlying Boolean model, variable bundles, and
+manager helpers used to build that formulation.
+
 .. automodule:: ocean.maxsat
    :members:
    :undoc-members:

docs/api/ocean.mip.rst

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,11 @@
 Mixed-integer Programming Backend
 =================================

+The :mod:`ocean.mip` module is the Gurobi-backed formulation. Most users start
+from :class:`ocean.mip.Explainer`, while :class:`ocean.mip.Model` and the
+variable classes are useful when you want to inspect or extend the lower-level
+formulation directly.
+
 .. automodule:: ocean.mip
    :members:
    :undoc-members:

docs/api/ocean.typing.rst

Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,11 @@
 Shared Typing Helpers
 =====================

+The :mod:`ocean.typing` module collects the shared protocols and type aliases
+used across the public explainers, tree parsing helpers, and internal backend
+implementations. It is the reference page for supported ensemble types,
+processed-array aliases, and the common explanation and explainer protocols.
+
 .. automodule:: ocean.typing
    :members:
    :undoc-members:

docs/conf.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@

 project = "OCEAN"
 author = "OCEAN contributors"
-copyright = "2026, Awa Khouna and the OCEAN contributors"
+copyright = "2026, Awa Khouna and the OCEAN contributors"  # noqa: A001

 try:
     release = version("oceanpy")

docs/custom-dataset.rst

Lines changed: 206 additions & 26 deletions

@@ -1,43 +1,223 @@
 Custom Dataset Example
 ======================

-This example mirrors the dataset-generation style used in the test suite: mix
-continuous features, ordered discrete features, binary flags, and unordered
-categorical features, parse them with OCEAN, train a tree ensemble, and
-explain one instance.
+This page turns the synthetic custom-dataset example into a notebook-style
+walkthrough. It starts from a raw pandas dataframe, parses the feature types
+with OCEAN, trains a random forest, and explains one class-``0`` prediction
+with the constraint-programming backend.

 Why this example matters
 ------------------------

-The packaged dataset loaders are convenient, but most real integrations start
-from a custom pandas dataframe. This example shows the full path from raw data
-to a readable counterfactual without depending on any external dataset.
+The packaged dataset loaders are useful when you want a quick start, but most
+real integrations begin with a dataframe that you already own. This example
+shows the full workflow on mixed feature types:

-Running the example
--------------------
+- ordered discrete values through ``credit_lines``,
+- binary flags through ``owns_home`` and ``has_guarantor``,
+- continuous ratios through ``income_ratio``, ``debt_ratio``, and
+  ``savings_ratio``,
+- unordered nominal features through ``job_type`` and ``region``.

-.. code-block:: bash
+Cell 1: Build a mixed-type dataframe
+------------------------------------

-    python examples/custom_dataset.py
+.. code-block:: python

-What it does
-------------
+    import numpy as np
+    import pandas as pd

-1. Generates a synthetic credit-style dataset with multiple feature types.
-2. Uses :func:`ocean.feature.parse_features` to build the processed matrix and
-   mapper.
-3. Trains a random forest on that processed matrix.
-4. Selects a query that the model predicts as class ``0``.
-5. Uses ``ocean.ConstraintProgrammingExplainer`` to search for the closest
-   class-``1`` counterfactual.
-6. Prints both the original raw instance and the decoded explanation.
+    rng = np.random.default_rng(42)
+    raw = pd.DataFrame({
+        "credit_lines": rng.choice([0, 1, 2, 4], size=300),
+        "owns_home": rng.integers(0, 2, size=300),
+        "has_guarantor": rng.integers(0, 2, size=300),
+        "income_ratio": rng.uniform(-0.4, 0.8, size=300),
+        "debt_ratio": rng.uniform(0.0, 1.0, size=300),
+        "savings_ratio": rng.uniform(-0.5, 0.6, size=300),
+        "job_type": rng.choice(
+            ["office", "manual", "service", "student"],
+            size=300,
+        ),
+        "region": rng.choice(
+            ["north", "south", "east", "west"],
+            size=300,
+        ),
+    })

-In this example, ``credit_lines`` is treated as an ordered discrete feature,
-while ``job_type`` and ``region`` are treated as unordered categories and are
-therefore one-hot encoded.
+    score = (
+        (raw["credit_lines"] >= 2).astype(int)
+        + raw["owns_home"].astype(int)
+        + raw["has_guarantor"].astype(int)
+        + (raw["income_ratio"] > 0.1).astype(int)
+        + (raw["savings_ratio"] > 0.0).astype(int)
+        + raw["job_type"].isin(["office", "service"]).astype(int)
+        + raw["region"].isin(["north", "east"]).astype(int)
+        - (raw["debt_ratio"] > 0.55).astype(int)
+    )
+    target = (score >= 4).astype(int).rename("approved")

-Source
-------
+Cell 2: Parse the features with OCEAN
+-------------------------------------
+
+.. code-block:: python
+
+    from ocean.feature import parse_features
+
+    data, mapper = parse_features(raw, discretes=("credit_lines",))
+    print(data.columns)
+
+.. code-block:: text
+
+    MultiIndex([( 'credit_lines',        ''),
+                (    'owns_home',        ''),
+                ('has_guarantor',        ''),
+                ( 'income_ratio',        ''),
+                (   'debt_ratio',        ''),
+                ('savings_ratio',        ''),
+                (     'job_type',  'manual'),
+                (     'job_type',  'office'),
+                (     'job_type', 'service'),
+                (     'job_type', 'student'),
+                (       'region',    'east'),
+                (       'region',   'north'),
+                (       'region',   'south'),
+                (       'region',    'west')],
+               )
+
+The important part is that ``credit_lines`` stays ordered and numeric, while
+``job_type`` and ``region`` expand into one-hot blocks.
+
+Cell 3: Fit a classifier and choose a query
+-------------------------------------------
+
+.. code-block:: python
+
+    import pandas as pd
+    from sklearn.ensemble import RandomForestClassifier
+
+    model = RandomForestClassifier(
+        n_estimators=40,
+        max_depth=4,
+        random_state=42,
+    )
+    model.fit(data, target)
+
+    predictions = pd.Series(model.predict(data), index=data.index)
+    query_index = predictions[predictions == 0].index[0]
+    query = data.loc[query_index].to_numpy(dtype=float).flatten()
+    query_frame = data.loc[[query_index]]
+    raw_query = raw.loc[query_index]
+
+    print(raw_query)
+    print()
+    print("Model prediction:", int(model.predict(query_frame).item()))
+
+.. code-block:: text
+
+    credit_lines            0
+    owns_home               0
+    has_guarantor           1
+    income_ratio    -0.349179
+    debt_ratio       0.260349
+    savings_ratio    0.234634
+    job_type          student
+    region               west
+    Name: 0, dtype: object
+
+    Model prediction: 0
+
+Cell 4: Explain the query
+-------------------------
+
+.. code-block:: python
+
+    from ocean import ConstraintProgrammingExplainer
+
+    explainer = ConstraintProgrammingExplainer(model, mapper=mapper)
+    explanation = explainer.explain(
+        query,
+        y=1,
+        norm=1,
+        max_time=10,
+        num_workers=1,
+        random_seed=42,
+    )
+    if explanation is None:
+        raise RuntimeError("No counterfactual was found for the synthetic example.")
+
+    counterfactual_frame = pd.DataFrame(
+        [explanation.to_numpy()],
+        columns=data.columns,
+    )
+
+    print("Target class:", 1)
+    print("Counterfactual prediction:", int(model.predict(counterfactual_frame).item()))
+
+.. code-block:: text
+
+    Target class: 1
+    Counterfactual prediction: 1
+
+Cell 5: Inspect the decoded explanation
+---------------------------------------
+
+.. code-block:: python
+
+    print(explanation)
+
+.. code-block:: text
+
+    Explanation:
+        credit_lines : 0.0
+        owns_home : 0
+        has_guarantor : 1
+        income_ratio : -0.2833683341741562
+        debt_ratio : -0.1887158378958702
+        savings_ratio : 0.29842646420001984
+        job_type : student
+        region : north
+
+This decoded view is usually the most readable one: categorical one-hot blocks
+are mapped back to labels, and the keys match the original dataframe columns.
+
+Cell 6: Inspect the processed vector and the final distance
+-----------------------------------------------------------
+
+.. code-block:: python
+
+    print(explanation.to_series())
+    print()
+    print("Distance:", explainer.get_distance())
+
+.. code-block:: text
+
+    credit_lines             0.000000
+    owns_home                0.000000
+    has_guarantor            1.000000
+    income_ratio            -0.283368
+    debt_ratio              -0.188716
+    savings_ratio            0.298426
+    job_type       manual    0.000000
+                   office    0.000000
+                   service   0.000000
+                   student   1.000000
+    region         east      0.000000
+                   north     1.000000
+                   south     0.000000
+                   west      0.000000
+    dtype: float64
+
+    Distance: 1.3566914800296987
+
+``get_distance()`` is the user-facing metric to report here: it reconstructs
+the post-processed :math:`L_1` distance between the original query and the
+decoded counterfactual, including the half-weight treatment for one-hot blocks.
+
+Full script
+-----------
+
+If you want the exact runnable version behind this page:

 .. literalinclude:: ../examples/custom_dataset.py
    :language: python
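The added walkthrough mentions a half-weight treatment for one-hot blocks in ``get_distance()``. The exact implementation is not part of this diff, and OCEAN also applies its own pre- and post-processing, so the toy sketch below only illustrates the arithmetic behind that statement: a category change flips two indicators (old level to 0, new level to 1), so weighting each indicator difference by one half makes a single categorical swap cost exactly 1. The names ``encode`` and ``weighted_l1`` and the toy rows are hypothetical, not OCEAN code.

```python
# Hypothetical sketch of an L1 distance with half-weighted one-hot blocks.
# NOT OCEAN's get_distance(); it only demonstrates why one category swap
# contributes 1.0 (two indicator flips at weight 0.5 each) to the distance.

def encode(row: dict, categories: dict) -> list[tuple[float, float]]:
    """One-hot encode categorical entries; pass numeric entries through.

    Returns (value, weight) pairs: weight 1.0 for numeric coordinates,
    0.5 for each one-hot indicator coordinate.
    """
    out: list[tuple[float, float]] = []
    for key, value in row.items():
        if key in categories:
            for level in categories[key]:
                out.append((1.0 if value == level else 0.0, 0.5))
        else:
            out.append((float(value), 1.0))
    return out

def weighted_l1(a: list[tuple[float, float]], b: list[tuple[float, float]]) -> float:
    """Weighted L1 distance between two encoded rows."""
    return sum(w * abs(x - y) for (x, w), (y, _) in zip(a, b))

categories = {"region": ["east", "north", "south", "west"]}
query = {"debt_ratio": 0.26, "region": "west"}
counterfactual = {"debt_ratio": 0.06, "region": "north"}

d = weighted_l1(encode(query, categories), encode(counterfactual, categories))
print(round(d, 2))  # 1.2: 0.2 from debt_ratio plus 1.0 for the region swap
```

Without the half weight, the region swap alone would cost 2.0 (both flipped indicators counted in full), which would overstate how far a single categorical change moves the input.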
