diff --git a/smarttree/_classes.py b/smarttree/_classes.py index 57f61be..2bbcde6 100644 --- a/smarttree/_classes.py +++ b/smarttree/_classes.py @@ -193,10 +193,16 @@ def tree_(self) -> Tree: assert self._tree is not None return self._tree + def get_n_leaves(self) -> int: + return self.tree_.leaf_counter + + def get_depth(self) -> int: + return self.tree_.max_depth + @property def feature_importances_(self) -> dict[str, float]: self._check_is_fitted() - return self.tree_.feature_importances + return self.tree_.compute_feature_importances() @abstractmethod def fit(self, X: pd.DataFrame, y: pd.Series) -> Self: diff --git a/smarttree/_tree.py b/smarttree/_tree.py index d2bd847..00631e3 100644 --- a/smarttree/_tree.py +++ b/smarttree/_tree.py @@ -103,3 +103,16 @@ def create_node( self.root = node return node + + def compute_feature_importances(self) -> dict[str, float]: + + amount = 0.0 + for importance in self.feature_importances.values(): + amount += importance + + normalized_feature_importances = dict() + for feature, importance in self.feature_importances.items(): + normalized_feature_importances[feature] = importance / amount + + return normalized_feature_importances + diff --git a/tests/decision_tree/classifier/test__classifier_not_fitted.py b/tests/decision_tree/classifier/test__classifier_not_fitted.py index 494fcbb..9e09fbf 100644 --- a/tests/decision_tree/classifier/test__classifier_not_fitted.py +++ b/tests/decision_tree/classifier/test__classifier_not_fitted.py @@ -32,5 +32,4 @@ def test__not_fitted__method(not_fitted_tree, X, y, method_call): ) def test__not_fitted__property(not_fitted_tree, property_name): with pytest.raises(NotFittedError): - property_ = getattr(not_fitted_tree, property_name) - _ = property_ + getattr(not_fitted_tree, property_name)