Flatten/unflatten
The torch.nn.Flatten and torch.nn.Unflatten layers let you change the shape of a tensor directly within the forward pass of a neural network.
import torch
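To illustrate what "within the forward pass" means in practice, here is a minimal sketch of a model that uses Flatten between a convolutional layer and a linear layer. The model and its layer sizes are hypothetical and chosen only so that the shapes are easy to follow.

```python
import torch
from torch import nn

# Hypothetical toy model: the layer sizes are illustrative only.
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),  # (N, 1, 8, 8) -> (N, 4, 6, 6)
    nn.Flatten(),                                             # (N, 4, 6, 6) -> (N, 144)
    nn.Linear(4 * 6 * 6, 10),                                 # (N, 144) -> (N, 10)
)

batch = torch.randn(2, 1, 8, 8)  # two single-channel 8x8 images
output = model(batch)
print(output.shape)  # torch.Size([2, 10])
```

Because Flatten is a module, it can sit inside nn.Sequential without a custom forward method.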
Flatten
Flatten merges a contiguous range of dimensions into a single dimension by concatenating their elements. See the official documentation for torch.nn.Flatten for details.
If the input has shape \(\left(S_0, ..., S_{\text{start}}, ..., S_{\text{end}}, ..., S_{n}\right)\), the output has shape \(\left(S_0, ..., \prod_{i=\text{start}}^{\text{end}} S_i, ..., S_{n}\right)\). The elements of all dimensions from start to end are concatenated sequentially.
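The shape rule above can be checked numerically. The sketch below compares it with the actual output of torch.nn.Flatten for a few combinations of start_dim and end_dim. The helper flattened_shape is a hypothetical function written for this illustration and is not part of PyTorch.

```python
import math
import torch

def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Shape predicted by the product rule (hypothetical helper, not a PyTorch API)."""
    n = len(shape)
    start = start_dim % n  # normalise negative indices such as -1
    end = end_dim % n
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

x = torch.zeros(2, 3, 4, 5)
for start_dim, end_dim in [(1, -1), (0, 1), (1, 2), (0, -1)]:
    actual = tuple(torch.nn.Flatten(start_dim, end_dim)(x).shape)
    assert actual == flattened_shape(x.shape, start_dim, end_dim)
    print((start_dim, end_dim), actual)
```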
For example, consider a 3-dimensional tensor of shape \(\left(3, 3, 3\right)\).
input = torch.arange(27).reshape([3, 3, 3])
input
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]]])
By default, start_dim=1 and end_dim=-1. So in our case the default Flatten should produce shape \(\left( S_0, S_1 S_2 \right)=\left(3, 9\right)\):
torch.nn.Flatten()(input)
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8],
[ 9, 10, 11, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 24, 25, 26]])
Elements of the last two axes were combined into flat vectors.
Now let's try a non-default Flatten. Suppose we want to merge the two outer dimensions. This stacks the layers of the cube on top of each other, producing a tall matrix of shape \(\left(9, 3\right)\).
torch.nn.Flatten(start_dim=0, end_dim=1)(input)
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17],
[18, 19, 20],
[21, 22, 23],
[24, 25, 26]])
Setting start_dim=0 and keeping the default end_dim=-1 turns an input of arbitrary dimensionality into a one-dimensional vector.
torch.nn.Flatten(start_dim=0)(input)
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26])
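As a cross-check on the examples above, the following sketch verifies that each Flatten call matches an equivalent reshape. It also compares the layer with the functional torch.flatten, whose default start_dim is 0 rather than 1, so the argument has to be passed explicitly to get the same result as the default layer.

```python
import torch

x = torch.arange(27).reshape(3, 3, 3)

# Default arguments keep dim 0 (typically the batch dimension) and merge the rest.
default = torch.nn.Flatten()(x)
assert torch.equal(default, x.reshape(3, 9))

# torch.flatten defaults to start_dim=0, so start_dim=1 is passed explicitly here.
assert torch.equal(default, torch.flatten(x, start_dim=1))

# Merging the two outer dimensions is the same as reshaping to (9, 3).
outer = torch.nn.Flatten(start_dim=0, end_dim=1)(x)
assert torch.equal(outer, x.reshape(9, 3))

# Flattening everything matches a reshape to a single dimension.
full = torch.nn.Flatten(start_dim=0)(x)
assert torch.equal(full, x.reshape(-1))
```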
Unflatten
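Unflatten splits the specified dimension (dim) into several dimensions whose sizes are given by unflattened_size. The product of these sizes must equal the size of the original dimension, and the pieces keep their original order.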
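A minimal sketch of the effect on shapes, using an arbitrary tensor chosen for illustration: the dimension of size 12 is replaced by two dimensions of sizes 3 and 4, while the other dimensions are left unchanged.

```python
import torch

x = torch.zeros(2, 12, 5)

# Split dim 1 (size 12) into two dimensions of sizes 3 and 4, since 3 * 4 == 12.
y = torch.nn.Unflatten(dim=1, unflattened_size=(3, 4))(x)
print(y.shape)  # torch.Size([2, 3, 4, 5])
```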
As an example, consider a matrix of shape \(\left(9, 9\right)\). The outer dimension of the matrix (dim=0) indexes its rows.
input = torch.arange(81).reshape([9, 9])
input
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8],
[ 9, 10, 11, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 24, 25, 26],
[27, 28, 29, 30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41, 42, 43, 44],
[45, 46, 47, 48, 49, 50, 51, 52, 53],
[54, 55, 56, 57, 58, 59, 60, 61, 62],
[63, 64, 65, 66, 67, 68, 69, 70, 71],
[72, 73, 74, 75, 76, 77, 78, 79, 80]])
The following cell applies Unflatten to the outer dimension of the input matrix: rows are grouped into separate matrices and arranged as layers of a 3D array of shape \(\left(3, 3, 9\right)\).
torch.nn.Unflatten(dim=0, unflattened_size=(3,3))(input)
tensor([[[ 0, 1, 2, 3, 4, 5, 6, 7, 8],
[ 9, 10, 11, 12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23, 24, 25, 26]],
[[27, 28, 29, 30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41, 42, 43, 44],
[45, 46, 47, 48, 49, 50, 51, 52, 53]],
[[54, 55, 56, 57, 58, 59, 60, 61, 62],
[63, 64, 65, 66, 67, 68, 69, 70, 71],
[72, 73, 74, 75, 76, 77, 78, 79, 80]]])
Applying Unflatten to the inner dimension of the input splits each row into three subarrays of length 3 and arranges them as the rows of a new matrix. The result has shape \(\left(9, 3, 3\right)\).
torch.nn.Unflatten(dim=1, unflattened_size=(3,3))(input)
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23],
[24, 25, 26]],
[[27, 28, 29],
[30, 31, 32],
[33, 34, 35]],
[[36, 37, 38],
[39, 40, 41],
[42, 43, 44]],
[[45, 46, 47],
[48, 49, 50],
[51, 52, 53]],
[[54, 55, 56],
[57, 58, 59],
[60, 61, 62]],
[[63, 64, 65],
[66, 67, 68],
[69, 70, 71]],
[[72, 73, 74],
[75, 76, 77],
[78, 79, 80]]])
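Flatten and Unflatten undo each other when they are applied to matching dimensions. The sketch below unflattens the inner dimension of the same matrix and then flattens the two resulting dimensions back, checking that the original tensor is recovered.

```python
import torch

x = torch.arange(81).reshape(9, 9)

unflatten = torch.nn.Unflatten(dim=1, unflattened_size=(3, 3))
flatten = torch.nn.Flatten(start_dim=1, end_dim=2)

y = unflatten(x)       # (9, 9) -> (9, 3, 3)
restored = flatten(y)  # (9, 3, 3) -> (9, 9)

assert torch.equal(restored, x)
# Unflattening the inner dimension is equivalent to a plain reshape.
assert torch.equal(y, x.reshape(9, 3, 3))
```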