Skip to content

Commit 0310657

Browse files
committed
Refactor Pruner class to improve pruning logic and enhance clarity by separating L1 and L0 pruning methods
1 parent fa3a22a commit 0310657

1 file changed

Lines changed: 22 additions & 4 deletions

File tree

fipe/prune/pruner.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ..feature import FeatureEncoder
88
from ..mip import MIP
9-
from ..typing import BaseEnsemble, MNumber, MProb
9+
from ..typing import BaseEnsemble, MNumber, MProb, Number
1010
from .base import BasePruner
1111

1212

@@ -45,12 +45,12 @@ def __init__(
4545

4646
def build(self) -> None:
4747
self._add_weight_vars()
48-
self._add_objective()
4948
self._n_samples = 0
5049

5150
def add_samples(self, X: npt.ArrayLike) -> None:
5251
X = np.asarray(X)
53-
classes = self.ensemble.predict(X=X, w=self._weights)
52+
w = self._weights
53+
classes = self.ensemble.predict(X=X, w=w)
5454
prob = self.ensemble.predict_proba(X=X)
5555
n = X.shape[0]
5656
for i in range(n):
@@ -60,7 +60,9 @@ def prune(self) -> None:
6060
if self._n_samples == 0:
6161
msg = "No samples have been added to the pruner."
6262
raise RuntimeError(msg)
63-
self.optimize()
63+
self._prune_l1()
64+
if self._norm == 0:
65+
self._prune_l0()
6466

6567
@property
6668
def n_samples(self) -> int:
@@ -110,3 +112,19 @@ def _validate_norm(self, norm: int) -> None:
110112
if norm not in self.VALID_NOMRS:
111113
msg = "The norm must be either 0 or 1."
112114
raise ValueError(msg)
115+
116+
def _prune_l1(self) -> None:
117+
w = self._weight_vars
118+
self.setObjective(w.sum(), gp.GRB.MINIMIZE)
119+
self.optimize()
120+
121+
def _prune_l0(self) -> None:
122+
W = Number(np.sum(self._weight_vars.X))
123+
n = self.n_estimators
124+
w = self._weight_vars
125+
u = self.addMVar(shape=n, vtype=gp.GRB.BINARY, name="u")
126+
contrs = self.addConstr(w <= W * u, name="bigM")
127+
self.setObjective(u.sum(), gp.GRB.MINIMIZE)
128+
self.optimize()
129+
self.remove(contrs)
130+
self.remove(u)

0 commit comments

Comments
 (0)