Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
512 changes: 222 additions & 290 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "smarttree"
version = "0.1.0"
description = ""
description = "Custom Decision Tree with bells and whistles"
authors = ["Mikhail Martin <mikhailmartin95@yandex.ru>"]
readme = "README.md"

Expand Down
2 changes: 1 addition & 1 deletion smarttree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._classes import BaseSmartDecisionTree, SmartDecisionTreeClassifier
from ._tree_node import TreeNode
from ._tree import TreeNode


__all__ = [
Expand Down
76 changes: 25 additions & 51 deletions smarttree/_builder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import bisect
import math
from collections import defaultdict

import numpy as np
import pandas as pd
from numpy.typing import NDArray

from ._node_splitter import NodeSplitter
from ._tree_node import TreeNode
from ._tree import Tree, TreeNode
from ._types import ClassificationCriterionType


Expand Down Expand Up @@ -35,11 +35,9 @@ def __init__(
self.impurity = self.entropy

if self.criterion in ("gini", "entropy", "log_loss"):
self.class_names = sorted(self.y.unique())
self.class_names = np.sort(self.y.unique())

self.node_counter: int = 0

def build(self) -> tuple[TreeNode, defaultdict[str, float]]:
def build(self, tree: Tree) -> None:

for value in self.hierarchy.values():
if isinstance(value, list):
Expand All @@ -48,25 +46,27 @@ def build(self) -> tuple[TreeNode, defaultdict[str, float]]:
else: # str
self.available_features.remove(value)

root = self.create_node(
mask=self.y.apply(lambda x: True),
mask = self.y.apply(lambda x: True)
root = tree.create_node(
mask=mask,
hierarchy=self.hierarchy,
distribution=self.distribution(mask),
impurity=self.impurity(mask),
label=self.y[mask].mode()[0],
available_features=self.available_features,
depth=0,
is_root=True,
)

splittable_leaf_nodes: list[TreeNode] = []
feature_importances: defaultdict[str, float] = defaultdict(float)

if self.splitter.is_splittable(root):
if self.splitter.is_splittable(root, tree.leaf_counter):
splittable_leaf_nodes.append(root)

while (
len(splittable_leaf_nodes) > 0
and self.splitter.leaf_counter < self.max_leaf_nodes
):
while len(splittable_leaf_nodes) > 0 and tree.leaf_counter < self.max_leaf_nodes:

node = splittable_leaf_nodes.pop()
feature_importances[node.split_feature] += node.information_gain
tree.feature_importances[node.split_feature] += node.information_gain

for child_mask, feature_value in zip(node.child_masks, node.feature_values):
# add opened features
Expand All @@ -77,59 +77,33 @@ def build(self) -> tuple[TreeNode, defaultdict[str, float]]:
else: # str
node.available_features.append(value)

child_node = self.create_node(
child_node = tree.create_node(
mask=child_mask,
hierarchy=node.hierarchy,
distribution=self.distribution(child_mask),
impurity=self.impurity(child_mask),
label=self.y[child_mask].mode()[0],
available_features=node.available_features,
depth=node.depth + 1,
depth=node.depth+1,
)
child_node.feature_value = feature_value
self.splitter.leaf_counter += 1

node.childs.append(child_node)
if self.splitter.is_splittable(child_node):
if self.splitter.is_splittable(child_node, tree.leaf_counter):
bisect.insort(
splittable_leaf_nodes,
child_node,
key=lambda n: n.information_gain,
)

node.is_leaf = False
self.splitter.leaf_counter -= 1
tree.leaf_counter -= 1

return root, feature_importances

def create_node(
self,
mask: pd.Series,
hierarchy: dict[str, str | list[str]],
available_features: list[str],
depth: int,
) -> TreeNode:
"""Creates a node of the tree."""
tree_node = TreeNode(
number=self.node_counter,
num_samples=mask.sum(),
distribution=self.distribution(mask),
impurity=self.impurity(mask),
label=self.y[mask].mode()[0],
depth=depth,
mask=mask,
hierarchy=hierarchy.copy(),
available_features=available_features.copy(),
)
self.node_counter += 1
return tree_node

def distribution(self, mask: pd.Series) -> np.ndarray:
"""Calculates the class distribution."""
distribution = np.array([
(mask & (self.y == class_name)).sum()
for class_name in self.class_names
def distribution(self, mask: pd.Series) -> NDArray[np.integer]:
return np.array([
(mask & (self.y == class_name)).sum() for class_name in self.class_names
])

return distribution

def gini_index(self, mask: pd.Series) -> float:
r"""
Calculates Gini index in a tree node.
Expand Down
23 changes: 11 additions & 12 deletions smarttree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ._exceptions import NotFittedError
from ._node_splitter import NodeSplitter
from ._renderer import Renderer
from ._tree_node import TreeNode
from ._tree import Tree, TreeNode
from ._types import (
CatNaModeType,
ClassificationCriterionType,
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
self.logger.setLevel(verbose)

self._is_fitted: bool = False
self._root: TreeNode | None = None
self._tree: Tree | None = None
self._feature_importances: dict = dict()
self._feature_na_filler: dict[str, int | float | str] = dict()

Expand Down Expand Up @@ -188,15 +188,15 @@ def feature_na_mode(self) -> dict[str, NaModeType | None]:
return self.__feature_na_mode

@property
def tree(self) -> TreeNode:
def tree_(self) -> Tree:
self._check_is_fitted()
assert self._root is not None
return self._root
assert self._tree is not None
return self._tree

@property
def feature_importances_(self) -> dict[str, float]:
self._check_is_fitted()
return self._feature_importances
return self.tree_.feature_importances

@abstractmethod
def fit(self, X: pd.DataFrame, y: pd.Series) -> Self:
Expand Down Expand Up @@ -577,6 +577,8 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self:
feature_na_mode=self.feature_na_mode,
)

self._tree = Tree()

builder = Builder(
X=X,
y=y,
Expand All @@ -585,10 +587,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self:
max_leaf_nodes=max_leaf_nodes,
hierarchy=self.hierarchy,
)
root, feature_importances = builder.build()

self._root = root
self._feature_importances = feature_importances
builder.build(self._tree)

self._is_fitted = True

Expand Down Expand Up @@ -627,7 +626,7 @@ def predict_proba(self, X: pd.DataFrame) -> NDArray[np.floating]:
X = self.__preprocess(X)

distributions = np.array([
self.__get_distribution(self.tree, point) for _, point in X.iterrows()
self.__get_distribution(self.tree_.root, point) for _, point in X.iterrows()
])

return distributions / distributions.sum(axis=1, keepdims=True)
Expand Down Expand Up @@ -733,7 +732,7 @@ def render(
"""
renderer = Renderer(criterion=self.criterion, rounded=rounded)
graph = renderer.render(
tree=self.tree,
root=self.tree_.root,
show_impurity=show_impurity,
show_num_samples=show_num_samples,
show_distribution=show_distribution,
Expand Down
10 changes: 5 additions & 5 deletions smarttree/_column_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

import numpy as np
import pandas as pd
from numpy.typing import NDArray

from ._dataset import Dataset
from ._tree_node import TreeNode
from ._tree import TreeNode
from ._types import ClassificationCriterionType, NaModeType


Expand Down Expand Up @@ -202,16 +203,15 @@ def split(self, node: TreeNode, split_feature: str) -> ColumnSplitResult:

return best_split_result

def __get_thresholds(self, array: np.ndarray) -> np.ndarray:
def __get_thresholds(self, array: NDArray) -> NDArray:

array.sort()
array = np.unique(array)
array = np.sort(np.unique(array))
thresholds = np.array([]) if len(array) <= 1 else self.__moving_average(array)

return thresholds

@staticmethod
def __moving_average(array: np.ndarray, window: int = 2) -> np.ndarray:
def __moving_average(array: NDArray, window: int = 2) -> NDArray:
return np.convolve(array, np.ones(window), mode="valid") / window

def __num_split(
Expand Down
11 changes: 5 additions & 6 deletions smarttree/_node_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._column_splitter import CatColumnSplitter, NumColumnSplitter, RankColumnSplitter
from ._dataset import Dataset
from ._tree_node import TreeNode
from ._tree import TreeNode
from ._types import ClassificationCriterionType, NaModeType, SplitType


Expand Down Expand Up @@ -49,7 +49,6 @@ def __init__(
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.min_impurity_decrease = min_impurity_decrease
self.leaf_counter: int = 0

self.feature_split_type: dict[str, SplitType] = dict()
for num_feature in num_features:
Expand Down Expand Up @@ -85,7 +84,7 @@ def __init__(
feature_na_mode=feature_na_mode,
)

def is_splittable(self, node: TreeNode) -> bool:
def is_splittable(self, node: TreeNode, leaf_counter: int) -> bool:
"""
Checks whether a tree node can be split.

Expand All @@ -97,7 +96,7 @@ def is_splittable(self, node: TreeNode) -> bool:
if node.num_samples < self.min_samples_split:
return False

split_result = self.find_best_split_for(node)
split_result = self.find_best_split_for(node, leaf_counter)
if split_result.information_gain >= self.min_impurity_decrease:
node.information_gain = split_result.information_gain
node.split_type = split_result.split_type
Expand All @@ -108,7 +107,7 @@ def is_splittable(self, node: TreeNode) -> bool:
else:
return False

def find_best_split_for(self, node: TreeNode) -> NodeSplitResult:
def find_best_split_for(self, node: TreeNode, leaf_counter: int) -> NodeSplitResult:

best_split_result = NodeSplitResult.no_split()
for feature in node.available_features:
Expand All @@ -117,7 +116,7 @@ def find_best_split_for(self, node: TreeNode) -> NodeSplitResult:
case "numerical":
split_result = self.num_col_splitter.split(node, feature)
case "categorical":
split_result = self.cat_col_splitter.split(node, feature, self.leaf_counter)
split_result = self.cat_col_splitter.split(node, feature, leaf_counter)
case "rank":
split_result = self.rank_col_splitter.split(node, feature)

Expand Down
6 changes: 3 additions & 3 deletions smarttree/_renderer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from graphviz import Digraph

from ._tree_node import TreeNode
from ._tree import TreeNode
from ._types import ClassificationCriterionType


Expand All @@ -15,7 +15,7 @@ def __init__(self, rounded: bool, criterion: ClassificationCriterionType) -> Non

def render(
self,
tree: TreeNode,
root: TreeNode,
*,
show_impurity: bool = False,
show_num_samples: bool = False,
Expand All @@ -24,7 +24,7 @@ def render(
**kwargs,
) -> Digraph:
self.__add_node(
node=tree,
node=root,
parent_name=None,
show_impurity=show_impurity,
show_num_samples=show_num_samples,
Expand Down
Loading
Loading