Unet

Contents

U-Net

This page describes an implementation of the U-Net network architecture in PyTorch.

import torch
from torch import nn

Architecture

This section describes the building blocks used to implement a U-Net in PyTorch.

For convenience, here is a widely used diagram of the U-Net architecture:

U-Net architecture

The following elements can be conventionally distinguished in the architecture:

  • Two convolutional transformations.

  • Downscaling, followed by two convolutional transformations.

  • Upscaling, followed by two convolutional transformations. Before these convolutions, the upscaled output of the previous layer is concatenated (along the channel dimension) with the output of the corresponding downscaling block — the so-called skip connection.

  • Output layer that ensures the network produces the required output shape.

Let us now consider possible implementations of each of these components.

class DoubleConv(nn.Module):
    '''
    Implementation of double conv layer.

    Parameters
    ----------
    in_channels: int
        Expected number of elements in the third dimension of the model's input.
    out_channels: int
        Number of elements in the third dimension of the model's output.
    mid_channels: int
        Number of channels used for communication between the first and second
        convolutions.
    '''
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mid_channels: int|None=None
    ) -> None:
        super().__init__()

        if not mid_channels:
            mid_channels = out_channels

        self.double_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=3,
                padding=1
            ),
            nn.BatchNorm2d(num_features=mid_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(
                in_channels=mid_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.double_conv(x)

class Down(nn.Module):
    '''
    Downscaling block: 2x2 max pooling followed by ``DoubleConv``.

    ``nn.MaxPool2d(2)`` uses a 2x2 window with stride 2, halving the spatial
    size (rounding down). ``DoubleConv`` preserves height and width, so an
    input feature map of size ``n`` yields an output of size ``floor(n / 2)``.

    Parameters
    ----------
    in_channels: int
        Number of channels of the input data.
    out_channels: int
        Number of channels of the output data.
    '''

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels=in_channels, out_channels=out_channels)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Pool ``x`` by a factor of 2, then apply the double convolution.'''
        return self.maxpool_conv(x)

class Up(nn.Module):
    '''
    Upscaling block with a skip connection.

    ``forward`` takes a deeper ("down") feature map ``x`` and a skip-connection
    ("left") feature map ``x_left``. It first upsamples ``x`` by a factor of 2
    with bilinear interpolation. Next, it pads ``x`` so its height and width
    match ``x_left``; if ``x`` is larger, the negative pad amounts make
    ``F.pad`` crop it instead. It then concatenates ``[x_left, x]`` along the
    channel dimension and applies ``DoubleConv``.

    Parameters
    ----------
    in_channels: int
        Number of channels after concatenation, i.e. the channels of
        ``x_left`` plus the channels of ``x`` (upsampling does not change the
        channel count).
    out_channels: int
        Number of output channels.
    '''

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()

        self.up = nn.Upsample(
            scale_factor=2,
            mode="bilinear",
            align_corners=True
        )
        self.conv = DoubleConv(
            in_channels=in_channels,
            out_channels=out_channels,
        )

    def forward(self, x: torch.Tensor, x_left: torch.Tensor) -> torch.Tensor:
        '''Upsample ``x``, align it to ``x_left``, concatenate and convolve.'''
        x = self.up(x)

        # Size mismatch arises when an odd-sized map was max-pooled on the
        # way down (pooling floors, upsampling doubles).
        diff_y = x_left.shape[2] - x.shape[2]
        diff_x = x_left.shape[3] - x.shape[3]

        # F.pad orders the last two dims as (left, right, top, bottom):
        # put diff // 2 on the left/top and the remainder on the right/bottom.
        pad = [
            diff_x // 2, diff_x - diff_x // 2,
            diff_y // 2, diff_y - diff_y // 2,
        ]
        x = torch.nn.functional.pad(input=x, pad=pad)

        x = torch.cat([x_left, x], dim=1)

        return self.conv(x)

Finally, here is the definition of the model itself.

class UNet(nn.Module):
    """
    U-Net built from ``DoubleConv``, ``Down`` and ``Up`` blocks.

    The encoder has an input ``DoubleConv`` and four ``Down`` blocks (the last
    one acting as the bottleneck). The decoder has four ``Up`` blocks, each
    concatenating the upsampled features with the matching encoder output.
    A final 1x1 convolution maps the features to ``n_classes`` channels. The
    ``Up`` blocks pad or crop to the skip-connection size, so the output has
    the same height and width as the input.

    Parameters
    ----------
    n_channels: int
        Number of channels of the input image.
    n_classes: int
        Number of output channels (one logit map per class).
    base_channels: int
        Width of the first level; deeper levels use 2x, 4x and 8x this value.
        The default of 32 reproduces the original fixed architecture.
    """

    def __init__(
        self,
        n_channels: int,
        n_classes: int,
        base_channels: int = 32,
    ) -> None:
        super().__init__()

        self.n_channels = n_channels
        self.n_classes = n_classes
        self.base_channels = base_channels

        c = base_channels
        self.inc = DoubleConv(n_channels, out_channels=c)
        self.down1 = Down(in_channels=c, out_channels=2 * c)
        self.down2 = Down(in_channels=2 * c, out_channels=4 * c)
        self.down3 = Down(in_channels=4 * c, out_channels=8 * c)

        # The bottleneck keeps the channel count, so up1 receives 8c + 8c.
        self.bottleneck = Down(in_channels=8 * c, out_channels=8 * c)

        # Each Up input is the channel-wise concatenation of the previous
        # layer's output and the corresponding encoder output.
        self.up1 = Up(in_channels=16 * c, out_channels=4 * c)  # 8c + 8c
        self.up2 = Up(in_channels=8 * c, out_channels=2 * c)   # 4c + 4c
        self.up3 = Up(in_channels=4 * c, out_channels=c)       # 2c + 2c
        self.up4 = Up(in_channels=2 * c, out_channels=c)       # c + c

        # The 1x1 convolution preserves spatial size and produces one
        # channel per predicted class.
        self.outc = torch.nn.Conv2d(
            in_channels=c,
            out_channels=n_classes,
            kernel_size=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class logits with the same height/width as ``x``."""
        x1 = self.inc(x=x)
        x2 = self.down1(x=x1)
        x3 = self.down2(x=x2)
        x4 = self.down3(x=x3)
        x5 = self.bottleneck(x=x4)
        x = self.up1(x=x5, x_left=x4)
        x = self.up2(x=x, x_left=x3)
        x = self.up3(x=x, x_left=x2)
        x = self.up4(x=x, x_left=x1)
        logits = self.outc(x)
        return logits