The MNIST dataset is a dataset of handwritten digits, each of which is a 28x28 grayscale image. We'll train a diffusion model to generate new images of handwritten digits.
We'll start by using the same model architecture as in the previous chapter, and train for 10 epochs.
class ConvBlock(torch.nn.Module):
    """Two 3x3 convolutions with a timestep-embedding injection in between."""

    def __init__(self, in_channels, out_channels, embed_dim):
        super(ConvBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, padding=1)
        # Projects the timestep embedding to one scalar per output channel.
        self.proj = torch.nn.Linear(embed_dim, out_channels)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, padding=1)

    def forward(self, x, embedding):
        # First convolution, then add the projected embedding as a per-channel bias.
        h = self.conv1(x)
        bias = self.proj(embedding).view(-1, h.size(1), 1, 1)
        h = torch.nn.functional.relu(h + bias)
        # Second convolution followed by its own non-linearity.
        return torch.nn.functional.relu(self.conv2(h))
class Model(torch.nn.Module):
def __init__(self, num_steps=1000, embed_dim=16):
super(Model, self).__init__()
self.embed = torch.nn.Embedding(num_steps, embed_dim)
self.enc1 = ConvBlock(1, 16, embed_dim)
self.enc2 = ConvBlock(16, 32, embed_dim)
self.bottleneck = ConvBlock(32, 64, embed_dim)
self.upconv2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
self.dec2 = ConvBlock(64, 32, embed_dim)
self.upconv1 = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
self.dec1 = ConvBlock(32, 16, embed_dim)
self.final = torch.nn.Conv2d(16, 1, kernel_size=1)
def forward(self, x, t):
emb = self.embed(t)
enc1 = self.enc1(x, emb)
enc2 = self.enc2(torch.nn.functional.max_pool2d(enc1, 2), emb)
bottleneck = self.bottleneck(torch.nn.functional.max_pool2d(enc2, 2), emb)
dec2 = self.dec2(torch.cat([enc2, self.upconv2(bottleneck)], 1), emb)
dec1 = self.dec1(torch.cat([enc1, self.upconv1(dec2)], 1), emb)
out = self.final(dec1)
return outWe can run this with:
python part_a_mnist.py train

You can also change additional hyperparameters, such as the learning rate, batch size, and number of epochs:
python part_a_mnist.py train --batch-size 128 --lr 1e-3 --epochs 10

After training, we can sample a grid of images from the model with:
python part_a_mnist.py test

Our resulting output looks like this:
As you can see, we have some room for improvement.
Let's make some changes to our model:
In our previous models, we used an embedding layer to learn a unique embedding for each timestep. However, another common technique is to use a positional encoding, which is a fixed function of the timestep. This is the same technique used with transformers to encode the position of each token in the sequence.
We can define a simple positional encoding function as follows:
class PositionalEncoding(torch.nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
# Create a matrix to hold the positional encodings
pe = torch.zeros(max_len, d_model)
# Compute the positional encodings
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# Register pe as a buffer to avoid updating it during backpropagation
self.register_buffer('pe', pe)
def forward(self, x):
# Retrieve the positional encodings
return self.pe[x]We can think of positional encoding as a fixed function that maps each timestep to a unique vector. If you're familiar with Fourier series, you might recognize that each component of the positional encoding is a sine or cosine function with a different frequency.
Let's train our new model:
python part_c_mnist.py train

You can also change additional hyperparameters, such as the learning rate, batch size, and number of epochs:
python part_c_mnist.py train --batch-size 128 --lr 1e-3 --epochs 10

After training, we can sample a grid of images from the model with:
python part_c_mnist.py test

Our resulting output looks like this:
We can see that our model has improved significantly, and we've even reduced the number of parameters by switching to positional encoding.
We'll continue with two more changes to our model:
We'll modify our ConvBlock modules to include GroupNorm layers for normalization:
class ConvBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, embed_dim):
super(ConvBlock, self).__init__()
self.norm1 = torch.nn.GroupNorm(16, in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.proj = torch.nn.Linear(embed_dim, out_channels)
self.norm2 = torch.nn.GroupNorm(16, out_channels)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x, embedding):
x = self.norm1(x)
x = torch.nn.functional.relu(x)
x = self.conv1(x)
emb_proj = self.proj(embedding).view(-1, x.size(1), 1, 1)
x = x + emb_proj
x = self.norm2(x)
x = torch.nn.functional.relu(x)
x = self.conv2(x)
return xWe'll also update our embedding to follow our positional encoding with two fully-connected layers:
self.embed = torch.nn.Sequential(
PositionalEncoding(embed_dim, num_steps),
torch.nn.Linear(embed_dim, embed_dim),
torch.nn.ReLU(),
torch.nn.Linear(embed_dim, embed_dim),
torch.nn.ReLU(),
)Our new model looks like this:
class Model(torch.nn.Module):
def __init__(self, num_steps=1000, embed_dim=64):
super(Model, self).__init__()
self.embed = torch.nn.Sequential(
PositionalEncoding(embed_dim, num_steps),
torch.nn.Linear(embed_dim, embed_dim),
torch.nn.ReLU(),
torch.nn.Linear(embed_dim, embed_dim),
torch.nn.ReLU(),
)
self.conv_in = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.enc1 = ConvBlock(16, 16, embed_dim)
self.enc2 = ConvBlock(16, 32, embed_dim)
self.bottleneck = ConvBlock(32, 64, embed_dim)
self.upconv2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
self.dec2 = ConvBlock(64, 32, embed_dim)
self.upconv1 = torch.nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
self.dec1 = ConvBlock(32, 16, embed_dim)
self.norm_out = torch.nn.GroupNorm(16, 16)
self.conv_out = torch.nn.Conv2d(16, 1, kernel_size=3, padding=1)
def forward(self, x, t):
emb = self.embed(t)
x = self.conv_in(x)
enc1 = self.enc1(x, emb)
enc2 = self.enc2(torch.nn.functional.max_pool2d(enc1, 2), emb)
bottleneck = self.bottleneck(torch.nn.functional.max_pool2d(enc2, 2), emb)
dec2 = self.dec2(torch.cat([enc2, self.upconv2(bottleneck)], 1), emb)
dec1 = self.dec1(torch.cat([enc1, self.upconv1(dec2)], 1), emb)
out = self.norm_out(dec1)
out = torch.nn.functional.relu(out)
out = self.conv_out(out)
return outLet's train this model now:
python part_d_mnist.py train

You can also change additional hyperparameters, such as the learning rate, batch size, and number of epochs:
python part_d_mnist.py train --batch-size 128 --lr 1e-3 --epochs 20

After training, we can sample a grid of images from the model with:
python part_d_mnist.py test

Our resulting output looks like this:
This is already looking quite good, but we can still make further improvements.
We can further improve our model with the following improvements:
We'll replace our ConvBlock modules with a more advanced residual block:
class ResnetBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, embed_dim):
super(ResnetBlock, self).__init__()
self.norm1 = torch.nn.GroupNorm(16, in_channels)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.proj = torch.nn.Linear(embed_dim, out_channels)
self.norm2 = torch.nn.GroupNorm(16, out_channels)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None
def forward(self, x, embedding):
_input = x
x = self.norm1(x)
x = torch.nn.functional.relu(x)
x = self.conv1(x)
emb_proj = self.proj(embedding).view(-1, x.size(1), 1, 1)
x = x + emb_proj
x = self.norm2(x)
x = torch.nn.functional.relu(x)
x = self.conv2(x)
if self.shortcut is not None:
_input = self.shortcut(_input)
return x + _inputThis allows us to stack multiple residual blocks together, and still maintain a good flow of gradients.
Currently, we use max pooling for downsampling and transposed convolutions for upsampling.
Let's replace our max pooling layers with strided convolutions:
class Downsample(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(Downsample, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
def forward(self, x):
return self.conv(x)Instead of transposed convolutions, we'll use nearest-neighbor upsampling followed by a convolution:
class Upsample(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super(Upsample, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
x = self.conv(x)
return xOur new model looks like this:
class Model(torch.nn.Module):
def __init__(self, num_steps=1000, embed_dim=64):
super(Model, self).__init__()
self.embed = torch.nn.Sequential(
PositionalEncoding(embed_dim, num_steps),
torch.nn.Linear(embed_dim, embed_dim),
torch.nn.ReLU(),
torch.nn.Linear(embed_dim, embed_dim),
torch.nn.ReLU(),
)
self.conv_in = torch.nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.enc1_1 = ResnetBlock(16, 16, embed_dim)
self.enc1_2 = ResnetBlock(16, 32, embed_dim)
self.downconv1 = Downsample(32, 32)
self.enc2_1 = ResnetBlock(32, 32, embed_dim)
self.enc2_2 = ResnetBlock(32, 64, embed_dim)
self.downconv2 = Downsample(64, 64)
self.bottleneck_1 = ResnetBlock(64, 64, embed_dim)
self.bottleneck_2 = ResnetBlock(64, 64, embed_dim)
self.upconv2 = Upsample(64, 64)
self.dec2_1 = ResnetBlock(128, 64, embed_dim)
self.dec2_2 = ResnetBlock(64, 32, embed_dim)
self.upconv1 = Upsample(32, 32)
self.dec1_1 = ResnetBlock(64, 32, embed_dim)
self.dec1_2 = ResnetBlock(32, 16, embed_dim)
self.norm_out = torch.nn.GroupNorm(16, 16)
self.conv_out = torch.nn.Conv2d(16, 1, kernel_size=3, padding=1)
def forward(self, x, t):
emb = self.embed(t)
x = self.conv_in(x)
x = self.enc1_1(x, emb)
enc1 = self.enc1_2(x, emb)
x = self.downconv1(enc1)
x = self.enc2_1(x, emb)
enc2 = self.enc2_2(x, emb)
x = self.downconv2(enc2)
x = self.bottleneck_1(x, emb)
x = self.bottleneck_2(x, emb)
x = self.upconv2(x)
x = torch.cat([x, enc2], 1)
x = self.dec2_1(x, emb)
x = self.dec2_2(x, emb)
x = self.upconv1(x)
x = torch.cat([x, enc1], 1)
x = self.dec1_1(x, emb)
x = self.dec1_2(x, emb)
x = self.norm_out(x)
x = torch.nn.functional.relu(x)
x = self.conv_out(x)
return xLet's train this model now:
python main.py train

We can also change additional hyperparameters, such as the learning rate, batch size, and number of epochs:
python main.py train --batch-size 128 --lr 1e-3 --epochs 80

After training, we can sample a grid of images from the model with:
python main.py test

Our resulting output looks like this:
This output is significantly better than our previous models!





