Bug: BaggingPuClassifier incompatible with scikit-learn ≥ 1.6 — partial_dependence and is_classifier fail #176

@MaxEtherington

Summary

BaggingPuClassifier is not recognised as a classifier by scikit-learn ≥ 1.6's
tag-based estimator introspection system. This causes sklearn.inspection.partial_dependence
(and any other sklearn utility that calls is_classifier()) to raise:

ValueError: 'estimator' must be a fitted regressor or classifier.

Environment

  • pulearn: 0.1.1
  • scikit-learn: 1.6+ (confirmed on 1.7.1)
  • Python: 3.13

Steps to reproduce

import numpy as np
from sklearn.base import is_classifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import RobustScaler
from sklearn.svm import SVC
from pulearn import BaggingPuClassifier

svc = SVC(kernel="rbf", probability=True, random_state=0)
clf = BaggingPuClassifier(estimator=svc, n_estimators=10, random_state=0)
pipeline = Pipeline([("scaler", RobustScaler()), ("classifier", clf)])

X = np.random.default_rng(0).standard_normal((100, 4))
y = np.array([1] * 20 + [0] * 80)
pipeline.fit(X, y)

print(is_classifier(pipeline))   # False — expected True
print(is_classifier(clf))        # False — expected True

from sklearn.inspection import partial_dependence
# Raises: ValueError: 'estimator' must be a fitted regressor or classifier.
partial_dependence(pipeline, X=X, features=0)

Root cause

In scikit-learn 1.6, is_classifier() was changed from:

# sklearn < 1.6
return getattr(estimator, "_estimator_type", None) == "classifier"

to:

# sklearn >= 1.6
return get_tags(estimator).estimator_type == "classifier"

get_tags() calls estimator.__sklearn_tags__(). The __sklearn_tags__ method
is defined on BaseEstimator and on each mixin (ClassifierMixin,
RegressorMixin, etc.). The mixin implementations chain via super() and expect
BaseEstimator to sit at the end of that chain.

BaggingPuClassifier's MRO is:

BaggingPuClassifier → BaseBaggingPU → BaseEnsemble → ... → BaseEstimator → ClassifierMixin

BaseEstimator appears before ClassifierMixin in the MRO, and
BaseEstimator.__sklearn_tags__() builds a fresh Tags object without calling
super().__sklearn_tags__(). As a result, ClassifierMixin.__sklearn_tags__(),
which sets estimator_type = "classifier", is never reached. The result is
Tags(estimator_type=None, ...), so is_classifier returns False.
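The ordering problem can be reproduced without scikit-learn. The following is a minimal sketch using stand-in classes; ToyTags, ToyBaseEstimator, ToyClassifierMixin, MixinFirst and MixinLast are hypothetical names that only mimic the chaining behaviour described above and are not the real sklearn classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTags:
    # Stand-in for sklearn's Tags; only the field relevant here.
    estimator_type: Optional[str] = None


class ToyBaseEstimator:
    def __sklearn_tags__(self):
        # Starts the chain: returns fresh tags and does NOT call super().
        return ToyTags()


class ToyClassifierMixin:
    def __sklearn_tags__(self):
        # Cooperative: asks the next class in the MRO first, then adds its tag.
        tags = super().__sklearn_tags__()
        tags.estimator_type = "classifier"
        return tags


class MixinFirst(ToyClassifierMixin, ToyBaseEstimator):
    """Mixin to the left of the base class: the chain is complete."""


class MixinLast(ToyBaseEstimator, ToyClassifierMixin):
    """Mixin to the right of the base class: the mixin is never reached."""


print([c.__name__ for c in MixinLast.__mro__])
print(MixinFirst().__sklearn_tags__().estimator_type)  # classifier
print(MixinLast().__sklearn_tags__().estimator_type)   # None
```

MixinLast corresponds to the situation reported here: the lookup stops at the base class, so the mixin's tag is silently dropped.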

The legacy _estimator_type = "classifier" class attribute set by ClassifierMixin
is still present, but sklearn ≥ 1.6 no longer uses it (it is marked
# TODO(1.8): Remove this attribute in the sklearn source).

The same issue propagates through sklearn.pipeline.Pipeline, which delegates its
own tags to its final step. So wrapping BaggingPuClassifier in a Pipeline is
also affected.
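To illustrate why the wrapper inherits the problem, here is a small sketch of tag delegation. ToyPipeline, ToyStep and toy_is_classifier are hypothetical stand-ins written for this issue; they only model the "copy estimator_type from the final step" behaviour and are not sklearn's actual implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTags:
    estimator_type: Optional[str] = None


class ToyStep:
    """Stand-in estimator that reports a fixed estimator_type."""

    def __init__(self, estimator_type=None):
        self._estimator_type = estimator_type

    def __sklearn_tags__(self):
        return ToyTags(estimator_type=self._estimator_type)


class ToyPipeline:
    """Stand-in pipeline that takes its estimator_type from the last step."""

    def __init__(self, steps):
        self.steps = steps

    def __sklearn_tags__(self):
        tags = ToyTags()
        if self.steps:
            _, last_step = self.steps[-1]
            tags.estimator_type = last_step.__sklearn_tags__().estimator_type
        return tags


def toy_is_classifier(estimator):
    # Mirrors the tag-based check quoted in the root-cause section.
    return estimator.__sklearn_tags__().estimator_type == "classifier"


good = ToyPipeline([("scaler", ToyStep()), ("classifier", ToyStep("classifier"))])
broken = ToyPipeline([("scaler", ToyStep()), ("classifier", ToyStep(None))])

print(toy_is_classifier(good))    # True
print(toy_is_classifier(broken))  # False
```

A final step that reports estimator_type=None makes the whole pipeline look like a non-classifier, which matches the is_classifier(pipeline) result in the reproduction.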

Fix

ClassifierMixin must appear to the left of BaseEstimator in the class
definition, or BaggingPuClassifier (and BaseBaggingPU) must explicitly
implement __sklearn_tags__. The minimal fix is:

# In pulearn/bagging.py

# ClassifierTags is also exported from the public sklearn.utils namespace
from sklearn.utils import ClassifierTags

class BaggingPuClassifier(BaseBaggingPU, ClassifierMixin):

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.estimator_type = "classifier"
        tags.classifier_tags = ClassifierTags()
        tags.target_tags.required = True
        return tags

Alternatively, reordering the base classes of BaseBaggingPU so that
ClassifierMixin appears before BaseEstimator in the MRO would also resolve
this, but the explicit __sklearn_tags__ override is the smaller change and
leaves method resolution for the rest of the hierarchy untouched.
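Both options can be compared on stand-in classes. This is a sketch under the assumption that the real hierarchy behaves like the toy one; ToyBaseBagging, OverrideFix and ReorderFix are hypothetical names, and the real fix would of course live in pulearn itself.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTags:
    estimator_type: Optional[str] = None


class ToyBaseEstimator:
    def __sklearn_tags__(self):
        # Does not call super(), like the behaviour described in this issue.
        return ToyTags()


class ToyClassifierMixin:
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.estimator_type = "classifier"
        return tags


class ToyBaseBagging(ToyBaseEstimator):
    """Stand-in for an intermediate base class that inherits the base estimator."""


class OverrideFix(ToyBaseBagging, ToyClassifierMixin):
    """Option 1: keep the base order and set the tag explicitly."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.estimator_type = "classifier"
        return tags


class ReorderFix(ToyClassifierMixin, ToyBaseBagging):
    """Option 2: put the mixin to the left so the super() chain reaches it."""


print(OverrideFix().__sklearn_tags__().estimator_type)  # classifier
print(ReorderFix().__sklearn_tags__().estimator_type)   # classifier
```

In this toy model the two options give the same tags; the override does so without changing which class any other inherited method resolves to.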

Workaround (for users)

Until a fix is released, monkey-patching before calling any sklearn inspection
utility works:

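from sklearn.base import BaseEstimator
from sklearn.utils import ClassifierTags  # public export of sklearn.utils._tags.ClassifierTags
from pulearn.bagging import BaggingPuClassifier

def _pu_sklearn_tags(self):
    # Start from the default tags, then fill in what ClassifierMixin would have set.
    tags = BaseEstimator.__sklearn_tags__(self)
    tags.estimator_type = "classifier"
    tags.classifier_tags = ClassifierTags()
    tags.target_tags.required = True
    return tags

# Apply once, before calling is_classifier or any sklearn.inspection utility.
BaggingPuClassifier.__sklearn_tags__ = _pu_sklearn_tags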
