From c746ecab218179e73fcf010c09c76a519fa7c7ca Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 25 Sep 2025 19:26:51 +0300 Subject: [PATCH 1/3] remove dataclass from dataset --- smarttree/_dataset.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/smarttree/_dataset.py b/smarttree/_dataset.py index 90ee9b0..2da3a61 100644 --- a/smarttree/_dataset.py +++ b/smarttree/_dataset.py @@ -1,23 +1,17 @@ -from dataclasses import dataclass, field - import numpy as np import pandas as pd from numpy.typing import NDArray -@dataclass class Dataset: - X: pd.DataFrame - y: pd.Series - class_names: NDArray = field(init=False) - has_na: dict[str, bool] = field(init=False) - mask_na: dict[str, pd.Series] = field(init=False) + def __init__(self, X: pd.DataFrame, y: pd.Series) -> None: + self.X = X + self.y = y - def __post_init__(self) -> None: - self.class_names = np.sort(self.y.unique()) - self.has_na = dict() - self.mask_na = dict() + self.class_names: NDArray = np.sort(self.y.unique()) + self.has_na: dict[str, bool] = dict() + self.mask_na: dict[str, pd.Series] = dict() for column in self.X.columns: mask_na = self.X[column].isna() has_na = mask_na.any() From abc36329d5e10bb5f2bd9a6ce47e913c96bee47b Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 25 Sep 2025 19:37:20 +0300 Subject: [PATCH 2/3] renaming class_names -> classes --- smarttree/_cy_column_splitter.pyx | 20 ++++++++++---------- smarttree/_dataset.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index 5e0ef56..a4bc212 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -16,12 +16,12 @@ cdef class CyBaseColumnSplitter: cdef int criterion cdef object[:] y - cdef object[:] class_names + cdef object[:] classes def __cinit__(self, dataset: Dataset, criterion: Criterion) -> None: self.criterion = criterion.value self.y = dataset.y.values - self.class_names = dataset.class_names + self.classes = dataset.classes cdef double impurity(self, int8_t[:] mask): if self.criterion == CRITERION_GINI: @@ -141,18 +141,18 @@ cdef class CyBaseColumnSplitter: cdef: int j cdef long N_i - cdef object class_name, label + cdef object class_, label double p_i = 0.0 gini_index = 1.0 - for j in range(len(self.class_names)): + for j in range(len(self.classes)): N_i = 0 - class_name = self.class_names[j] + class_ = self.classes[j] for i in range(n): mask_value = mask[i] if mask_value: label = self.y[i] - if label == class_name: + if label == class_: N_i += 1 p_i = N_i / N @@ -184,18 +184,18 @@ cdef class CyBaseColumnSplitter: cdef: int j long N_i = 0 - object class_name, label + object class_, label double p_i = 0.0 entropy = 0.0 - for j in range(len(self.class_names)): + for j in range(len(self.classes)): N_i = 0 - class_name = self.class_names[j] + class_ = self.classes[j] for i in range(n): mask_value = mask[i] if mask_value: label = self.y[i] - if label == class_name: + if label == class_: N_i += 1 if N_i != 0: diff --git a/smarttree/_dataset.py b/smarttree/_dataset.py index 2da3a61..699ff37 100644 --- a/smarttree/_dataset.py +++ b/smarttree/_dataset.py @@ -9,7 +9,7 @@ def __init__(self, X: pd.DataFrame, y: pd.Series) -> None: self.X = X self.y = y - self.class_names: NDArray = np.sort(self.y.unique()) + self.classes: NDArray = np.sort(self.y.unique()) self.has_na: dict[str, bool] = dict() self.mask_na: dict[str, pd.Series] = dict() for column in self.X.columns: From cd6e53becbb65f6bd14838b63b125a7da63ba2f1 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 25 Sep 2025 20:48:33 +0300 Subject: [PATCH 3/3] y to int32 --- smarttree/_cy_column_splitter.pyx | 64 +++++++++---------- smarttree/_dataset.py | 6 +- .../test__cy_base_column_splitter.py | 4 +- 3 files changed, 35 insertions(+), 39 deletions(-) diff --git a/smarttree/_cy_column_splitter.pyx b/smarttree/_cy_column_splitter.pyx index a4bc212..aa4b9e1 100644 --- a/smarttree/_cy_column_splitter.pyx +++ b/smarttree/_cy_column_splitter.pyx @@ -15,13 +15,15 @@ cdef int CRITERION_GINI = 1 cdef class CyBaseColumnSplitter: cdef int criterion - cdef object[:] y - cdef object[:] classes + cdef int[:] y + cdef Py_ssize_t n_classes + cdef Py_ssize_t n_samples def __cinit__(self, dataset: Dataset, criterion: Criterion) -> None: self.criterion = criterion.value - self.y = dataset.y.values - self.classes = dataset.classes + self.y = dataset.y + self.n_classes = len(dataset.classes) + self.n_samples = len(dataset.y) cdef double impurity(self, int8_t[:] mask): if self.criterion == CRITERION_GINI: @@ -76,12 +78,11 @@ cdef class CyBaseColumnSplitter: ] cdef: - int i - Py_ssize_t n = len(parent_mask_arr) + Py_ssize_t i long N = 0 long N_parent = 0 int8_t parent_mask_value - for i in range(n): + for i in range(self.n_samples): N += 1 parent_mask_value = parent_mask_arr[i] if parent_mask_value: @@ -90,16 +91,17 @@ cdef class CyBaseColumnSplitter: cdef double impurity_parent = self.impurity(parent_mask_arr) cdef: - int j + Py_ssize_t j + Py_ssize_t n_childs = len(child_mask_arrs) double weighted_impurity_childs = 0.0 long N_childs = 0 long N_child_j int8_t child_mask_value double impurity_child_i - for j in range(len(child_mask_arrs)): + for j in range(n_childs): N_child_j = 0 child_mask_arr = child_mask_arrs[j] - for i in range(n): + for i in range(self.n_samples): child_mask_value = child_mask_arr[i] if child_mask_value: N_child_j += 1 @@ -129,30 +131,28 @@ cdef class CyBaseColumnSplitter: p_i - the probability of choosing a sample with class i. """ cdef: - int i - Py_ssize_t n = len(mask) + Py_ssize_t i int8_t mask_value long N = 0 - for i in range(n): + for i in range(self.n_samples): mask_value = mask[i] if mask_value: N += 1 cdef: - int j - cdef long N_i - cdef object class_, label + Py_ssize_t j + long N_i + int encoded_label double p_i = 0.0 - gini_index = 1.0 - for j in range(len(self.classes)): + double gini_index = 1.0 + for j in range(self.n_classes): N_i = 0 - class_ = self.classes[j] - for i in range(n): + for i in range(self.n_samples): mask_value = mask[i] if mask_value: - label = self.y[i] - if label == class_: + encoded_label = self.y[i] + if encoded_label == j: N_i += 1 p_i = N_i / N @@ -172,30 +172,28 @@ cdef class CyBaseColumnSplitter: p_i - probability of the i-th system state. """ cdef: - int i - Py_ssize_t n = len(mask) + Py_ssize_t i int8_t mask_value long N = 0 - for i in range(n): + for i in range(self.n_samples): mask_value = mask[i] if mask_value: N += 1 cdef: - int j + Py_ssize_t j long N_i = 0 - object class_, label + int encoded_label double p_i = 0.0 - entropy = 0.0 - for j in range(len(self.classes)): + double entropy = 0.0 + for j in range(self.n_classes): N_i = 0 - class_ = self.classes[j] - for i in range(n): + for i in range(self.n_samples): mask_value = mask[i] if mask_value: - label = self.y[i] - if label == class_: + encoded_label = self.y[i] + if encoded_label == j: N_i += 1 if N_i != 0: diff --git a/smarttree/_dataset.py b/smarttree/_dataset.py index 699ff37..e3a2811 100644 --- a/smarttree/_dataset.py +++ b/smarttree/_dataset.py @@ -1,15 +1,13 @@ import numpy as np import pandas as pd -from numpy.typing import NDArray class Dataset: def __init__(self, X: pd.DataFrame, y: pd.Series) -> None: self.X = X - self.y = y - - self.classes: NDArray = np.sort(self.y.unique()) + self.classes = np.sort(y.unique()) + self.y = np.searchsorted(self.classes, y.to_numpy()).astype(np.int32) self.has_na: dict[str, bool] = dict() self.mask_na: dict[str, pd.Series] = dict() for column in self.X.columns: diff --git a/tests/column_splitter/test__cy_base_column_splitter.py b/tests/column_splitter/test__cy_base_column_splitter.py index 77605d2..8c1f591 100644 --- a/tests/column_splitter/test__cy_base_column_splitter.py +++ b/tests/column_splitter/test__cy_base_column_splitter.py @@ -10,7 +10,7 @@ def test__gini_index(dataset): dataset=dataset, criterion=Criterion.GINI ) - mask = dataset.y.apply(lambda x: True).values.astype(np.int8) + mask = np.ones(dataset.y.shape, dtype=np.int8) gini_index = cy_base_column_splitter.gini_index(mask) assert gini_index == 0.6666591342419322 @@ -22,7 +22,7 @@ def test__entropy(dataset): dataset=dataset, criterion=Criterion.ENTROPY ) - mask = dataset.y.apply(lambda x: True).values.astype(np.int8) + mask = np.ones(dataset.y.shape, dtype=np.int8) gini_index = cy_base_column_splitter.entropy(mask) assert gini_index == 1.584946181877191