diff --git a/smarttree/_check.py b/smarttree/_check.py index d9ebeb9..cdb9e40 100644 --- a/smarttree/_check.py +++ b/smarttree/_check.py @@ -47,9 +47,13 @@ def check__params( if num_features is not None: _check__num_features(num_features) + param_name = "num_features" + _check__features_contain_duplicates(param_name, num_features) if cat_features is not None: _check__cat_features(cat_features) + param_name = "cat_features" + _check__features_contain_duplicates(param_name, cat_features) if rank_features is not None: _check__rank_features(rank_features) @@ -212,6 +216,11 @@ def _check__rank_features(rank_features): ) +def _check__features_contain_duplicates(param_name, features): + if isinstance(features, list) and len(features) != len(set(features)): + raise ValueError(f"`{param_name}` contains duplicates.") + + def _check__hierarchy(hierarchy): common_message = ( "`hierarchy` must be a dictionary" diff --git a/tests/decision_tree/base/test__check_params.py b/tests/decision_tree/base/test__check_params.py index b87c7b0..9043388 100644 --- a/tests/decision_tree/base/test__check_params.py +++ b/tests/decision_tree/base/test__check_params.py @@ -422,6 +422,26 @@ def test__check_params__rank_features(rank_features, expected_context): SmartDecisionTreeClassifier(rank_features=rank_features) +@pytest.mark.parametrize( + ("params_to_set", "expected_context"), + [ + ({"num_features": ["f1", "f2"]}, does_not_raise()), + ( + {"num_features": ["f1", "f1"]}, + pytest.raises(ValueError, match="`num_features` contains duplicates.") + ), + ({"cat_features": ["f1", "f2"]}, does_not_raise()), + ( + {"cat_features": ["f1", "f1"]}, + pytest.raises(ValueError, match="`cat_features` contains duplicates.") + ), + ], +) +def test__check_params__features_contain_duplicates(params_to_set, expected_context): + with expected_context: + SmartDecisionTreeClassifier(**params_to_set) + + @pytest.mark.parametrize( ("hierarchy", "expected_context"), [