From 569e98ed5b8e056bbf4a0edb2d083178c058114b Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Mon, 22 Sep 2025 22:19:11 +0300 Subject: [PATCH 1/5] extract .include_best_split() --- smarttree/_column_splitter.py | 111 +++++++++++++++------------------- 1 file changed, 49 insertions(+), 62 deletions(-) diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index 2b0b22f..86c7ad8 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -61,6 +61,36 @@ def __init__( def split(self, *args, **kwargs) -> ColumnSplitResult: raise NotImplementedError + def include_best_split( + self, + parent_mask: pd.Series, + mask_na: pd.Series, + child_masks: list[pd.Series], + ) -> tuple[float, list[pd.Series], int]: + + candidates = [] + origin_child_masks = child_masks + for i, child_mask in enumerate(origin_child_masks): + child_masks = deepcopy(origin_child_masks) + child_masks[i] = child_mask | (parent_mask & mask_na) + for child_mask in child_masks: + if child_mask.sum() < self.min_samples_leaf: + break + else: + candidates.append(child_masks) + + best_information_gain = NO_INFORMATION_GAIN + best_child_masks = [] + best_child_na_index = -1 + for child_na_index, child_masks in enumerate(candidates): + information_gain = self.information_gain(parent_mask, child_masks, "include_best") + if best_information_gain < information_gain: + best_information_gain = information_gain + best_child_masks = child_masks + best_child_na_index = child_na_index + + return best_information_gain, best_child_masks, best_child_na_index + def information_gain( self, parent_mask: pd.Series, @@ -133,15 +163,6 @@ def gini_index(self, mask: pd.Series) -> float: C - total number of classes; p_i - the probability of choosing a sample with class i. """ - # N = mask.sum() - # - # gini_index = 1 - # for label in self.dataset.class_names: - # N_i = (mask & (self.dataset.y == label)).sum() - # p_i = N_i / N - # gini_index -= pow(p_i, 2) - # - # return gini_index return cgini_index(mask, self.dataset.y, self.dataset.class_names) def entropy(self, mask: pd.Series) -> float: @@ -231,33 +252,12 @@ def __num_split( na_mode = self.feature_na_mode[split_feature] if na_mode == "include_all": for i, child_mask in enumerate(child_masks): - child_masks[i] = child_mask | (parent_mask & mask_na) # update + child_masks[i] = child_mask | (parent_mask & mask_na) if child_masks[i].sum() < self.min_samples_leaf: return NO_INFORMATION_GAIN, [], -1 elif na_mode == "include_best": - candidates = [] - origin_child_masks = child_masks - for i, child_mask in enumerate(origin_child_masks): - child_masks = deepcopy(origin_child_masks) - child_masks[i] = child_mask | (parent_mask & mask_na) # update - for child_mask in child_masks: - if child_mask.sum() < self.min_samples_leaf: - break - else: - candidates.append(child_masks) - - best_information_gain = NO_INFORMATION_GAIN - best_child_masks = [] - best_child_na_index = -1 - for i, child_masks in enumerate(candidates): - information_gain = self.information_gain(parent_mask, child_masks, na_mode) - if best_information_gain < information_gain: - best_information_gain = information_gain - best_child_masks = child_masks - best_child_na_index = i - - return best_information_gain, best_child_masks, best_child_na_index + return self.include_best_split(parent_mask, mask_na, child_masks) information_gain = self.information_gain(parent_mask, child_masks, na_mode) @@ -340,33 +340,12 @@ def __cat_split( na_mode = self.feature_na_mode[split_feature] if na_mode == "include_all": for i, child_mask in enumerate(child_masks): - child_masks[i] = child_mask | (parent_mask & mask_na) # update + child_masks[i] = child_mask | (parent_mask & mask_na) if child_masks[i].sum() < self.min_samples_leaf: return NO_INFORMATION_GAIN, [], -1 elif na_mode == "include_best": - candidates = [] - origin_child_masks = child_masks - for i, child_mask in enumerate(origin_child_masks): - child_masks = deepcopy(origin_child_masks) - child_masks[i] = child_mask | (parent_mask & mask_na) # update - for child_mask in child_masks: - if child_mask.sum() < self.min_samples_leaf: - break - else: - candidates.append(child_masks) - - best_information_gain = NO_INFORMATION_GAIN - best_child_masks = [] - best_child_na_index = -1 - for i, child_masks in enumerate(candidates): - information_gain = self.information_gain(parent_mask, child_masks, na_mode) - if best_information_gain < information_gain: - best_information_gain = information_gain - best_child_masks = child_masks - best_child_na_index = i - - return best_information_gain, best_child_masks, best_child_na_index + return self.include_best_split(parent_mask, mask_na, child_masks) information_gain = self.information_gain(parent_mask, child_masks, na_mode) @@ -417,12 +396,12 @@ def split(self, node: TreeNode, split_feature: str) -> ColumnSplitResult: best_split_result = ColumnSplitResult.no_split() for feature_values in self.__rank_partitions(available_feature_values): - information_gain, child_masks = self.__rank_split( + information_gain, child_masks, child_na_index = self.__rank_split( node.mask, split_feature, feature_values ) if best_split_result.information_gain < information_gain: best_split_result = ColumnSplitResult( - information_gain, list(feature_values), child_masks + information_gain, list(feature_values), child_masks, child_na_index ) return best_split_result @@ -432,7 +411,9 @@ def __rank_split( parent_mask: pd.Series, split_feature: str, feature_values: tuple[list, list], - ) -> tuple[float, list[pd.Series]]: + ) -> tuple[float, list[pd.Series], int]: + + mask_na = parent_mask & self.dataset.mask_na[split_feature] feature_values_left, feature_values_right = feature_values @@ -440,13 +421,19 @@ def __rank_split( mask_right = parent_mask & self.dataset.X[split_feature].isin(feature_values_right) child_masks = [mask_left, mask_right] - for child_mask in child_masks: - if child_mask.sum() < self.min_samples_leaf: - return NO_INFORMATION_GAIN, [] + na_mode = self.feature_na_mode[split_feature] + if na_mode == "include_all": + for i, child_mask in enumerate(child_masks): + child_masks[i] = child_mask | (parent_mask & mask_na) + if child_masks[i].sum() < self.min_samples_leaf: + return NO_INFORMATION_GAIN, [], -1 + + elif na_mode == "include_best": + return self.include_best_split(parent_mask, mask_na, child_masks) information_gain = self.information_gain(parent_mask, child_masks) - return information_gain, child_masks + return information_gain, child_masks, -1 @staticmethod def __rank_partitions(collection: list) -> Generator[tuple[list, list], None, None]: From b26de58c0359933903b2300e5a491d2374233993 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Mon, 22 Sep 2025 22:29:46 +0300 Subject: [PATCH 2/5] extract .include_all_split() --- smarttree/_column_splitter.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index 86c7ad8..aea78e2 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -61,6 +61,22 @@ def __init__( def split(self, *args, **kwargs) -> ColumnSplitResult: raise NotImplementedError + def include_all_split( + self, + parent_mask: pd.Series, + mask_na: pd.Series, + child_masks: list[pd.Series], + ) -> tuple[float, list[pd.Series], int]: + + for i, child_mask in enumerate(child_masks): + child_masks[i] = child_mask | (parent_mask & mask_na) + if child_masks[i].sum() < self.min_samples_leaf: + return NO_INFORMATION_GAIN, [], -1 + + information_gain = self.information_gain(parent_mask, child_masks, "include_all") + + return information_gain, child_masks, -1 + def include_best_split( self, parent_mask: pd.Series, @@ -251,10 +267,7 @@ def __num_split( na_mode = self.feature_na_mode[split_feature] if na_mode == "include_all": - for i, child_mask in enumerate(child_masks): - child_masks[i] = child_mask | (parent_mask & mask_na) - if child_masks[i].sum() < self.min_samples_leaf: - return NO_INFORMATION_GAIN, [], -1 + return self.include_all_split(parent_mask, mask_na, child_masks) elif na_mode == "include_best": return self.include_best_split(parent_mask, mask_na, child_masks) @@ -339,10 +352,7 @@ def __cat_split( na_mode = self.feature_na_mode[split_feature] if na_mode == "include_all": - for i, child_mask in enumerate(child_masks): - child_masks[i] = child_mask | (parent_mask & mask_na) - if child_masks[i].sum() < self.min_samples_leaf: - return NO_INFORMATION_GAIN, [], -1 + return self.include_all_split(parent_mask, mask_na, child_masks) elif na_mode == "include_best": return self.include_best_split(parent_mask, mask_na, child_masks) @@ -423,10 +433,7 @@ def __rank_split( na_mode = self.feature_na_mode[split_feature] if na_mode == "include_all": - for i, child_mask in enumerate(child_masks): - child_masks[i] = child_mask | (parent_mask & mask_na) - if child_masks[i].sum() < self.min_samples_leaf: - return NO_INFORMATION_GAIN, [], -1 + return self.include_all_split(parent_mask, mask_na, child_masks) elif na_mode == "include_best": return self.include_best_split(parent_mask, mask_na, child_masks) From 9de7bdfc0782ff87244d3ebf69cce72fa3af9e88 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Mon, 22 Sep 2025 22:39:33 +0300 Subject: [PATCH 3/5] refactoring --- smarttree/_column_splitter.py | 1 + smarttree/_node_splitter.py | 1 + 2 files changed, 2 insertions(+) diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index aea78e2..55178b1 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -20,6 +20,7 @@ class ColumnSplitResult(NamedTuple): + information_gain: float feature_values: list[list] child_masks: list[pd.Series] diff --git a/smarttree/_node_splitter.py b/smarttree/_node_splitter.py index fca0af4..c4fec9b 100644 --- a/smarttree/_node_splitter.py +++ b/smarttree/_node_splitter.py @@ -9,6 +9,7 @@ class NodeSplitResult(NamedTuple): + information_gain: float split_type: str split_feature: str From c38dbc3af683e21ddbef05205ab5eceba1628fb9 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Mon, 22 Sep 2025 23:37:24 +0300 Subject: [PATCH 4/5] unified --- smarttree/_column_splitter.py | 85 ++++++++++++++++++----------------- smarttree/_dataset.py | 16 +++++-- tests/conftest.py | 9 ++-- tests/test__node_splitter.py | 12 ++--- 4 files changed, 65 insertions(+), 57 deletions(-) diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index 55178b1..0e73414 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -74,7 +74,7 @@ def include_all_split( if child_masks[i].sum() < self.min_samples_leaf: return NO_INFORMATION_GAIN, [], -1 - information_gain = self.information_gain(parent_mask, child_masks, "include_all") + information_gain = self.information_gain(parent_mask, child_masks, normalize=True) return information_gain, child_masks, -1 @@ -100,7 +100,7 @@ def include_best_split( best_child_masks = [] best_child_na_index = -1 for child_na_index, child_masks in enumerate(candidates): - information_gain = self.information_gain(parent_mask, child_masks, "include_best") + information_gain = self.information_gain(parent_mask, child_masks) if best_information_gain < information_gain: best_information_gain = information_gain best_child_masks = child_masks @@ -112,7 +112,7 @@ def information_gain( self, parent_mask: pd.Series, child_masks: list[pd.Series], - na_mode: NaModeType | None = None, + normalize: bool = False, ) -> float: r""" Calculates information gain of the split. @@ -122,8 +122,9 @@ def information_gain( boolean mask of parent node. child_masks: pd.Series list of boolean masks of child nodes. - na_mode: {"include_all", ...}, default=None - If "include_all" use normalization. + normalize: bool, default=False + if True, normalizes information gain by split factor to handle + unbalanced splits. Uses child node counts for normalization. Returns: float: information gain. @@ -160,7 +161,7 @@ def information_gain( impurity_child_i = self.impurity(child_mask_i) weighted_impurity_childs += (N_child_i / N_parent) * impurity_child_i - if na_mode == "include_all": + if normalize: norm_coef = N_parent / N_childs weighted_impurity_childs *= norm_coef @@ -260,22 +261,22 @@ def __num_split( threshold: float, ) -> tuple[float, list[pd.Series], int]: - mask_na = parent_mask & self.dataset.mask_na[split_feature] - mask_less = parent_mask & (self.dataset.X[split_feature] <= threshold) mask_more = parent_mask & (self.dataset.X[split_feature] > threshold) child_masks = [mask_less, mask_more] - na_mode = self.feature_na_mode[split_feature] - if na_mode == "include_all": - return self.include_all_split(parent_mask, mask_na, child_masks) - - elif na_mode == "include_best": - return self.include_best_split(parent_mask, mask_na, child_masks) - - information_gain = self.information_gain(parent_mask, child_masks, na_mode) - - return information_gain, child_masks, -1 + if self.dataset.has_na[split_feature]: + mask_na = parent_mask & self.dataset.mask_na[split_feature] + na_mode = self.feature_na_mode[split_feature] + if na_mode == "include_all": + return self.include_all_split(parent_mask, mask_na, child_masks) + elif na_mode == "include_best": + return self.include_best_split(parent_mask, mask_na, child_masks) + else: + assert False + else: + information_gain = self.information_gain(parent_mask, child_masks) + return information_gain, child_masks, -1 class CatColumnSplitter(BaseColumnSplitter): @@ -343,24 +344,24 @@ def __cat_split( feature_values: list[list], ) -> tuple[float, list[pd.Series], int]: - mask_na = parent_mask & self.dataset.mask_na[split_feature] - child_masks = [] for partition in feature_values: partition_mask = self.dataset.X[split_feature].isin(partition) child_mask = parent_mask & partition_mask child_masks.append(child_mask) - na_mode = self.feature_na_mode[split_feature] - if na_mode == "include_all": - return self.include_all_split(parent_mask, mask_na, child_masks) - - elif na_mode == "include_best": - return self.include_best_split(parent_mask, mask_na, child_masks) - - information_gain = self.information_gain(parent_mask, child_masks, na_mode) - - return information_gain, child_masks, -1 + if self.dataset.has_na[split_feature]: + mask_na = parent_mask & self.dataset.mask_na[split_feature] + na_mode = self.feature_na_mode[split_feature] + if na_mode == "include_all": + return self.include_all_split(parent_mask, mask_na, child_masks) + elif na_mode == "include_best": + return self.include_best_split(parent_mask, mask_na, child_masks) + else: + assert False + else: + information_gain = self.information_gain(parent_mask, child_masks) + return information_gain, child_masks, -1 def __cat_partitions( self, @@ -424,24 +425,24 @@ def __rank_split( feature_values: tuple[list, list], ) -> tuple[float, list[pd.Series], int]: - mask_na = parent_mask & self.dataset.mask_na[split_feature] - feature_values_left, feature_values_right = feature_values mask_left = parent_mask & self.dataset.X[split_feature].isin(feature_values_left) mask_right = parent_mask & self.dataset.X[split_feature].isin(feature_values_right) child_masks = [mask_left, mask_right] - na_mode = self.feature_na_mode[split_feature] - if na_mode == "include_all": - return self.include_all_split(parent_mask, mask_na, child_masks) - - elif na_mode == "include_best": - return self.include_best_split(parent_mask, mask_na, child_masks) - - information_gain = self.information_gain(parent_mask, child_masks) - - return information_gain, child_masks, -1 + if self.dataset.has_na[split_feature]: + mask_na = parent_mask & self.dataset.mask_na[split_feature] + na_mode = self.feature_na_mode[split_feature] + if na_mode == "include_all": + return self.include_all_split(parent_mask, mask_na, child_masks) + elif na_mode == "include_best": + return self.include_best_split(parent_mask, mask_na, child_masks) + else: + assert False + else: + information_gain = self.information_gain(parent_mask, child_masks) + return information_gain, child_masks, -1 @staticmethod def __rank_partitions(collection: list) -> Generator[tuple[list, list], None, None]: diff --git a/smarttree/_dataset.py b/smarttree/_dataset.py index fed54eb..90ee9b0 100644 --- a/smarttree/_dataset.py +++ b/smarttree/_dataset.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field import numpy as np import pandas as pd @@ -10,10 +10,20 @@ class Dataset: X: pd.DataFrame y: pd.Series + class_names: NDArray = field(init=False) + has_na: dict[str, bool] = field(init=False) + mask_na: dict[str, pd.Series] = field(init=False) def __post_init__(self) -> None: - self.class_names: NDArray = np.sort(self.y.unique()) - self.mask_na = {column: self.X[column].isna() for column in self.X.columns} + self.class_names = np.sort(self.y.unique()) + self.has_na = dict() + self.mask_na = dict() + for column in self.X.columns: + mask_na = self.X[column].isna() + has_na = mask_na.any() + self.has_na[column] = has_na + if has_na: + self.mask_na[column] = mask_na @property def size(self) -> int: diff --git a/tests/conftest.py b/tests/conftest.py index 89da3ef..482b5e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +from itertools import chain import numpy as np import pandas as pd @@ -151,12 +152,8 @@ def feature_na_mode( ) -> dict[str, NaModeType | None]: result = dict() - for num_feature in num_features: - result[num_feature] = "min" - for cat_feature in cat_features: - result[cat_feature] = "as_category" - for rank_feature in rank_features: - result[rank_feature] = None + for feature in chain(num_features, cat_features, rank_features): + result[feature] = "include_best" return result diff --git a/tests/test__node_splitter.py b/tests/test__node_splitter.py index e816f04..010af8b 100644 --- a/tests/test__node_splitter.py +++ b/tests/test__node_splitter.py @@ -4,7 +4,7 @@ @pytest.fixture(scope="module") -def concrete_node_splitter( +def node_splitter( X, y, num_features, cat_features, rank_features, feature_na_mode ) -> NodeSplitter: return NodeSplitter( @@ -24,10 +24,10 @@ def concrete_node_splitter( ) -def test__find_best_split(concrete_node_splitter, root_node): - concrete_node_splitter.find_best_split_for(root_node, leaf_counter=0) +def test__find_best_split(node_splitter, root_node): + node_splitter.find_best_split_for(root_node, leaf_counter=0) -def test__is_splittable(concrete_node_splitter, root_node): - concrete_node_splitter.find_best_split_for(root_node, leaf_counter=0) - concrete_node_splitter.is_splittable(root_node, leaf_counter=0) +def test__is_splittable(node_splitter, root_node): + node_splitter.find_best_split_for(root_node, leaf_counter=0) + node_splitter.is_splittable(root_node, leaf_counter=0) From 43890af71c42bab6e70d48a89f2476f5adcc87b0 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Mon, 22 Sep 2025 23:58:46 +0300 Subject: [PATCH 5/5] extract unified as .foo() --- smarttree/_column_splitter.py | 62 ++++++++++++++--------------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index 0e73414..e245624 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -62,6 +62,26 @@ def __init__( def split(self, *args, **kwargs) -> ColumnSplitResult: raise NotImplementedError + def foo( + self, + parent_mask: pd.Series, + split_feature: str, + child_masks: list[pd.Series], + ) -> tuple[float, list[pd.Series], int]: + + if self.dataset.has_na[split_feature]: + mask_na = parent_mask & self.dataset.mask_na[split_feature] + na_mode = self.feature_na_mode[split_feature] + if na_mode == "include_all": + return self.include_all_split(parent_mask, mask_na, child_masks) + elif na_mode == "include_best": + return self.include_best_split(parent_mask, mask_na, child_masks) + else: + assert False + else: + information_gain = self.information_gain(parent_mask, child_masks) + return information_gain, child_masks, -1 + def include_all_split( self, parent_mask: pd.Series, @@ -256,7 +276,8 @@ def __moving_average(array: NDArray, window: int = 2) -> NDArray: return np.convolve(array, np.ones(window), mode="valid") / window def __num_split( - self, parent_mask: pd.Series, + self, + parent_mask: pd.Series, split_feature: str, threshold: float, ) -> tuple[float, list[pd.Series], int]: @@ -265,18 +286,7 @@ def __num_split( mask_more = parent_mask & (self.dataset.X[split_feature] > threshold) child_masks = [mask_less, mask_more] - if self.dataset.has_na[split_feature]: - mask_na = parent_mask & self.dataset.mask_na[split_feature] - na_mode = self.feature_na_mode[split_feature] - if na_mode == "include_all": - return self.include_all_split(parent_mask, mask_na, child_masks) - elif na_mode == "include_best": - return self.include_best_split(parent_mask, mask_na, child_masks) - else: - assert False - else: - information_gain = self.information_gain(parent_mask, child_masks) - return information_gain, child_masks, -1 + return self.foo(parent_mask, split_feature, child_masks) class CatColumnSplitter(BaseColumnSplitter): @@ -350,18 +360,7 @@ def __cat_split( child_mask = parent_mask & partition_mask child_masks.append(child_mask) - if self.dataset.has_na[split_feature]: - mask_na = parent_mask & self.dataset.mask_na[split_feature] - na_mode = self.feature_na_mode[split_feature] - if na_mode == "include_all": - return self.include_all_split(parent_mask, mask_na, child_masks) - elif na_mode == "include_best": - return self.include_best_split(parent_mask, mask_na, child_masks) - else: - assert False - else: - information_gain = self.information_gain(parent_mask, child_masks) - return information_gain, child_masks, -1 + return self.foo(parent_mask, split_feature, child_masks) def __cat_partitions( self, @@ -431,18 +430,7 @@ def __rank_split( mask_right = parent_mask & self.dataset.X[split_feature].isin(feature_values_right) child_masks = [mask_left, mask_right] - if self.dataset.has_na[split_feature]: - mask_na = parent_mask & self.dataset.mask_na[split_feature] - na_mode = self.feature_na_mode[split_feature] - if na_mode == "include_all": - return self.include_all_split(parent_mask, mask_na, child_masks) - elif na_mode == "include_best": - return self.include_best_split(parent_mask, mask_na, child_masks) - else: - assert False - else: - information_gain = self.information_gain(parent_mask, child_masks) - return information_gain, child_masks, -1 + return self.foo(parent_mask, split_feature, child_masks) @staticmethod def __rank_partitions(collection: list) -> Generator[tuple[list, list], None, None]: