Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion smarttree/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def check__params(
num_na_mode=None,
cat_na_mode=None,
cat_na_filler=None,
rank_na_mode=None,
feature_na_mode=None,
):
if criterion is not None:
Expand Down Expand Up @@ -86,6 +87,9 @@ def check__params(
if cat_na_filler is not None:
_check__cat_na_filler(cat_na_filler)

if rank_na_mode is not None:
_check__rank_na_mode(rank_na_mode)

if feature_na_mode is not None:
_check__feature_na_mode(feature_na_mode)

Expand Down Expand Up @@ -279,7 +283,7 @@ def _check__hierarchy(hierarchy):
def _check_na_mode(na_mode):
if na_mode not in ("include_all", "include_best"):
raise ValueError(
"`num_na_mode` must be Literal['include_all', 'include_best']."
"`na_mode` must be Literal['include_all', 'include_best']."
f" The current value of `na_mode` is {na_mode!r}."
)

Expand Down Expand Up @@ -308,6 +312,14 @@ def _check__cat_na_filler(cat_na_filler):
)


def _check__rank_na_mode(rank_na_mode):
if rank_na_mode not in ("include_all", "include_best"):
raise ValueError(
"`rank_na_mode` must be Literal['include_all', 'include_best']."
f" The current value of `na_mode` is {rank_na_mode!r}."
)


def _check__feature_na_mode(feature_na_mode):
if not isinstance(feature_na_mode, dict):
raise ValueError(
Expand Down
24 changes: 24 additions & 0 deletions smarttree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
num_na_mode: NumNaModeType | None = None,
cat_na_mode: CatNaModeType | None = None,
cat_na_filler: str = "missing_value",
rank_na_mode: CommonNaModeType | None = None,
feature_na_mode: dict[str, NaModeType] | None = None,
verbose: VerboseType = "WARNING",
) -> None:
Expand All @@ -68,6 +69,7 @@ def __init__(
num_na_mode=num_na_mode,
cat_na_mode=cat_na_mode,
cat_na_filler=cat_na_filler,
rank_na_mode=rank_na_mode,
feature_na_mode=feature_na_mode,
)

Expand Down Expand Up @@ -102,6 +104,7 @@ def __init__(
self.__num_na_mode = num_na_mode
self.__cat_na_mode = cat_na_mode
self.__cat_na_filler = cat_na_filler
self.__rank_na_mode = rank_na_mode
self.__feature_na_mode: dict[str, NaModeType] = feature_na_mode or dict()

self.logger = logging.getLogger()
Expand Down Expand Up @@ -177,6 +180,10 @@ def cat_na_mode(self) -> CatNaModeType | None:
def cat_na_filler(self) -> str:
return self.__cat_na_filler

@property
def rank_na_mode(self) -> CommonNaModeType | None:
    """Return the missing-value handling mode for rank features, or None if unset."""
    return self.__rank_na_mode

@property
def feature_na_mode(self) -> dict[str, NaModeType]:
    """Return the mapping of feature name to its missing-value handling mode."""
    return self.__feature_na_mode
Expand Down Expand Up @@ -242,6 +249,7 @@ def get_params(
"num_na_mode": self.num_na_mode,
"cat_na_mode": self.cat_na_mode,
"cat_na_filler": self.cat_na_filler,
"rank_na_mode": self.rank_na_mode,
"feature_na_mode": self.feature_na_mode,
}

Expand Down Expand Up @@ -398,6 +406,16 @@ class SmartDecisionTreeClassifier(BaseSmartDecisionTree):
training and predicting missing values will be filled with
`categorical_na_filler`.

rank_na_mode: {"include_all", "include_best"}, default=None
The mode of handling missing values in a rank feature.

- If "include_all", then during training, samples with missing values
are included in all child nodes. During prediction, the decision is
the weighted mean of the decisions of all child nodes.
- If "include_best", then during both training and prediction, samples
with missing values are included in the best child node according to
information gain.

feature_na_mode: dict[str, {"min", "max", "as_category", "include_all", "include_best"}],
default=None
The mode of handling missing values in a feature.
Expand Down Expand Up @@ -425,6 +443,7 @@ def __init__(
num_na_mode: NumNaModeType | None = None,
cat_na_mode: CatNaModeType | None = None,
cat_na_filler: str = "missing_value",
rank_na_mode: CommonNaModeType | None = None,
feature_na_mode: dict[str, NaModeType] | None = None,
verbose: VerboseType = "WARNING",
) -> None:
Expand All @@ -445,6 +464,7 @@ def __init__(
num_na_mode=num_na_mode,
cat_na_mode=cat_na_mode,
cat_na_filler=cat_na_filler,
rank_na_mode=rank_na_mode,
feature_na_mode=feature_na_mode,
verbose=verbose,
)
Expand Down Expand Up @@ -489,6 +509,8 @@ def __repr__(self) -> str:
repr_.append(f"cat_na_mode={self.cat_na_mode!r}")
if self.cat_na_filler != "missing_value":
repr_.append(f"cat_na_filler={self.cat_na_filler!r}")
if self.rank_na_mode:
repr_.append(f"rank_na_mode={self.rank_na_mode!r}")

return (
f"{self.__class__.__name__}({', '.join(repr_)})"
Expand Down Expand Up @@ -562,6 +584,8 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self:
self.feature_na_mode.update({f: self.num_na_mode for f in self.num_features})
if self.cat_na_mode is not None:
self.feature_na_mode.update({f: self.cat_na_mode for f in self.cat_features})
if self.rank_na_mode is not None:
self.feature_na_mode.update({f: self.rank_na_mode for f in self.rank_features})
self.feature_na_mode.update(temp_feature_na_mode)

self.__classes = np.sort(y.unique())
Expand Down
6 changes: 5 additions & 1 deletion tests/decision_tree/base/test__get_set_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"na_mode": "include_best",
"num_na_mode": None,
"cat_na_mode": None,
"rank_na_mode": None,
"cat_na_filler": "missing_value",
"feature_na_mode": {},
}
Expand Down Expand Up @@ -50,6 +51,7 @@ def test__get_params(concrete_smart_tree):
({"num_na_mode": "max"}, does_not_raise()),
({"cat_na_mode": "as_category"}, does_not_raise()),
({"cat_na_filler": "NA"}, does_not_raise()),
({"rank_na_mode": "include_all"}, does_not_raise()),
({"feature_na_mode": {"feature": "max"}}, does_not_raise()),
(
{"aboba": "aboba"},
Expand All @@ -60,7 +62,8 @@ def test__get_params(concrete_smart_tree):
" Valid parameters are: criterion, max_depth, min_samples_split,"
" min_samples_leaf, max_leaf_nodes, min_impurity_decrease,"
" max_childs, num_features, cat_features, rank_features, hierarchy,"
" na_mode, num_na_mode, cat_na_mode, cat_na_filler, feature_na_mode."
" na_mode, num_na_mode, cat_na_mode, cat_na_filler, rank_na_mode,"
" feature_na_mode."
),
),
),
Expand All @@ -82,6 +85,7 @@ def test__get_params(concrete_smart_tree):
"num_na_mode",
"cat_na_mode",
"cat_na_filler",
"rank_na_mode",
"feature_na_mode",
"invalid",
],
Expand Down
14 changes: 14 additions & 0 deletions tests/decision_tree/classifier/test__repr_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def test_repr_tree__hierarchy(hierarchy, expected):
("include_best", f"{CLASS_NAME}()"),
("include_all", f"{CLASS_NAME}(na_mode='include_all')"),
],
ids=["default value", "not default value"],
)
def test_repr_tree__na_mode(na_mode, expected):
na_mode: CommonNaModeType
Expand Down Expand Up @@ -224,3 +225,16 @@ def test_repr_tree__cat_na_mode(cat_na_mode, expected):
def test_repr_tree__cat_na_filler(cat_na_filler, expected):
tree_classifier = SmartDecisionTreeClassifier(cat_na_filler=cat_na_filler)
assert repr(tree_classifier) == expected


@pytest.mark.parametrize(
    ("rank_na_mode", "expected"),
    [
        (None, f"{CLASS_NAME}()"),
        ("include_all", f"{CLASS_NAME}(rank_na_mode='include_all')"),
    ],
    ids=["default value", "not default value"],
)
def test_repr_tree_rank_na_mode(rank_na_mode, expected):
    """Check the classifier repr for the default and a non-default `rank_na_mode`."""
    actual = repr(SmartDecisionTreeClassifier(rank_na_mode=rank_na_mode))
    assert actual == expected
Loading