Transformers#

This page discusses the PyTorch layers that implement the transformer architecture.

import torch

Encoder layer#

The encoder layer is implemented by torch.nn.TransformerEncoderLayer.

The crucial parameters:

  • d_model: The dimensionality of each element of the processed sequence.

  • nhead: The number of attention heads that process the sequence.

The output of the layer has the same dimensionality as the input.
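The shape contract above can be checked directly. The following sketch (the d_model=8, nhead=2 values are arbitrary choices for illustration) shows that an input whose last dimension matches d_model passes through unchanged in shape, while a mismatched last dimension is rejected:

```python
import torch

# An encoder layer expecting sequence elements from R^8, split across 2 heads.
layer = torch.nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True)

# Correct input: the last dimension matches d_model, so the shape is preserved.
out = layer(torch.rand(4, 3, 8))
print(out.shape)  # torch.Size([4, 3, 8])

# An input whose last dimension differs from d_model raises an error.
try:
    layer(torch.rand(4, 3, 6))
except Exception:
    print("input with last dimension != d_model is rejected")
```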


The following cell defines the regular encoder layer.

encoder = torch.nn.TransformerEncoderLayer(
    d_model=5, nhead=1, batch_first=True
)

The following cell defines the kind of input the layer is supposed to process:

inp = torch.rand(32, 10, 5)

A batch of 32 data units, each of which is a sequence of 10 elements from \(\mathbb{R}^5\).

encoder(inp).shape
torch.Size([32, 10, 5])

Transformer encoder#

A transformer encoder is an object that stacks a specified number of transformer encoder layers. The output of each layer becomes the input of the following layer.


The next code shows how to initialize the transformer encoder.

transformer_encoder = torch.nn.TransformerEncoder(
    encoder_layer=torch.nn.TransformerEncoderLayer(
        d_model=5, nhead=1, batch_first=True
    ),
    num_layers=2,
    enable_nested_tensor=False
)

Here is an example data shape that can be processed by the TransformerEncoder defined earlier.

inp = torch.rand(32, 10, 5)

Passing the data through the encoder preserves its shape:

transformer_encoder(inp).shape
torch.Size([32, 10, 5])
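The claim that each layer's output feeds the next layer can be verified directly. The sketch below relies on the encoder's internal .layers ModuleList, which is an implementation detail rather than documented public API; eval() and no_grad() make the comparison deterministic by disabling dropout:

```python
import torch

# Two-layer encoder; eval() disables dropout so both passes are deterministic.
transformer_encoder = torch.nn.TransformerEncoder(
    encoder_layer=torch.nn.TransformerEncoderLayer(
        d_model=5, nhead=1, batch_first=True
    ),
    num_layers=2,
    enable_nested_tensor=False,
).eval()

inp = torch.rand(32, 10, 5)

with torch.no_grad():
    # One call to the stacked encoder...
    stacked = transformer_encoder(inp)
    # ...versus feeding the input through the internal layers one by one.
    manual = inp
    for layer in transformer_encoder.layers:
        manual = layer(manual)

print(torch.allclose(stacked, manual))  # True
```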

Nested tensor#

The TransformerEncoder has an enable_nested_tensor parameter. Setting it to True tells torch to use a special data structure, the nested tensor, which is optimized for working with sequences of varying lengths.

Note: The nested-tensor path has requirements. In particular, if the number of attention heads (nhead) is odd, torch automatically falls back and sets enable_nested_tensor to False.


The following cell attempts to define a TransformerEncoder with enable_nested_tensor=True using a TransformerEncoderLayer that has an odd number of attention heads.

transformer_encoder = torch.nn.TransformerEncoder(
    encoder_layer=torch.nn.TransformerEncoderLayer(
        d_model=5, nhead=1, batch_first=True
    ),
    num_layers=2,
    enable_nested_tensor=True
)
/home/fedor/.virtualenvs/knowledge/lib/python3.12/site-packages/torch/nn/modules/transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd
  warnings.warn(