Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions smarttree/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,16 +317,16 @@ def _check__feature_na_mode(feature_na_mode):


def check__data(
X,
*,
X=None,
y=None,
num_features=None,
cat_features=None,
rank_features=None,
all_features=None,
):
if X is not None:
_check__X(X)
_check__X(X)
X = X.copy()

if y is not None:
_check__y(y)
Expand All @@ -346,6 +346,11 @@ def check__data(
if all_features is not None:
_check__all_features_in(X, all_features)

if y is not None:
return X, y
else:
return X


def _check__X(X):
if not isinstance(X, pd.DataFrame):
Expand All @@ -369,9 +374,17 @@ def _check__num_features_in(X, num_features):
for num_feature in num_features:
if num_feature not in X.columns:
raise ValueError(
f"`num_features` contain feature {num_feature!r},"
f"`num_features` contain feature '{num_feature}',"
" which isnt present in the training data."
)
if not pd.api.types.is_numeric_dtype(X[num_feature]):
try:
X[num_feature] = pd.to_numeric(X[num_feature])
except ValueError:
raise ValueError(
f"`num_features` contain feature '{num_feature}',"
" which isnt numeric or convertable to numeric."
)


def _check__cat_features_in(X, cat_features):
Expand Down
6 changes: 3 additions & 3 deletions smarttree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series) -> Self:
y: pd.Series
The target values.
"""
check__data(
X, y = check__data(
X=X,
y=y,
num_features=self.num_features,
Expand Down Expand Up @@ -627,7 +627,7 @@ def predict_proba(self, X: pd.DataFrame) -> NDArray[np.floating]:
ndarray: The class probabilities of the input samples. The order of
the classes corresponds to that in the attribute :term:`class_names`.
"""
check__data(X=X, all_features=self.all_features)
X = check__data(X=X, all_features=self.all_features)

X = self.__preprocess(X)

Expand Down Expand Up @@ -699,7 +699,7 @@ def score(
sample_weight: pd.Series | None = None,
) -> float | np.floating:
"""Returns the accuracy metric."""
check__data(X=X, y=y, all_features=self.all_features)
X, y = check__data(X=X, y=y, all_features=self.all_features)
return accuracy_score(y, self.predict(X), sample_weight=sample_weight)

@lru_cache
Expand Down
16 changes: 15 additions & 1 deletion tests/decision_tree/base/test__check_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,21 @@ def decision_tree():
),
),
),
(
"not_num_dtype",
"valid",
pytest.raises(
ValueError,
match=(
f"`num_features` contain feature '{NUM_FEATURE}',"
" which isnt numeric or convertable to numeric."
),
),
),
],
ids=[
"valid", "not_series", "short", "not_df", "contain_na",
"missing_num", "missing_cat", "missing_rank",
"missing_num", "missing_cat", "missing_rank", "not_num_dtype",
],
)
def test__check_data__fit(X, y, X_scenario, y_scenario, decision_tree, expected_context):
Expand All @@ -85,6 +96,9 @@ def test__check_data__fit(X, y, X_scenario, y_scenario, decision_tree, expected_
"missing_num": X[SELECTED].drop(columns=NUM_FEATURE),
"missing_cat": X[SELECTED].drop(columns=CAT_FEATURE),
"missing_rank": X[SELECTED].drop(columns=RANK_FEATURE),
"not_num_dtype": X[SELECTED].rename(
columns={NUM_FEATURE: CAT_FEATURE, CAT_FEATURE: NUM_FEATURE}
)
}
y_map = {
"valid": y,
Expand Down
Loading