From 017d94c5a459b0b63cb344f744e337d9d4ae45e9 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Tue, 23 Sep 2025 00:36:35 +0300 Subject: [PATCH 1/8] CyBaseColumnSplitter --- smarttree/_cgini_index.pyi | 4 --- smarttree/_cgini_index.pyx | 47 ----------------------------- smarttree/_column_splitter.py | 6 ++-- smarttree/_cy_column_splitter.pyi | 7 +++++ smarttree/_cy_column_splitter.pyx | 50 +++++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 53 deletions(-) delete mode 100644 smarttree/_cgini_index.pyi delete mode 100644 smarttree/_cgini_index.pyx create mode 100644 smarttree/_cy_column_splitter.pyi create mode 100644 smarttree/_cy_column_splitter.pyx diff --git a/smarttree/_cgini_index.pyi b/smarttree/_cgini_index.pyi deleted file mode 100644 index 2e4e34b..0000000 --- a/smarttree/_cgini_index.pyi +++ /dev/null @@ -1,4 +0,0 @@ -import pandas as pd -from numpy.typing import NDArray - -def cgini_index(mask: pd.Series, y: pd.Series, class_names: NDArray) -> float: ... diff --git a/smarttree/_cgini_index.pyx b/smarttree/_cgini_index.pyx deleted file mode 100644 index 2e97138..0000000 --- a/smarttree/_cgini_index.pyx +++ /dev/null @@ -1,47 +0,0 @@ -cimport cython -from libc.stdint cimport int8_t -import numpy as np - - -@cython.boundscheck(False) -@cython.wraparound(False) -@cython.cdivision(True) -def cgini_index(mask, y, class_names): - - cdef int8_t[:] mask_arr = mask.values.astype(np.int8) - cdef object[:] y_arr = y.values - cdef long N = 0 - cdef long N_i = 0 - cdef double p_i = 0.0 - cdef double gini_index = 1.0 - cdef int i - cdef int j - cdef int n = len(mask) - cdef int n_classes = len(class_names) - cdef object class_name - cdef object label - cdef int8_t mask_value - - for i in range(n): - mask_value = mask_arr[i] - if mask_value: - N += 1 - - if N == 0: - return 0.0 - - for j in range(n_classes): - N_i = 0 - class_name = class_names[j] - - for i in range(n): - mask_value = mask_arr[i] - if mask_value: - label = y_arr[i] - if label == class_name: - N_i += 1 - - p_i = N_i / N - gini_index -= p_i * p_i - - return gini_index diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index e245624..abc1a5e 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -10,7 +10,7 @@ import pandas as pd from numpy.typing import NDArray -from ._cgini_index import cgini_index +from ._cy_column_splitter import CyBaseColumnSplitter from ._dataset import Dataset from ._tree import TreeNode from ._types import ClassificationCriterionType, NaModeType @@ -201,7 +201,9 @@ def gini_index(self, mask: pd.Series) -> float: C - total number of classes; p_i - the probability of choosing a sample with class i. """ - return cgini_index(mask, self.dataset.y, self.dataset.class_names) + return CyBaseColumnSplitter.gini_index( + mask, self.dataset.y, self.dataset.class_names + ) def entropy(self, mask: pd.Series) -> float: r""" diff --git a/smarttree/_cy_column_splitter.pyi b/smarttree/_cy_column_splitter.pyi new file mode 100644 index 0000000..fe7ed90 --- /dev/null +++ b/smarttree/_cy_column_splitter.pyi @@ -0,0 +1,7 @@ +import pandas as pd +from numpy.typing import NDArray + +class CyBaseColumnSplitter: + @staticmethod + def gini_index(mask: pd.Series, y: pd.Series, class_names: NDArray) -> float: + ... diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx new file mode 100644 index 0000000..ad84df3 --- /dev/null +++ b/smarttree/_cy_column_splitter.pyx @@ -0,0 +1,50 @@ +cimport cython +from libc.stdint cimport int8_t +import numpy as np + + +cdef class CyBaseColumnSplitter: + + @staticmethod + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.cdivision(True) + def gini_index(mask, y, class_names): + + cdef int8_t[:] mask_arr = mask.values.astype(np.int8) + cdef object[:] y_arr = y.values + cdef long N = 0 + cdef long N_i = 0 + cdef double p_i = 0.0 + cdef double gini_index = 1.0 + cdef int i + cdef int j + cdef int n = len(mask) + cdef int n_classes = len(class_names) + cdef object class_name + cdef object label + cdef int8_t mask_value + + for i in range(n): + mask_value = mask_arr[i] + if mask_value: + N += 1 + + if N == 0: + return 0.0 + + for j in range(n_classes): + N_i = 0 + class_name = class_names[j] + + for i in range(n): + mask_value = mask_arr[i] + if mask_value: + label = y_arr[i] + if label == class_name: + N_i += 1 + + p_i = N_i / N + gini_index -= p_i * p_i + + return gini_index From c183c2c328059cc8df4e9ac27f394c90401cc10c Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Tue, 23 Sep 2025 01:10:56 +0300 Subject: [PATCH 2/8] CyBaseColumnSplitter.entropy() --- smarttree/_column_splitter.py | 14 ++------- smarttree/_cy_column_splitter.pyi | 10 +++++++ smarttree/_cy_column_splitter.pyx | 50 +++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index abc1a5e..9be5e0c 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -1,6 +1,5 @@ from __future__ import annotations -import math from abc import ABC, abstractmethod from collections.abc import Generator from copy import deepcopy @@ -216,16 +215,9 @@ def entropy(self, mask: pd.Series) -> float: \overline{N} - effective number of states; p_i - probability of the i-th system state. """ - N = mask.sum() - - entropy = 0 - for label in self.dataset.class_names: - N_i = (mask & (self.dataset.y == label)).sum() - if N_i != 0: - p_i = N_i / N - entropy -= p_i * math.log2(p_i) - - return entropy + return CyBaseColumnSplitter.entropy( + mask, self.dataset.y, self.dataset.class_names + ) class NumColumnSplitter(BaseColumnSplitter): diff --git a/smarttree/_cy_column_splitter.pyi b/smarttree/_cy_column_splitter.pyi index fe7ed90..b6d3698 100644 --- a/smarttree/_cy_column_splitter.pyi +++ b/smarttree/_cy_column_splitter.pyi @@ -1,7 +1,17 @@ import pandas as pd from numpy.typing import NDArray +from ._types import ClassificationCriterionType + class CyBaseColumnSplitter: + + def __init__(self, criterion: ClassificationCriterionType) -> None: + ... + @staticmethod def gini_index(mask: pd.Series, y: pd.Series, class_names: NDArray) -> float: ... + + @staticmethod + def entropy(mask: pd.Series, y: pd.Series, class_names: NDArray) -> float: + ... diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index ad84df3..1449075 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -1,10 +1,18 @@ cimport cython from libc.stdint cimport int8_t +import math + import numpy as np cdef class CyBaseColumnSplitter: + def __init__(self, criterion): + if criterion == "gini": + self.impurity = self.gini_index + elif criterion in ("entropy", "log_loss"): + self.impurity = self.entropy + @staticmethod @cython.boundscheck(False) @cython.wraparound(False) @@ -48,3 +56,45 @@ cdef class CyBaseColumnSplitter: gini_index -= p_i * p_i return gini_index + + @staticmethod + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.cdivision(True) + def entropy(mask, y, class_names): + + cdef int8_t[:] mask_arr = mask.values.astype(np.int8) + cdef object[:] y_arr = y.values + cdef long N = 0 + cdef long N_i = 0 + cdef double p_i = 0.0 + cdef double entropy = 0.0 + cdef int i + cdef int j + cdef int n = len(mask) + cdef int n_classes = len(class_names) + cdef object class_name + cdef object label + cdef int8_t mask_value + + for i in range(n): + mask_value = mask_arr[i] + if mask_value: + N += 1 + + for j in range(n_classes): + N_i = 0 + class_name = class_names[j] + + for i in range(n): + mask_value = mask_arr[i] + if mask_value: + label = y_arr[i] + if label == class_name: + N_i += 1 + + if N_i != 0: + p_i = N_i / N + entropy -= p_i * math.log2(p_i) + + return entropy From 0123066c4e8f13e432ac4ff5ee795f81239cbe18 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Tue, 23 Sep 2025 13:14:00 +0300 Subject: [PATCH 3/8] remove n_classes --- smarttree/_cy_column_splitter.pyx | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index 1449075..6939d34 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -28,7 +28,6 @@ cdef class CyBaseColumnSplitter: cdef int i cdef int j cdef int n = len(mask) - cdef int n_classes = len(class_names) cdef object class_name cdef object label cdef int8_t mask_value @@ -41,7 +40,7 @@ cdef class CyBaseColumnSplitter: if N == 0: return 0.0 - for j in range(n_classes): + for j in range(len(class_names)): N_i = 0 class_name = class_names[j] @@ -72,7 +71,6 @@ cdef class CyBaseColumnSplitter: cdef int i cdef int j cdef int n = len(mask) - cdef int n_classes = len(class_names) cdef object class_name cdef object label cdef int8_t mask_value @@ -82,7 +80,7 @@ cdef class CyBaseColumnSplitter: if mask_value: N += 1 - for j in range(n_classes): + for j in range(len(class_names)): N_i = 0 class_name = class_names[j] From 3e2bfe738d6ba0f7f940b61c267fb542204f4182 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Tue, 23 Sep 2025 14:02:36 +0300 Subject: [PATCH 4/8] refactoring --- smarttree/_cy_column_splitter.pyx | 59 ++++++++++++++++--------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index 6939d34..8106106 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -1,6 +1,6 @@ cimport cython +from libc.math cimport log2 from libc.stdint cimport int8_t -import math import numpy as np @@ -19,27 +19,26 @@ cdef class CyBaseColumnSplitter: @cython.cdivision(True) def gini_index(mask, y, class_names): - cdef int8_t[:] mask_arr = mask.values.astype(np.int8) - cdef object[:] y_arr = y.values - cdef long N = 0 - cdef long N_i = 0 - cdef double p_i = 0.0 - cdef double gini_index = 1.0 - cdef int i - cdef int j - cdef int n = len(mask) - cdef object class_name - cdef object label - cdef int8_t mask_value + cdef: + int8_t[:] mask_arr = mask.values.astype(np.int8) + object[:] y_arr = y.values + cdef: + int i + Py_ssize_t n = len(mask) + int8_t mask_value + long N = 0 for i in range(n): mask_value = mask_arr[i] if mask_value: N += 1 - if N == 0: - return 0.0 - + cdef: + int j + long N_i = 0 + object class_name, label + double p_i = 0.0 + gini_index = 1.0 for j in range(len(class_names)): N_i = 0 class_name = class_names[j] @@ -62,24 +61,26 @@ cdef class CyBaseColumnSplitter: @cython.cdivision(True) def entropy(mask, y, class_names): - cdef int8_t[:] mask_arr = mask.values.astype(np.int8) - cdef object[:] y_arr = y.values - cdef long N = 0 - cdef long N_i = 0 - cdef double p_i = 0.0 - cdef double entropy = 0.0 - cdef int i - cdef int j - cdef int n = len(mask) - cdef object class_name - cdef object label - cdef int8_t mask_value + cdef: + int8_t[:] mask_arr = mask.values.astype(np.int8) + object[:] y_arr = y.values + cdef: + int i + Py_ssize_t n = len(mask) + int8_t mask_value + long N = 0 for i in range(n): mask_value = mask_arr[i] if mask_value: N += 1 + cdef: + int j + long N_i = 0 + object class_name, label + double p_i = 0.0 + entropy = 0.0 for j in range(len(class_names)): N_i = 0 class_name = class_names[j] @@ -93,6 +94,6 @@ cdef class CyBaseColumnSplitter: if N_i != 0: p_i = N_i / N - entropy -= p_i * math.log2(p_i) + entropy -= p_i * log2(p_i) return entropy From 6c0c15ab99bb2a9b2c561e45a4850ebc292173ba Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 25 Sep 2025 01:44:35 +0300 Subject: [PATCH 5/8] cythonize .information_gain() --- smarttree/_column_splitter.py | 59 +------ smarttree/_cy_column_splitter.pyi | 87 ++++++++- smarttree/_cy_column_splitter.pyx | 165 ++++++++++++++---- .../test__base_column_splitter.py | 12 -- .../test__cy_base_column_splitter.py | 27 +++ tests/conftest.py | 6 + 6 files changed, 251 insertions(+), 105 deletions(-) create mode 100644 tests/column_splitter/test__cy_base_column_splitter.py diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index 9be5e0c..809f4f2 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -51,12 +51,6 @@ def __init__( self.min_samples_leaf = min_samples_leaf self.feature_na_mode = feature_na_mode - match self.criterion: - case "gini": - self.impurity = self.gini_index - case "entropy" | "log_loss": - self.impurity = self.entropy - @abstractmethod def split(self, *args, **kwargs) -> ColumnSplitResult: raise NotImplementedError @@ -167,57 +161,8 @@ def information_gain( \item $\text{impurity}_{\text{child}_i}$ — child node impurity. \end{itemize} """ - N = self.dataset.size - N_parent = parent_mask.sum() - - impurity_parent = self.impurity(parent_mask) - - weighted_impurity_childs = 0 - N_childs = 0 - for child_mask_i in child_masks: - N_child_i = child_mask_i.sum() - N_childs += N_child_i - impurity_child_i = self.impurity(child_mask_i) - weighted_impurity_childs += (N_child_i / N_parent) * impurity_child_i - - if normalize: - norm_coef = N_parent / N_childs - weighted_impurity_childs *= norm_coef - - local_information_gain = impurity_parent - weighted_impurity_childs - - information_gain = (N_parent / N) * local_information_gain - - return information_gain - - def gini_index(self, mask: pd.Series) -> float: - r""" - Calculates Gini index in a tree node. - - Gini index formula in LaTeX: - \text{Gini Index} = 1 - \sum^C_{i=1} p_i^2 - where - C - total number of classes; - p_i - the probability of choosing a sample with class i. - """ - return CyBaseColumnSplitter.gini_index( - mask, self.dataset.y, self.dataset.class_names - ) - - def entropy(self, mask: pd.Series) -> float: - r""" - Calculates entropy in a tree node. - - Entropy formula in LaTeX: - H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i} - where - H - entropy; - \overline{N} - effective number of states; - p_i - probability of the i-th system state. - """ - return CyBaseColumnSplitter.entropy( - mask, self.dataset.y, self.dataset.class_names - ) + cs = CyBaseColumnSplitter(self.dataset, self.criterion) + return cs.information_gain(parent_mask, child_masks, normalize) class NumColumnSplitter(BaseColumnSplitter): diff --git a/smarttree/_cy_column_splitter.pyi b/smarttree/_cy_column_splitter.pyi index b6d3698..f04b74d 100644 --- a/smarttree/_cy_column_splitter.pyi +++ b/smarttree/_cy_column_splitter.pyi @@ -1,17 +1,94 @@ +import numpy as np import pandas as pd from numpy.typing import NDArray +from ._dataset import Dataset from ._types import ClassificationCriterionType class CyBaseColumnSplitter: - def __init__(self, criterion: ClassificationCriterionType) -> None: + criterion: ClassificationCriterionType + + def __init__( + self, + dataset: Dataset, + criterion: ClassificationCriterionType, + ) -> None: + ... + + def information_gain( + self, + parent_mask: pd.Series, + child_masks: list[pd.Series], + normalize: bool = False, + ) -> float: + r""" + Calculates information gain of the split. + + Parameters: + parent_mask: pd.Series + boolean mask of parent node. + child_masks: pd.Series + list of boolean masks of child nodes. + normalize: bool, default=False + if True, normalizes information gain by split factor to handle + unbalanced splits. Uses child node counts for normalization. + + Returns: + float: information gain. + + Formula in LaTeX: + \begin{align*} + \text{Information Gain} = + \frac{N_{\text{parent}}}{N} \cdot + \Biggl( & \text{impurity}_{\text{parent}} - \\ + & \sum^C_{i=1} \frac{N_{\text{child}_i}}{N_{\text{parent}}} + \cdot \text{impurity}_{\text{child}_i} \Biggr) + \end{align*} + where: + \begin{itemize} + \item $\text{Information Gain}$ — information gain; + \item $N$ — number of samples in entire training set; + \item $N_{\text{parent}}$ — number of samples in parent node; + \item $\text{impurity}_{\text{parent}}$ — parent node impurity; + \item $C$ — number of child nodes; + \item $N_{\text{child}_i}$ — number of samples in child node; + \item $\text{impurity}_{\text{child}_i}$ — child node impurity. + \end{itemize} + """ ... - @staticmethod - def gini_index(mask: pd.Series, y: pd.Series, class_names: NDArray) -> float: + def gini_index( + self, + mask: NDArray[np.int8], + y: NDArray[np.object_], + class_names: NDArray[np.object_], + ) -> float: + r""" + Calculates Gini index in a tree node. + + Gini index formula in LaTeX: + \text{Gini Index} = 1 - \sum^C_{i=1} p_i^2 + where + C - total number of classes; + p_i - the probability of choosing a sample with class i. + """ ... - @staticmethod - def entropy(mask: pd.Series, y: pd.Series, class_names: NDArray) -> float: + def entropy( + self, + mask: NDArray[np.int8], + y: NDArray[np.object_], + class_names: NDArray[np.object_], + ) -> float: + r""" + Calculates entropy in a tree node. + + Entropy formula in LaTeX: + H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i} + where + H - entropy; + \overline{N} - effective number of states; + p_i - probability of the i-th system state. + """ ... diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index 8106106..a3332bc 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -3,40 +3,142 @@ from libc.math cimport log2 from libc.stdint cimport int8_t import numpy as np +import pandas as pd cdef class CyBaseColumnSplitter: - def __init__(self, criterion): - if criterion == "gini": - self.impurity = self.gini_index - elif criterion in ("entropy", "log_loss"): - self.impurity = self.entropy - - @staticmethod - @cython.boundscheck(False) - @cython.wraparound(False) - @cython.cdivision(True) - def gini_index(mask, y, class_names): + cdef public object dataset + cdef public str criterion + + def __init__(self, dataset, criterion) -> None: + self.dataset = dataset + self.criterion = criterion + + cdef double impurity(self, int8_t[:] mask, object[:] y, object[:] class_names): + if self.criterion == "gini": + return self.gini_index(mask, y, class_names) + elif self.criterion in ("entropy", "log_loss"): + return self.entropy(mask, y, class_names) + else: + assert False + + def information_gain( + self, + parent_mask: pd.Series, + child_masks: list[pd.Series], + normalize: bool = False, + ) -> float: + r""" + Calculates information gain of the split. + + Parameters: + parent_mask: pd.Series + boolean mask of parent node. + child_masks: pd.Series + list of boolean masks of child nodes. + normalize: bool, default=False + if True, normalizes information gain by split factor to handle + unbalanced splits. Uses child node counts for normalization. + + Returns: + float: information gain. + + Formula in LaTeX: + \begin{align*} + \text{Information Gain} = + \frac{N_{\text{parent}}}{N} \cdot + \Biggl( & \text{impurity}_{\text{parent}} - \\ + & \sum^C_{i=1} \frac{N_{\text{child}_i}}{N_{\text{parent}}} + \cdot \text{impurity}_{\text{child}_i} \Biggr) + \end{align*} + where: + \begin{itemize} + \item $\text{Information Gain}$ — information gain; + \item $N$ — number of samples in entire training set; + \item $N_{\text{parent}}$ — number of samples in parent node; + \item $\text{impurity}_{\text{parent}}$ — parent node impurity; + \item $C$ — number of child nodes; + \item $N_{\text{child}_i}$ — number of samples in child node; + \item $\text{impurity}_{\text{child}_i}$ — child node impurity. + \end{itemize} + """ + cdef int8_t[:] parent_mask_arr, child_mask_arr + parent_mask_arr = parent_mask.values.astype(np.int8) + child_mask_arrs = [ + child_mask.values.astype(np.int8) for child_mask in child_masks + ] + cdef object[:] y_arr, class_names_arr + y_arr = self.dataset.y.values + class_names_arr = self.dataset.class_names cdef: - int8_t[:] mask_arr = mask.values.astype(np.int8) - object[:] y_arr = y.values + int i + Py_ssize_t n = len(parent_mask_arr) + long N = 0 + long N_parent = 0 + int8_t parent_mask_value + for i in range(n): + N += 1 + parent_mask_value = parent_mask_arr[i] + if parent_mask_value: + N_parent += 1 + + cdef double impurity_parent = self.impurity(parent_mask_arr, y_arr, class_names_arr) + cdef: + int j + double weighted_impurity_childs = 0.0 + long N_childs = 0 + long N_child_j + int8_t child_mask_value + double impurity_child_i + for j in range(len(child_mask_arrs)): + N_child_j = 0 + child_mask_arr = child_mask_arrs[j] + for i in range(n): + child_mask_value = child_mask_arr[i] + if child_mask_value: + N_child_j += 1 + N_childs += N_child_j + impurity_child_i = self.impurity(child_mask_arr, y_arr, class_names_arr) + weighted_impurity_childs += (N_child_j / N_parent) * impurity_child_i + + cdef double norm_coef + if normalize: + norm_coef = N_parent / N_childs + weighted_impurity_childs *= norm_coef + + cdef double local_information_gain = impurity_parent - weighted_impurity_childs + + cdef double information_gain = (N_parent / N) * local_information_gain + + return information_gain + + cpdef double gini_index(self, int8_t[:] mask, object[:] y, object[:] class_names): + r""" + Calculates Gini index in a tree node. + + Gini index formula in LaTeX: + \text{Gini Index} = 1 - \sum^C_{i=1} p_i^2 + where + C - total number of classes; + p_i - the probability of choosing a sample with class i. + """ cdef: int i Py_ssize_t n = len(mask) int8_t mask_value long N = 0 for i in range(n): - mask_value = mask_arr[i] + mask_value = mask[i] if mask_value: N += 1 cdef: int j - long N_i = 0 - object class_name, label + cdef long N_i + cdef object class_name, label double p_i = 0.0 gini_index = 1.0 for j in range(len(class_names)): @@ -44,9 +146,9 @@ cdef class CyBaseColumnSplitter: class_name = class_names[j] for i in range(n): - mask_value = mask_arr[i] + mask_value = mask[i] if mask_value: - label = y_arr[i] + label = y[i] if label == class_name: N_i += 1 @@ -55,23 +157,24 @@ cdef class CyBaseColumnSplitter: return gini_index - @staticmethod - @cython.boundscheck(False) - @cython.wraparound(False) - @cython.cdivision(True) - def entropy(mask, y, class_names): - - cdef: - int8_t[:] mask_arr = mask.values.astype(np.int8) - object[:] y_arr = y.values - + cpdef double entropy(self, int8_t[:] mask, object[:] y, object[:] class_names): + r""" + Calculates entropy in a tree node. + + Entropy formula in LaTeX: + H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i} + where + H - entropy; + \overline{N} - effective number of states; + p_i - probability of the i-th system state. + """ cdef: int i Py_ssize_t n = len(mask) int8_t mask_value long N = 0 for i in range(n): - mask_value = mask_arr[i] + mask_value = mask[i] if mask_value: N += 1 @@ -86,9 +189,9 @@ cdef class CyBaseColumnSplitter: class_name = class_names[j] for i in range(n): - mask_value = mask_arr[i] + mask_value = mask[i] if mask_value: - label = y_arr[i] + label = y[i] if label == class_name: N_i += 1 diff --git a/tests/column_splitter/test__base_column_splitter.py b/tests/column_splitter/test__base_column_splitter.py index 6a03709..7f7ae7e 100644 --- a/tests/column_splitter/test__base_column_splitter.py +++ b/tests/column_splitter/test__base_column_splitter.py @@ -24,18 +24,6 @@ def split( ) -def test__gini_index(concrete_column_splitter, y): - parent_mask = y.apply(lambda x: True) - gini_index = concrete_column_splitter.gini_index(parent_mask) - assert gini_index == 0.6666591342419322 - - -def test__entropy(concrete_column_splitter, y): - parent_mask = y.apply(lambda x: True) - entropy = concrete_column_splitter.entropy(parent_mask) - assert entropy == 1.584946181877191 - - def test__information_gain(concrete_column_splitter, y): parent_mask = y.apply(lambda x: True) diff --git a/tests/column_splitter/test__cy_base_column_splitter.py b/tests/column_splitter/test__cy_base_column_splitter.py new file mode 100644 index 0000000..e8cebf1 --- /dev/null +++ b/tests/column_splitter/test__cy_base_column_splitter.py @@ -0,0 +1,27 @@ +import numpy as np + +from smarttree._cy_column_splitter import CyBaseColumnSplitter + + +def test__gini_index(dataset): + + cy_base_column_splitter = CyBaseColumnSplitter(dataset=dataset, criterion="gini") + + mask = dataset.y.apply(lambda x: True).values.astype(np.int8) + class_names = np.sort(dataset.y.unique()) + y = dataset.y.values + + gini_index = cy_base_column_splitter.gini_index(mask, y, class_names) + assert gini_index == 0.6666591342419322 + + +def test__entropy(dataset): + + cy_base_column_splitter = CyBaseColumnSplitter(dataset=dataset, criterion="entropy") + + mask = dataset.y.apply(lambda x: True).values.astype(np.int8) + class_names = np.sort(dataset.y.unique()) + y = dataset.y.values + + gini_index = cy_base_column_splitter.entropy(mask, y, class_names) + assert gini_index == 1.584946181877191 diff --git a/tests/conftest.py b/tests/conftest.py index 482b5e0..0c3b604 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from numpy.typing import NDArray from smarttree import BaseSmartDecisionTree +from smarttree._dataset import Dataset from smarttree._tree import TreeNode from smarttree._types import NaModeType @@ -163,6 +164,11 @@ def y(data) -> pd.Series: return data[TARGET_COL] +@pytest.fixture(scope="session") +def dataset(X, y) -> Dataset: + return Dataset(X, y) + + @pytest.fixture(scope="function") def root_node(X, y): return TreeNode( From c8269722028d7eb820b22dbf62496c2397b9982c Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 25 Sep 2025 02:03:11 +0300 Subject: [PATCH 6/8] cinit --- smarttree/_cy_column_splitter.pyi | 14 +------ smarttree/_cy_column_splitter.pyx | 39 ++++++++++--------- .../test__cy_base_column_splitter.py | 8 +--- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/smarttree/_cy_column_splitter.pyi b/smarttree/_cy_column_splitter.pyi index f04b74d..d578b70 100644 --- a/smarttree/_cy_column_splitter.pyi +++ b/smarttree/_cy_column_splitter.pyi @@ -58,12 +58,7 @@ class CyBaseColumnSplitter: """ ... - def gini_index( - self, - mask: NDArray[np.int8], - y: NDArray[np.object_], - class_names: NDArray[np.object_], - ) -> float: + def gini_index(self, mask: NDArray[np.int8]) -> float: r""" Calculates Gini index in a tree node. @@ -75,12 +70,7 @@ class CyBaseColumnSplitter: """ ... - def entropy( - self, - mask: NDArray[np.int8], - y: NDArray[np.object_], - class_names: NDArray[np.object_], - ) -> float: + def entropy(self, mask: NDArray[np.int8]) -> float: r""" Calculates entropy in a tree node. diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index a3332bc..27af703 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -5,21 +5,25 @@ from libc.stdint cimport int8_t import numpy as np import pandas as pd +from ._dataset import Dataset + cdef class CyBaseColumnSplitter: - cdef public object dataset cdef public str criterion + cdef public object[:] y + cdef public object[:] class_names - def __init__(self, dataset, criterion) -> None: - self.dataset = dataset + def __cinit__(self, dataset: Dataset, criterion: str) -> None: self.criterion = criterion + self.y = dataset.y.values + self.class_names = dataset.class_names - cdef double impurity(self, int8_t[:] mask, object[:] y, object[:] class_names): + cdef double impurity(self, int8_t[:] mask): if self.criterion == "gini": - return self.gini_index(mask, y, class_names) + return self.gini_index(mask) elif self.criterion in ("entropy", "log_loss"): - return self.entropy(mask, y, class_names) + return self.entropy(mask) else: assert False @@ -68,9 +72,6 @@ cdef class CyBaseColumnSplitter: child_mask_arrs = [ child_mask.values.astype(np.int8) for child_mask in child_masks ] - cdef object[:] y_arr, class_names_arr - y_arr = self.dataset.y.values - class_names_arr = self.dataset.class_names cdef: int i @@ -84,7 +85,7 @@ cdef class CyBaseColumnSplitter: if parent_mask_value: N_parent += 1 - cdef double impurity_parent = self.impurity(parent_mask_arr, y_arr, class_names_arr) + cdef double impurity_parent = self.impurity(parent_mask_arr) cdef: int j @@ -101,7 +102,7 @@ cdef class CyBaseColumnSplitter: if child_mask_value: N_child_j += 1 N_childs += N_child_j - impurity_child_i = self.impurity(child_mask_arr, y_arr, class_names_arr) + impurity_child_i = self.impurity(child_mask_arr) weighted_impurity_childs += (N_child_j / N_parent) * impurity_child_i cdef double norm_coef @@ -115,7 +116,7 @@ cdef class CyBaseColumnSplitter: return information_gain - cpdef double gini_index(self, int8_t[:] mask, object[:] y, object[:] class_names): + cpdef double gini_index(self, int8_t[:] mask): r""" Calculates Gini index in a tree node. @@ -141,14 +142,14 @@ cdef class CyBaseColumnSplitter: cdef object class_name, label double p_i = 0.0 gini_index = 1.0 - for j in range(len(class_names)): + for j in range(len(self.class_names)): N_i = 0 - class_name = class_names[j] + class_name = self.class_names[j] for i in range(n): mask_value = mask[i] if mask_value: - label = y[i] + label = self.y[i] if label == class_name: N_i += 1 @@ -157,7 +158,7 @@ cdef class CyBaseColumnSplitter: return gini_index - cpdef double entropy(self, int8_t[:] mask, object[:] y, object[:] class_names): + cpdef double entropy(self, int8_t[:] mask): r""" Calculates entropy in a tree node. @@ -184,14 +185,14 @@ cdef class CyBaseColumnSplitter: object class_name, label double p_i = 0.0 entropy = 0.0 - for j in range(len(class_names)): + for j in range(len(self.class_names)): N_i = 0 - class_name = class_names[j] + class_name = self.class_names[j] for i in range(n): mask_value = mask[i] if mask_value: - label = y[i] + label = self.y[i] if label == class_name: N_i += 1 diff --git a/tests/column_splitter/test__cy_base_column_splitter.py b/tests/column_splitter/test__cy_base_column_splitter.py index e8cebf1..2997e01 100644 --- a/tests/column_splitter/test__cy_base_column_splitter.py +++ b/tests/column_splitter/test__cy_base_column_splitter.py @@ -8,10 +8,8 @@ def test__gini_index(dataset): cy_base_column_splitter = CyBaseColumnSplitter(dataset=dataset, criterion="gini") mask = dataset.y.apply(lambda x: True).values.astype(np.int8) - class_names = np.sort(dataset.y.unique()) - y = dataset.y.values - gini_index = cy_base_column_splitter.gini_index(mask, y, class_names) + gini_index = cy_base_column_splitter.gini_index(mask) assert gini_index == 0.6666591342419322 @@ -20,8 +18,6 @@ def test__entropy(dataset): cy_base_column_splitter = CyBaseColumnSplitter(dataset=dataset, criterion="entropy") mask = dataset.y.apply(lambda x: True).values.astype(np.int8) - class_names = np.sort(dataset.y.unique()) - y = dataset.y.values - gini_index = cy_base_column_splitter.entropy(mask, y, class_names) + gini_index = cy_base_column_splitter.entropy(mask) assert gini_index == 1.584946181877191 From e27d29373177f6de51e8891d12a80fe561793c3c Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 25 Sep 2025 02:10:29 +0300 Subject: [PATCH 7/8] no public --- smarttree/_cy_column_splitter.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index 27af703..04aaada 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -10,9 +10,9 @@ from ._dataset import Dataset cdef class CyBaseColumnSplitter: - cdef public str criterion - cdef public object[:] y - cdef public object[:] class_names + cdef str criterion + cdef object[:] y + cdef object[:] class_names def __cinit__(self, dataset: Dataset, criterion: str) -> None: self.criterion = criterion From 69cd7e9b8110114461845e6a20ee0c28e74d2581 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 25 Sep 2025 02:56:38 +0300 Subject: [PATCH 8/8] Criterion(Enum) --- smarttree/_column_splitter.py | 10 ++++++++-- smarttree/_cy_column_splitter.pyi | 10 ++-------- smarttree/_cy_column_splitter.pyx | 16 +++++++++------- smarttree/_types.py | 7 +++++++ .../test__cy_base_column_splitter.py | 9 +++++++-- 5 files changed, 33 insertions(+), 19 deletions(-) diff --git a/smarttree/_column_splitter.py b/smarttree/_column_splitter.py index 809f4f2..1a58ad1 100644 --- a/smarttree/_column_splitter.py +++ b/smarttree/_column_splitter.py @@ -12,7 +12,7 @@ from ._cy_column_splitter import CyBaseColumnSplitter from ._dataset import Dataset from ._tree import TreeNode -from ._types import ClassificationCriterionType, NaModeType +from ._types import ClassificationCriterionType, Criterion, NaModeType NO_INFORMATION_GAIN = float("-inf") @@ -36,6 +36,12 @@ def no_split(cls) -> ColumnSplitResult: class BaseColumnSplitter(ABC): + mapping: dict[ClassificationCriterionType, Criterion] = { + "gini": Criterion.GINI, + "entropy": Criterion.ENTROPY, + "log_loss": Criterion.LOG_LOSS, + } + def __init__( self, dataset: Dataset, @@ -46,7 +52,7 @@ def __init__( ) -> None: self.dataset = dataset - self.criterion = criterion + self.criterion = self.mapping[criterion] self.min_samples_split = min_samples_split self.min_samples_leaf = min_samples_leaf self.feature_na_mode = feature_na_mode diff --git a/smarttree/_cy_column_splitter.pyi b/smarttree/_cy_column_splitter.pyi index d578b70..5328e3c 100644 --- a/smarttree/_cy_column_splitter.pyi +++ b/smarttree/_cy_column_splitter.pyi @@ -3,17 +3,11 @@ import pandas as pd from numpy.typing import NDArray from ._dataset import Dataset -from ._types import ClassificationCriterionType +from ._types import Criterion class CyBaseColumnSplitter: - criterion: ClassificationCriterionType - - def __init__( - self, - dataset: Dataset, - criterion: ClassificationCriterionType, - ) -> None: + def __init__(self, dataset: Dataset, criterion: Criterion) -> None: ... def information_gain( diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index 04aaada..5e0ef56 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -6,26 +6,28 @@ import numpy as np import pandas as pd from ._dataset import Dataset +from ._types import Criterion + + +cdef int CRITERION_GINI = 1 cdef class CyBaseColumnSplitter: - cdef str criterion + cdef int criterion cdef object[:] y cdef object[:] class_names - def __cinit__(self, dataset: Dataset, criterion: str) -> None: - self.criterion = criterion + def __cinit__(self, dataset: Dataset, criterion: Criterion) -> None: + self.criterion = criterion.value self.y = dataset.y.values self.class_names = dataset.class_names cdef double impurity(self, int8_t[:] mask): - if self.criterion == "gini": + if self.criterion == CRITERION_GINI: return self.gini_index(mask) - elif self.criterion in ("entropy", "log_loss"): - return self.entropy(mask) else: - assert False + return self.entropy(mask) def information_gain( self, diff --git a/smarttree/_types.py b/smarttree/_types.py index 927e113..9ffa7b8 100644 --- a/smarttree/_types.py +++ b/smarttree/_types.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Literal @@ -11,3 +12,9 @@ VerboseType = Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] | int SplitType = Literal["numerical", "categorical", "rank"] + + +class Criterion(Enum): + GINI = 1 + ENTROPY = 2 + LOG_LOSS = 2 diff --git a/tests/column_splitter/test__cy_base_column_splitter.py b/tests/column_splitter/test__cy_base_column_splitter.py index 2997e01..77605d2 100644 --- a/tests/column_splitter/test__cy_base_column_splitter.py +++ b/tests/column_splitter/test__cy_base_column_splitter.py @@ -1,11 +1,14 @@ import numpy as np from smarttree._cy_column_splitter import CyBaseColumnSplitter +from smarttree._types import Criterion def test__gini_index(dataset): - cy_base_column_splitter = CyBaseColumnSplitter(dataset=dataset, criterion="gini") + cy_base_column_splitter = CyBaseColumnSplitter( + dataset=dataset, criterion=Criterion.GINI + ) mask = dataset.y.apply(lambda x: True).values.astype(np.int8) @@ -15,7 +18,9 @@ def test__gini_index(dataset): def test__entropy(dataset): - cy_base_column_splitter = CyBaseColumnSplitter(dataset=dataset, criterion="entropy") + cy_base_column_splitter = CyBaseColumnSplitter( + dataset=dataset, criterion=Criterion.ENTROPY + ) mask = dataset.y.apply(lambda x: True).values.astype(np.int8)