diff --git a/smarttree/_check.py b/smarttree/_check.py index d9c1ad1..0469d50 100644 --- a/smarttree/_check.py +++ b/smarttree/_check.py @@ -316,7 +316,7 @@ def _check__rank_na_mode(rank_na_mode): if rank_na_mode not in ("include_all", "include_best"): raise ValueError( "`rank_na_mode` must be Literal['include_all', 'include_best']." - f" The current value of `na_mode` is {rank_na_mode!r}." + f" The current value of `rank_na_mode` is {rank_na_mode!r}." ) diff --git a/tests/decision_tree/base/test__check_params.py b/tests/decision_tree/base/test__check_params.py index 2bcd540..822e58f 100644 --- a/tests/decision_tree/base/test__check_params.py +++ b/tests/decision_tree/base/test__check_params.py @@ -7,6 +7,7 @@ from smarttree._types import ( CatNaModeType, ClassificationCriterionType, + CommonNaModeType, NaModeType, NumNaModeType, ) @@ -559,6 +560,31 @@ def test__check_params__hierarchy(hierarchy, expected_context): SmartDecisionTreeClassifier(hierarchy=hierarchy) +@pytest.mark.parametrize( + ("na_mode", "expected_context"), + [ + ("include_all", does_not_raise()), + ("include_best", does_not_raise()), + ( + "smth", + pytest.raises( + ValueError, + match=re.escape( + "`na_mode` must be Literal['include_all', 'include_best']." + " The current value of `na_mode` is 'smth'." + ), + ), + ), + ], + ids=["include_all", "include_best", "invalid"], +) +def test__check_params__na_mode(na_mode, expected_context): + with expected_context: + na_mode: CommonNaModeType + SmartDecisionTreeClassifier(na_mode=na_mode) + + + @pytest.mark.parametrize( ("num_na_mode", "expected_context"), [ @@ -632,6 +658,30 @@ def test__check_param__cat_na_filler(cat_na_filler, expected_context): SmartDecisionTreeClassifier(cat_na_filler=cat_na_filler) +@pytest.mark.parametrize( + ("rank_na_mode", "expected_context"), + [ + ("include_all", does_not_raise()), + ("include_best", does_not_raise()), + ( + "smth", + pytest.raises( + ValueError, + match=re.escape( + "`rank_na_mode` must be Literal['include_all', 'include_best']." + " The current value of `rank_na_mode` is 'smth'." + ), + ), + ), + ], + ids=["include_all", "include_best", "invalid"], +) +def test__check_params__rank_na_mode(rank_na_mode, expected_context): + with expected_context: + rank_na_mode: CommonNaModeType + SmartDecisionTreeClassifier(rank_na_mode=rank_na_mode) + + @pytest.mark.parametrize( ("min_samples_split", "min_samples_leaf", "expected_context"), [ @@ -700,5 +750,5 @@ def test__check_params__min_samples_split__min_samples_leaf( ) def test__check_params__feature_na_mode(feature_na_mode, expected_context): with expected_context: - feature_na_mode: dict[str, NaModeType | None] + feature_na_mode: dict[str, NaModeType] SmartDecisionTreeClassifier(feature_na_mode=feature_na_mode)