From ca8c3c6c0dc724c92de55872384a80bf9a13cc8c Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Sat, 27 Sep 2025 12:05:49 +0300 Subject: [PATCH] extract Criterion --- smarttree/_builder.py | 60 ++++------------ smarttree/_criterion.pxd | 15 ++++ smarttree/_criterion.pyi | 46 ++++++++++++ smarttree/_criterion.pyx | 71 +++++++++++++++++++ smarttree/_cy_column_splitter.pxd | 8 +-- smarttree/_cy_column_splitter.pyi | 27 ------- smarttree/_cy_column_splitter.pyx | 71 +++---------------- .../test__cy_base_column_splitter.py | 28 -------- tests/test__criterion.py | 17 +++++ 9 files changed, 173 insertions(+), 170 deletions(-) create mode 100644 smarttree/_criterion.pxd create mode 100644 smarttree/_criterion.pyi create mode 100644 smarttree/_criterion.pyx delete mode 100644 tests/column_splitter/test__cy_base_column_splitter.py create mode 100644 tests/test__criterion.py diff --git a/smarttree/_builder.py b/smarttree/_builder.py index 346acd3..5e099d2 100644 --- a/smarttree/_builder.py +++ b/smarttree/_builder.py @@ -1,10 +1,11 @@ import bisect -import math import numpy as np import pandas as pd from numpy.typing import NDArray +from ._criterion import ClassificationCriterion, Entropy, Gini +from ._dataset import Dataset from ._node_splitter import NodeSplitter from ._tree import Tree, TreeNode from ._types import ClassificationCriterionType @@ -21,6 +22,7 @@ def __init__( hierarchy: dict[str, str | list[str]], ) -> None: + self.X = X self.available_features = X.columns.to_list() self.y = y self.criterion = criterion @@ -28,12 +30,6 @@ def __init__( self.max_leaf_nodes = max_leaf_nodes self.hierarchy = hierarchy - match self.criterion: - case "gini": - self.impurity = self.gini_index - case "entropy" | "log_loss": - self.impurity = self.entropy - if self.criterion in ("gini", "entropy", "log_loss"): self.class_names = np.sort(self.y.unique()) @@ -110,44 +106,12 @@ def distribution(self, mask: pd.Series) -> NDArray[np.integer]: return result - def gini_index(self, mask: pd.Series) -> float: - r""" - Calculates Gini index in a tree node. - - Gini index formula in LaTeX: - \text{Gini Index} = 1 - \sum^C_{i=1} p_i^2 - where - C - total number of classes; - p_i - the probability of choosing a sample with class i. - """ - N = mask.sum() - - gini_index = 1 - for label in self.class_names: - N_i = (mask & (self.y == label)).sum() - p_i = N_i / N - gini_index -= pow(p_i, 2) - - return gini_index - - def entropy(self, mask: pd.Series) -> float: - r""" - Calculates entropy in a tree node. - - Entropy formula in LaTeX: - H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i} - where - H - entropy; - \overline{N} - effective number of states; - p_i - probability of the i-th system state. - """ - N = mask.sum() - - entropy = 0 - for label in self.class_names: - N_i = (mask & (self.y == label)).sum() - if N_i != 0: - p_i = N_i / N - entropy -= p_i * math.log2(p_i) - - return entropy + def impurity(self, mask: pd.Series) -> float: + + criterion: ClassificationCriterion + if self.criterion == "gini": + criterion = Gini(Dataset(self.X, self.y)) + else: # "entropy" | "log_loss" + criterion = Entropy(Dataset(self.X, self.y)) + + return criterion.impurity(mask.to_numpy(dtype=np.int8)) diff --git a/smarttree/_criterion.pxd b/smarttree/_criterion.pxd new file mode 100644 index 0000000..4d6f2ba --- /dev/null +++ b/smarttree/_criterion.pxd @@ -0,0 +1,15 @@ +from libc.stdint cimport int8_t + + +cdef class ClassificationCriterion: + + cdef int[:] y + cdef Py_ssize_t n_classes + cdef Py_ssize_t n_samples + + +cdef class Gini(ClassificationCriterion): + cpdef double impurity(self, int8_t[:] mask) + +cdef class Entropy(ClassificationCriterion): + cpdef double impurity(self, int8_t[:] mask) diff --git a/smarttree/_criterion.pyi b/smarttree/_criterion.pyi new file mode 100644 index 0000000..e9a9d12 --- /dev/null +++ b/smarttree/_criterion.pyi @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod + +import numpy as np +from numpy.typing import NDArray + +from ._dataset import Dataset + +class ClassificationCriterion(ABC): + + def __init__(self, dataset: Dataset) -> None: + ... + + @abstractmethod + def impurity(self, mask: NDArray[np.int8]) -> float: + raise NotImplementedError + + +class Gini(ClassificationCriterion): + + def impurity(self, mask: NDArray[np.int8]) -> float: + r""" + Calculates Gini index in a tree node. + + Gini index formula in LaTeX: + \text{Gini Index} = 1 - \sum^C_{i=1} p_i^2 + where + C - total number of classes; + p_i - the probability of choosing a sample with class i. + """ + ... + + +class Entropy(ClassificationCriterion): + + def impurity(self, mask: NDArray[np.int8]) -> float: + r""" + Calculates entropy in a tree node. + + Entropy formula in LaTeX: + H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i} + where + H - entropy; + \overline{N} - effective number of states; + p_i - probability of the i-th system state. + """ + ... diff --git a/smarttree/_criterion.pyx b/smarttree/_criterion.pyx new file mode 100644 index 0000000..79014ac --- /dev/null +++ b/smarttree/_criterion.pyx @@ -0,0 +1,71 @@ +cimport cython +from libc.math cimport log2 +from libc.stdint cimport int8_t + +import numpy as np + +from ._dataset import Dataset + + +cdef class ClassificationCriterion: + + def __cinit__(self, dataset: Dataset) -> None: + self.y = dataset.y + self.n_classes = len(dataset.classes) + self.n_samples = len(dataset.y) + + +cdef class Gini(ClassificationCriterion): + + @cython.boundscheck(False) + @cython.cdivision(True) + @cython.wraparound(False) + cpdef double impurity(self, int8_t[:] mask): + + cdef Py_ssize_t i + cdef long[:] counts + cdef long N + cdef double p_i, gini + + counts = np.zeros(self.n_classes, dtype=np.int32) + N = 0 + for i in range(self.n_samples): + if mask[i]: + N += 1 + counts[self.y[i]] += 1 + + gini = 1.0 + for i in range(self.n_classes): + if counts[i] > 0: + p_i = counts[i] / N + gini -= p_i * p_i + + return gini + + +cdef class Entropy(ClassificationCriterion): + + @cython.boundscheck(False) + @cython.cdivision(True) + @cython.wraparound(False) + cpdef double impurity(self, int8_t[:] mask): + + cdef Py_ssize_t i + cdef long[:] counts + cdef long N + cdef double p_i, entropy + + counts = np.zeros(self.n_classes, dtype=np.int32) + N = 0 + for i in range(self.n_samples): + if mask[i]: + N += 1 + counts[self.y[i]] += 1 + + entropy = 0.0 + for i in range(self.n_classes): + if counts[i] > 0: + p_i = counts[i] / N + entropy -= p_i * log2(p_i) + + return entropy diff --git a/smarttree/_cy_column_splitter.pxd b/smarttree/_cy_column_splitter.pxd index 0630d36..e297fb4 100644 --- a/smarttree/_cy_column_splitter.pxd +++ b/smarttree/_cy_column_splitter.pxd @@ -1,13 +1,11 @@ from libc.stdint cimport int8_t +from ._criterion cimport ClassificationCriterion + cdef class CyBaseColumnSplitter: - cdef int criterion + cdef ClassificationCriterion criterion cdef int[:] y cdef Py_ssize_t n_classes cdef Py_ssize_t n_samples - - cdef double impurity(self, int8_t[:] mask) - cpdef double gini_index(self, int8_t[:] mask) - cpdef double entropy(self, int8_t[:] mask) diff --git a/smarttree/_cy_column_splitter.pyi b/smarttree/_cy_column_splitter.pyi index 5328e3c..efa445f 100644 --- a/smarttree/_cy_column_splitter.pyi +++ b/smarttree/_cy_column_splitter.pyi @@ -1,6 +1,4 @@ -import numpy as np import pandas as pd -from numpy.typing import NDArray from ._dataset import Dataset from ._types import Criterion @@ -51,28 +49,3 @@ class CyBaseColumnSplitter: \end{itemize} """ ... - - def gini_index(self, mask: NDArray[np.int8]) -> float: - r""" - Calculates Gini index in a tree node. - - Gini index formula in LaTeX: - \text{Gini Index} = 1 - \sum^C_{i=1} p_i^2 - where - C - total number of classes; - p_i - the probability of choosing a sample with class i. - """ - ... - - def entropy(self, mask: NDArray[np.int8]) -> float: - r""" - Calculates entropy in a tree node. - - Entropy formula in LaTeX: - H = \log{\overline{N}} = \sum^N_{i=1} p_i \log{(1/p_i)} = -\sum^N_{i=1} p_i \log{p_i} - where - H - entropy; - \overline{N} - effective number of states; - p_i - probability of the i-th system state. - """ - ... diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index 484c87f..6f7d724 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -1,5 +1,4 @@ cimport cython -from libc.math cimport log2 from libc.stdint cimport int8_t import numpy as np @@ -7,6 +6,7 @@ import pandas as pd from ._dataset import Dataset from ._types import Criterion +from ._criterion cimport Entropy, Gini cdef int CRITERION_GINI = 1 @@ -15,10 +15,13 @@ cdef int CRITERION_GINI = 1 cdef class CyBaseColumnSplitter: def __cinit__(self, dataset: Dataset, criterion: Criterion) -> None: - self.criterion = criterion.value self.y = dataset.y self.n_classes = len(dataset.classes) self.n_samples = len(dataset.y) + if criterion.value == CRITERION_GINI: + self.criterion = Gini(dataset) + else: + self.criterion = Entropy(dataset) def information_gain( self, @@ -32,9 +35,9 @@ cdef class CyBaseColumnSplitter: cdef long N, N_parent, N_childs, N_child_j cdef double impurity_parent, weighted_impurity_childs, impurity_child_i - parent_mask_arr = parent_mask.values.astype(np.int8) + parent_mask_arr = parent_mask.to_numpy(dtype=np.int8) child_mask_arrs = [ - child_mask.values.astype(np.int8) for child_mask in child_masks + child_mask.to_numpy(dtype=np.int8) for child_mask in child_masks ] N = 0 @@ -44,7 +47,7 @@ cdef class CyBaseColumnSplitter: if parent_mask_arr[i]: N_parent += 1 - impurity_parent = self.impurity(parent_mask_arr) + impurity_parent = self.criterion.impurity(parent_mask_arr) N_childs = 0 n_childs = len(child_mask_arrs) @@ -56,7 +59,7 @@ cdef class CyBaseColumnSplitter: if child_mask_arr[i]: N_child_j += 1 N_childs += N_child_j - impurity_child_i = self.impurity(child_mask_arr) + impurity_child_i = self.criterion.impurity(child_mask_arr) weighted_impurity_childs += (N_child_j / N_parent) * impurity_child_i cdef double norm_coef @@ -69,59 +72,3 @@ cdef class CyBaseColumnSplitter: cdef double information_gain = (N_parent / N) * local_information_gain return information_gain - - cdef double impurity(self, int8_t[:] mask): - if self.criterion == CRITERION_GINI: - return self.gini_index(mask) - else: - return self.entropy(mask) - - @cython.boundscheck(False) - @cython.cdivision(True) - @cython.wraparound(False) - cpdef double gini_index(self, int8_t[:] mask): - - cdef Py_ssize_t i - cdef long[:] counts - cdef long N - cdef double p_i, gini - - counts = np.zeros(self.n_classes, dtype=np.int32) - N = 0 - for i in range(self.n_samples): - if mask[i]: - N += 1 - counts[self.y[i]] += 1 - - gini = 1.0 - for i in range(self.n_classes): - if counts[i] > 0: - p_i = counts[i] / N - gini -= p_i * p_i - - return gini - - @cython.boundscheck(False) - @cython.cdivision(True) - @cython.wraparound(False) - cpdef double entropy(self, int8_t[:] mask): - - cdef Py_ssize_t i - cdef long[:] counts - cdef long N - cdef double p_i, entropy - - counts = np.zeros(self.n_classes, dtype=np.int32) - N = 0 - for i in range(self.n_samples): - if mask[i]: - N += 1 - counts[self.y[i]] += 1 - - entropy = 0.0 - for i in range(self.n_classes): - if counts[i] > 0: - p_i = counts[i] / N - entropy -= p_i * log2(p_i) - - return entropy diff --git a/tests/column_splitter/test__cy_base_column_splitter.py b/tests/column_splitter/test__cy_base_column_splitter.py deleted file mode 100644 index 8c1f591..0000000 --- a/tests/column_splitter/test__cy_base_column_splitter.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np - -from smarttree._cy_column_splitter import CyBaseColumnSplitter -from smarttree._types import Criterion - - -def test__gini_index(dataset): - - cy_base_column_splitter = CyBaseColumnSplitter( - dataset=dataset, criterion=Criterion.GINI - ) - - mask = np.ones(dataset.y.shape, dtype=np.int8) - - gini_index = cy_base_column_splitter.gini_index(mask) - assert gini_index == 0.6666591342419322 - - -def test__entropy(dataset): - - cy_base_column_splitter = CyBaseColumnSplitter( - dataset=dataset, criterion=Criterion.ENTROPY - ) - - mask = np.ones(dataset.y.shape, dtype=np.int8) - - gini_index = cy_base_column_splitter.entropy(mask) - assert gini_index == 1.584946181877191 diff --git a/tests/test__criterion.py b/tests/test__criterion.py new file mode 100644 index 0000000..4a690ea --- /dev/null +++ b/tests/test__criterion.py @@ -0,0 +1,17 @@ +import numpy as np + +from smarttree._criterion import Entropy, Gini + + +def test__gini(dataset): + gini_criterion = Gini(dataset) + mask = np.ones(dataset.y.shape, dtype=np.int8) + gini_index = gini_criterion.impurity(mask) + assert gini_index == 0.6666591342419322 + + +def test__entropy(dataset): + entropy_criterion = Entropy(dataset) + mask = np.ones(dataset.y.shape, dtype=np.int8) + entropy = entropy_criterion.impurity(mask) + assert entropy == 1.584946181877191