Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2025 Yoemi
Copyright (c) 2025 Vidal, and the OCEAN contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

[![Maintained](https://img.shields.io/badge/Maintained-YES-14b8a6?style=for-the-badge&logo=github)](https://github.com/vidalt/OCEAN/graphs/commit-activity)
[![License](https://img.shields.io/github/license/vidalt/OCEAN?style=for-the-badge&color=0ea5e9&logo=unlicense&logoColor=white)](https://github.com/vidalt/OCEAN/blob/main/LICENSE)
[![Documentation](https://img.shields.io/badge/Documentation-YES-06b6d4?style=for-the-badge&logo=readthedocs&logoColor=white)](https://ocean-py.readthedocs.io/en/latest/)
[![Contributors](https://img.shields.io/github/contributors/vidalt/OCEAN?style=for-the-badge&color=38bdf8&logo=github)](https://github.com/vidalt/OCEAN/graphs/contributors)
[![Stars](https://img.shields.io/github/stars/vidalt/OCEAN?style=for-the-badge&color=0284c7&logo=github)](https://github.com/vidalt/OCEAN/stargazers)
[![Watchers](https://img.shields.io/github/watchers/vidalt/OCEAN?style=for-the-badge&color=2563eb&logo=github)](https://github.com/vidalt/OCEAN/watchers)
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

project = "OCEAN"
author = "OCEAN contributors"
copyright = "2026, Awa Khouna and the OCEAN contributors"

try:
release = version("oceanpy")
Expand Down
4 changes: 4 additions & 0 deletions ocean/maxsat/_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def explain(
raise
except TimeoutError as exc:
warnings.warn(str(exc), category=UserWarning, stacklevel=2)
signal.alarm(0)
if clean_up:
self.cleanup()
return None
finally:
signal.alarm(0)
# Store the query in the explanation
Expand Down
12 changes: 5 additions & 7 deletions ocean/mip/_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,11 @@ def format_discrete_value(
return float(val)
return float(query_arr[f])

def _continuous_index(self, feature: FeatureVar) -> int:
idx = 0
for mu_idx in range(len(feature.levels) - 1):
value = feature.mget(mu_idx).X
if not np.isclose(value, 0.0, rtol=0.0, atol=self._atol):
idx = mu_idx
return idx
@staticmethod
def _continuous_index(feature: FeatureVar) -> int:
x = float(feature.xget().X)
idx = int(np.searchsorted(feature.levels, x, side="left")) - 1
return max(0, min(idx, len(feature.levels) - 2))

@property
def value(self) -> Mapping[Key, Key | Number]:
Expand Down
63 changes: 63 additions & 0 deletions tests/test_explainer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import sys

import gurobipy as gp
import numpy as np
import pandas as pd
import pytest
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

from ocean import (
ConstraintProgrammingExplainer,
MaxSATExplainer,
MixedIntegerProgramExplainer,
)
from ocean.feature import parse_features

from .utils import ENV, generate_data

Expand Down Expand Up @@ -197,3 +202,61 @@ def test_cp_explain_xgb(

except gp.GurobiError as e:
pytest.skip(f"Skipping test due to {e}")


@pytest.mark.skipif(
sys.platform == "win32", reason="tests for non-windows platforms"
)
def test_explainers_return_same_distance_on_discrete_data() -> None:
raw = pd.DataFrame({
"age_bucket": [0, 0, 1, 1, 2, 2, 3, 3, 1, 2, 0, 3],
"owns_home": [0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1],
"job_type": [
"office",
"office",
"manual",
"manual",
"service",
"service",
"office",
"manual",
"service",
"office",
"manual",
"service",
],
})
y = np.array([0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1], dtype=np.int64)
data, mapper = parse_features(
raw,
discretes=("age_bucket",),
encoded=("job_type",),
scale=False,
)
clf = RandomForestClassifier(
random_state=7,
n_estimators=5,
max_depth=3,
)
clf.fit(data, y)

x = data.iloc[0, :].to_numpy(dtype=float).flatten()
prediction = np.asarray(clf.predict([x]), dtype=np.int64)
target = int(1 - prediction[0])

mip = MixedIntegerProgramExplainer(clf, mapper=mapper, env=ENV)
cp = ConstraintProgrammingExplainer(clf, mapper=mapper)
maxsat = MaxSATExplainer(clf, mapper=mapper)

exp_mip = mip.explain(x, y=target, norm=1, random_seed=7)
exp_cp = cp.explain(x, y=target, norm=1, random_seed=7)
exp_maxsat = maxsat.explain(x, y=target, norm=1, random_seed=7)

assert exp_mip is not None
assert exp_cp is not None
assert exp_maxsat is not None
assert clf.predict([exp_mip.to_numpy()])[0] == target
assert clf.predict([exp_cp.to_numpy()])[0] == target
assert clf.predict([exp_maxsat.to_numpy()])[0] == target
assert mip.get_distance() == pytest.approx(cp.get_distance())
assert mip.get_distance() == pytest.approx(maxsat.get_distance())
Loading