From af8a3412b54dcdf532c0a12b63c06f972789b798 Mon Sep 17 00:00:00 2001 From: Mikhail Martin Date: Thu, 18 Sep 2025 00:19:39 +0300 Subject: [PATCH] added check numeric dtype --- smarttree/_check.py | 21 ++++++++++++++++---- smarttree/_classes.py | 6 +++--- tests/decision_tree/base/test__check_data.py | 16 ++++++++++++++- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/smarttree/_check.py b/smarttree/_check.py index 883a965..c1cd400 100644 --- a/smarttree/_check.py +++ b/smarttree/_check.py @@ -317,16 +317,16 @@ def _check__feature_na_mode(feature_na_mode): def check__data( + X, *, - X=None, y=None, num_features=None, cat_features=None, rank_features=None, all_features=None, ): - if X is not None: - _check__X(X) + _check__X(X) + X = X.copy() if y is not None: _check__y(y) @@ -346,6 +346,11 @@ def check__data( if all_features is not None: _check__all_features_in(X, all_features) + if y is not None: + return X, y + else: + return X + def _check__X(X): if not isinstance(X, pd.DataFrame): @@ -369,9 +374,17 @@ def _check__num_features_in(X, num_features): for num_feature in num_features: if num_feature not in X.columns: raise ValueError( - f"`num_features` contain feature {num_feature!r}," + f"`num_features` contain feature '{num_feature}'," " which isnt present in the training data." ) + if not pd.api.types.is_numeric_dtype(X[num_feature]): + try: + X[num_feature] = pd.to_numeric(X[num_feature]) + except ValueError: + raise ValueError( + f"`num_features` contain feature '{num_feature}'," + " which isnt numeric or convertable to numeric." + ) def _check__cat_features_in(X, cat_features): diff --git a/smarttree/_classes.py b/smarttree/_classes.py index 2bbcde6..73951e4 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -495,7 +495,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: y: pd.Series The target values. """ - check__data( + X, y = check__data( X=X, y=y, num_features=self.num_features, @@ -627,7 +627,7 @@ def predict_proba(self, X: pd.DataFrame) -> NDArray[np.floating]: ndarray: The class probabilities of the input samples. The order of the classes corresponds to that in the attribute :term:`class_names`. """ - check__data(X=X, all_features=self.all_features) + X = check__data(X=X, all_features=self.all_features) X = self.__preprocess(X) @@ -699,7 +699,7 @@ def score( sample_weight: pd.Series | None = None, ) -> float | np.floating: """Returns the accuracy metric.""" - check__data(X=X, y=y, all_features=self.all_features) + X, y = check__data(X=X, y=y, all_features=self.all_features) return accuracy_score(y, self.predict(X), sample_weight=sample_weight) @lru_cache diff --git a/tests/decision_tree/base/test__check_data.py b/tests/decision_tree/base/test__check_data.py index 309b6c0..90245b2 100644 --- a/tests/decision_tree/base/test__check_data.py +++ b/tests/decision_tree/base/test__check_data.py @@ -71,10 +71,21 @@ def decision_tree(): ), ), ), + ( + "not_num_dtype", + "valid", + pytest.raises( + ValueError, + match=( + f"`num_features` contain feature '{NUM_FEATURE}'," + " which isnt numeric or convertable to numeric." + ), + ), + ), ], ids=[ "valid", "not_series", "short", "not_df", "contain_na", - "missing_num", "missing_cat", "missing_rank", + "missing_num", "missing_cat", "missing_rank", "not_num_dtype", ], ) def test__check_data__fit(X, y, X_scenario, y_scenario, decision_tree, expected_context): @@ -85,6 +96,9 @@ def test__check_data__fit(X, y, X_scenario, y_scenario, decision_tree, expected_ "missing_num": X[SELECTED].drop(columns=NUM_FEATURE), "missing_cat": X[SELECTED].drop(columns=CAT_FEATURE), "missing_rank": X[SELECTED].drop(columns=RANK_FEATURE), + "not_num_dtype": X[SELECTED].rename( + columns={NUM_FEATURE: CAT_FEATURE, CAT_FEATURE: NUM_FEATURE} + ) } y_map = { "valid": y,