
Error when using sample_weight parameter in fit function of LocalClassifierPerParentNode and SGDClassifier #164

@johanmazelanssi

Describe the bug
Passing a sample_weight argument to the fit method of LocalClassifierPerParentNode, when the local classifier is an SGDClassifier, causes an error inside scikit-learn code.

To Reproduce

import numpy as np
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.linear_model import SGDClassifier
from hiclass import LocalClassifierPerParentNode


random_state = 0

# Example from https://hiclass.readthedocs.io/en/latest/auto_examples/plot_empty_levels.html
# Define data
X_train = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
X_test = [[9, 10], [7, 8], [5, 6], [3, 4], [1, 2]]
Y_train = np.array(
    [
        ["Bird"],
        ["Reptile", "Snake"],
        ["Reptile", "Lizard"],
        ["Mammal", "Cat"],
        ["Mammal", "Wolf", "Dog"],
    ],
    dtype=object,
)

clf = SGDClassifier(loss="log_loss",
                    n_jobs=1,
                    shuffle=True,
                    random_state=random_state)

classifier = LocalClassifierPerParentNode(local_classifier=clf)

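# Join each label path into one string so compute_sample_weight sees one class per full path
Y_train_mod = ["_".join(path) for path in Y_train]
print(f"Y_train_mod: {Y_train_mod}")
sample_weight_l = compute_sample_weight(class_weight="balanced", y=Y_train_mod)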

classifier.fit(X_train, Y_train, sample_weight=sample_weight_l)

predictions = classifier.predict(X_test)

print(f"predictions: {predictions}")

Running the script fails with ValueError: Provided coef_ does not match dataset.

Expected behavior
The script should execute without an error.
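For reference, LocalClassifierPerParentNode fits one local classifier per parent node, each on the rows whose label path passes through that node, so handling sample_weight as expected means each local fit receives exactly the weights of its own rows. The pure-Python sketch below does not use HiClass internals; the helper per_parent_training_sets is hypothetical. It recomputes the "balanced" weights for the example (n_samples / (n_classes * count), which gives 5 / (5 * 1) = 1.0 for every row here) and groups rows by parent node.

```python
from collections import Counter, defaultdict

# Label paths from the reproduction script.
y_train = [
    ["Bird"],
    ["Reptile", "Snake"],
    ["Reptile", "Lizard"],
    ["Mammal", "Cat"],
    ["Mammal", "Wolf", "Dog"],
]

# "balanced" weights: n_samples / (n_classes * count of the sample's class),
# computed over the joined label paths, as in the script.
joined = ["_".join(path) for path in y_train]
counts = Counter(joined)
sample_weight = [len(joined) / (len(counts) * counts[lbl]) for lbl in joined]


def per_parent_training_sets(paths, weights):
    """Hypothetical helper: map each parent node (a path prefix) to the
    (row index, child label, weight) triples its local classifier would see."""
    sets = defaultdict(list)
    for row, (path, weight) in enumerate(zip(paths, weights)):
        for depth, label in enumerate(path):
            parent = tuple(path[:depth])  # () is the root node
            sets[parent].append((row, label, weight))
    return dict(sets)


training_sets = per_parent_training_sets(y_train, sample_weight)
for parent, triples in training_sets.items():
    rows = [row for row, _, _ in triples]
    labels = [label for _, label, _ in triples]
    weights = [weight for _, _, weight in triples]
    # A local classifier would be fit on X[rows], labels, sample_weight=weights,
    # so len(weights) always matches the number of rows that node sees.
    print(parent or "root", rows, labels, weights)
```

Since all weights equal 1.0 in this example, their values alone are unlikely to be what trips SGDClassifier; the shape of whatever array reaches its fit method matters more.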

Versions:

  • Python 3.12.1
  • HiClass 5.0.4

Additional information
The problem appears to be specific to SGDClassifier: with clf = RandomForestClassifier(max_depth=20, random_state=random_state), the script runs without error.
