GAN#
This notebook covers the process of building a Generative Adversarial Network (GAN) and its application to the MNIST
dataset.
import numpy as np
from pathlib import Path
from IPython.display import HTML
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.functional import binary_cross_entropy
import torchvision.utils as vutils
import torchvision.transforms as T
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# Prefer CUDA, then Apple's Metal backend (MPS), and fall back to the CPU.
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print("using device", DEVICE)
using device cpu
The MNIST dataset will be used as an example.
# The generator ends with Tanh, so its pictures lie in [-1, 1]. ToTensor alone
# maps MNIST pixels to [0, 1]; Normalize((0.5,), (0.5,)) computes
# (x - 0.5) / 0.5 and rescales real pictures to the same [-1, 1] range, so the
# discriminator cannot tell real from fake just by the range of pixel values.
TRAIN_DATASET = MNIST(
    Path("mnist_files"),
    download=True,
    train=True,
    transform=T.Compose([
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]),
)
# Shuffle so that every epoch presents the batches in a different order.
DATA_LOADER = DataLoader(TRAIN_DATASET, batch_size=64, shuffle=True)
Generator#
The following cell implements the Generator class.
class Generator(nn.Module):
    '''
    A class that generates a picture from a set of random noise.

    The noise vector is treated as a 1 x 1 "image" with `noise_size` channels
    and upsampled by transposed convolutions to 7 x 7, then 14 x 14, then
    28 x 28 (the MNIST resolution), so the output spatial size is fixed.

    Parameters
    ----------
    feature_map_size: int
        Feature map's size of the generator: the first hidden layer has
        `2 * feature_map_size` channels, the second `feature_map_size`.
    number_channels: int
        Number of channels of the generated pictures (1 for grayscale).
    noise_size: int
        Size of the vector that is expected to be transformed to the picture by
        the model.
    '''
    def __init__(
        self,
        feature_map_size: int,
        number_channels: int,
        noise_size: int
    ) -> None:
        super().__init__()
        self.main = nn.Sequential(
            # (noise_size) x 1 x 1
            nn.ConvTranspose2d(
                in_channels=noise_size,
                out_channels=feature_map_size * 2,
                kernel_size=7,
                stride=1,
                padding=0,
                bias=False
            ),
            nn.BatchNorm2d(feature_map_size * 2),
            nn.ReLU(inplace=True),
            # (feature_map_size*2) x 7 x 7
            nn.ConvTranspose2d(
                in_channels=feature_map_size * 2,
                out_channels=feature_map_size,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(feature_map_size),
            nn.ReLU(inplace=True),
            # (feature_map_size) x 14 x 14
            nn.ConvTranspose2d(
                in_channels=feature_map_size,
                out_channels=number_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.Tanh(),
            # (number_channels) x 28 x 28
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        '''
        Apply model to given data.

        Parameters
        ----------
        input: torch.Tensor
            Noise with size (n_samples, noise_size), or equivalently
            (n_samples, noise_size, 1, 1).

        Returns
        -------
        out: torch.Tensor
            Tensor with size (n_samples, number_channels, 28, 28) and values
            in [-1, 1] (the last layer is Tanh) that represents the set of
            generated pictures.
        '''
        if input.dim() == 2:
            # Transposed convolutions expect (N, C, H, W): view every noise
            # vector as a 1 x 1 "image" with `noise_size` channels.
            input = input[:, :, None, None]
        return self.main(input)
Now let’s check how to use it: pass the generator a batch of vectors with random values, and for each random vector it returns a picture.
# Generator with 64 feature maps producing 1-channel (grayscale) pictures
# from 100-dimensional noise; feed it 20 random noise vectors.
generator = Generator(
    noise_size=100,
    number_channels=1,
    feature_map_size=64,
)
generator(torch.randn(20, 100)).shape
torch.Size([20, 1, 28, 28])
Discriminator#
The discriminator is a model that tries to determine whether a picture was created by the generator or not. The following cell defines the discriminator that we will use for this example.
class Discriminator(nn.Module):
    '''
    Realisation of the discriminator. It takes a batch of pictures and returns,
    for each picture, a score in (0, 1) (Sigmoid output) that expresses how
    much the model thinks the picture is real (not generated).

    Inputs are expected to be 28 x 28: two stride-2 convolutions bring them to
    7 x 7, and a final 7 x 7 convolution reduces that to a single value. For
    other sizes the final Flatten would merge spatial positions into the
    first dimension.

    Parameters
    ----------
    number_channels: int
        Number of channels in input.
    feature_map_size: int
        Feature map's size (channels of the first hidden layer; the second
        one has twice as many).
    '''
    def __init__(self, number_channels: int, feature_map_size: int):
        super().__init__()
        doubled = feature_map_size * 2
        self.main = nn.Sequential(
            # (number_channels) x 28 x 28 -> (feature_map_size) x 14 x 14
            nn.Conv2d(number_channels, feature_map_size, 4, 2, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # -> (feature_map_size * 2) x 7 x 7
            nn.Conv2d(feature_map_size, doubled, 4, 2, 1, bias=False),
            nn.BatchNorm2d(doubled),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # -> 1 x 1 x 1: a single score per picture
            nn.Conv2d(doubled, 1, 7, 1, 0, bias=False),
            nn.Sigmoid(),
            # (n_samples, 1, 1, 1) -> (n_samples,)
            nn.Flatten(start_dim=0, end_dim=-1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        '''
        Score a batch of pictures.

        Parameters
        ----------
        input: torch.Tensor
            Pictures with size (n_samples, number_channels, 28, 28).

        Returns
        -------
        out: torch.Tensor
            Tensor with size (n_samples,): for each picture, the score that
            the picture is real (not generated).
        '''
        return self.main(input)
Consider how the discriminator works by passing a sample picture from the training data to it:
# Score the first training picture; unsqueeze(0) adds the batch dimension
# that the convolution layers expect.
discriminator = Discriminator(feature_map_size=64, number_channels=1)
discriminator(TRAIN_DATASET[0][0].unsqueeze(0))
tensor([0.4685], grad_fn=<ViewBackward0>)
We got a score that represents the model’s estimate of whether the picture we passed is real (not generated). The discriminator is still untrained, so this value carries no information yet.
Model fitting#
We need to train the discriminator to determine if its input was original or generated by the generator. The following cell implements the optimization step with two core elements:
Gradient accumulation to increase predicted scores for real images.
Gradient accumulation to decrease predicted scores for generated images.
def discriminator_step(
    pictures: torch.Tensor,
    generation: torch.Tensor,
    discriminator: torch.nn.Module,
    optimizer: torch.optim.Optimizer
) -> tuple[float, float, torch.Tensor]:
    '''
    Step of the discriminator. Maximize log(D(x)) + log(1 - D(G(z))): tries
    to improve the prediction that the real images have 1. scores and the
    generated images have 0. scores.

    Parameters
    ----------
    pictures: torch.Tensor
        Batch of real images that we're trying to imitate.
    generation: torch.Tensor
        Batch of generated images. It is detached here, so no gradient
        flows back into the generator.
    discriminator: torch.nn.Module
        Model that we optimise. Must return one score in (0, 1) per picture,
        i.e. a tensor of size (batch_size,).
    optimizer: torch.optim.Optimizer
        Optimizer that uses the weights of the discriminator.

    Returns
    -------
    out: tuple[float, float, torch.Tensor]
        - Mean prediction for real images.
        - Mean prediction for fake images (before the update).
        - Total loss on both real and fake images, as a detached scalar
          tensor (use `.item()` to get a Python float).
    '''
    discriminator.zero_grad()

    # Gradient accumulation on real images: scores should be close to 1.
    # Labels are built on the output's device/dtype instead of a global
    # device, and sized from the batch so BCE still checks the output shape.
    output = discriminator(pictures)
    real_label = torch.ones(
        pictures.shape[0], dtype=output.dtype, device=output.device
    )
    errD_real = binary_cross_entropy(output, real_label)
    errD_real.backward()
    D_x = output.mean().item()

    # Gradient accumulation on fake images: scores should be close to 0.
    # generation is detached because optimizing the discriminator doesn't
    # need gradients with respect to the generator.
    output = discriminator(generation.detach())
    fake_label = torch.zeros(
        generation.shape[0], dtype=output.dtype, device=output.device
    )
    errD_fake = binary_cross_entropy(output, fake_label)
    errD_fake.backward()
    D_G_z1 = output.mean().item()

    # Step based on the gradients accumulated from both halves.
    optimizer.step()
    # Detached: both graphs were already consumed by backward(), so keeping
    # them referenced from the returned loss would only hold memory.
    return D_x, D_G_z1, (errD_real + errD_fake).detach()
In this generator step, by calculating the loss (the difference between the discriminator’s judgment and the target “real” label), the generator learns the extent of improvement needed. This loss is used to compute gradients, guiding adjustments to the generator’s weights to make future outputs more realistic.
def generator_step(
    generation: torch.Tensor,
    discriminator: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    generator: torch.nn.Module
) -> tuple[float, torch.Tensor]:
    '''
    Step of the generator. Maximize log(D(G(z))): push the discriminator to
    classify the generated pictures as real.

    Parameters
    ----------
    generation: torch.Tensor
        Set of objects generated by generator. It must still be attached to
        the generator's graph so gradients can reach its weights.
    discriminator: torch.nn.Module
        Model whose decision determines the behavior of the generator. Must
        return one score in (0, 1) per picture.
    optimizer: torch.optim.Optimizer
        The optimizer that changes weights of the generator.
    generator: torch.nn.Module
        The model whose weights we are adjusting in this step.

    Returns
    -------
    out: tuple[float, torch.Tensor]
        - Mean prediction of the discriminator on generated data.
        - Loss value of the generator on the generated pictures, as a
          detached scalar tensor (use `.item()` to get a Python float).
    '''
    generator.zero_grad()
    output = discriminator(generation)
    # The generator wants its pictures classified as real (label 1). Labels
    # follow the output's device/dtype instead of a global device.
    label = torch.ones(
        generation.shape[0], dtype=output.dtype, device=output.device
    )
    errG = binary_cross_entropy(output, label)
    # NOTE: backward() also accumulates gradients into the discriminator's
    # parameters. Only `optimizer` (the generator's) steps here, so zero the
    # discriminator's gradients before updating it.
    errG.backward()
    D_G_z2 = output.mean().item()
    optimizer.step()
    # Detached so the returned loss doesn't keep the consumed graph alive.
    return D_G_z2, errG.detach()
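GANs like this one are usually initialized with specific random values instead of PyTorch’s default initialization. The following function implements this normal distribution-based initialization. Convolution weights are drawn from N(0, 0.02), batch-norm weights from N(1, 0.02), and batch-norm biases are set to 0.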
def weights_init(m: torch.nn.Module) -> None:
    '''
    Initialize the weights of a single module.

    Meant to be used as `model.apply(weights_init)`, which calls it on every
    submodule. Modules whose class name contains "Conv" get their weight
    drawn from N(0, 0.02). Modules whose class name contains "BatchNorm" get
    their weight drawn from N(1, 0.02) and their bias set to 0. Parameters
    that are absent or None are skipped, and other modules are left untouched.

    Parameters
    ----------
    m: torch.nn.Module
        Module that requires weight initialization.
    '''
    classname = type(m).__name__
    if "Conv" in classname:
        # getattr guard: a container whose name merely contains "Conv" may
        # have no `weight`. Conv biases (if any) keep their default values.
        if getattr(m, "weight", None) is not None:
            nn.init.normal_(m.weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        # BatchNorm(affine=False) has weight and bias set to None.
        if getattr(m, "weight", None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias, 0)
The following cell creates both models, applies this initialization to them, sets up an Adam optimizer for each, and runs the training loop. The loop records both losses at every step and periodically saves the generator’s output for a fixed batch of noise.
# Hyperparameters shared by both networks and the optimizers.
feature_map_size = 64
noise_size = 100
lr = 0.001
beta1 = 0.5
num_epochs = 2

# Build both models on DEVICE and apply the custom weight initialization
# (Module.apply returns the module itself, so the calls can be chained).
netG = Generator(
    feature_map_size=feature_map_size,
    number_channels=1,
    noise_size=noise_size
).to(DEVICE).apply(weights_init)
netD = Discriminator(
    feature_map_size=feature_map_size,
    number_channels=1
).to(DEVICE).apply(weights_init)

# We'll draw images from the same input to compare results.
fixed_noise = torch.randn(64, noise_size, device=DEVICE)

optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

img_list = []
G_losses = []
D_losses = []

steps_per_epoch = len(DATA_LOADER)
for epoch in range(num_epochs):
    for step, (pictures, _) in enumerate(DATA_LOADER):
        pictures = pictures.to(DEVICE)
        # Generated ("fake") pictures that try to trick the discriminator.
        noise = torch.randn(pictures.size(0), noise_size, device=DEVICE)
        generation = netG(noise)

        D_x, D_G_z1, errD = discriminator_step(
            pictures=pictures,
            generation=generation,
            discriminator=netD,
            optimizer=optimizerD
        )
        # Maximizing for generator log(D(G(z)))
        D_G_z2, errG = generator_step(
            generation=generation,
            discriminator=netD,
            optimizer=optimizerG,
            generator=netG
        )

        d_loss = errD.item()
        g_loss = errG.item()
        D_losses.append(d_loss)
        G_losses.append(g_loss)

        if step % 50 == 0:
            print(
                f"[{epoch}/{num_epochs}][{step}/{steps_per_epoch}]\t"
                f"Discriminator's loss: {d_loss:.4f}\t"
                f"Generator's loss: {g_loss:.4f}\t"
                f"D(x): {D_x:.4f}\t"
                f"D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}"
            )

        # Snapshot every 500 steps, plus the very last step of training.
        is_last_step = (
            epoch == num_epochs - 1 and step == steps_per_epoch - 1
        )
        if step % 500 == 0 or is_last_step:
            with torch.no_grad():
                snapshot = netG(fixed_noise).cpu()
            img_list.append(
                vutils.make_grid(snapshot, padding=2, normalize=True)
            )
[0/2][0/938] Discriminator's loss: 1.7518 Generator's loss: 4.7653 D(x): 0.6382 D(G(z)): 0.6811 / 0.0109
[0/2][50/938] Discriminator's loss: 0.0725 Generator's loss: 5.2438 D(x): 0.9644 D(G(z)): 0.0336 / 0.0091
[0/2][100/938] Discriminator's loss: 0.1012 Generator's loss: 3.6272 D(x): 0.9586 D(G(z)): 0.0549 / 0.0356
[0/2][150/938] Discriminator's loss: 0.1291 Generator's loss: 3.5180 D(x): 0.9185 D(G(z)): 0.0374 / 0.0387
[0/2][200/938] Discriminator's loss: 1.3565 Generator's loss: 0.9188 D(x): 0.5612 D(G(z)): 0.4311 / 0.4408
[0/2][250/938] Discriminator's loss: 0.7970 Generator's loss: 1.1331 D(x): 0.5356 D(G(z)): 0.0933 / 0.3967
[0/2][300/938] Discriminator's loss: 0.1237 Generator's loss: 3.6578 D(x): 0.9075 D(G(z)): 0.0214 / 0.0426
[0/2][350/938] Discriminator's loss: 0.1801 Generator's loss: 3.6914 D(x): 0.8570 D(G(z)): 0.0195 / 0.0399
[0/2][400/938] Discriminator's loss: 0.9553 Generator's loss: 7.4316 D(x): 0.4471 D(G(z)): 0.0008 / 0.0041
[0/2][450/938] Discriminator's loss: 0.1762 Generator's loss: 3.9034 D(x): 0.9406 D(G(z)): 0.0991 / 0.0354
[0/2][500/938] Discriminator's loss: 0.4488 Generator's loss: 2.2763 D(x): 0.8926 D(G(z)): 0.2537 / 0.1311
[0/2][550/938] Discriminator's loss: 0.2853 Generator's loss: 3.4381 D(x): 0.8152 D(G(z)): 0.0321 / 0.0599
[0/2][600/938] Discriminator's loss: 0.3209 Generator's loss: 3.0028 D(x): 0.8489 D(G(z)): 0.1176 / 0.0741
[0/2][650/938] Discriminator's loss: 0.3860 Generator's loss: 3.8455 D(x): 0.9540 D(G(z)): 0.2577 / 0.0285
[0/2][700/938] Discriminator's loss: 0.1155 Generator's loss: 6.1508 D(x): 0.8971 D(G(z)): 0.0023 / 0.0038
[0/2][750/938] Discriminator's loss: 1.0524 Generator's loss: 2.0052 D(x): 0.5873 D(G(z)): 0.1514 / 0.2225
[0/2][800/938] Discriminator's loss: 0.2644 Generator's loss: 2.6836 D(x): 0.8687 D(G(z)): 0.1012 / 0.0975
[0/2][850/938] Discriminator's loss: 0.6571 Generator's loss: 4.0227 D(x): 0.9525 D(G(z)): 0.4043 / 0.0249
[0/2][900/938] Discriminator's loss: 0.3223 Generator's loss: 2.8650 D(x): 0.8586 D(G(z)): 0.1276 / 0.0781
[1/2][0/938] Discriminator's loss: 0.7768 Generator's loss: 1.0826 D(x): 0.6014 D(G(z)): 0.0529 / 0.4178
[1/2][50/938] Discriminator's loss: 0.4546 Generator's loss: 1.9901 D(x): 0.7953 D(G(z)): 0.1627 / 0.1818
[1/2][100/938] Discriminator's loss: 0.8563 Generator's loss: 1.9985 D(x): 0.5223 D(G(z)): 0.0308 / 0.2026
[1/2][150/938] Discriminator's loss: 0.3607 Generator's loss: 3.4995 D(x): 0.9286 D(G(z)): 0.2162 / 0.0461
[1/2][200/938] Discriminator's loss: 0.3745 Generator's loss: 2.5101 D(x): 0.7906 D(G(z)): 0.0939 / 0.1151
[1/2][250/938] Discriminator's loss: 1.0054 Generator's loss: 1.4995 D(x): 0.4711 D(G(z)): 0.0315 / 0.3851
[1/2][300/938] Discriminator's loss: 0.9267 Generator's loss: 0.3721 D(x): 0.4974 D(G(z)): 0.0642 / 0.7397
[1/2][350/938] Discriminator's loss: 0.5278 Generator's loss: 1.9388 D(x): 0.6836 D(G(z)): 0.0487 / 0.2187
[1/2][400/938] Discriminator's loss: 0.6281 Generator's loss: 1.5017 D(x): 0.7426 D(G(z)): 0.1975 / 0.3022
[1/2][450/938] Discriminator's loss: 0.2970 Generator's loss: 2.9298 D(x): 0.8307 D(G(z)): 0.0727 / 0.0935
[1/2][500/938] Discriminator's loss: 1.5452 Generator's loss: 4.1130 D(x): 0.9816 D(G(z)): 0.6332 / 0.0274
[1/2][550/938] Discriminator's loss: 0.2576 Generator's loss: 3.0462 D(x): 0.8743 D(G(z)): 0.0995 / 0.0731
[1/2][600/938] Discriminator's loss: 0.4843 Generator's loss: 2.8078 D(x): 0.8329 D(G(z)): 0.2077 / 0.0998
[1/2][650/938] Discriminator's loss: 0.3904 Generator's loss: 2.9556 D(x): 0.8128 D(G(z)): 0.1237 / 0.0928
[1/2][700/938] Discriminator's loss: 0.3428 Generator's loss: 3.0855 D(x): 0.8799 D(G(z)): 0.1516 / 0.0747
[1/2][750/938] Discriminator's loss: 0.9444 Generator's loss: 2.9951 D(x): 0.9222 D(G(z)): 0.4475 / 0.0877
[1/2][800/938] Discriminator's loss: 0.5628 Generator's loss: 2.4849 D(x): 0.8243 D(G(z)): 0.2346 / 0.1288
[1/2][850/938] Discriminator's loss: 0.4026 Generator's loss: 2.4009 D(x): 0.8424 D(G(z)): 0.1728 / 0.1268
[1/2][900/938] Discriminator's loss: 0.3921 Generator's loss: 2.4337 D(x): 0.8245 D(G(z)): 0.1380 / 0.1236
Below you can see how the generator’s output for the fixed noise evolves, sampled every 500 steps.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
# One animation frame per saved grid. The grids are channel-first (C, H, W),
# while imshow expects channel-last (H, W, C).
ims = []
for grid in img_list:
    ims.append([plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)])
ani = animation.ArtistAnimation(
    fig,
    ims,
    interval=1000,
    repeat_delay=1000,
    blit=True,
)
# Close the figure so the notebook doesn't also show it as a static image.
plt.close()
HTML(ani.to_jshtml())
The resulting pictures aren’t brilliant, but at least they look like something handwritten.