GAN#

This notebook covers building a Generative Adversarial Network (GAN) and its application to the MNIST dataset.

import numpy as np
from pathlib import Path
from IPython.display import HTML

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.functional import binary_cross_entropy

import torchvision.utils as vutils
import torchvision.transforms as T
from torchvision.datasets import MNIST

import matplotlib.pyplot as plt
import matplotlib.animation as animation


if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print("using device", DEVICE)
using device cpu

The MNIST dataset is used as the example.

TRAIN_DATASET = MNIST(
    Path("mnist_files"),
    download=True,
    train=True,
    transform=T.ToTensor()
)
DATA_LOADER = DataLoader(TRAIN_DATASET, batch_size=64)

Generator#

The following cell implements the Generator class.

class Generator(nn.Module):
    '''
    A class that generates a picture from a vector of random noise.

    Parameters
    ----------
    feature_map_size: int
        Base number of feature maps (channels) in the generator's layers.
    number_channels: int
        Number of channels in the generated picture.
    noise_size: int
        Size of the noise vector that the model transforms into a picture.
    '''
    def __init__(
        self, 
        feature_map_size: int, 
        number_channels: int, 
        noise_size: int
    ) -> None:
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # (noise_size) x 1 x 1
            nn.ConvTranspose2d(
                in_channels=noise_size, 
                out_channels=feature_map_size * 2, 
                kernel_size=7,
                stride=1, 
                padding=0, 
                bias=False
            ),
            nn.BatchNorm2d(feature_map_size * 2),
            nn.ReLU(True),

            # (feature_map_size*2) x 7 x 7
            nn.ConvTranspose2d(
                in_channels=feature_map_size * 2, 
                out_channels=feature_map_size, 
                kernel_size=4, 
                stride=2, 
                padding=1, 
                bias=False
            ),
            nn.BatchNorm2d(feature_map_size),
            nn.ReLU(True),

            # (feature_map_size) x 14 x 14
            nn.ConvTranspose2d(
                in_channels=feature_map_size, 
                out_channels=number_channels, 
                kernel_size=4, 
                stride=2, 
                padding=1, 
                bias=False
            ),
            nn.Tanh(),
            # (number_channels) x 28 x 28
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        '''
        Apply model to given data.

        Parameters
        ----------
        input: torch.Tensor
            Tensor with size (n_samples, noise_size).

        Returns
        -------
        out: torch.Tensor
            Tensor that represents set of generated pictures.
        '''
        return self.main(input[:, :, None, None])

Now let’s check how to use it: the generator takes a batch of vectors with random values and returns one picture for each vector.

generator = Generator(
    feature_map_size=64, 
    number_channels=1, 
    noise_size=100
)
generator(torch.randn(20, 100)).shape
torch.Size([20, 1, 28, 28])
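
The shape comments in the Generator can be verified by hand. For a transposed convolution with no dilation and no output padding, the output size along one spatial axis is (in - 1) * stride - 2 * padding + kernel_size. The helper below is a small sketch of that arithmetic; conv_transpose_out is a hypothetical name, not part of PyTorch.

```python
def conv_transpose_out(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output size of ConvTranspose2d along one axis (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel_size


# The three transposed convolutions of the Generator: (kernel_size, stride, padding).
generator_layers = [(7, 1, 0), (4, 2, 1), (4, 2, 1)]

size = 1  # the noise vector is reshaped to a 1 x 1 map
generator_sizes = [size]
for kernel_size, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel_size, stride, padding)
    generator_sizes.append(size)

print(generator_sizes)  # [1, 7, 14, 28]
```

The final size of 28 matches the 28 x 28 pictures in the output shape above.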

Discriminator#

The discriminator is a model that tries to determine whether a picture was created by the generator or not. The following cell defines the discriminator that we will use for this example.

class Discriminator(nn.Module):
    '''
    Implementation of the discriminator. A class that takes a picture and 
    generates a score that expresses how strongly the model believes the 
    picture is real (not generated).

    Parameters
    ----------
    number_channels: int
        Number of channels in input.
    feature_map_size: int
        Feature map's size.
    '''

    def __init__(self, number_channels: int, feature_map_size: int):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # (number_channels) x 28 x 28
            nn.Conv2d(
                in_channels=number_channels, 
                out_channels=feature_map_size, 
                kernel_size=4, 
                stride=2, 
                padding=1, 
                bias=False
            ),
            nn.LeakyReLU(0.2, inplace=True),


            # (feature_map_size) x 14 x 14
            nn.Conv2d(
                in_channels=feature_map_size, 
                out_channels=feature_map_size * 2, 
                kernel_size=4, 
                stride=2, 
                padding=1, 
                bias=False
            ),
            nn.BatchNorm2d(feature_map_size * 2),
            nn.LeakyReLU(0.2, inplace=True),


            # (feature_map_size*2) x 7 x 7
            nn.Conv2d(
                in_channels=feature_map_size * 2, 
                out_channels=1, 
                kernel_size=7, 
                stride=1, 
                padding=0, 
                bias=False
            ),
            
            nn.Sigmoid(),
            nn.Flatten(start_dim=0, end_dim=-1)
        )

    def forward(self, input: torch.Tensor):
        '''
        Apply model to given data.

        Parameters
        ----------
        input: torch.Tensor
            Tensor with size (n_samples, number_channels, 28, 28).

        Returns
        -------
        out: torch.Tensor
            Tensor with size (n_samples,) that holds, for each picture, the 
            score that the picture is real (not generated).
        '''
        return self.main(input)

Consider how the discriminator works by passing a sample picture from the training data to it:

discriminator = Discriminator(
    number_channels=1, 
    feature_map_size=64
)
discriminator(TRAIN_DATASET[0][0][None, :, :, :])
tensor([0.4685], grad_fn=<ViewBackward0>)

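We got a score that represents the model’s estimate that the picture we passed is real rather than generated. Since the model has not been trained yet, the score is close to 0.5 and carries no information.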
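
The shape comments in the Discriminator follow from the usual convolution arithmetic. For a convolution without dilation, the output size along one axis is floor((in + 2 * padding - kernel_size) / stride) + 1. The sketch below applies it to the three layers; conv_out is a hypothetical helper name.

```python
def conv_out(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output size of Conv2d along one axis (dilation=1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# The three convolutions of the Discriminator: (kernel_size, stride, padding).
discriminator_layers = [(4, 2, 1), (4, 2, 1), (7, 1, 0)]

size = 28  # MNIST pictures are 28 x 28
discriminator_sizes = [size]
for kernel_size, stride, padding in discriminator_layers:
    size = conv_out(size, kernel_size, stride, padding)
    discriminator_sizes.append(size)

print(discriminator_sizes)  # [28, 14, 7, 1]
```

The last layer leaves a single 1 x 1 value per picture, which nn.Flatten(start_dim=0, end_dim=-1) turns into a one-dimensional tensor with one score per sample.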

Model fitting#

We need to train the discriminator to determine if its input was original or generated by the generator. The following cell implements the optimization step with two core elements:

  • Gradient accumulation to increase predicted scores for real images.

  • Gradient accumulation to decrease predicted scores for generated images.

def discriminator_step(
    pictures: torch.Tensor,
    generation: torch.Tensor,
    discriminator: torch.nn.Module,
    optimizer: torch.optim.Optimizer
) -> tuple[float, float, torch.Tensor]:
    '''
    Step of the discriminator. Maximizes log(D(x)) + log(1 - D(G(z))): it tries 
    to improve the prediction so that real images get scores close to 1 and 
    generated images get scores close to 0.

    Parameters
    ----------
    pictures: torch.Tensor
        Batch of real images that we're trying to imitate.
    generation: torch.Tensor
        Batch of generated images.
    discriminator: torch.nn.Module
        Model that we optimise.
    optimizer: torch.optim.Optimizer
        Optimizer that updates the weights of the discriminator.

    Returns
    -------
    out: tuple[float, float, torch.Tensor]
        - Mean prediction for real images.
        - Mean prediction for fake images.
        - Total loss on both real and fake images, as a scalar tensor.
    '''

    batch_size = pictures.shape[0]
    discriminator.zero_grad()

    # Gradient accumulation on real images
    # Model should predict scores close to 1
    label = torch.full((batch_size,), 1., dtype=torch.float, device=DEVICE)
    output = discriminator(pictures)
    errD_real = binary_cross_entropy(output, label)
    errD_real.backward()
    D_x = output.mean().item()

    # Gradient accumulation on fake images
    # Model should predict scores close to 0
    label.fill_(0.)
    # Note: the generation is detached before it is passed to the 
    # discriminator because we don't need generator gradients to optimize the 
    # discriminator.
    output = discriminator(generation.detach())
    errD_fake = binary_cross_entropy(output, label)
    errD_fake.backward()
    D_G_z1 = output.mean().item()


    errD = errD_real + errD_fake

    # Step based on accumulated gradients
    optimizer.step()

    return D_x, D_G_z1, errD
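
The link between binary cross-entropy and the objective log(D(x)) + log(1 - D(G(z))) can be seen on single scores. For one prediction p and label y, binary cross-entropy is -(y * log(p) + (1 - y) * log(1 - p)), so label 1 leaves -log(p) and label 0 leaves -log(1 - p). The sketch below uses plain Python and made-up scores; bce is a hypothetical helper, not the PyTorch function.

```python
import math


def bce(p: float, y: float) -> float:
    """Binary cross-entropy for a single prediction p in (0, 1) and label y."""
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


d_real = 0.9  # made-up discriminator score for a real picture, D(x)
d_fake = 0.2  # made-up discriminator score for a generated picture, D(G(z))

loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
objective = math.log(d_real) + math.log(1.0 - d_fake)

# Minimizing the summed loss is the same as maximizing the objective.
print(math.isclose(loss, -objective))  # True
```

In discriminator_step the same identity holds per sample; binary_cross_entropy then averages over the batch, which is its default reduction.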

The generator step passes the generated batch through the discriminator and computes the loss against the target “real” label (1), so the loss is small when the discriminator is fooled. The gradients of this loss flow back through the discriminator into the generator, and the optimizer adjusts the generator’s weights to make future outputs more realistic.

def generator_step(
    generation: torch.Tensor,
    discriminator: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    generator: torch.nn.Module
) -> tuple[float, torch.Tensor]:
    '''
    Step of the generator.

    Parameters
    ----------
    generation: torch.Tensor
        Batch of images generated by the generator.
    discriminator: torch.nn.Module
        Model whose decision determines the behavior of the generator.
    optimizer: torch.optim.Optimizer
        The optimizer that changes weights of the generator.
    generator: torch.nn.Module
        The model whose weights we are adjusting in this step.

    Returns
    -------
    out: tuple[float, torch.Tensor]
        - Mean prediction of the discriminator on generated data.
        - Loss of the generator on the generated batch, as a scalar tensor.
    '''

    batch_size = generation.shape[0]
    generator.zero_grad()
    
    label = torch.full(
        size=(batch_size,), 
        fill_value=1.,
        dtype=torch.float,
        device=DEVICE
    )
    output = discriminator(generation)
    errG = binary_cross_entropy(output, label)
    errG.backward()


    D_G_z2 = output.mean().item()
    optimizer.step()

    return D_G_z2, errG
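
The label of 1 in generator_step means the generator minimizes -log(D(G(z))) rather than the term log(1 - D(G(z))) from the discriminator's objective. A commonly cited reason for this choice is the size of the gradient early in training, when the discriminator rejects generated pictures confidently. Writing D = sigmoid(a) for the discriminator's pre-sigmoid output a, the derivative with respect to a is -(1 - D) for -log(D) and -D for log(1 - D). The sketch below evaluates both for a few made-up values of a; the function names are hypothetical.

```python
import math


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def grad_non_saturating(a: float) -> float:
    """d/da of -log(sigmoid(a)), the loss used in generator_step."""
    return -(1.0 - sigmoid(a))


def grad_saturating(a: float) -> float:
    """d/da of log(1 - sigmoid(a)), the generator term of the minimax objective."""
    return -sigmoid(a)


# Negative values of a correspond to a discriminator that rejects the picture.
for a in (-6.0, -2.0, 0.0):
    print(
        f"D(G(z))={sigmoid(a):.4f}  "
        f"non-saturating={grad_non_saturating(a):.4f}  "
        f"saturating={grad_saturating(a):.4f}"
    )
```

At a = -6 the first gradient is close to -1 while the second is close to 0, so the loss used in generator_step still gives a useful signal when the discriminator is winning.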

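Model weights are initialized with specific random values, as is common for DCGAN-style models: convolution weights are drawn from a normal distribution with mean 0 and standard deviation 0.02, and batch-normalization scales from a normal distribution with mean 1 and standard deviation 0.02, with biases set to 0. The following function implements this initialization for our models.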

def weights_init(m: torch.nn.Module) -> None:
    '''
    Function to initialize the weights of the model.

    Parameters
    ----------
    m: torch.nn.Module
        Model that requires weight initialization.
    '''
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
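
nn.Module.apply calls the given function on every submodule, which is how weights_init reaches each convolution and batch-normalization layer. The sketch below checks the effect on a small throwaway network. It selects layers with isinstance instead of matching class names, which is an alternative formulation and not the notebook's code; init_weights_by_type and tiny_net are hypothetical names.

```python
import torch
from torch import nn


def init_weights_by_type(m: nn.Module) -> None:
    # Same values as weights_init for the layer types used in this notebook,
    # selected by layer type instead of class name (an assumption of this sketch).
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0.0)


torch.manual_seed(0)
tiny_net = nn.Sequential(
    nn.Conv2d(8, 32, kernel_size=3, bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(),
)
# apply() visits tiny_net itself and each of its three submodules.
tiny_net.apply(init_weights_by_type)

conv, bn = tiny_net[0], tiny_net[1]
print(f"conv weight std: {conv.weight.std().item():.4f}")
print(f"batch-norm weight mean: {bn.weight.mean().item():.4f}")
print(f"batch-norm bias all zero: {bool((bn.bias == 0).all())}")
```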

The following cell creates both models, applies the weight initialization, sets up the optimizers and runs the training loop.

feature_map_size = 64
noise_size = 100

netG = Generator(
    feature_map_size=feature_map_size, 
    number_channels=1, 
    noise_size=noise_size
).to(DEVICE)
netG = netG.apply(weights_init)

netD = Discriminator(
    feature_map_size=feature_map_size, 
    number_channels=1
).to(DEVICE)
netD = netD.apply(weights_init)

lr = 0.001
beta1 = 0.5

# We'll draw images from the same input to compare results.
fixed_noise = torch.randn(64, noise_size, device=DEVICE)

optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

img_list = []
G_losses = []
D_losses = []

num_epochs = 2

for epoch in range(num_epochs):
    for i, (pictures, _) in enumerate(DATA_LOADER, 0):
        pictures = pictures.to(DEVICE)
        batch_size = pictures.size(0)
        
        # Getting generated ("fake") picture that tries to trick discriminator
        noise = torch.randn(batch_size, noise_size, device=DEVICE)
        generation = netG(noise)

        D_x, D_G_z1, errD = discriminator_step(
            pictures=pictures,
            generation=generation,
            discriminator=netD,
            optimizer=optimizerD
        )

        # Maximizing for generator log(D(G(z)))
        D_G_z2, errG = generator_step(
            generation=generation,
            discriminator=netD,
            optimizer=optimizerG,
            generator=netG
        )

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        if i % 50 == 0:
            print(
                f"[{epoch}/{num_epochs}][{i}/{len(DATA_LOADER)}]\t"
                f"Discriminator's loss: {errD.item():.4f}\t"
                f"Generator's loss: {errG.item():.4f}\t"
                f"D(x): {D_x:.4f}\t"
                f"D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}"
            )

        if (i % 500 == 0) or (
            (epoch == num_epochs - 1) and (i == len(DATA_LOADER) - 1)
        ):
            with torch.no_grad():
                generation = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(generation, padding=2, normalize=True))
[0/2][0/938]	Discriminator's loss: 1.7518	Generator's loss: 4.7653	D(x): 0.6382	D(G(z)): 0.6811 / 0.0109
[0/2][50/938]	Discriminator's loss: 0.0725	Generator's loss: 5.2438	D(x): 0.9644	D(G(z)): 0.0336 / 0.0091
[0/2][100/938]	Discriminator's loss: 0.1012	Generator's loss: 3.6272	D(x): 0.9586	D(G(z)): 0.0549 / 0.0356
[0/2][150/938]	Discriminator's loss: 0.1291	Generator's loss: 3.5180	D(x): 0.9185	D(G(z)): 0.0374 / 0.0387
[0/2][200/938]	Discriminator's loss: 1.3565	Generator's loss: 0.9188	D(x): 0.5612	D(G(z)): 0.4311 / 0.4408
[0/2][250/938]	Discriminator's loss: 0.7970	Generator's loss: 1.1331	D(x): 0.5356	D(G(z)): 0.0933 / 0.3967
[0/2][300/938]	Discriminator's loss: 0.1237	Generator's loss: 3.6578	D(x): 0.9075	D(G(z)): 0.0214 / 0.0426
[0/2][350/938]	Discriminator's loss: 0.1801	Generator's loss: 3.6914	D(x): 0.8570	D(G(z)): 0.0195 / 0.0399
[0/2][400/938]	Discriminator's loss: 0.9553	Generator's loss: 7.4316	D(x): 0.4471	D(G(z)): 0.0008 / 0.0041
[0/2][450/938]	Discriminator's loss: 0.1762	Generator's loss: 3.9034	D(x): 0.9406	D(G(z)): 0.0991 / 0.0354
[0/2][500/938]	Discriminator's loss: 0.4488	Generator's loss: 2.2763	D(x): 0.8926	D(G(z)): 0.2537 / 0.1311
[0/2][550/938]	Discriminator's loss: 0.2853	Generator's loss: 3.4381	D(x): 0.8152	D(G(z)): 0.0321 / 0.0599
[0/2][600/938]	Discriminator's loss: 0.3209	Generator's loss: 3.0028	D(x): 0.8489	D(G(z)): 0.1176 / 0.0741
[0/2][650/938]	Discriminator's loss: 0.3860	Generator's loss: 3.8455	D(x): 0.9540	D(G(z)): 0.2577 / 0.0285
[0/2][700/938]	Discriminator's loss: 0.1155	Generator's loss: 6.1508	D(x): 0.8971	D(G(z)): 0.0023 / 0.0038
[0/2][750/938]	Discriminator's loss: 1.0524	Generator's loss: 2.0052	D(x): 0.5873	D(G(z)): 0.1514 / 0.2225
[0/2][800/938]	Discriminator's loss: 0.2644	Generator's loss: 2.6836	D(x): 0.8687	D(G(z)): 0.1012 / 0.0975
[0/2][850/938]	Discriminator's loss: 0.6571	Generator's loss: 4.0227	D(x): 0.9525	D(G(z)): 0.4043 / 0.0249
[0/2][900/938]	Discriminator's loss: 0.3223	Generator's loss: 2.8650	D(x): 0.8586	D(G(z)): 0.1276 / 0.0781
[1/2][0/938]	Discriminator's loss: 0.7768	Generator's loss: 1.0826	D(x): 0.6014	D(G(z)): 0.0529 / 0.4178
[1/2][50/938]	Discriminator's loss: 0.4546	Generator's loss: 1.9901	D(x): 0.7953	D(G(z)): 0.1627 / 0.1818
[1/2][100/938]	Discriminator's loss: 0.8563	Generator's loss: 1.9985	D(x): 0.5223	D(G(z)): 0.0308 / 0.2026
[1/2][150/938]	Discriminator's loss: 0.3607	Generator's loss: 3.4995	D(x): 0.9286	D(G(z)): 0.2162 / 0.0461
[1/2][200/938]	Discriminator's loss: 0.3745	Generator's loss: 2.5101	D(x): 0.7906	D(G(z)): 0.0939 / 0.1151
[1/2][250/938]	Discriminator's loss: 1.0054	Generator's loss: 1.4995	D(x): 0.4711	D(G(z)): 0.0315 / 0.3851
[1/2][300/938]	Discriminator's loss: 0.9267	Generator's loss: 0.3721	D(x): 0.4974	D(G(z)): 0.0642 / 0.7397
[1/2][350/938]	Discriminator's loss: 0.5278	Generator's loss: 1.9388	D(x): 0.6836	D(G(z)): 0.0487 / 0.2187
[1/2][400/938]	Discriminator's loss: 0.6281	Generator's loss: 1.5017	D(x): 0.7426	D(G(z)): 0.1975 / 0.3022
[1/2][450/938]	Discriminator's loss: 0.2970	Generator's loss: 2.9298	D(x): 0.8307	D(G(z)): 0.0727 / 0.0935
[1/2][500/938]	Discriminator's loss: 1.5452	Generator's loss: 4.1130	D(x): 0.9816	D(G(z)): 0.6332 / 0.0274
[1/2][550/938]	Discriminator's loss: 0.2576	Generator's loss: 3.0462	D(x): 0.8743	D(G(z)): 0.0995 / 0.0731
[1/2][600/938]	Discriminator's loss: 0.4843	Generator's loss: 2.8078	D(x): 0.8329	D(G(z)): 0.2077 / 0.0998
[1/2][650/938]	Discriminator's loss: 0.3904	Generator's loss: 2.9556	D(x): 0.8128	D(G(z)): 0.1237 / 0.0928
[1/2][700/938]	Discriminator's loss: 0.3428	Generator's loss: 3.0855	D(x): 0.8799	D(G(z)): 0.1516 / 0.0747
[1/2][750/938]	Discriminator's loss: 0.9444	Generator's loss: 2.9951	D(x): 0.9222	D(G(z)): 0.4475 / 0.0877
[1/2][800/938]	Discriminator's loss: 0.5628	Generator's loss: 2.4849	D(x): 0.8243	D(G(z)): 0.2346 / 0.1288
[1/2][850/938]	Discriminator's loss: 0.4026	Generator's loss: 2.4009	D(x): 0.8424	D(G(z)): 0.1728 / 0.1268
[1/2][900/938]	Discriminator's loss: 0.3921	Generator's loss: 2.4337	D(x): 0.8245	D(G(z)): 0.1380 / 0.1236
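
The loop collects G_losses and D_losses, and the per-step values are noisy, as the log above suggests. A simple way to read the trend is to smooth them before plotting. The helper below is a sketch in plain Python; moving_average is a hypothetical name, and the numbers in the example are made up rather than taken from the run.

```python
def moving_average(values: list[float], window: int) -> list[float]:
    """Mean of each run of `window` consecutive values (no padding)."""
    if window < 1:
        raise ValueError("window must be at least 1")
    smoothed = []
    running_sum = 0.0
    for index, value in enumerate(values):
        running_sum += value
        if index >= window:
            # Drop the value that just left the window.
            running_sum -= values[index - window]
        if index >= window - 1:
            smoothed.append(running_sum / window)
    return smoothed


example_losses = [4.0, 2.0, 3.0, 1.0, 2.0]
print(moving_average(example_losses, window=2))  # [3.0, 2.5, 2.0, 1.5]
```

With the lists from the training loop, plt.plot(moving_average(G_losses, 50)) and plt.plot(moving_average(D_losses, 50)) would draw the two smoothed curves on one figure; the window of 50 is an arbitrary choice.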

The animation below shows how the generator’s output evolves, with one frame for every 500 training steps.

fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(
    fig=fig, 
    artists=ims, 
    interval=1000, 
    repeat_delay=1000, 
    blit=True
)
plt.close()

HTML(ani.to_jshtml())

The resulting pictures aren’t brilliant, but at least they look like something handwritten.