This repository was archived by the owner on Aug 6, 2025. It is now read-only.

Changing number of classes #282

@vikashg

Hi,
I am trying to retrain the linear classifier with a different number of classes. This is what I did.
First, I trained the features using the following command:

out_dir=./main_dino_output
python -m torch.distributed.launch main_dino.py --arch vit_small --data_path </path/to/my/datadir> --output_dir $out_dir --epochs 1000

Now that I have trained my features, I train a linear classifier on top of them as follows:

python eval_linear.py --pretrained_weights $out_dir/checkpoint.pth --num_labels 5 --data_path $data_dir --epochs 500 --arch vit_small

This part executes properly. Now I would like to evaluate the trained classifier. The linear classifier from this step is saved as ./checkpoint.pth.tar, so I run the same script with the --evaluate flag turned on:

python eval_linear.py --evaluate --pretrained_weights ./checkpoint.pth.tar --num_labels 5 --data_path $data_dir

However, in this case I get the following error:
[Screenshot of the console output, showing the weights being downloaded]

Loading fails because the checkpoint being loaded has a linear layer for 1000 classes, while the model was built with 5:

size mismatch for module.linear.weight: copying a param with shape torch.Size([1000, 1536]) from checkpoint, the shape in current model is torch.Size([5, 1536]).
size mismatch for module.linear.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([5]).
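
The two messages above report the same kind of problem: a parameter whose shape in the checkpoint differs from its shape in the model. As a minimal sketch (the helper name `find_shape_mismatches` is hypothetical and not part of DINO or PyTorch), the same comparison can be done by hand on plain dictionaries that map parameter names to shapes. With PyTorch, such dictionaries could be built from a state dict with something like `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return (name, checkpoint_shape, model_shape) for every parameter
    that exists in both dictionaries but with a different shape."""
    mismatches = []
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        # Parameters missing from the model are skipped, not reported.
        if model_shape is not None and tuple(ckpt_shape) != tuple(model_shape):
            mismatches.append((name, tuple(ckpt_shape), tuple(model_shape)))
    return mismatches


# Shapes copied from the error message above.
checkpoint_shapes = {
    "module.linear.weight": (1000, 1536),
    "module.linear.bias": (1000,),
}
model_shapes = {
    "module.linear.weight": (5, 1536),
    "module.linear.bias": (5,),
}

for name, ckpt, model in find_shape_mismatches(checkpoint_shapes, model_shapes):
    print(f"{name}: checkpoint {ckpt} vs model {model}")
```

Both parameters are reported, which is consistent with a 1000-class linear layer being loaded into a 5-class model.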

Looking at the code:

dino/eval_linear.py

Lines 79 to 83 in 7c446df

if args.evaluate:
    utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
    test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
    return

With --evaluate, the script downloads pretrained linear weights and runs the evaluation on those instead of on the classifier I trained.
I think this is a bug: if I provide my own weights, the script should not download others, which is what utils.load_pretrained_linear_weights does.
When I comment out line 80 in eval_linear.py, the code works fine.

Is this the right thing to do? Please let me know.
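
Instead of commenting the line out, the behaviour I would expect could be sketched roughly as follows. This is only an illustration of the decision, not DINO code: the function `choose_linear_weights`, its return values, and the constant `RELEASED_NUM_LABELS` are hypothetical names, and the value 1000 is an assumption taken from the shapes in the error message above.

```python
import os

# Assumption: the downloadable linear head has 1000 outputs,
# as suggested by the shapes in the size-mismatch error.
RELEASED_NUM_LABELS = 1000


def choose_linear_weights(evaluate, local_checkpoint, num_labels, exists=os.path.isfile):
    """Decide which linear-classifier weights to use.

    Returns "none" when not evaluating, "local" when a local checkpoint
    is available, and "download" when falling back to the released head.
    """
    if not evaluate:
        return "none"
    # Prefer a locally trained classifier whenever one is present.
    if local_checkpoint and exists(local_checkpoint):
        return "local"
    # The downloaded head only fits a model with the same number of labels.
    if num_labels == RELEASED_NUM_LABELS:
        return "download"
    raise ValueError(
        f"No local checkpoint found and the downloadable head has "
        f"{RELEASED_NUM_LABELS} outputs, not {num_labels}."
    )


# My case: evaluating with 5 labels and a local checkpoint
# (the file is assumed to exist for this example).
print(choose_linear_weights(True, "./checkpoint.pth.tar", 5, exists=lambda path: True))
```

In this sketch the download happens only when no local checkpoint exists and the number of labels matches, so a 5-class classifier would never be overwritten by 1000-class weights.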

P.S.: I know that posting screenshots is generally not the norm, but I wanted to show that it is downloading new weights.
