Conditional GAN#
Conditional GAN (CGAN) is a type of generative model that takes extra inputs specifying what kind of sample should be generated. This notebook uses the MNIST dataset as an example: the model takes random noise together with the label of the digit that should be generated from that noise.
import numpy as np
from pathlib import Path
from IPython.display import HTML
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.functional import binary_cross_entropy
import torchvision.utils as vutils
import torchvision.transforms as T
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import matplotlib.animation as animation
if torch.cuda.is_available():
DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
DEVICE = torch.device("mps")
else:
DEVICE = torch.device("cpu")
print("using device", DEVICE)
TRAIN_DATASET = MNIST(
Path("mnist_files"),
download=True,
train=True,
transform=T.ToTensor()
)
DATA_LOADER = DataLoader(TRAIN_DATASET, batch_size=64)
using device cpu
Generator#
The main feature of the generator for CGAN is that it must somehow feed information about which digit is required into the model.
The following generator class uses torch.nn.Embedding
to transform labels into vectors, which are then concatenated with the noise in the forward pass - that's how the label information enters the network.
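As a minimal sketch of this mechanism (the sizes here are illustrative, not the ones used below): the embedding turns each integer label into a dense vector that can simply be glued to the noise vector.
noise_size, classes_number = 4, 10
embedding = nn.Embedding(
    num_embeddings=classes_number, embedding_dim=classes_number
)
noise = torch.randn(3, noise_size)
labels = torch.tensor([0, 7, 9])
# Each label becomes a (classes_number,)-dimensional vector that is
# concatenated with the corresponding noise vector
conditioned = torch.cat([noise, embedding(labels)], dim=1)
print(conditioned.shape)  # torch.Size([3, 14])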
class Generator(nn.Module):
    '''
    A class that generates a picture from a vector of random noise.
    Parameters
    ----------
    feature_map_size: int
        Feature map's size of the generator.
    number_channels: int
        Number of channels of the generated picture.
    noise_size: int
        Size of the vector that is expected to be transformed into the
        picture by the model.
    classes_number: int
        Number of classes the generator can be conditioned on.
    '''
def __init__(
self,
feature_map_size: int,
number_channels: int,
noise_size: int,
classes_number: int
) -> None:
super(Generator, self).__init__()
self.embedding = nn.Embedding(
num_embeddings=classes_number,
embedding_dim=classes_number
)
self.main = nn.Sequential(
            # (noise_size + classes_number) x 1 x 1
nn.ConvTranspose2d(
in_channels=noise_size + classes_number,
out_channels=feature_map_size * 2,
kernel_size=7,
stride=1,
padding=0,
bias=False
),
nn.BatchNorm2d(feature_map_size * 2),
nn.ReLU(True),
# (feature_map_size*2) x 7 x 7
nn.ConvTranspose2d(
in_channels=feature_map_size * 2,
out_channels=feature_map_size,
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.BatchNorm2d(feature_map_size),
nn.ReLU(True),
# (feature_map_size) x 14 x 14
nn.ConvTranspose2d(
in_channels=feature_map_size,
out_channels=number_channels,
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.Tanh(),
            # (number_channels) x 28 x 28
)
def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        '''
        Apply model to given data.
        Parameters
        ----------
        input: torch.Tensor
            Tensor of noise with size (n_samples, noise_size).
        labels: torch.Tensor
            Labels of the digits that have to be generated.
        Returns
        -------
        out: torch.Tensor
            Tensor that represents the set of generated pictures.
        '''
labels = self.embedding(labels)
input = torch.cat([input, labels], dim=1)
return self.main(input[:, :, None, None])
Consider how this class is supposed to be used: along with random noise, it is passed the labels of the digits that have to be generated.
generator = Generator(
feature_map_size=64,
number_channels=1,
noise_size=100,
classes_number=10
)
generator(
torch.randn(20, 100),
torch.randint(low=0, high=10, size=(20,))
).shape
torch.Size([20, 1, 28, 28])
The result is a batch of 20 single-channel 28x28 pictures.
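To peek at one of these pictures (keeping in mind that the generator is still untrained at this point, so it only produces noise), a quick matplotlib sketch - the digit label 3 here is an arbitrary choice:
with torch.no_grad():
    sample = generator(torch.randn(1, 100), torch.tensor([3]))
# Drop the batch and channel dimensions before plotting
plt.imshow(sample[0, 0], cmap="gray")
plt.axis("off")
plt.show()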
Discriminator#
The discriminator, like the generator, should receive information about which number is in the image. The procedure is almost the same as for the generator, except that the label is concatenated as an extra "label" channel, which must have the same spatial dimensions as the image channels.
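In isolation the trick looks roughly like this (the batch size is illustrative): the label is embedded into 28*28 values, reshaped into a 1 x 28 x 28 "picture", and stacked with the real channels.
images = torch.randn(3, 1, 28, 28)  # stand-in for a batch of MNIST pictures
labels = torch.tensor([0, 7, 9])
embedding = nn.Embedding(num_embeddings=10, embedding_dim=28**2)
# Reshape each 784-dimensional embedding into an extra image channel
label_channel = embedding(labels).unflatten(dim=1, sizes=(1, 28, 28))
print(torch.cat([images, label_channel], dim=1).shape)
# torch.Size([3, 2, 28, 28])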
The following cell implements a discriminator that additionally expects the labels of the images as input.
class Discriminator(nn.Module):
'''
    Realisation of the discriminator. A class that takes a picture and
    generates a score which expresses how much the model thinks the picture
    is real (not generated).
    Parameters
    ----------
    number_channels: int
        Number of channels in the input pictures.
    feature_map_size: int
        Feature map's size.
    classes_number: int
        Number of classes the discriminator can be conditioned on.
    '''
def __init__(
self,
number_channels: int,
feature_map_size: int,
classes_number: int
):
super(Discriminator, self).__init__()
self.embedding = torch.nn.Sequential(
torch.nn.Embedding(
num_embeddings=classes_number, embedding_dim=28**2
),
# Transforms an embedding vector to a set of pictures with
# given amount of channels
torch.nn.Unflatten(
dim=1, unflattened_size=(number_channels, 28, 28)
)
)
# Increasing number of channels for embedding information
number_channels += 1
self.main = nn.Sequential(
# (number_channels) x 28 x 28
nn.Conv2d(
in_channels=number_channels,
out_channels=feature_map_size,
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.LeakyReLU(0.2, inplace=True),
# (feature_map_size) x 14 x 14
nn.Conv2d(
in_channels=feature_map_size,
out_channels=feature_map_size * 2,
kernel_size=4,
stride=2,
padding=1,
bias=False
),
nn.BatchNorm2d(feature_map_size * 2),
nn.LeakyReLU(0.2, inplace=True),
# (feature_map_size*2) x 7 x 7
nn.Conv2d(
in_channels=feature_map_size * 2,
out_channels=1,
kernel_size=7,
stride=1,
padding=0,
bias=False
),
nn.Sigmoid(),
nn.Flatten(start_dim=0, end_dim=-1)
)
    def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        '''
        Apply model to given data.
        Parameters
        ----------
        input: torch.Tensor
            Tensor of pictures with size (n_samples, number_channels, 28, 28).
        labels: torch.Tensor
            Labels of the digits displayed on the pictures.
        Returns
        -------
        out: torch.Tensor
            Tensor of scores, one per picture, each expressing how much the
            model thinks the picture is real (not generated).
        '''
labels = self.embedding(labels)
input = torch.cat([input, labels], dim=1)
return self.main(input)
The next example shows the usage of the discriminator.
discriminator = Discriminator(
number_channels=1,
feature_map_size=64,
classes_number=10
)
images, labels = next(iter(DATA_LOADER))
discriminator(images, labels)
tensor([0.4961, 0.3902, 0.6212, 0.6132, 0.4327, 0.3448, 0.6016, 0.5595, 0.5648,
0.6109, 0.6144, 0.5024, 0.5680, 0.6251, 0.5695, 0.4638, 0.3622, 0.5321,
0.6464, 0.4238, 0.6291, 0.4133, 0.3882, 0.6054, 0.5787, 0.3850, 0.5927,
0.5878, 0.3770, 0.4629, 0.5566, 0.5205, 0.6753, 0.4036, 0.4051, 0.5235,
0.6389, 0.4187, 0.4417, 0.6263, 0.5856, 0.5550, 0.4788, 0.4285, 0.5058,
0.4392, 0.5430, 0.5248, 0.4525, 0.6055, 0.5848, 0.3872, 0.4926, 0.6212,
0.4331, 0.4974, 0.4209, 0.4056, 0.6371, 0.6044, 0.5749, 0.6219, 0.6230,
0.3880], grad_fn=<ViewBackward0>)
Model fitting#
A typical GAN training loop is enough to train a CGAN as well - the only difference is that the labels are passed to both models.
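For reference, the objective being optimized is the usual GAN minimax game, just conditioned on the label $y$ (Mirza and Osindero, 2014):

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{data}}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z \mid y))\big)\big]$$

In the code below, the discriminator step minimizes the binary cross-entropy form of $-\log D(x \mid y) - \log(1 - D(G(z \mid y)))$, and the generator step uses the non-saturating variant, minimizing $-\log D(G(z \mid y))$.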
def discriminator_step(
pictures: torch.Tensor,
generation: torch.Tensor,
discriminator: torch.nn.Module,
optimizer: torch.optim.Optimizer,
labels: torch.Tensor
) -> tuple[float, float, torch.Tensor]:
    '''
    Step of the discriminator. Maximizes log(D(x)) + log(1 - D(G(z))) - tries
    to improve the predictions so that real images get scores close to 1 and
    generated images get scores close to 0.
    Parameters
    ----------
    pictures: torch.Tensor
        Batch of real images that we're trying to imitate.
    generation: torch.Tensor
        Batch of generated images.
    discriminator: torch.nn.Module
        Model that we optimise.
    optimizer: torch.optim.Optimizer
        Optimizer that updates the weights of the discriminator.
    labels: torch.Tensor
        Labels that describe what is displayed on the pictures.
    Returns
    -------
    out: tuple[float, float, torch.Tensor]
        - Mean prediction for real images.
        - Mean prediction for fake images.
        - Total loss value on both real and fake images.
    '''
batch_size = pictures.shape[0]
discriminator.zero_grad()
# Gradient accumulation on real images
# Model should predict scores close to 1
label = torch.full((batch_size,), 1., dtype=torch.float, device=DEVICE)
output = discriminator(pictures, labels)
errD_real = binary_cross_entropy(output, label)
errD_real.backward()
D_x = output.mean().item()
    # Gradient accumulation on fake images
    # Model should predict scores close to 0
label.fill_(0.)
# Note: generation here bypasses the discriminator without gradient
# accumulation because we don't need generator gradients to optimize the
# discriminator.
output = discriminator(generation.detach(), labels)
errD_fake = binary_cross_entropy(output, label)
errD_fake.backward()
D_G_z1 = output.mean().item()
errD = errD_real + errD_fake
# Step based on accumulated gradients
optimizer.step()
return D_x, D_G_z1, errD
def generator_step(
generation: torch.Tensor,
discriminator: torch.nn.Module,
optimizer: torch.optim.Optimizer,
generator: torch.nn.Module,
labels: torch.Tensor
) -> tuple[float, torch.Tensor]:
    '''
    Step of the generator. Maximizes log(D(G(z))) - tries to make the
    discriminator score the generated images as real.
    Parameters
    ----------
    generation: torch.Tensor
        Set of objects generated by the generator.
    discriminator: torch.nn.Module
        Model whose decision determines the behavior of the generator.
    optimizer: torch.optim.Optimizer
        The optimizer that changes the weights of the generator.
    generator: torch.nn.Module
        The model whose weights we are adjusting in this step.
    labels: torch.Tensor
        Labels that describe what is displayed on the pictures.
    Returns
    -------
    out: tuple[float, torch.Tensor]
        - Mean prediction of the discriminator on the generated data.
        - Loss value of the discriminator on the generated pictures.
    '''
batch_size = generation.shape[0]
generator.zero_grad()
label = torch.full(
size=(batch_size,),
fill_value=1.,
dtype=torch.float,
device=DEVICE
)
output = discriminator(generation, labels)
errG = binary_cross_entropy(output, label)
errG.backward()
D_G_z2 = output.mean().item()
optimizer.step()
return D_G_z2, errG
feature_map_size = 64
noise_size = 100
def weights_init(m: torch.nn.Module) -> None:
'''
Function to initialize the weights of the model.
Parameters
----------
m: torch.nn.Module
Model that requires weight initialization.
'''
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
netG = Generator(
feature_map_size=feature_map_size,
number_channels=1,
noise_size=noise_size,
classes_number=10
).to(DEVICE)
netG = netG.apply(weights_init)
netD = Discriminator(
feature_map_size=feature_map_size,
number_channels=1,
classes_number=10
).to(DEVICE)
netD = netD.apply(weights_init)
lr = 0.001
beta1 = 0.5
# We'll draw images from the same input to compare results.
fixed_noise = torch.randn(10, noise_size, device=DEVICE)
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
img_list = []
G_losses = []
D_losses = []
num_epochs = 2
for epoch in range(num_epochs):
for i, (pictures, labels) in enumerate(DATA_LOADER, 0):
        pictures = pictures.to(DEVICE)
        labels = labels.to(DEVICE)
        batch_size = pictures.size(0)
# Getting generated ("fake") picture that tries to trick discriminator
noise = torch.randn(batch_size, noise_size, device=DEVICE)
generation = netG(noise, labels)
D_x, D_G_z1, errD = discriminator_step(
pictures=pictures,
generation=generation,
discriminator=netD,
optimizer=optimizerD,
labels=labels
)
# Maximizing for generator log(D(G(z)))
D_G_z2, errG = generator_step(
generation=generation,
discriminator=netD,
optimizer=optimizerG,
generator=netG,
labels=labels
)
G_losses.append(errG.item())
D_losses.append(errD.item())
if i % 50 == 0:
print(
f"[{epoch}/{num_epochs}][{i}/{len(DATA_LOADER)}]\t"
f"Discriminator's loss: {errD.item():.4f}\t"
f"Generator's loss: {errG.item():.4f}\t"
f"D(x): {D_x:.4f}\t"
f"D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}"
)
if (i % 500 == 0) or (
(epoch == num_epochs - 1) and (i == len(DATA_LOADER) - 1)
):
with torch.no_grad():
                generation = netG(
                    fixed_noise, torch.arange(0, 10, device=DEVICE)
                ).detach().cpu()
img_list.append(vutils.make_grid(generation, nrow=10, padding=2, normalize=True))
[0/2][0/938] Discriminator's loss: 1.5434 Generator's loss: 1.5112 D(x): 0.4506 D(G(z)): 0.4346 / 0.5283
[0/2][50/938] Discriminator's loss: 0.9908 Generator's loss: 1.8837 D(x): 0.6253 D(G(z)): 0.3311 / 0.2062
[0/2][100/938] Discriminator's loss: 1.0235 Generator's loss: 2.1464 D(x): 0.6667 D(G(z)): 0.3978 / 0.1631
[0/2][150/938] Discriminator's loss: 1.0180 Generator's loss: 1.9072 D(x): 0.7596 D(G(z)): 0.4630 / 0.1902
[0/2][200/938] Discriminator's loss: 1.1699 Generator's loss: 1.3564 D(x): 0.5074 D(G(z)): 0.2551 / 0.3003
[0/2][250/938] Discriminator's loss: 1.1039 Generator's loss: 2.7239 D(x): 0.8327 D(G(z)): 0.5602 / 0.0901
[0/2][300/938] Discriminator's loss: 0.9014 Generator's loss: 0.6127 D(x): 0.5212 D(G(z)): 0.1416 / 0.5653
[0/2][350/938] Discriminator's loss: 0.8082 Generator's loss: 2.1880 D(x): 0.8379 D(G(z)): 0.3838 / 0.1728
[0/2][400/938] Discriminator's loss: 0.8043 Generator's loss: 2.8220 D(x): 0.9046 D(G(z)): 0.4386 / 0.0972
[0/2][450/938] Discriminator's loss: 0.4749 Generator's loss: 2.4705 D(x): 0.8067 D(G(z)): 0.2074 / 0.1026
[0/2][500/938] Discriminator's loss: 1.8155 Generator's loss: 3.9962 D(x): 0.9607 D(G(z)): 0.7792 / 0.0299
[0/2][550/938] Discriminator's loss: 0.6510 Generator's loss: 1.8618 D(x): 0.7318 D(G(z)): 0.2293 / 0.1955
[0/2][600/938] Discriminator's loss: 0.6526 Generator's loss: 2.8902 D(x): 0.8570 D(G(z)): 0.3522 / 0.0809
[0/2][650/938] Discriminator's loss: 0.6477 Generator's loss: 2.6249 D(x): 0.8625 D(G(z)): 0.3372 / 0.1012
[0/2][700/938] Discriminator's loss: 0.5802 Generator's loss: 1.5953 D(x): 0.7020 D(G(z)): 0.1330 / 0.2409
[0/2][750/938] Discriminator's loss: 1.2677 Generator's loss: 2.2023 D(x): 0.8543 D(G(z)): 0.5788 / 0.1417
[0/2][800/938] Discriminator's loss: 0.7851 Generator's loss: 1.6041 D(x): 0.6738 D(G(z)): 0.2623 / 0.2452
[0/2][850/938] Discriminator's loss: 1.0156 Generator's loss: 2.2827 D(x): 0.8613 D(G(z)): 0.5117 / 0.1314
[0/2][900/938] Discriminator's loss: 0.8801 Generator's loss: 1.6902 D(x): 0.6595 D(G(z)): 0.2740 / 0.2287
[1/2][0/938] Discriminator's loss: 1.2976 Generator's loss: 0.5750 D(x): 0.5425 D(G(z)): 0.2673 / 0.6127
[1/2][50/938] Discriminator's loss: 0.8084 Generator's loss: 1.5425 D(x): 0.7090 D(G(z)): 0.3087 / 0.2631
[1/2][100/938] Discriminator's loss: 0.9786 Generator's loss: 1.4213 D(x): 0.7709 D(G(z)): 0.4189 / 0.2823
[1/2][150/938] Discriminator's loss: 0.8726 Generator's loss: 3.2140 D(x): 0.8713 D(G(z)): 0.4469 / 0.0657
[1/2][200/938] Discriminator's loss: 1.2559 Generator's loss: 2.7035 D(x): 0.8443 D(G(z)): 0.5841 / 0.0957
[1/2][250/938] Discriminator's loss: 1.4120 Generator's loss: 0.5689 D(x): 0.3295 D(G(z)): 0.0884 / 0.6100
[1/2][300/938] Discriminator's loss: 1.4050 Generator's loss: 0.4096 D(x): 0.3465 D(G(z)): 0.1380 / 0.7043
[1/2][350/938] Discriminator's loss: 1.0377 Generator's loss: 1.1427 D(x): 0.5735 D(G(z)): 0.3087 / 0.3678
[1/2][400/938] Discriminator's loss: 1.5422 Generator's loss: 3.3755 D(x): 0.8966 D(G(z)): 0.6883 / 0.0565
[1/2][450/938] Discriminator's loss: 1.0001 Generator's loss: 2.4868 D(x): 0.8613 D(G(z)): 0.5230 / 0.1062
[1/2][500/938] Discriminator's loss: 2.1109 Generator's loss: 2.8956 D(x): 0.8344 D(G(z)): 0.7905 / 0.0792
[1/2][550/938] Discriminator's loss: 0.8318 Generator's loss: 1.4294 D(x): 0.6806 D(G(z)): 0.2979 / 0.2711
[1/2][600/938] Discriminator's loss: 1.4873 Generator's loss: 0.3662 D(x): 0.3109 D(G(z)): 0.1496 / 0.7095
[1/2][650/938] Discriminator's loss: 0.9725 Generator's loss: 1.4026 D(x): 0.6244 D(G(z)): 0.3180 / 0.2869
[1/2][700/938] Discriminator's loss: 0.8597 Generator's loss: 1.3669 D(x): 0.6091 D(G(z)): 0.2226 / 0.3114
[1/2][750/938] Discriminator's loss: 1.2769 Generator's loss: 1.7263 D(x): 0.7795 D(G(z)): 0.5819 / 0.2088
[1/2][800/938] Discriminator's loss: 1.1612 Generator's loss: 1.1500 D(x): 0.5756 D(G(z)): 0.3915 / 0.3607
[1/2][850/938] Discriminator's loss: 1.2540 Generator's loss: 1.6004 D(x): 0.7783 D(G(z)): 0.5696 / 0.2422
[1/2][900/938] Discriminator's loss: 0.8771 Generator's loss: 1.0318 D(x): 0.5868 D(G(z)): 0.2210 / 0.3907
The next cell builds an animation that allows you to track the generator's outputs at different stages of training. In each frame, the digits 0 through 9 are shown from left to right.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(
fig=fig,
artists=ims,
interval=1000,
repeat_delay=1000,
blit=True
)
plt.close()
HTML(ani.to_jshtml())
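Once the model is trained, generating a particular digit is just a matter of passing the desired label. As a final sketch (the digit 7 is an arbitrary choice), draw eight variants of the same digit from the trained generator:
with torch.no_grad():
    digits = netG(
        torch.randn(8, noise_size, device=DEVICE),
        torch.full((8,), 7, device=DEVICE)
    ).cpu()
grid = vutils.make_grid(digits, nrow=8, padding=2, normalize=True)
plt.figure(figsize=(8, 2))
plt.axis("off")
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()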