Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions ocean/cp/_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
class Explanation(Mapper[FeatureVar], BaseExplanation):
"""Concrete explanation container returned by the CP backend."""

_epsilon: float = float(np.finfo(np.float32).eps)
_x: Array1D = np.zeros((0,), dtype=int)

def vget(self, i: int) -> cp.IntVar:
Expand Down Expand Up @@ -95,9 +94,9 @@ def format_value(
if j == idx:
value = float(query_arr[f])
elif j < idx:
value = float(levels[idx]) + self._epsilon
value = self._next_float32_up(levels[idx])
else:
value = float(levels[idx + 1]) - self._epsilon
value = self._next_float32_down(levels[idx + 1])
return value

def format_discrete_value(
Expand Down
5 changes: 2 additions & 3 deletions ocean/maxsat/_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
class Explanation(Mapper[FeatureVar], BaseExplanation):
"""Concrete explanation container returned by the MaxSAT backend."""

_epsilon: float = float(np.finfo(np.float32).eps)
_x: Array1D = np.zeros((0,), dtype=int)

def vget(self, i: int) -> int:
Expand Down Expand Up @@ -142,9 +141,9 @@ def format_continuous_value(
if j == idx:
value = float(query_arr[f])
elif j < idx:
value = float(levels[idx]) + self._epsilon
value = self._next_float32_up(levels[idx])
else:
value = float(levels[idx + 1]) - self._epsilon
value = self._next_float32_down(levels[idx + 1])
return value

def format_discrete_value(
Expand Down
5 changes: 2 additions & 3 deletions ocean/mip/_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
class Explanation(Mapper[FeatureVar], BaseExplanation):
"""Concrete explanation container returned by the MIP backend."""

_epsilon: float = float(np.finfo(np.float32).eps)
_atol: float = 1e-10
_x: Array1D = np.zeros((0,), dtype=int)

Expand Down Expand Up @@ -67,9 +66,9 @@ def format_value(
if j == idx:
value = float(query_arr[f])
elif j < idx:
value = float(levels[idx]) + self._epsilon
value = self._next_float32_up(levels[idx])
else:
value = float(levels[idx + 1]) - self._epsilon
value = self._next_float32_down(levels[idx + 1])
return value

def format_discrete_value(
Expand Down
35 changes: 35 additions & 0 deletions ocean/mip/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Model(BaseModel, FeatureManager, TreeManager, GarbageManager):

DEFAULT_EPSILON: Unit = 1.0 / (2.0**16)
DEFAULT_NUM_EPSILON: Unit = 1.0 / (2.0**6)
MIN_NUMERIC_TOL: float = 1e-9

class Type(Enum):
MIP = "MIP"
Expand Down Expand Up @@ -126,6 +127,15 @@ def build(self) -> None:
self.build_trees(self)
self._builder.build(self, trees=self.trees, mapper=self.mapper)
self._set_isolation()
self._stabilize_tolerances()

@property
def epsilon(self) -> Unit:
"""Return the score margin held in ``self._epsilon`` (read-only)."""
return self._epsilon

@property
def num_epsilon(self) -> Unit:
"""Return the value held in ``self._num_epsilon`` (read-only)."""
return self._num_epsilon

def add_objective(
self,
Expand Down Expand Up @@ -251,6 +261,31 @@ def _set_isolation(self) -> None:

self.addConstr(self.length >= self.min_length)

def _stabilize_tolerances(self) -> None:
"""
Tighten solver tolerances when the class margin is very small.

The MIP uses a tiny score margin ``self._epsilon`` to enforce the
target class. If integer-valued branching variables are accepted with
a larger tolerance than that margin, Gurobi can report solutions whose
near-integral leaf selections satisfy the model scores but decode to an
invalid counterfactual under exact tree traversal.
"""
total_weight = float(np.sum(self.weights, dtype=np.float64))
if total_weight <= 0.0:
return

safe_tol = self._epsilon / (2.0 * total_weight)
safe_tol = max(self.MIN_NUMERIC_TOL, safe_tol)

feasibility_tol = float(self.getParamInfo("FeasibilityTol")[2])
if safe_tol < feasibility_tol:
self.setParam("FeasibilityTol", safe_tol)

int_feas_tol = float(self.getParamInfo("IntFeasTol")[2])
if safe_tol < int_feas_tol:
self.setParam("IntFeasTol", safe_tol)

def _add_objective(self, x: Array1D, norm: int) -> Objective:
r"""
Build the symbolic objective expression for :math:`d(x, \hat{x})`.
Expand Down
28 changes: 27 additions & 1 deletion ocean/tree/_parse_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,30 @@ def _logit(p: Array1D) -> Array1D:
return np.log(p / (1 - p))


def _previous_float32(value: float) -> float:
"""
Return the greatest float32 value strictly smaller than ``value``.

XGBoost routes numeric splits with a strict ``< split`` comparison, while
OCEAN's internal tree abstraction uses ``<= threshold``. Using the
previous representable float32 value preserves the exact branch semantics
for float32-encoded inputs.

Returns
-------
float
Largest representable float32 value strictly below ``value``.

"""
return float(
np.nextafter(
np.float32(value),
np.float32(-np.inf),
dtype=np.float32,
)
)


def _get_column_value(
xgb_tree: XGBTree, node_id: NonNegativeInt, column: str
) -> str | float | int:
Expand Down Expand Up @@ -105,7 +129,9 @@ def _build_xgb_node(

threshold = None
if mapper[name].is_numeric:
threshold = float(_get_column_value(xgb_tree, node_id, "Split")) - 1e-8
threshold = _previous_float32(
float(_get_column_value(xgb_tree, node_id, "Split"))
)
mapper[name].add(threshold)

left_id = _get_child_id(xgb_tree, node_id, "Yes")
Expand Down
20 changes: 20 additions & 0 deletions ocean/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,26 @@ def value(self) -> Mapping[Key, Key | Number]: ...
@property
def query(self) -> Array1D: ...

@staticmethod
def _next_float32_up(value: float) -> float:
return float(
np.nextafter(
np.float32(value),
np.float32(np.inf),
dtype=np.float32,
)
)

@staticmethod
def _next_float32_down(value: float) -> float:
return float(
np.nextafter(
np.float32(value),
np.float32(-np.inf),
dtype=np.float32,
)
)


class BaseExplainer(Protocol):
"""Protocol implemented by all public OCEAN explainers."""
Expand Down
Loading
Loading