Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions smarttree/_cgini_index.pyi

This file was deleted.

47 changes: 0 additions & 47 deletions smarttree/_cgini_index.pyx

This file was deleted.

77 changes: 11 additions & 66 deletions smarttree/_column_splitter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import math
from abc import ABC, abstractmethod
from collections.abc import Generator
from copy import deepcopy
Expand All @@ -10,10 +9,10 @@
import pandas as pd
from numpy.typing import NDArray

from ._cgini_index import cgini_index
from ._cy_column_splitter import CyBaseColumnSplitter
from ._dataset import Dataset
from ._tree import TreeNode
from ._types import ClassificationCriterionType, NaModeType
from ._types import ClassificationCriterionType, Criterion, NaModeType


NO_INFORMATION_GAIN = float("-inf")
Expand All @@ -37,6 +36,12 @@ def no_split(cls) -> ColumnSplitResult:

class BaseColumnSplitter(ABC):

mapping: dict[ClassificationCriterionType, Criterion] = {
"gini": Criterion.GINI,
"entropy": Criterion.ENTROPY,
"log_loss": Criterion.LOG_LOSS,
}

def __init__(
self,
dataset: Dataset,
Expand All @@ -47,17 +52,11 @@ def __init__(
) -> None:

self.dataset = dataset
self.criterion = criterion
self.criterion = self.mapping[criterion]
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
self.feature_na_mode = feature_na_mode

match self.criterion:
case "gini":
self.impurity = self.gini_index
case "entropy" | "log_loss":
self.impurity = self.entropy

@abstractmethod
def split(self, *args, **kwargs) -> ColumnSplitResult:
raise NotImplementedError
Expand Down Expand Up @@ -168,62 +167,8 @@ def information_gain(
\item $\text{impurity}_{\text{child}_i}$ — child node impurity.
\end{itemize}
"""
N = self.dataset.size
N_parent = parent_mask.sum()

impurity_parent = self.impurity(parent_mask)

weighted_impurity_childs = 0
N_childs = 0
for child_mask_i in child_masks:
N_child_i = child_mask_i.sum()
N_childs += N_child_i
impurity_child_i = self.impurity(child_mask_i)
weighted_impurity_childs += (N_child_i / N_parent) * impurity_child_i

if normalize:
norm_coef = N_parent / N_childs
weighted_impurity_childs *= norm_coef

local_information_gain = impurity_parent - weighted_impurity_childs

information_gain = (N_parent / N) * local_information_gain

return information_gain

def gini_index(self, mask: pd.Series) -> float:
r"""
Calculates Gini index in a tree node.

Gini index formula in LaTeX:
\text{Gini Index} = 1 - \sum^C_{i=1} p_i^2
where
C - total number of classes;
p_i - the probability of choosing a sample with class i.
"""
return cgini_index(mask, self.dataset.y, self.dataset.class_names)

def entropy(self, mask: pd.Series) -> float:
r"""
Calculates entropy in a tree node.

Entropy formula in LaTeX:
H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i}
where
H - entropy;
\overline{N} - effective number of states;
p_i - probability of the i-th system state.
"""
N = mask.sum()

entropy = 0
for label in self.dataset.class_names:
N_i = (mask & (self.dataset.y == label)).sum()
if N_i != 0:
p_i = N_i / N
entropy -= p_i * math.log2(p_i)

return entropy
cs = CyBaseColumnSplitter(self.dataset, self.criterion)
return cs.information_gain(parent_mask, child_masks, normalize)


class NumColumnSplitter(BaseColumnSplitter):
Expand Down
78 changes: 78 additions & 0 deletions smarttree/_cy_column_splitter.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np
import pandas as pd
from numpy.typing import NDArray

from ._dataset import Dataset
from ._types import Criterion

class CyBaseColumnSplitter:
    """
    Type stub for the splitter that computes split-quality metrics.

    Declares the information gain of a split and the two node impurity
    measures (Gini index and entropy). Only signatures are given here;
    the implementation lives outside this stub.
    """

    def __init__(self, dataset: Dataset, criterion: Criterion) -> None:
        """
        Parameters:
            dataset: Dataset
                dataset the metrics are computed on.
            criterion: Criterion
                presumably selects which impurity measure
                ``information_gain`` uses - confirm against the
                implementation.
        """
        ...

    def information_gain(
        self,
        parent_mask: pd.Series,
        child_masks: list[pd.Series],
        normalize: bool = False,
    ) -> float:
        r"""
        Calculates information gain of the split.

        Parameters:
            parent_mask: pd.Series
                boolean mask of parent node.
            child_masks: list[pd.Series]
                list of boolean masks of child nodes.
            normalize: bool, default=False
                if True, normalizes information gain by split factor to handle
                unbalanced splits. Uses child node counts for normalization.

        Returns:
            float: information gain.

        Formula in LaTeX:
            \begin{align*}
            \text{Information Gain} =
            \frac{N_{\text{parent}}}{N} \cdot
            \Biggl( & \text{impurity}_{\text{parent}} - \\
            & \sum^C_{i=1} \frac{N_{\text{child}_i}}{N_{\text{parent}}}
            \cdot \text{impurity}_{\text{child}_i} \Biggr)
            \end{align*}
            where:
            \begin{itemize}
                \item $\text{Information Gain}$ — information gain;
                \item $N$ — number of samples in entire training set;
                \item $N_{\text{parent}}$ — number of samples in parent node;
                \item $\text{impurity}_{\text{parent}}$ — parent node impurity;
                \item $C$ — number of child nodes;
                \item $N_{\text{child}_i}$ — number of samples in the $i$-th child node;
                \item $\text{impurity}_{\text{child}_i}$ — impurity of the $i$-th child node.
            \end{itemize}
        """
        ...

    def gini_index(self, mask: NDArray[np.int8]) -> float:
        r"""
        Calculates Gini index in a tree node.

        Parameters:
            mask: NDArray[np.int8]
                mask of the node as an int8 array (presumably nonzero marks
                samples that belong to the node - confirm against the
                implementation).

        Returns:
            float: Gini index of the node.

        Gini index formula in LaTeX:
            \text{Gini Index} = 1 - \sum^C_{i=1} p_i^2
            where
                C - total number of classes;
                p_i - the probability of choosing a sample with class i.
        """
        ...

    def entropy(self, mask: NDArray[np.int8]) -> float:
        r"""
        Calculates entropy in a tree node.

        Parameters:
            mask: NDArray[np.int8]
                mask of the node as an int8 array (presumably nonzero marks
                samples that belong to the node - confirm against the
                implementation).

        Returns:
            float: entropy of the node.

        Entropy formula in LaTeX:
            H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i}
            where
                H - entropy;
                \overline{N} - effective number of states;
                p_i - probability of the i-th system state.
        """
        ...
Loading
Loading