Conditional GAN#

Conditional GAN (CGAN) is a type of generative model that takes extra parameters specifying what exactly should be generated. This notebook uses the MNIST dataset as an example: the model takes random noise together with the digit that should be generated from that noise.

import numpy as np
from pathlib import Path
from IPython.display import HTML

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.functional import binary_cross_entropy

import torchvision.utils as vutils
import torchvision.transforms as T
from torchvision.datasets import MNIST

import matplotlib.pyplot as plt
import matplotlib.animation as animation

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print("using device", DEVICE)

TRAIN_DATASET = MNIST(
    Path("mnist_files"),
    download=True,
    train=True,
    transform=T.ToTensor()
)
DATA_LOADER = DataLoader(TRAIN_DATASET, batch_size=64)
using device cpu

Generator#

The main feature of the generator in a CGAN is that information about which number is required must somehow be pushed into the model.
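
Before looking at the full class, here is the conditioning trick in isolation: a minimal sketch whose sizes (a noise vector of length 4, 10 classes, an embedding of dimension 10) are illustrative assumptions, using the torch and nn imports from above.

embedding = nn.Embedding(num_embeddings=10, embedding_dim=10)
noise = torch.randn(3, 4)           # batch of 3 noise vectors
digits = torch.tensor([0, 7, 3])    # digits to condition on
# Concatenate the label embeddings with the noise along the feature axis
torch.cat([noise, embedding(digits)], dim=1).shape
torch.Size([3, 14])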


The following generator class uses torch.nn.Embedding to transform labels into vectors, which are then concatenated with the noise in the forward pass - that's how the label information enters the network.

class Generator(nn.Module):
    '''
    A class that generates a picture from a set of random noise.

    Parameters
    ----------
    feature_map_size: int
        Size of the feature maps in the generator.
    number_channels: int
        Number of channels in the generated pictures.
    noise_size: int
        Size of the vector that the model is expected to transform into a
        picture.
    classes_number: int
        Number of classes the model can be asked to generate.
    '''
    def __init__(
        self, 
        feature_map_size: int, 
        number_channels: int, 
        noise_size: int,
        classes_number: int
    ) -> None:
        super(Generator, self).__init__()

        self.embedding = nn.Embedding(
            num_embeddings=classes_number,
            embedding_dim=classes_number
        )

        self.main = nn.Sequential(
            # (nz) x 1 x 1
            nn.ConvTranspose2d(
                in_channels=noise_size + classes_number,
                out_channels=feature_map_size * 2, 
                kernel_size=7,
                stride=1, 
                padding=0, 
                bias=False
            ),
            nn.BatchNorm2d(feature_map_size * 2),
            nn.ReLU(True),

            # (feature_map_size*2) x 7 x 7
            nn.ConvTranspose2d(
                in_channels=feature_map_size * 2, 
                out_channels=feature_map_size, 
                kernel_size=4, 
                stride=2, 
                padding=1, 
                bias=False
            ),
            nn.BatchNorm2d(feature_map_size),
            nn.ReLU(True),

            # (feature_map_size) x 14 x 14
            nn.ConvTranspose2d(
                in_channels=feature_map_size, 
                out_channels=number_channels, 
                kernel_size=4, 
                stride=2, 
                padding=1, 
                bias=False
            ),
            nn.Tanh(),
            # (nc) x 28 x 28
        )

    def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        '''
        Apply the model to the given data.

        Parameters
        ----------
        input: torch.Tensor
            Tensor with size (n_samples, noise_size).
        labels: torch.Tensor
            Labels of the digits that have to be generated.

        Returns
        -------
        out: torch.Tensor
            Tensor that represents the set of generated pictures.
        '''
        labels = self.embedding(labels)
        input = torch.cat([input, labels], dim=1)
        return self.main(input[:, :, None, None])

Consider how this class is supposed to be used: it is given random noise together with the labels of the digits that have to be generated.

generator = Generator(
    feature_map_size=64,
    number_channels=1,
    noise_size=100,
    classes_number=10
)
generator(
    torch.randn(20, 100), 
    torch.randint(low=0, high=10, size=(20,))
).shape
torch.Size([20, 1, 28, 28])

The result is a batch of 20 single-channel 28 x 28 pictures.
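
To look at the output of the still untrained generator, a sample can be rendered directly (a quick sketch; at this point the output is just noise):

with torch.no_grad():
    sample = generator(torch.randn(1, 100), torch.tensor([3]))
plt.imshow(sample[0, 0], cmap="gray")
plt.axis("off")
plt.show()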

Discriminator#

The discriminator, like the generator, should receive information about which number is in the image. The procedure is almost the same as for the generator, except that the label has to be concatenated as an extra channel with the same spatial dimensions as the image channels.
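
Again, the trick can be shown in isolation before wiring it into the full class. A minimal sketch with assumed sizes (a batch of 3 one-channel 28 x 28 images, 10 classes):

label_to_channel = nn.Sequential(
    nn.Embedding(num_embeddings=10, embedding_dim=28**2),
    nn.Unflatten(dim=1, unflattened_size=(1, 28, 28)),
)
images = torch.randn(3, 1, 28, 28)
digits = torch.tensor([0, 7, 3])
# Stack the label channel on top of the image channels
torch.cat([images, label_to_channel(digits)], dim=1).shape
torch.Size([3, 2, 28, 28])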


The following cell implements a discriminator that additionally expects the labels of the images as input.

class Discriminator(nn.Module):
    '''
    Implementation of the discriminator: a class that takes a picture and
    produces a score expressing how confident the model is that the picture
    is real.

    Parameters
    ----------
    number_channels: int
        Number of channels in the input pictures.
    feature_map_size: int
        Size of the feature maps.
    classes_number: int
        Number of classes the pictures can belong to.
    '''

    def __init__(
        self, 
        number_channels: int, 
        feature_map_size: int,
        classes_number: int
    ):
        super(Discriminator, self).__init__()

        self.embedding = torch.nn.Sequential(
            torch.nn.Embedding(
                num_embeddings=classes_number, embedding_dim=28**2
            ),
            # Transforms an embedding vector to a set of pictures with
            # given amount of channels
            torch.nn.Unflatten(
                dim=1, unflattened_size=(number_channels, 28, 28)
            )
        )

        # Increasing number of channels for embedding information
        number_channels += 1

        self.main = nn.Sequential(
            # (number_channels) x 28 x 28
            nn.Conv2d(
                in_channels=number_channels, 
                out_channels=feature_map_size, 
                kernel_size=4, 
                stride=2, 
                padding=1, 
                bias=False
            ),
            nn.LeakyReLU(0.2, inplace=True),


            # (feature_map_size) x 14 x 14
            nn.Conv2d(
                in_channels=feature_map_size, 
                out_channels=feature_map_size * 2, 
                kernel_size=4, 
                stride=2, 
                padding=1, 
                bias=False
            ),
            nn.BatchNorm2d(feature_map_size * 2),
            nn.LeakyReLU(0.2, inplace=True),


            # (feature_map_size*2) x 7 x 7
            nn.Conv2d(
                in_channels=feature_map_size * 2, 
                out_channels=1, 
                kernel_size=7, 
                stride=1, 
                padding=0, 
                bias=False
            ),
            
            nn.Sigmoid(),
            nn.Flatten(start_dim=0, end_dim=-1)
        )

    def forward(self, input: torch.Tensor, labels: torch.Tensor):
        '''
        Apply the model to the given data.

        Parameters
        ----------
        input: torch.Tensor
            Tensor with size (n_samples, number_channels, 28, 28).
        labels: torch.Tensor
            Labels corresponding to the input pictures.

        Returns
        -------
        out: torch.Tensor
            Tensor of scores, one per picture, each expressing how likely it
            is that the picture is real (not generated).
        '''
        labels = self.embedding(labels)
        input = torch.cat([input, labels], dim=1)
        return self.main(input)

The next example shows the usage of the discriminator.

discriminator = Discriminator(
    number_channels=1, 
    feature_map_size=64,
    classes_number=10
)

images, labels = next(iter(DATA_LOADER))
discriminator(images, labels)
tensor([0.4961, 0.3902, 0.6212, 0.6132, 0.4327, 0.3448, 0.6016, 0.5595, 0.5648,
        0.6109, 0.6144, 0.5024, 0.5680, 0.6251, 0.5695, 0.4638, 0.3622, 0.5321,
        0.6464, 0.4238, 0.6291, 0.4133, 0.3882, 0.6054, 0.5787, 0.3850, 0.5927,
        0.5878, 0.3770, 0.4629, 0.5566, 0.5205, 0.6753, 0.4036, 0.4051, 0.5235,
        0.6389, 0.4187, 0.4417, 0.6263, 0.5856, 0.5550, 0.4788, 0.4285, 0.5058,
        0.4392, 0.5430, 0.5248, 0.4525, 0.6055, 0.5848, 0.3872, 0.4926, 0.6212,
        0.4331, 0.4974, 0.4209, 0.4056, 0.6371, 0.6044, 0.5749, 0.6219, 0.6230,
        0.3880], grad_fn=<ViewBackward0>)

Model fitting#

A typical GAN training loop is enough to train a CGAN as well.

def discriminator_step(
    pictures: torch.Tensor,
    generation: torch.Tensor,
    discriminator: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    labels: torch.Tensor
) -> tuple[float, float, torch.Tensor]:
    '''
    Step of the discriminator. Maximizes log(D(x)) + log(1 - D(G(z))): tries
    to push the predictions for real images towards 1 and the predictions
    for generated images towards 0.

    Parameters
    ----------
    pictures: torch.Tensor
        Batch of real images that we're trying to imitate.
    generation: torch.Tensor
        Batch of generated images.
    discriminator: torch.nn.Module
        Model that we optimize.
    optimizer: torch.optim.Optimizer
        Optimizer that updates the weights of the discriminator.
    labels: torch.Tensor
        Labels that describe what is displayed in the pictures.

    Returns
    -------
    out: tuple[float, float, torch.Tensor]
        - Mean prediction for real images.
        - Mean prediction for fake images.
        - Total loss value on both real and fake images.
    '''

    batch_size = pictures.shape[0]
    discriminator.zero_grad()

    # Gradient accumulation on real images
    # Model should predict scores close to 1
    label = torch.full((batch_size,), 1., dtype=torch.float, device=DEVICE)
    output = discriminator(pictures, labels)
    errD_real = binary_cross_entropy(output, label)
    errD_real.backward()
    D_x = output.mean().item()

    # Gradient accumulation on fake images
    # Model should predict scores close to 0
    label.fill_(0.)
    # Note: the generated images are detached so that no gradients flow back
    # into the generator - we don't need generator gradients to optimize the
    # discriminator.
    output = discriminator(generation.detach(), labels)
    errD_fake = binary_cross_entropy(output, label)
    errD_fake.backward()
    D_G_z1 = output.mean().item()


    errD = errD_real + errD_fake

    # Step based on accumulated gradients
    optimizer.step()

    return D_x, D_G_z1, errD

def generator_step(
    generation: torch.Tensor,
    discriminator: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    generator: torch.nn.Module,
    labels: torch.Tensor
) -> tuple[float, torch.Tensor]:
    '''
    Step of the generator. Maximizes log(D(G(z))) by pushing the
    discriminator's predictions on generated images towards 1.

    Parameters
    ----------
    generation: torch.Tensor
        Set of objects produced by the generator.
    discriminator: torch.nn.Module
        Model whose decisions determine the behavior of the generator.
    optimizer: torch.optim.Optimizer
        The optimizer that changes the weights of the generator.
    generator: torch.nn.Module
        The model whose weights we are adjusting in this step.
    labels: torch.Tensor
        Labels that describe what is displayed in the pictures.

    Returns
    -------
    out: tuple[float, torch.Tensor]
        - Mean prediction of the discriminator on the generated data.
        - Loss value of the generator.
    '''

    batch_size = generation.shape[0]
    generator.zero_grad()
    
    label = torch.full(
        size=(batch_size,), 
        fill_value=1.,
        dtype=torch.float,
        device=DEVICE
    )
    output = discriminator(generation, labels)
    errG = binary_cross_entropy(output, label)
    errG.backward()


    D_G_z2 = output.mean().item()
    optimizer.step()

    return D_G_z2, errG

feature_map_size = 64
noise_size = 100

def weights_init(m: torch.nn.Module) -> None:
    '''
    Function to initialize the weights of the model.

    Parameters
    ----------
    m: torch.nn.Module
        Model that requires weight initialization.
    '''
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG = Generator(
    feature_map_size=feature_map_size, 
    number_channels=1, 
    noise_size=noise_size,
    classes_number=10
).to(DEVICE)
netG = netG.apply(weights_init)

netD = Discriminator(
    feature_map_size=feature_map_size, 
    number_channels=1,
    classes_number=10
).to(DEVICE)
netD = netD.apply(weights_init)

lr = 0.001
beta1 = 0.5

# We'll draw images from the same input to compare results.
fixed_noise = torch.randn(10, noise_size, device=DEVICE)

optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

img_list = []
G_losses = []
D_losses = []

num_epochs = 2

for epoch in range(num_epochs):
    for i, (pictures, labels) in enumerate(DATA_LOADER, 0):
        pictures = pictures.to(DEVICE)
        labels = labels.to(DEVICE)
        batch_size = pictures.size(0)
        
        # Getting generated ("fake") picture that tries to trick discriminator
        noise = torch.randn(batch_size, noise_size, device=DEVICE)
        generation = netG(noise, labels)

        D_x, D_G_z1, errD = discriminator_step(
            pictures=pictures,
            generation=generation,
            discriminator=netD,
            optimizer=optimizerD,
            labels=labels
        )

        # Maximizing for generator log(D(G(z)))
        D_G_z2, errG = generator_step(
            generation=generation,
            discriminator=netD,
            optimizer=optimizerG,
            generator=netG,
            labels=labels
        )

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        if i % 50 == 0:
            print(
                f"[{epoch}/{num_epochs}][{i}/{len(DATA_LOADER)}]\t"
                f"Discriminator's loss: {errD.item():.4f}\t"
                f"Generator's loss: {errG.item():.4f}\t"
                f"D(x): {D_x:.4f}\t"
                f"D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}"
            )

        if (i % 500 == 0) or (
            (epoch == num_epochs - 1) and (i == len(DATA_LOADER) - 1)
        ):
            with torch.no_grad():
                generation = netG(
                    fixed_noise, torch.arange(0, 10, device=DEVICE)
                ).detach().cpu()
            img_list.append(
                vutils.make_grid(generation, nrow=10, padding=2, normalize=True)
            )
[0/2][0/938]	Discriminator's loss: 1.5434	Generator's loss: 1.5112	D(x): 0.4506	D(G(z)): 0.4346 / 0.5283
[0/2][50/938]	Discriminator's loss: 0.9908	Generator's loss: 1.8837	D(x): 0.6253	D(G(z)): 0.3311 / 0.2062
[0/2][100/938]	Discriminator's loss: 1.0235	Generator's loss: 2.1464	D(x): 0.6667	D(G(z)): 0.3978 / 0.1631
[0/2][150/938]	Discriminator's loss: 1.0180	Generator's loss: 1.9072	D(x): 0.7596	D(G(z)): 0.4630 / 0.1902
[0/2][200/938]	Discriminator's loss: 1.1699	Generator's loss: 1.3564	D(x): 0.5074	D(G(z)): 0.2551 / 0.3003
[0/2][250/938]	Discriminator's loss: 1.1039	Generator's loss: 2.7239	D(x): 0.8327	D(G(z)): 0.5602 / 0.0901
[0/2][300/938]	Discriminator's loss: 0.9014	Generator's loss: 0.6127	D(x): 0.5212	D(G(z)): 0.1416 / 0.5653
[0/2][350/938]	Discriminator's loss: 0.8082	Generator's loss: 2.1880	D(x): 0.8379	D(G(z)): 0.3838 / 0.1728
[0/2][400/938]	Discriminator's loss: 0.8043	Generator's loss: 2.8220	D(x): 0.9046	D(G(z)): 0.4386 / 0.0972
[0/2][450/938]	Discriminator's loss: 0.4749	Generator's loss: 2.4705	D(x): 0.8067	D(G(z)): 0.2074 / 0.1026
[0/2][500/938]	Discriminator's loss: 1.8155	Generator's loss: 3.9962	D(x): 0.9607	D(G(z)): 0.7792 / 0.0299
[0/2][550/938]	Discriminator's loss: 0.6510	Generator's loss: 1.8618	D(x): 0.7318	D(G(z)): 0.2293 / 0.1955
[0/2][600/938]	Discriminator's loss: 0.6526	Generator's loss: 2.8902	D(x): 0.8570	D(G(z)): 0.3522 / 0.0809
[0/2][650/938]	Discriminator's loss: 0.6477	Generator's loss: 2.6249	D(x): 0.8625	D(G(z)): 0.3372 / 0.1012
[0/2][700/938]	Discriminator's loss: 0.5802	Generator's loss: 1.5953	D(x): 0.7020	D(G(z)): 0.1330 / 0.2409
[0/2][750/938]	Discriminator's loss: 1.2677	Generator's loss: 2.2023	D(x): 0.8543	D(G(z)): 0.5788 / 0.1417
[0/2][800/938]	Discriminator's loss: 0.7851	Generator's loss: 1.6041	D(x): 0.6738	D(G(z)): 0.2623 / 0.2452
[0/2][850/938]	Discriminator's loss: 1.0156	Generator's loss: 2.2827	D(x): 0.8613	D(G(z)): 0.5117 / 0.1314
[0/2][900/938]	Discriminator's loss: 0.8801	Generator's loss: 1.6902	D(x): 0.6595	D(G(z)): 0.2740 / 0.2287
[1/2][0/938]	Discriminator's loss: 1.2976	Generator's loss: 0.5750	D(x): 0.5425	D(G(z)): 0.2673 / 0.6127
[1/2][50/938]	Discriminator's loss: 0.8084	Generator's loss: 1.5425	D(x): 0.7090	D(G(z)): 0.3087 / 0.2631
[1/2][100/938]	Discriminator's loss: 0.9786	Generator's loss: 1.4213	D(x): 0.7709	D(G(z)): 0.4189 / 0.2823
[1/2][150/938]	Discriminator's loss: 0.8726	Generator's loss: 3.2140	D(x): 0.8713	D(G(z)): 0.4469 / 0.0657
[1/2][200/938]	Discriminator's loss: 1.2559	Generator's loss: 2.7035	D(x): 0.8443	D(G(z)): 0.5841 / 0.0957
[1/2][250/938]	Discriminator's loss: 1.4120	Generator's loss: 0.5689	D(x): 0.3295	D(G(z)): 0.0884 / 0.6100
[1/2][300/938]	Discriminator's loss: 1.4050	Generator's loss: 0.4096	D(x): 0.3465	D(G(z)): 0.1380 / 0.7043
[1/2][350/938]	Discriminator's loss: 1.0377	Generator's loss: 1.1427	D(x): 0.5735	D(G(z)): 0.3087 / 0.3678
[1/2][400/938]	Discriminator's loss: 1.5422	Generator's loss: 3.3755	D(x): 0.8966	D(G(z)): 0.6883 / 0.0565
[1/2][450/938]	Discriminator's loss: 1.0001	Generator's loss: 2.4868	D(x): 0.8613	D(G(z)): 0.5230 / 0.1062
[1/2][500/938]	Discriminator's loss: 2.1109	Generator's loss: 2.8956	D(x): 0.8344	D(G(z)): 0.7905 / 0.0792
[1/2][550/938]	Discriminator's loss: 0.8318	Generator's loss: 1.4294	D(x): 0.6806	D(G(z)): 0.2979 / 0.2711
[1/2][600/938]	Discriminator's loss: 1.4873	Generator's loss: 0.3662	D(x): 0.3109	D(G(z)): 0.1496 / 0.7095
[1/2][650/938]	Discriminator's loss: 0.9725	Generator's loss: 1.4026	D(x): 0.6244	D(G(z)): 0.3180 / 0.2869
[1/2][700/938]	Discriminator's loss: 0.8597	Generator's loss: 1.3669	D(x): 0.6091	D(G(z)): 0.2226 / 0.3114
[1/2][750/938]	Discriminator's loss: 1.2769	Generator's loss: 1.7263	D(x): 0.7795	D(G(z)): 0.5819 / 0.2088
[1/2][800/938]	Discriminator's loss: 1.1612	Generator's loss: 1.1500	D(x): 0.5756	D(G(z)): 0.3915 / 0.3607
[1/2][850/938]	Discriminator's loss: 1.2540	Generator's loss: 1.6004	D(x): 0.7783	D(G(z)): 0.5696 / 0.2422
[1/2][900/938]	Discriminator's loss: 0.8771	Generator's loss: 1.0318	D(x): 0.5868	D(G(z)): 0.2210 / 0.3907
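
The collected G_losses and D_losses make it easy to inspect the training dynamics. A minimal sketch of such a plot:

plt.figure(figsize=(10, 4))
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.legend()
plt.show()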

The next cell presents an interactive animation that allows you to track the generator's output at different stages of training. In each frame, the digits 0 to 9 are shown from left to right.

fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(
    fig=fig, 
    artists=ims, 
    interval=1000, 
    repeat_delay=1000, 
    blit=True
)
plt.close()

HTML(ani.to_jshtml())
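
Finally, the conditioning can be checked directly: the trained generator can be asked for any particular digit. A short sketch that samples five sevens, reusing netG, noise_size and DEVICE from above:

with torch.no_grad():
    noise = torch.randn(5, noise_size, device=DEVICE)
    digits = torch.full((5,), 7, device=DEVICE)
    generation = netG(noise, digits).cpu()
grid = vutils.make_grid(generation, nrow=5, padding=2, normalize=True)
plt.figure(figsize=(8, 2))
plt.axis("off")
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()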