attaullah/Pretraining-WideResNet
Pretraining Wide Residual Network

This repository contains code for pretraining Wide Residual Networks (WRN) [1] with either cross-entropy or triplet loss [3]. It supports downsampled [2] ImageNet 32x32 and ImageNet 64x64, as well as full-resolution ImageNet 224x224.
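For reference, here is a minimal NumPy sketch of the standard hinge form of the triplet loss. It takes batches of anchor, positive and negative embeddings and uses squared Euclidean distances and a margin (pretrain.py's `--margin` defaults to 1.0). This illustrates the loss in general and is not the repository's implementation. The function name `triplet_loss` is hypothetical.

```python
import numpy as np


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Mean hinge triplet loss over a batch (illustrative sketch).

    anchor, positive, negative: array-likes of shape (batch, dim).
    The loss is zero for a triplet once the negative is at least `margin`
    farther from the anchor than the positive (in squared Euclidean distance).
    """
    anchor = np.asarray(anchor, dtype=np.float64)
    positive = np.asarray(positive, dtype=np.float64)
    negative = np.asarray(negative, dtype=np.float64)
    d_pos = np.sum((anchor - positive) ** 2, axis=1)  # squared distance to positive
    d_neg = np.sum((anchor - negative) ** 2, axis=1)  # squared distance to negative
    # Hinge on the distance gap, averaged over the batch.
    return float(np.mean(np.maximum(d_pos - d_neg + margin, 0.0)))
```

As an example, place the anchor and positive at the origin and the negative at (0.5, 0). The loss is then max(0 - 0.25 + 1, 0) = 0.75. Moving the negative to (2, 0) gives a loss of 0.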

Environment setup

To replicate the setup, create a conda environment from the provided tf2.yml file:

conda env create -f tf2.yml
conda activate tf2

Data preparation

The full ImageNet dataset can be downloaded from the official ImageNet website. After downloading, set base_dir in data.py to the dataset path.
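Before training on the full dataset, it can help to confirm that base_dir points at the extracted images. The sketch below assumes a class-per-folder layout (base_dir/train/<class_id>/...). That layout is an assumption, not something data.py is documented to require. The helper name `count_classes` is hypothetical.

```python
from pathlib import Path


def count_classes(base_dir, split="train"):
    """Count class sub-directories under base_dir/split.

    Assumes (hypothetically) one folder per class, as in the usual
    extracted ImageNet training set.
    """
    split_dir = Path(base_dir) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"{split_dir} does not exist; check base_dir")
    # Only directories count as classes; stray files are ignored.
    return sum(1 for p in split_dir.iterdir() if p.is_dir())
```

Under that layout, `count_classes(base_dir)` should return 1000 for the ImageNet training set.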

ImageNet 32x32 and ImageNet 64x64 can be generated in one of two ways. You can use the scripts provided by the Downsampled ImageNet project [2], or you can use the TensorFlow Datasets package (tensorflow_datasets), which can be installed with pip:

pip install tensorflow_datasets

Version 4.4.0 of tensorflow_datasets (the current release at the time of writing) has a broken download link for ImageNet 32x32 and ImageNet 64x64. A workaround is available on GitHub.
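In TensorFlow Datasets, the downsampled variants are exposed as the `imagenet_resized/32x32` and `imagenet_resized/64x64` builders. These are the same names accepted by pretrain.py's `--d` flag. The sketch below loads them with `tfds.load`. Scaling pixels to [0, 1] is an illustrative assumption, not necessarily the preprocessing done in data.py. Both function names are hypothetical.

```python
import numpy as np


def load_imagenet_resized(resolution=32, split="train"):
    """Load downsampled ImageNet from TensorFlow Datasets as (image, label) pairs."""
    # Imported lazily so the rest of this module works without TensorFlow installed.
    import tensorflow_datasets as tfds

    name = f"imagenet_resized/{resolution}x{resolution}"
    return tfds.load(name, split=split, as_supervised=True)


def to_unit_range(images):
    """Scale uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    return np.asarray(images, dtype=np.float32) / 255.0
```

As a usage example, `tfds.as_numpy(load_imagenet_resized(32).take(1))` yields one (image, label) pair as NumPy arrays, which can then be passed through `to_unit_range`.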

Pretraining

Use pretrain.py to pretrain from scratch under the different setups. Its command-line arguments are self-explanatory; pass --help to list them:

 python pretrain.py --help
 
       USAGE: pretrain.py [flags]
flags:

pretrain.py:
  --bs: batch_size
    (default: '128')
    (an integer)
  --d: <imagenet_resized/32x32|imagenet_resized/64x64|imagenet-full>: dataset
    (default: 'imagenet_resized/32x32')
  --e: number of epochs
    (default: '50')
    (an integer)
  --g: gpu id
    (default: '0')
  --lbl: <lda|knn>: Specify labelling method either LDA or KNN.
    (default: 'lda')
  --lr: learning_rate
    (default: '0.001')
    (a number)
  --lt: <cross-entropy|triplet>: loss_type  either cross-entropy  or triplet.
    (default: 'cross-entropy')
  --margin: margin for triplet loss
    (default: '1.0')
    (a number)
  --n: network
    (default: 'wrn-28-2')
  --[no]sw: save weights
    (default: 'false')

Try --helpfull to get a list of all flags.
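The `--lbl` flag chooses between LDA and k-nearest neighbours (KNN) as the labelling method. The NumPy sketch below illustrates the KNN option only, and the repository's exact procedure may differ. It labels each query embedding by majority vote among its k nearest labelled embeddings under Euclidean distance. The name `knn_label` and the default k=3 are assumptions, not the repository's settings.

```python
import numpy as np


def knn_label(train_emb, train_labels, query_emb, k=3):
    """Label each query embedding by majority vote of its k nearest training embeddings."""
    train_emb = np.asarray(train_emb, dtype=np.float64)
    train_labels = np.asarray(train_labels)
    query_emb = np.asarray(query_emb, dtype=np.float64)
    # Pairwise squared Euclidean distances, shape (n_query, n_train).
    dists = np.sum((query_emb[:, None, :] - train_emb[None, :, :]) ** 2, axis=-1)
    # Indices of the k closest training points for each query.
    nearest = np.argsort(dists, axis=1)[:, :k]
    preds = []
    for idx in nearest:
        values, counts = np.unique(train_labels[idx], return_counts=True)
        preds.append(values[np.argmax(counts)])  # ties resolve to the smallest label
    return np.array(preds)
```

KNN needs no training step beyond storing the labelled embeddings. For that reason it pairs naturally with triplet-loss features, where distances in embedding space are what the loss shapes.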

When --sw is passed, pretrained weights are saved to the weights/ directory. We also provide pretrained weights, which can be downloaded from the repository's releases and placed in the weights/ directory. The path to the downloaded weights can be set in wrn.py.
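Several configurations can be pretrained in one scripted sweep, with the weights kept for each run. The sketch below uses only flags listed by --help (`--d`, `--lt`, `--e`, `--margin`, `--sw`). The helper names `pretrain_command` and `run_sweep`, and the choice to sweep over datasets and loss types, are illustrative assumptions.

```python
import subprocess
import sys


def pretrain_command(dataset, loss_type, epochs=50, margin=1.0, save_weights=True):
    """Build an argument list for pretrain.py using its documented flags."""
    cmd = [sys.executable, "pretrain.py", f"--d={dataset}", f"--lt={loss_type}", f"--e={epochs}"]
    if loss_type == "triplet":
        cmd.append(f"--margin={margin}")  # the margin is only used by the triplet loss
    if save_weights:
        cmd.append("--sw")  # ask pretrain.py to save the trained weights
    return cmd


def run_sweep(datasets, loss_types, dry_run=True):
    """Print (and, unless dry_run, execute) one pretrain.py run per dataset/loss pair."""
    for dataset in datasets:
        for loss_type in loss_types:
            cmd = pretrain_command(dataset, loss_type)
            print(" ".join(cmd))
            if not dry_run:
                subprocess.run(cmd, check=True)
```

For example, `run_sweep(["imagenet_resized/32x32", "imagenet_resized/64x64"], ["cross-entropy", "triplet"])` prints the four commands. Pass `dry_run=False` to actually run them one after another.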

Example usage

An example notebook showing how to use the pretrained weights is provided in cifar_example.ipynb.

Citation

If you use the provided weights, please cite our paper:

@inproceedings{sahito2022better,
  title={Better self-training for image classification through self-supervision},
  author={Sahito, Attaullah and Frank, Eibe and Pfahringer, Bernhard},
  booktitle={Australasian Joint Conference on Artificial Intelligence},
  pages={645--657},
  year={2022},
  organization={Springer}
}

References

  1. Wide Residual Networks. Sergey Zagoruyko and Nikos Komodakis. In British Machine Vision Conference 2016. British Machine Vision Association, 2016.
  2. A downsampled variant of ImageNet as an alternative to the CIFAR datasets. Patryk Chrabaszcz, Ilya Loshchilov, and Frank Hutter. arXiv preprint arXiv:1707.08819, 2017.
  3. Distance metric learning for large margin nearest neighbour classification. Kilian Q Weinberger and Lawrence K Saul. Journal of Machine Learning Research, 10(2), 2009.

License

MIT