66
77from ..feature import FeatureEncoder
88from ..mip import MIP
9- from ..typing import BaseEnsemble , MNumber , MProb
9+ from ..typing import BaseEnsemble , MNumber , MProb , Number
1010from .base import BasePruner
1111
1212
@@ -45,12 +45,12 @@ def __init__(
4545
4646 def build (self ) -> None :
4747 self ._add_weight_vars ()
48- self ._add_objective ()
4948 self ._n_samples = 0
5049
5150 def add_samples (self , X : npt .ArrayLike ) -> None :
5251 X = np .asarray (X )
53- classes = self .ensemble .predict (X = X , w = self ._weights )
52+ w = self ._weights
53+ classes = self .ensemble .predict (X = X , w = w )
5454 prob = self .ensemble .predict_proba (X = X )
5555 n = X .shape [0 ]
5656 for i in range (n ):
@@ -60,7 +60,9 @@ def prune(self) -> None:
6060 if self ._n_samples == 0 :
6161 msg = "No samples have been added to the pruner."
6262 raise RuntimeError (msg )
63- self .optimize ()
63+ self ._prune_l1 ()
64+ if self ._norm == 0 :
65+ self ._prune_l0 ()
6466
6567 @property
6668 def n_samples (self ) -> int :
@@ -110,3 +112,19 @@ def _validate_norm(self, norm: int) -> None:
110112 if norm not in self .VALID_NOMRS :
111113 msg = "The norm must be either 0 or 1."
112114 raise ValueError (msg )
115+
116+ def _prune_l1 (self ) -> None :
117+ w = self ._weight_vars
118+ self .setObjective (w .sum (), gp .GRB .MINIMIZE )
119+ self .optimize ()
120+
121+ def _prune_l0 (self ) -> None :
122+ W = Number (np .sum (self ._weight_vars .X ))
123+ n = self .n_estimators
124+ w = self ._weight_vars
125+ u = self .addMVar (shape = n , vtype = gp .GRB .BINARY , name = "u" )
126+ contrs = self .addConstr (w <= W * u , name = "bigM" )
127+ self .setObjective (u .sum (), gp .GRB .MINIMIZE )
128+ self .optimize ()
129+ self .remove (contrs )
130+ self .remove (u )
0 commit comments