From 8a8776315898307321c46465a088dd7a5653a401 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 22:17:59 +0300 Subject: [PATCH 1/9] added SmartDecisionTreeRegressor --- smarttree/_classes.py | 65 +++++++++++++++++++++++++++++++++++++++++-- smarttree/_types.py | 2 ++ 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/smarttree/_classes.py b/smarttree/_classes.py index bbdadcb..1b01d3c 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -3,7 +3,7 @@ import math from abc import ABC, abstractmethod from functools import lru_cache -from typing import Self +from typing import cast, Self import numpy as np import pandas as pd @@ -23,6 +23,7 @@ CommonNaModeType, NaModeType, NumNaModeType, + RegressionCriterionType, VerboseType, ) @@ -33,7 +34,7 @@ class BaseSmartDecisionTree(ABC): def __init__( self, *, - criterion: ClassificationCriterionType = "gini", + criterion: ClassificationCriterionType | RegressionCriterionType = "gini", max_depth: int | None = None, min_samples_split: int | float = 2, min_samples_leaf: int | float = 1, @@ -116,7 +117,7 @@ def __init__( self._feature_na_filler: dict[str, int | float | str] = dict() @property - def criterion(self) -> ClassificationCriterionType: + def criterion(self) -> ClassificationCriterionType | RegressionCriterionType: return self.__criterion @property @@ -469,6 +470,10 @@ def __init__( ) self.__classes: NDArray = np.array([]) + @property + def criterion(self) -> ClassificationCriterionType: + return cast(ClassificationCriterionType, super().criterion) + @property def classes_(self) -> NDArray: self._check_is_fitted() @@ -782,3 +787,57 @@ def render( ) return graph + + +class SmartDecisionTreeRegressor(BaseSmartDecisionTree): + """ + TODO. + + """ + def __init__( + self, + *, + criterion: RegressionCriterionType = "squared_error", + max_depth: int | None = None, + min_samples_split: int | float = 2, + min_samples_leaf: int | float = 1, + max_leaf_nodes: int | None = None, + min_impurity_decrease: float = .0, + max_childs: int | None = None, + num_features: list[str] | str | None = None, + cat_features: list[str] | str | None = None, + rank_features: dict[str, list] | None = None, + hierarchy: dict[str, str | list[str]] | None = None, + na_mode: CommonNaModeType = "include_best", + num_na_mode: NumNaModeType | None = None, + cat_na_mode: CatNaModeType | None = None, + cat_na_filler: str = "missing_value", + rank_na_mode: CommonNaModeType | None = None, + feature_na_mode: dict[str, NaModeType] | None = None, + verbose: VerboseType = "WARNING", + ) -> None: + + super().__init__( + criterion=criterion, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + max_leaf_nodes=max_leaf_nodes, + min_impurity_decrease=min_impurity_decrease, + max_childs=max_childs, + num_features=num_features, + cat_features=cat_features, + rank_features=rank_features, + hierarchy=hierarchy, + na_mode=na_mode, + num_na_mode=num_na_mode, + cat_na_mode=cat_na_mode, + cat_na_filler=cat_na_filler, + rank_na_mode=rank_na_mode, + feature_na_mode=feature_na_mode, + verbose=verbose, + ) + + @property + def criterion(self) -> RegressionCriterionType: + return cast(RegressionCriterionType, super().criterion) diff --git a/smarttree/_types.py b/smarttree/_types.py index 927e113..b8fa8c3 100644 --- a/smarttree/_types.py +++ b/smarttree/_types.py @@ -8,6 +8,8 @@ CatNaModeType = Literal["as_category", "include_all", "include_best"] NaModeType = Literal["min", "max", "as_category", "include_all", "include_best"] +RegressionCriterionType = Literal["squared_error"] + VerboseType = Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] | int SplitType = Literal["numerical", "categorical", "rank"] From 7fd55c16fe08b495b0d209e66351a25c9ab66fd8 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 22:31:57 +0300 Subject: [PATCH 2/9] up .__repr__() --- smarttree/_classes.py | 84 +++++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/smarttree/_classes.py b/smarttree/_classes.py index 1b01d3c..ac65ef4 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -3,7 +3,7 @@ import math from abc import ABC, abstractmethod from functools import lru_cache -from typing import cast, Self +from typing import Self, cast import numpy as np import pandas as pd @@ -206,6 +206,47 @@ def feature_importances_(self) -> dict[str, float]: self._check_is_fitted() return self.tree_.compute_feature_importances() + def __repr__(self) -> str: + repr_ = [] + + # if a parameter value differs from default, then it added to the representation + if self.criterion not in ("gini", "squared_error"): + repr_.append(f"criterion={self.criterion!r}") + if self.max_depth: + repr_.append(f"max_depth={self.max_depth}") + if self.min_samples_split != 2: + repr_.append(f"min_samples_split={self.min_samples_split}") + if self.min_samples_leaf != 1: + repr_.append(f"min_samples_leaf={self.min_samples_leaf}") + if self.max_leaf_nodes: + repr_.append(f"max_leaf_nodes={self.max_leaf_nodes}") + if self.min_impurity_decrease != .0: + repr_.append(f"min_impurity_decrease={self.min_impurity_decrease}") + if self.max_childs: + repr_.append(f"max_childs={self.max_childs}") + if self.num_features: + repr_.append(f"num_features={self.num_features}") + if self.cat_features: + repr_.append(f"cat_features={self.cat_features}") + if self.rank_features: + repr_.append(f"rank_features={self.rank_features}") + if self.hierarchy: + repr_.append(f"hierarchy={self.hierarchy}") + if self.na_mode != "include_best": + repr_.append(f"na_mode={self.na_mode!r}") + if self.num_na_mode: + repr_.append(f"num_na_mode={self.num_na_mode!r}") + if self.cat_na_mode: + repr_.append(f"cat_na_mode={self.cat_na_mode!r}") + if self.cat_na_filler != "missing_value": + repr_.append(f"cat_na_filler={self.cat_na_filler!r}") + if self.rank_na_mode: + repr_.append(f"rank_na_mode={self.rank_na_mode!r}") + + return ( + f"{self.__class__.__name__}({', '.join(repr_)})" + ) + @abstractmethod def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: raise NotImplementedError @@ -479,47 +520,6 @@ def classes_(self) -> NDArray: self._check_is_fitted() return self.__classes - def __repr__(self) -> str: - repr_ = [] - - # if a parameter value differs from default, then it added to the representation - if self.criterion != "gini": - repr_.append(f"criterion={self.criterion!r}") - if self.max_depth: - repr_.append(f"max_depth={self.max_depth}") - if self.min_samples_split != 2: - repr_.append(f"min_samples_split={self.min_samples_split}") - if self.min_samples_leaf != 1: - repr_.append(f"min_samples_leaf={self.min_samples_leaf}") - if self.max_leaf_nodes: - repr_.append(f"max_leaf_nodes={self.max_leaf_nodes}") - if self.min_impurity_decrease != .0: - repr_.append(f"min_impurity_decrease={self.min_impurity_decrease}") - if self.max_childs: - repr_.append(f"max_childs={self.max_childs}") - if self.num_features: - repr_.append(f"num_features={self.num_features}") - if self.cat_features: - repr_.append(f"cat_features={self.cat_features}") - if self.rank_features: - repr_.append(f"rank_features={self.rank_features}") - if self.hierarchy: - repr_.append(f"hierarchy={self.hierarchy}") - if self.na_mode != "include_best": - repr_.append(f"na_mode={self.na_mode!r}") - if self.num_na_mode: - repr_.append(f"num_na_mode={self.num_na_mode!r}") - if self.cat_na_mode: - repr_.append(f"cat_na_mode={self.cat_na_mode!r}") - if self.cat_na_filler != "missing_value": - repr_.append(f"cat_na_filler={self.cat_na_filler!r}") - if self.rank_na_mode: - repr_.append(f"rank_na_mode={self.rank_na_mode!r}") - - return ( - f"{self.__class__.__name__}({', '.join(repr_)})" - ) - def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: """ Build a decision tree classifier from the training set (X, y). From 40d8cc29974dafcce0e07ac1be79f3d499c91edd Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 22:34:43 +0300 Subject: [PATCH 3/9] del rank_features from repr --- smarttree/_classes.py | 2 -- .../decision_tree/classifier/test__repr_tree.py | 16 ---------------- 2 files changed, 18 deletions(-) diff --git a/smarttree/_classes.py b/smarttree/_classes.py index ac65ef4..1473998 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -228,8 +228,6 @@ def __repr__(self) -> str: repr_.append(f"num_features={self.num_features}") if self.cat_features: repr_.append(f"cat_features={self.cat_features}") - if self.rank_features: - repr_.append(f"rank_features={self.rank_features}") if self.hierarchy: repr_.append(f"hierarchy={self.hierarchy}") if self.na_mode != "include_best": diff --git a/tests/decision_tree/classifier/test__repr_tree.py b/tests/decision_tree/classifier/test__repr_tree.py index 46ac5f7..0fe835e 100644 --- a/tests/decision_tree/classifier/test__repr_tree.py +++ b/tests/decision_tree/classifier/test__repr_tree.py @@ -138,22 +138,6 @@ def test_repr_tree__cat_features(cat_features, expected): assert repr(tree_classifier) == expected -@pytest.mark.parametrize( - ("rank_features", "expected"), - [ - (None, f"{CLASS_NAME}()"), - ( - {"f": ["v1", "v2"]}, - f"{CLASS_NAME}(rank_features={{'f': ['v1', 'v2']}})", - ), - ], - ids=["default value", "not default value"], -) -def test_repr_tree__rank_features(rank_features, expected): - tree_classifier = SmartDecisionTreeClassifier(rank_features=rank_features) - assert repr(tree_classifier) == expected - - @pytest.mark.parametrize( ("hierarchy", "expected"), [ From 78fb9128a5e293a5a42c5a2a3a9e32ba6804e77b Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 22:42:34 +0300 Subject: [PATCH 4/9] up .__repr__() in tests --- tests/__init__.py | 0 tests/decision_tree/__init__.py | 0 tests/decision_tree/base/__init__.py | 0 .../{classifier => base}/test__repr_tree.py | 36 +++++++++---------- 4 files changed, 17 insertions(+), 19 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/decision_tree/__init__.py create mode 100644 tests/decision_tree/base/__init__.py rename tests/decision_tree/{classifier => base}/test__repr_tree.py (82%) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/decision_tree/__init__.py b/tests/decision_tree/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/decision_tree/base/__init__.py b/tests/decision_tree/base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/decision_tree/classifier/test__repr_tree.py b/tests/decision_tree/base/test__repr_tree.py similarity index 82% rename from tests/decision_tree/classifier/test__repr_tree.py rename to tests/decision_tree/base/test__repr_tree.py index 0fe835e..c1995b3 100644 --- a/tests/decision_tree/classifier/test__repr_tree.py +++ b/tests/decision_tree/base/test__repr_tree.py @@ -1,6 +1,6 @@ import pytest -from smarttree import SmartDecisionTreeClassifier +from ...conftest import ConcreteSmartTree from smarttree._types import ( CatNaModeType, ClassificationCriterionType, @@ -9,7 +9,7 @@ ) -CLASS_NAME = SmartDecisionTreeClassifier.__name__ +CLASS_NAME = ConcreteSmartTree.__name__ @pytest.mark.parametrize( @@ -22,7 +22,7 @@ ) def test_repr_tree__criterion(criterion, expected): criterion: ClassificationCriterionType - tree_classifier = SmartDecisionTreeClassifier(criterion=criterion) + tree_classifier = ConcreteSmartTree(criterion=criterion) assert repr(tree_classifier) == expected @@ -35,7 +35,7 @@ def test_repr_tree__criterion(criterion, expected): ids=["default value", "not default value"], ) def test_repr_tree__max_depth(max_depth, expected): - tree_classifier = SmartDecisionTreeClassifier(max_depth=max_depth) + tree_classifier = ConcreteSmartTree(max_depth=max_depth) assert repr(tree_classifier) == expected @@ -49,7 +49,7 @@ def test_repr_tree__max_depth(max_depth, expected): ids=["default value", "not default value(int)", "not default value(float)"], ) def test_repr_tree__min_samples_split(min_samples_split, expected): - tree_classifier = SmartDecisionTreeClassifier(min_samples_split=min_samples_split) + tree_classifier = ConcreteSmartTree(min_samples_split=min_samples_split) assert repr(tree_classifier) == expected @@ -63,7 +63,7 @@ def test_repr_tree__min_samples_split(min_samples_split, expected): ids=["default value", "not default value(int)", "not default value(float)"], ) def test_repr_tree__min_samples_leaf(min_samples_split, min_samples_leaf, expected): - tree_classifier = SmartDecisionTreeClassifier( + tree_classifier = ConcreteSmartTree( min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf ) assert repr(tree_classifier) == expected @@ -78,7 +78,7 @@ def test_repr_tree__min_samples_leaf(min_samples_split, min_samples_leaf, expect ids=["default value", "not default value"], ) def test_repr_tree__max_leaf_nodes(max_leaf_nodes, expected): - tree_classifier = SmartDecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes) + tree_classifier = ConcreteSmartTree(max_leaf_nodes=max_leaf_nodes) assert repr(tree_classifier) == expected @@ -91,9 +91,7 @@ def test_repr_tree__max_leaf_nodes(max_leaf_nodes, expected): ids=["default value", "not default value"], ) def test_repr_tree__min_impurity_decrease(min_impurity_decrease, expected): - tree_classifier = SmartDecisionTreeClassifier( - min_impurity_decrease=min_impurity_decrease - ) + tree_classifier = ConcreteSmartTree(min_impurity_decrease=min_impurity_decrease) assert repr(tree_classifier) == expected @@ -106,7 +104,7 @@ def test_repr_tree__min_impurity_decrease(min_impurity_decrease, expected): ids=["default value", "not default value"], ) def test_repr_tree__max_childs(max_childs, expected): - tree_classifier = SmartDecisionTreeClassifier(max_childs=max_childs) + tree_classifier = ConcreteSmartTree(max_childs=max_childs) assert repr(tree_classifier) == expected @@ -120,7 +118,7 @@ def test_repr_tree__max_childs(max_childs, expected): ids=["default value", "not default value(str)", "not default value(list[str])"], ) def test_repr_tree__num_features(num_features, expected): - tree_classifier = SmartDecisionTreeClassifier(num_features=num_features) + tree_classifier = ConcreteSmartTree(num_features=num_features) assert repr(tree_classifier) == expected @@ -134,7 +132,7 @@ def test_repr_tree__num_features(num_features, expected): ids=["default value", "not default value(str)", "not default value(list[str])"], ) def test_repr_tree__cat_features(cat_features, expected): - tree_classifier = SmartDecisionTreeClassifier(cat_features=cat_features) + tree_classifier = ConcreteSmartTree(cat_features=cat_features) assert repr(tree_classifier) == expected @@ -152,7 +150,7 @@ def test_repr_tree__cat_features(cat_features, expected): ], ) def test_repr_tree__hierarchy(hierarchy, expected): - tree_classifier = SmartDecisionTreeClassifier(hierarchy=hierarchy) + tree_classifier = ConcreteSmartTree(hierarchy=hierarchy) assert repr(tree_classifier) == expected @@ -166,7 +164,7 @@ def test_repr_tree__hierarchy(hierarchy, expected): ) def test_repr_tree__na_mode(na_mode, expected): na_mode: CommonNaModeType - tree_classifier = SmartDecisionTreeClassifier(na_mode=na_mode) + tree_classifier = ConcreteSmartTree(na_mode=na_mode) assert repr(tree_classifier) == expected @@ -180,7 +178,7 @@ def test_repr_tree__na_mode(na_mode, expected): ) def test_repr_tree__num_na_mode(num_na_mode, expected): num_na_mode: NumNaModeType - tree_classifier = SmartDecisionTreeClassifier(num_na_mode=num_na_mode) + tree_classifier = ConcreteSmartTree(num_na_mode=num_na_mode) assert repr(tree_classifier) == expected @@ -194,7 +192,7 @@ def test_repr_tree__num_na_mode(num_na_mode, expected): ) def test_repr_tree__cat_na_mode(cat_na_mode, expected): cat_na_mode: CatNaModeType - tree_classifier = SmartDecisionTreeClassifier(cat_na_mode=cat_na_mode) + tree_classifier = ConcreteSmartTree(cat_na_mode=cat_na_mode) assert repr(tree_classifier) == expected @@ -207,7 +205,7 @@ def test_repr_tree__cat_na_mode(cat_na_mode, expected): ids=["default value", "not default value"], ) def test_repr_tree__cat_na_filler(cat_na_filler, expected): - tree_classifier = SmartDecisionTreeClassifier(cat_na_filler=cat_na_filler) + tree_classifier = ConcreteSmartTree(cat_na_filler=cat_na_filler) assert repr(tree_classifier) == expected @@ -220,5 +218,5 @@ def test_repr_tree__cat_na_filler(cat_na_filler, expected): ids=["default value", "not default value"], ) def test_repr_tree_rank_na_mode(rank_na_mode, expected): - tree_classifier = SmartDecisionTreeClassifier(rank_na_mode=rank_na_mode) + tree_classifier = ConcreteSmartTree(rank_na_mode=rank_na_mode) assert repr(tree_classifier) == expected From d6d6152c6becc18ead2e7bd5845980726bfbefb5 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 22:44:40 +0300 Subject: [PATCH 5/9] del cat_features from repr --- smarttree/_classes.py | 2 -- tests/decision_tree/base/test__repr_tree.py | 14 -------------- 2 files changed, 16 deletions(-) diff --git a/smarttree/_classes.py b/smarttree/_classes.py index 1473998..87d3747 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -226,8 +226,6 @@ def __repr__(self) -> str: repr_.append(f"max_childs={self.max_childs}") if self.num_features: repr_.append(f"num_features={self.num_features}") - if self.cat_features: - repr_.append(f"cat_features={self.cat_features}") if self.hierarchy: repr_.append(f"hierarchy={self.hierarchy}") if self.na_mode != "include_best": diff --git a/tests/decision_tree/base/test__repr_tree.py b/tests/decision_tree/base/test__repr_tree.py index c1995b3..01d22f9 100644 --- a/tests/decision_tree/base/test__repr_tree.py +++ b/tests/decision_tree/base/test__repr_tree.py @@ -122,20 +122,6 @@ def test_repr_tree__num_features(num_features, expected): assert repr(tree_classifier) == expected -@pytest.mark.parametrize( - ("cat_features", "expected"), - [ - (None, f"{CLASS_NAME}()"), - ("feature", f"{CLASS_NAME}(cat_features=['feature'])"), - (["feature"], f"{CLASS_NAME}(cat_features=['feature'])"), - ], - ids=["default value", "not default value(str)", "not default value(list[str])"], -) -def test_repr_tree__cat_features(cat_features, expected): - tree_classifier = ConcreteSmartTree(cat_features=cat_features) - assert repr(tree_classifier) == expected - - @pytest.mark.parametrize( ("hierarchy", "expected"), [ From 53a7ccd5ec6ed8b0609b2950911d68a1e40db7e2 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 22:45:31 +0300 Subject: [PATCH 6/9] del num_features from repr --- smarttree/_classes.py | 2 -- tests/decision_tree/base/test__repr_tree.py | 14 -------------- 2 files changed, 16 deletions(-) diff --git a/smarttree/_classes.py b/smarttree/_classes.py index 87d3747..6428bd5 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -224,8 +224,6 @@ def __repr__(self) -> str: repr_.append(f"min_impurity_decrease={self.min_impurity_decrease}") if self.max_childs: repr_.append(f"max_childs={self.max_childs}") - if self.num_features: - repr_.append(f"num_features={self.num_features}") if self.hierarchy: repr_.append(f"hierarchy={self.hierarchy}") if self.na_mode != "include_best": diff --git a/tests/decision_tree/base/test__repr_tree.py b/tests/decision_tree/base/test__repr_tree.py index 01d22f9..377dee7 100644 --- a/tests/decision_tree/base/test__repr_tree.py +++ b/tests/decision_tree/base/test__repr_tree.py @@ -108,20 +108,6 @@ def test_repr_tree__max_childs(max_childs, expected): assert repr(tree_classifier) == expected -@pytest.mark.parametrize( - ("num_features", "expected"), - [ - (None, f"{CLASS_NAME}()"), - ("feature", f"{CLASS_NAME}(num_features=['feature'])"), - (["feature"], f"{CLASS_NAME}(num_features=['feature'])"), - ], - ids=["default value", "not default value(str)", "not default value(list[str])"], -) -def test_repr_tree__num_features(num_features, expected): - tree_classifier = ConcreteSmartTree(num_features=num_features) - assert repr(tree_classifier) == expected - - @pytest.mark.parametrize( ("hierarchy", "expected"), [ From 1f34e553a0d1c9ff763a10358460a3ed5d45aedf Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 23:22:09 +0300 Subject: [PATCH 7/9] up dataset --- smarttree/_classes.py | 6 ++++-- smarttree/_node_splitter.py | 5 +---- tests/decision_tree/base/test__repr_tree.py | 3 ++- tests/test__node_splitter.py | 5 ++--- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/smarttree/_classes.py b/smarttree/_classes.py index 6428bd5..e51e674 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -13,6 +13,7 @@ from ._builder import Builder from ._check import check__data, check__params +from ._dataset import Dataset from ._exceptions import NotFittedError from ._node_splitter import NodeSplitter from ._renderer import Renderer @@ -601,9 +602,10 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: X = self.__preprocess(X) + dataset = Dataset(X, y) + splitter = NodeSplitter( - X=X, - y=y, + dataset=dataset, criterion=self.criterion, max_depth=max_depth, min_samples_split=min_samples_split, diff --git a/smarttree/_node_splitter.py b/smarttree/_node_splitter.py index a713f98..7482dff 100644 --- a/smarttree/_node_splitter.py +++ b/smarttree/_node_splitter.py @@ -1,7 +1,6 @@ from typing import NamedTuple import numpy as np -import pandas as pd from numpy.typing import NDArray from ._column_splitter import CatColumnSplitter, NumColumnSplitter, RankColumnSplitter @@ -34,8 +33,7 @@ class NodeSplitter: def __init__( self, - X: pd.DataFrame, - y: pd.Series, + dataset: Dataset, criterion: ClassificationCriterionType, max_depth: int | float, min_samples_split: int, @@ -61,7 +59,6 @@ def __init__( for rank_feature in rank_features: self.feature_split_type[rank_feature] = "rank" - dataset = Dataset(X, y) self.num_col_splitter = NumColumnSplitter( dataset=dataset, criterion=criterion, diff --git a/tests/decision_tree/base/test__repr_tree.py b/tests/decision_tree/base/test__repr_tree.py index 377dee7..5e9ca5a 100644 --- a/tests/decision_tree/base/test__repr_tree.py +++ b/tests/decision_tree/base/test__repr_tree.py @@ -1,6 +1,5 @@ import pytest -from ...conftest import ConcreteSmartTree from smarttree._types import ( CatNaModeType, ClassificationCriterionType, @@ -8,6 +7,8 @@ NumNaModeType, ) +from ...conftest import ConcreteSmartTree + CLASS_NAME = ConcreteSmartTree.__name__ diff --git a/tests/test__node_splitter.py b/tests/test__node_splitter.py index 010af8b..e75b56e 100644 --- a/tests/test__node_splitter.py +++ b/tests/test__node_splitter.py @@ -5,11 +5,10 @@ @pytest.fixture(scope="module") def node_splitter( - X, y, num_features, cat_features, rank_features, feature_na_mode + dataset, num_features, cat_features, rank_features, feature_na_mode ) -> NodeSplitter: return NodeSplitter( - X=X, - y=y, + dataset=dataset, criterion="gini", max_depth=float("+inf"), min_samples_split=2, From 9d00d1598c760a7fd920467161f22721f08e7719 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 23:34:00 +0300 Subject: [PATCH 8/9] up dataset --- smarttree/_builder.py | 12 ++++-------- smarttree/_classes.py | 3 +-- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/smarttree/_builder.py b/smarttree/_builder.py index 68f9f66..21130a7 100644 --- a/smarttree/_builder.py +++ b/smarttree/_builder.py @@ -1,7 +1,6 @@ import bisect import numpy as np -import pandas as pd from ._criterion import ClassificationCriterion, Entropy, Gini from ._dataset import Dataset @@ -13,18 +12,15 @@ class Builder: def __init__( self, - X: pd.DataFrame, - y: pd.Series, + dataset: Dataset, criterion: ClassificationCriterionType, splitter: NodeSplitter, max_leaf_nodes: int | float, hierarchy: dict[str, str | list[str]], ) -> None: - self.X = X - self.y = y - self.dataset = Dataset(X, y) - self.available_features = X.columns.to_list() + self.dataset = dataset + self.available_features = list(dataset.columns) self.splitter = splitter self.max_leaf_nodes = max_leaf_nodes self.hierarchy = hierarchy @@ -44,7 +40,7 @@ def build(self, tree: Tree) -> None: else: # str self.available_features.remove(value) - mask = self.y.apply(lambda x: True).to_numpy() + mask = np.ones(self.dataset.n_samples, dtype=bool) distribution = np.frombuffer(self.criterion.distribution(mask), dtype=np.int64) label = self.dataset.classes[distribution.argmax()] root = tree.create_node( diff --git a/smarttree/_classes.py b/smarttree/_classes.py index e51e674..94bd542 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -622,8 +622,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: self._tree = Tree() builder = Builder( - X=X, - y=y, + dataset=dataset, criterion=self.criterion, splitter=splitter, max_leaf_nodes=max_leaf_nodes, From 4d7233f39951277c84c063230755368b55ea7621 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Wed, 1 Oct 2025 23:41:11 +0300 Subject: [PATCH 9/9] refactoring --- smarttree/_classes.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/smarttree/_classes.py b/smarttree/_classes.py index 94bd542..fc53d8d 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -587,8 +587,6 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: self.feature_na_mode.update({f: self.rank_na_mode for f in self.rank_features}) self.feature_na_mode.update(temp_feature_na_mode) - self.__classes = np.sort(y.unique()) - for feature, na_mode in self.feature_na_mode.items(): if na_mode == "min": na_filler = X[feature].min() @@ -630,6 +628,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: ) builder.build(self._tree) + self.__classes = dataset.classes self._is_fitted = True return self