Let's now train some conditional models. Instead of our model just taking the noisy image and timestep as input, we'll also pass in the class label, so the model can learn to denoise images of a specific class.
We'll use an additional embedding layer to encode the class index:

```python
self.class_emb = torch.nn.Embedding(num_classes, embed_dim)
```

In our forward pass, we'll add the class embedding to the time embedding:

```python
emb = self.embed(t)                    # time embedding
emb_class = self.class_emb(class_idx)  # class embedding
emb = emb + emb_class
```

In order to simplify the training options, we'll create some default configurations for our datasets and models. These will act as defaults for our CLI args.
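As a sketch of how such presets might be merged with user-supplied flags (the real `main.py` may resolve `--config` differently; `CONFIGS` and `build_argv` here are hypothetical names), one option is to prepend the preset's arguments before parsing, since `argparse` keeps the last value it sees for a repeated flag:

```python
import argparse

# Hypothetical preset table; the real configuration dict may differ.
CONFIGS = {
    "demo": ["train", "--lr", "2e-4", "--batch-size", "128"],
}

def build_argv(argv):
    """Expand a `--config NAME` preset, letting explicit CLI flags override it."""
    if "--config" in argv:
        i = argv.index("--config")
        preset = CONFIGS[argv[i + 1]]
        rest = argv[:i] + argv[i + 2:]
        # Preset first, user flags last: argparse keeps the last value it sees.
        return preset + rest
    return argv

parser = argparse.ArgumentParser()
parser.add_argument("command")
parser.add_argument("--lr", type=float)
parser.add_argument("--batch-size", type=int)

args = parser.parse_args(build_argv(["--config", "demo", "--lr", "1e-4"]))
```

Here an explicit `--lr 1e-4` overrides the preset's `2e-4`, while `--batch-size` falls back to the preset's `128`.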
For example, this is our CIFAR10 conditional model configuration:

```python
"cifar10-cond": [
    'train',
    '--dataset', 'cifar10',
    '--conditional',
    '--batch-size', '128',
    '--grad-clip', '1',
    '--lr', '2e-4',
    '--warmup', '5000',
    '--steps', '800_000',
    '--val-interval', '2000',
    '--model-channels', '128',
    '--channel-mult', '1', '2', '2', '2',
    '--num-res-blocks', '2',
    '--attention-resolutions', '2',
    '--dropout', '0.1',
    '--hflip',
    '--save-checkpoints',
    '--log-interval', '5',
    '--progress',
]
```

Now we can train a CIFAR10 conditional model with the following command:
```shell
python main.py train --config cifar10-cond --gpu 0 --output-dir results/cifar10-cond
```

Our model will periodically generate samples from each of the classes and save the grid of images. Here are the samples from the fully trained CIFAR10 conditional model:
Each of the rows corresponds to a different class:
- airplane
- automobile
- bird
- cat
- deer
- dog
- frog
- horse
- ship
- truck
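A per-class grid like the one above could be produced by sampling with a fixed label per row, roughly as in the sketch below (`sample_class_grid` and `model.denoise_step` are hypothetical names; the actual sampler in the training script will differ):

```python
import torch

@torch.no_grad()
def sample_class_grid(model, num_classes=10, per_class=8,
                      shape=(3, 32, 32), steps=1000):
    """Generate `per_class` samples for every class; one grid row per class."""
    rows = []
    for c in range(num_classes):
        x = torch.randn(per_class, *shape)                # start from pure noise
        labels = torch.full((per_class,), c, dtype=torch.long)
        for t in reversed(range(steps)):
            t_batch = torch.full((per_class,), t, dtype=torch.long)
            x = model.denoise_step(x, t_batch, labels)    # one reverse step
        rows.append(x)
    return torch.cat(rows, dim=0)  # (num_classes * per_class, C, H, W)
```

The stacked tensor can then be written out with `torchvision.utils.save_image(..., nrow=per_class)` so that each row of the saved grid shows one class.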
We can also train a conditional model on MNIST:

```shell
python main.py train --config mnist-cond --gpu 0 --output-dir results/mnist-cond
```
