From 49932520aca8a97525df42412a5b667d11780fe6 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 18 Sep 2025 23:56:13 +0300 Subject: [PATCH] added rank_na_mode --- smarttree/_check.py | 14 ++++++++++- smarttree/_classes.py | 24 +++++++++++++++++++ .../base/test__get_set_params.py | 6 ++++- .../classifier/test__repr_tree.py | 14 +++++++++++ 4 files changed, 56 insertions(+), 2 deletions(-) diff --git a/smarttree/_check.py b/smarttree/_check.py index cd63a10..d9c1ad1 100644 --- a/smarttree/_check.py +++ b/smarttree/_check.py @@ -18,6 +18,7 @@ def check__params( num_na_mode=None, cat_na_mode=None, cat_na_filler=None, + rank_na_mode=None, feature_na_mode=None, ): if criterion is not None: @@ -86,6 +87,9 @@ def check__params( if cat_na_filler is not None: _check__cat_na_filler(cat_na_filler) + if rank_na_mode is not None: + _check__rank_na_mode(rank_na_mode) + if feature_na_mode is not None: _check__feature_na_mode(feature_na_mode) @@ -279,7 +283,7 @@ def _check__hierarchy(hierarchy): def _check_na_mode(na_mode): if na_mode not in ("include_all", "include_best"): raise ValueError( - "`num_na_mode` must be Literal['include_all', 'include_best']." + "`na_mode` must be Literal['include_all', 'include_best']." f" The current value of `na_mode` is {na_mode!r}." ) @@ -308,6 +312,14 @@ def _check__cat_na_filler(cat_na_filler): ) +def _check__rank_na_mode(rank_na_mode): + if rank_na_mode not in ("include_all", "include_best"): + raise ValueError( + "`rank_na_mode` must be Literal['include_all', 'include_best']." + f" The current value of `rank_na_mode` is {rank_na_mode!r}." 
+ ) + + def _check__feature_na_mode(feature_na_mode): if not isinstance(feature_na_mode, dict): raise ValueError( diff --git a/smarttree/_classes.py b/smarttree/_classes.py index 2240930..cfe7aa8 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -48,6 +48,7 @@ def __init__( num_na_mode: NumNaModeType | None = None, cat_na_mode: CatNaModeType | None = None, cat_na_filler: str = "missing_value", + rank_na_mode: CommonNaModeType | None = None, feature_na_mode: dict[str, NaModeType] | None = None, verbose: VerboseType = "WARNING", ) -> None: @@ -68,6 +69,7 @@ def __init__( num_na_mode=num_na_mode, cat_na_mode=cat_na_mode, cat_na_filler=cat_na_filler, + rank_na_mode=rank_na_mode, feature_na_mode=feature_na_mode, ) @@ -102,6 +104,7 @@ def __init__( self.__num_na_mode = num_na_mode self.__cat_na_mode = cat_na_mode self.__cat_na_filler = cat_na_filler + self.__rank_na_mode = rank_na_mode self.__feature_na_mode: dict[str, NaModeType] = feature_na_mode or dict() self.logger = logging.getLogger() @@ -177,6 +180,10 @@ def cat_na_mode(self) -> CatNaModeType | None: def cat_na_filler(self) -> str: return self.__cat_na_filler + @property + def rank_na_mode(self) -> CommonNaModeType | None: + return self.__rank_na_mode + @property def feature_na_mode(self) -> dict[str, NaModeType]: return self.__feature_na_mode @@ -242,6 +249,7 @@ def get_params( "num_na_mode": self.num_na_mode, "cat_na_mode": self.cat_na_mode, "cat_na_filler": self.cat_na_filler, + "rank_na_mode": self.rank_na_mode, "feature_na_mode": self.feature_na_mode, } @@ -398,6 +406,16 @@ class SmartDecisionTreeClassifier(BaseSmartDecisionTree): training and predicting missing values will be filled with `categorical_na_filler`. + rank_na_mode: {"include_all", "include_best"}, default=None + The mode of handling missing values in a rank feature. + + - If "include_all", then while training samples with missing values + are included into all child nodes. 
While predicting, the decision is the + weighted mean of all decisions in child nodes. + - If "include_best", then while training and prediction samples with + missing values are included into the best child node according to + information gain. + feature_na_mode: dict[str, {"min", "max", "as_category", "include_all", "include_best"}], default=None The mode of handling missing values in a feature. @@ -425,6 +443,7 @@ def __init__( num_na_mode: NumNaModeType | None = None, cat_na_mode: CatNaModeType | None = None, cat_na_filler: str = "missing_value", + rank_na_mode: CommonNaModeType | None = None, feature_na_mode: dict[str, NaModeType] | None = None, verbose: VerboseType = "WARNING", ) -> None: @@ -445,6 +464,7 @@ def __init__( num_na_mode=num_na_mode, cat_na_mode=cat_na_mode, cat_na_filler=cat_na_filler, + rank_na_mode=rank_na_mode, feature_na_mode=feature_na_mode, verbose=verbose, ) @@ -489,6 +509,8 @@ def __repr__(self) -> str: repr_.append(f"cat_na_mode={self.cat_na_mode!r}") if self.cat_na_filler != "missing_value": repr_.append(f"cat_na_filler={self.cat_na_filler!r}") + if self.rank_na_mode: + repr_.append(f"rank_na_mode={self.rank_na_mode!r}") return ( f"{self.__class__.__name__}({', '.join(repr_)})" @@ -562,6 +584,8 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: self.feature_na_mode.update({f: self.num_na_mode for f in self.num_features}) if self.cat_na_mode is not None: self.feature_na_mode.update({f: self.cat_na_mode for f in self.cat_features}) + if self.rank_na_mode is not None: + self.feature_na_mode.update({f: self.rank_na_mode for f in self.rank_features}) self.feature_na_mode.update(temp_feature_na_mode) self.__classes = np.sort(y.unique()) diff --git a/tests/decision_tree/base/test__get_set_params.py b/tests/decision_tree/base/test__get_set_params.py index 84ff8f6..240d2d9 100644 --- a/tests/decision_tree/base/test__get_set_params.py +++ b/tests/decision_tree/base/test__get_set_params.py @@ -20,6 +20,7 @@ "na_mode": "include_best", "num_na_mode": 
None, "cat_na_mode": None, + "rank_na_mode": None, "cat_na_filler": "missing_value", "feature_na_mode": {}, } @@ -50,6 +51,7 @@ def test__get_params(concrete_smart_tree): ({"num_na_mode": "max"}, does_not_raise()), ({"cat_na_mode": "as_category"}, does_not_raise()), ({"cat_na_filler": "NA"}, does_not_raise()), + ({"rank_na_mode": "include_all"}, does_not_raise()), ({"feature_na_mode": {"feature": "max"}}, does_not_raise()), ( {"aboba": "aboba"}, @@ -60,7 +62,8 @@ def test__get_params(concrete_smart_tree): " Valid parameters are: criterion, max_depth, min_samples_split," " min_samples_leaf, max_leaf_nodes, min_impurity_decrease," " max_childs, num_features, cat_features, rank_features, hierarchy," - " na_mode, num_na_mode, cat_na_mode, cat_na_filler, feature_na_mode." + " na_mode, num_na_mode, cat_na_mode, cat_na_filler, rank_na_mode," + " feature_na_mode." ), ), ), @@ -82,6 +85,7 @@ def test__get_params(concrete_smart_tree): "num_na_mode", "cat_na_mode", "cat_na_filler", + "rank_na_mode", "feature_na_mode", "invalid", ], diff --git a/tests/decision_tree/classifier/test__repr_tree.py b/tests/decision_tree/classifier/test__repr_tree.py index 26f9733..46ac5f7 100644 --- a/tests/decision_tree/classifier/test__repr_tree.py +++ b/tests/decision_tree/classifier/test__repr_tree.py @@ -178,6 +178,7 @@ def test_repr_tree__hierarchy(hierarchy, expected): ("include_best", f"{CLASS_NAME}()"), ("include_all", f"{CLASS_NAME}(na_mode='include_all')"), ], + ids=["default value", "not default value"], ) def test_repr_tree__na_mode(na_mode, expected): na_mode: CommonNaModeType @@ -224,3 +225,16 @@ def test_repr_tree__cat_na_mode(cat_na_mode, expected): def test_repr_tree__cat_na_filler(cat_na_filler, expected): tree_classifier = SmartDecisionTreeClassifier(cat_na_filler=cat_na_filler) assert repr(tree_classifier) == expected + + +@pytest.mark.parametrize( + ("rank_na_mode", "expected"), + [ + (None, f"{CLASS_NAME}()"), + ("include_all", f"{CLASS_NAME}(rank_na_mode='include_all')"), + 
], + ids=["default value", "not default value"], +) +def test_repr_tree_rank_na_mode(rank_na_mode, expected): + tree_classifier = SmartDecisionTreeClassifier(rank_na_mode=rank_na_mode) + assert repr(tree_classifier) == expected