Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 80 additions & 96 deletions smarttree/_column_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


class ColumnSplitResult(NamedTuple):

information_gain: float
feature_values: list[list]
child_masks: list[pd.Series]
Expand Down Expand Up @@ -61,11 +62,77 @@ def __init__(
def split(self, *args, **kwargs) -> ColumnSplitResult:
raise NotImplementedError

def foo(
self,
parent_mask: pd.Series,
split_feature: str,
child_masks: list[pd.Series],
) -> tuple[float, list[pd.Series], int]:

if self.dataset.has_na[split_feature]:
mask_na = parent_mask & self.dataset.mask_na[split_feature]
na_mode = self.feature_na_mode[split_feature]
if na_mode == "include_all":
return self.include_all_split(parent_mask, mask_na, child_masks)
elif na_mode == "include_best":
return self.include_best_split(parent_mask, mask_na, child_masks)
else:
assert False
else:
information_gain = self.information_gain(parent_mask, child_masks)
return information_gain, child_masks, -1

def include_all_split(
self,
parent_mask: pd.Series,
mask_na: pd.Series,
child_masks: list[pd.Series],
) -> tuple[float, list[pd.Series], int]:

for i, child_mask in enumerate(child_masks):
child_masks[i] = child_mask | (parent_mask & mask_na)
if child_masks[i].sum() < self.min_samples_leaf:
return NO_INFORMATION_GAIN, [], -1

information_gain = self.information_gain(parent_mask, child_masks, normalize=True)

return information_gain, child_masks, -1

def include_best_split(
self,
parent_mask: pd.Series,
mask_na: pd.Series,
child_masks: list[pd.Series],
) -> tuple[float, list[pd.Series], int]:

candidates = []
origin_child_masks = child_masks
for i, child_mask in enumerate(origin_child_masks):
child_masks = deepcopy(origin_child_masks)
child_masks[i] = child_mask | (parent_mask & mask_na)
for child_mask in child_masks:
if child_mask.sum() < self.min_samples_leaf:
break
else:
candidates.append(child_masks)

best_information_gain = NO_INFORMATION_GAIN
best_child_masks = []
best_child_na_index = -1
for child_na_index, child_masks in enumerate(candidates):
information_gain = self.information_gain(parent_mask, child_masks)
if best_information_gain < information_gain:
best_information_gain = information_gain
best_child_masks = child_masks
best_child_na_index = child_na_index

return best_information_gain, best_child_masks, best_child_na_index

def information_gain(
self,
parent_mask: pd.Series,
child_masks: list[pd.Series],
na_mode: NaModeType | None = None,
normalize: bool = False,
) -> float:
r"""
Calculates information gain of the split.
Expand All @@ -75,8 +142,9 @@ def information_gain(
boolean mask of parent node.
child_masks: pd.Series
list of boolean masks of child nodes.
na_mode: {"include_all", ...}, default=None
If "include_all" use normalization.
normalize: bool, default=False
if True, normalizes information gain by split factor to handle
unbalanced splits. Uses child node counts for normalization.

Returns:
float: information gain.
Expand Down Expand Up @@ -113,7 +181,7 @@ def information_gain(
impurity_child_i = self.impurity(child_mask_i)
weighted_impurity_childs += (N_child_i / N_parent) * impurity_child_i

if na_mode == "include_all":
if normalize:
norm_coef = N_parent / N_childs
weighted_impurity_childs *= norm_coef

Expand All @@ -133,15 +201,6 @@ def gini_index(self, mask: pd.Series) -> float:
C - total number of classes;
p_i - the probability of choosing a sample with class i.
"""
# N = mask.sum()
#
# gini_index = 1
# for label in self.dataset.class_names:
# N_i = (mask & (self.dataset.y == label)).sum()
# p_i = N_i / N
# gini_index -= pow(p_i, 2)
#
# return gini_index
return cgini_index(mask, self.dataset.y, self.dataset.class_names)

def entropy(self, mask: pd.Series) -> float:
Expand Down Expand Up @@ -217,51 +276,17 @@ def __moving_average(array: NDArray, window: int = 2) -> NDArray:
return np.convolve(array, np.ones(window), mode="valid") / window

def __num_split(
self, parent_mask: pd.Series,
self,
parent_mask: pd.Series,
split_feature: str,
threshold: float,
) -> tuple[float, list[pd.Series], int]:

mask_na = parent_mask & self.dataset.mask_na[split_feature]

mask_less = parent_mask & (self.dataset.X[split_feature] <= threshold)
mask_more = parent_mask & (self.dataset.X[split_feature] > threshold)
child_masks = [mask_less, mask_more]

na_mode = self.feature_na_mode[split_feature]
if na_mode == "include_all":
for i, child_mask in enumerate(child_masks):
child_masks[i] = child_mask | (parent_mask & mask_na) # update
if child_masks[i].sum() < self.min_samples_leaf:
return NO_INFORMATION_GAIN, [], -1

elif na_mode == "include_best":
candidates = []
origin_child_masks = child_masks
for i, child_mask in enumerate(origin_child_masks):
child_masks = deepcopy(origin_child_masks)
child_masks[i] = child_mask | (parent_mask & mask_na) # update
for child_mask in child_masks:
if child_mask.sum() < self.min_samples_leaf:
break
else:
candidates.append(child_masks)

best_information_gain = NO_INFORMATION_GAIN
best_child_masks = []
best_child_na_index = -1
for i, child_masks in enumerate(candidates):
information_gain = self.information_gain(parent_mask, child_masks, na_mode)
if best_information_gain < information_gain:
best_information_gain = information_gain
best_child_masks = child_masks
best_child_na_index = i

return best_information_gain, best_child_masks, best_child_na_index

information_gain = self.information_gain(parent_mask, child_masks, na_mode)

return information_gain, child_masks, -1
return self.foo(parent_mask, split_feature, child_masks)


class CatColumnSplitter(BaseColumnSplitter):
Expand Down Expand Up @@ -329,48 +354,13 @@ def __cat_split(
feature_values: list[list],
) -> tuple[float, list[pd.Series], int]:

mask_na = parent_mask & self.dataset.mask_na[split_feature]

child_masks = []
for partition in feature_values:
partition_mask = self.dataset.X[split_feature].isin(partition)
child_mask = parent_mask & partition_mask
child_masks.append(child_mask)

na_mode = self.feature_na_mode[split_feature]
if na_mode == "include_all":
for i, child_mask in enumerate(child_masks):
child_masks[i] = child_mask | (parent_mask & mask_na) # update
if child_masks[i].sum() < self.min_samples_leaf:
return NO_INFORMATION_GAIN, [], -1

elif na_mode == "include_best":
candidates = []
origin_child_masks = child_masks
for i, child_mask in enumerate(origin_child_masks):
child_masks = deepcopy(origin_child_masks)
child_masks[i] = child_mask | (parent_mask & mask_na) # update
for child_mask in child_masks:
if child_mask.sum() < self.min_samples_leaf:
break
else:
candidates.append(child_masks)

best_information_gain = NO_INFORMATION_GAIN
best_child_masks = []
best_child_na_index = -1
for i, child_masks in enumerate(candidates):
information_gain = self.information_gain(parent_mask, child_masks, na_mode)
if best_information_gain < information_gain:
best_information_gain = information_gain
best_child_masks = child_masks
best_child_na_index = i

return best_information_gain, best_child_masks, best_child_na_index

information_gain = self.information_gain(parent_mask, child_masks, na_mode)

return information_gain, child_masks, -1
return self.foo(parent_mask, split_feature, child_masks)

def __cat_partitions(
self,
Expand Down Expand Up @@ -417,12 +407,12 @@ def split(self, node: TreeNode, split_feature: str) -> ColumnSplitResult:

best_split_result = ColumnSplitResult.no_split()
for feature_values in self.__rank_partitions(available_feature_values):
information_gain, child_masks = self.__rank_split(
information_gain, child_masks, child_na_index = self.__rank_split(
node.mask, split_feature, feature_values
)
if best_split_result.information_gain < information_gain:
best_split_result = ColumnSplitResult(
information_gain, list(feature_values), child_masks
information_gain, list(feature_values), child_masks, child_na_index
)

return best_split_result
Expand All @@ -432,21 +422,15 @@ def __rank_split(
parent_mask: pd.Series,
split_feature: str,
feature_values: tuple[list, list],
) -> tuple[float, list[pd.Series]]:
) -> tuple[float, list[pd.Series], int]:

feature_values_left, feature_values_right = feature_values

mask_left = parent_mask & self.dataset.X[split_feature].isin(feature_values_left)
mask_right = parent_mask & self.dataset.X[split_feature].isin(feature_values_right)
child_masks = [mask_left, mask_right]

for child_mask in child_masks:
if child_mask.sum() < self.min_samples_leaf:
return NO_INFORMATION_GAIN, []

information_gain = self.information_gain(parent_mask, child_masks)

return information_gain, child_masks
return self.foo(parent_mask, split_feature, child_masks)

@staticmethod
def __rank_partitions(collection: list) -> Generator[tuple[list, list], None, None]:
Expand Down
16 changes: 13 additions & 3 deletions smarttree/_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field

import numpy as np
import pandas as pd
Expand All @@ -10,10 +10,20 @@ class Dataset:

X: pd.DataFrame
y: pd.Series
class_names: NDArray = field(init=False)
has_na: dict[str, bool] = field(init=False)
mask_na: dict[str, pd.Series] = field(init=False)

def __post_init__(self) -> None:
self.class_names: NDArray = np.sort(self.y.unique())
self.mask_na = {column: self.X[column].isna() for column in self.X.columns}
self.class_names = np.sort(self.y.unique())
self.has_na = dict()
self.mask_na = dict()
for column in self.X.columns:
mask_na = self.X[column].isna()
has_na = mask_na.any()
self.has_na[column] = has_na
if has_na:
self.mask_na[column] = mask_na

@property
def size(self) -> int:
Expand Down
1 change: 1 addition & 0 deletions smarttree/_node_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


class NodeSplitResult(NamedTuple):

information_gain: float
split_type: str
split_feature: str
Expand Down
9 changes: 3 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from itertools import chain

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -151,12 +152,8 @@ def feature_na_mode(
) -> dict[str, NaModeType | None]:

result = dict()
for num_feature in num_features:
result[num_feature] = "min"
for cat_feature in cat_features:
result[cat_feature] = "as_category"
for rank_feature in rank_features:
result[rank_feature] = None
for feature in chain(num_features, cat_features, rank_features):
result[feature] = "include_best"

return result

Expand Down
12 changes: 6 additions & 6 deletions tests/test__node_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@pytest.fixture(scope="module")
def concrete_node_splitter(
def node_splitter(
X, y, num_features, cat_features, rank_features, feature_na_mode
) -> NodeSplitter:
return NodeSplitter(
Expand All @@ -24,10 +24,10 @@ def concrete_node_splitter(
)


def test__find_best_split(concrete_node_splitter, root_node):
concrete_node_splitter.find_best_split_for(root_node, leaf_counter=0)
def test__find_best_split(node_splitter, root_node):
node_splitter.find_best_split_for(root_node, leaf_counter=0)


def test__is_splittable(concrete_node_splitter, root_node):
concrete_node_splitter.find_best_split_for(root_node, leaf_counter=0)
concrete_node_splitter.is_splittable(root_node, leaf_counter=0)
def test__is_splittable(node_splitter, root_node):
node_splitter.find_best_split_for(root_node, leaf_counter=0)
node_splitter.is_splittable(root_node, leaf_counter=0)
Loading