Flatten/unflatten


The torch.nn.Flatten and torch.nn.Unflatten layers let you change the shape of a tensor directly within the forward pass of a neural network, for example as steps inside torch.nn.Sequential.

import torch

Flatten

Flatten merges a consecutive range of dimensions into a single dimension by concatenating their elements. See the official documentation for torch.nn.Flatten for details.

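If the input has shape \(\left(S_0, \ldots, S_{\text{start}}, \ldots, S_{\text{end}}, \ldots, S_{n}\right)\), the output has shape \(\left(S_0, \ldots, S_{\text{start}-1}, \prod_{i=\text{start}}^{\text{end}} S_i, S_{\text{end}+1}, \ldots, S_{n}\right)\). In other words, dimensions start_dim through end_dim are merged into one dimension whose size is the product of their sizes, and their elements are laid out sequentially in row-major order (the last dimension varies fastest).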
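The shape formula can be checked numerically. The sketch below uses a hypothetical helper, `flattened_shape` (not part of PyTorch), that applies the formula, normalizing negative indices as PyTorch does, and compares its result with the actual output shape of `torch.nn.Flatten` for a few `(start_dim, end_dim)` pairs on an arbitrary 4-D tensor.

```python
import math
import torch

def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: output shape of Flatten according to the formula."""
    n = len(shape)
    # Normalize negative dimension indices (e.g. -1 -> n - 1).
    start = start_dim % n
    end = end_dim % n
    # Dimensions start..end (inclusive) collapse into their product.
    merged = math.prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])

x = torch.zeros(2, 3, 4, 5)
for start, end in [(1, -1), (0, 1), (1, 2), (0, -1)]:
    expected = flattened_shape(tuple(x.shape), start, end)
    actual = tuple(torch.nn.Flatten(start_dim=start, end_dim=end)(x).shape)
    print((start, end), expected, actual)
    assert expected == actual
```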


For example, consider a 3-dimensional tensor of shape \(\left(3, 3, 3\right)\).

input = torch.arange(27).reshape([3, 3, 3])
input
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]])

By default, start_dim=1 and end_dim=-1, so in our case the default Flatten should produce the shape \(\left( S_0, S_1 S_2 \right)=\left(3, 9\right)\):

torch.nn.Flatten()(input)
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23, 24, 25, 26]])

The last two axes were merged: each \(3 \times 3\) slice became a row of 9 elements.
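Because the default start_dim=1 leaves dimension 0 untouched, the default Flatten keeps the batch dimension intact. That makes it a natural bridge between convolutional and fully connected layers. The small model below is only an illustrative sketch: the channel count, image size and number of outputs are arbitrary choices, not taken from the original text.

```python
import torch
from torch import nn

# Toy CNN head; all layer sizes are illustrative assumptions.
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1),  # (N, 1, 8, 8) -> (N, 4, 8, 8)
    nn.ReLU(),
    nn.Flatten(),              # default start_dim=1: (N, 4, 8, 8) -> (N, 256)
    nn.Linear(4 * 8 * 8, 10),  # expects 256 features per sample
)

batch = torch.randn(5, 1, 8, 8)  # batch of 5 single-channel 8x8 images
out = model(batch)
print(out.shape)  # torch.Size([5, 10])
```

The batch size (5) passes through unchanged. Only the per-sample feature maps are flattened.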

Now let’s try a non-default Flatten. Suppose we want to merge the two outer dimensions instead (start_dim=0, end_dim=1). This stacks the layers of the cube on top of one another and produces a tall \(9 \times 3\) matrix.

torch.nn.Flatten(start_dim=0, end_dim=1)(input)
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]])
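Merging dimensions 0 and 1 of a \(3 \times 3 \times 3\) tensor gives the same values as reshaping it to \(\left(3 \cdot 3, 3\right)\), since both follow row-major order. The check below is a small sanity test of that equivalence. The names `cube` and `tall` are just illustrative.

```python
import torch

cube = torch.arange(27).reshape(3, 3, 3)
tall = torch.nn.Flatten(start_dim=0, end_dim=1)(cube)

# Same values as an explicit reshape to (9, 3).
assert torch.equal(tall, cube.reshape(9, 3))
# Rows 0-2 come from layer 0, rows 3-5 from layer 1, and so on:
# row 3 of the result is the first row of the second layer, [9, 10, 11].
assert torch.equal(tall[3], cube[1, 0])
```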

Setting only start_dim=0 (and keeping the default end_dim=-1) turns an input of arbitrary dimensionality into a one-dimensional vector.

torch.nn.Flatten(start_dim=0)(input)
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26])
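The same call works for any number of dimensions. The sketch below applies it to an arbitrary 4-D tensor and compares the result with the functional `torch.flatten` and with `reshape(-1)`. Note that `torch.flatten` defaults to start_dim=0, whereas the `torch.nn.Flatten` layer defaults to start_dim=1, so the two only agree here because start_dim=0 is passed explicitly to the layer.

```python
import torch

x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)  # 4-D example tensor
flat = torch.nn.Flatten(start_dim=0)(x)

print(flat.shape)  # torch.Size([120])
# Same result as the functional form (default start_dim=0) and as reshape(-1).
assert torch.equal(flat, torch.flatten(x))
assert torch.equal(flat, x.reshape(-1))
```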

Unflatten

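Unflatten does the opposite of Flatten: it expands a single dimension dim into several dimensions whose sizes are given by unflattened_size. The product of unflattened_size must equal the size of dim, and the elements keep their row-major order.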
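Because Unflatten splits a dimension in the same row-major order that Flatten uses to merge it, the two layers can undo each other. The round trip below is a minimal sketch on an arbitrary random tensor. The chosen dimensions are assumptions made for illustration only.

```python
import torch
from torch import nn

x = torch.randn(4, 2, 3, 5)

flatten = nn.Flatten(start_dim=1, end_dim=2)               # (4, 2, 3, 5) -> (4, 6, 5)
unflatten = nn.Unflatten(dim=1, unflattened_size=(2, 3))   # (4, 6, 5) -> (4, 2, 3, 5)

y = flatten(x)
print(y.shape)  # torch.Size([4, 6, 5])

restored = unflatten(y)
assert torch.equal(restored, x)  # the round trip recovers the original tensor
```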


As an example, consider a \(9 \times 9\) matrix. Its outer dimension (dim=0) indexes rows, and its inner dimension (dim=1) indexes columns.

input = torch.arange(81).reshape([9, 9])
input
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23, 24, 25, 26],
        [27, 28, 29, 30, 31, 32, 33, 34, 35],
        [36, 37, 38, 39, 40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49, 50, 51, 52, 53],
        [54, 55, 56, 57, 58, 59, 60, 61, 62],
        [63, 64, 65, 66, 67, 68, 69, 70, 71],
        [72, 73, 74, 75, 76, 77, 78, 79, 80]])

The following cell applies Unflatten to the outer dimension (dim=0) with unflattened_size=(3, 3). The nine rows are grouped into three blocks of three rows each, and these blocks become the layers of a 3D tensor of shape \(\left(3, 3, 9\right)\).

torch.nn.Unflatten(dim=0, unflattened_size=(3,3))(input)
tensor([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
         [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23, 24, 25, 26]],

        [[27, 28, 29, 30, 31, 32, 33, 34, 35],
         [36, 37, 38, 39, 40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49, 50, 51, 52, 53]],

        [[54, 55, 56, 57, 58, 59, 60, 61, 62],
         [63, 64, 65, 66, 67, 68, 69, 70, 71],
         [72, 73, 74, 75, 76, 77, 78, 79, 80]]])
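The result matches an explicit reshape to \(\left(3, 3, 9\right)\): layer \(k\) holds rows \(3k\), \(3k+1\) and \(3k+2\) of the original matrix. The sketch below also shows what happens when the product of unflattened_size does not equal the size of dim. In the PyTorch versions we are aware of, this fails with a `RuntimeError` when the layer is called, but treat the exact exception type and message as version-dependent.

```python
import torch
from torch import nn

m = torch.arange(81).reshape(9, 9)
layers = nn.Unflatten(dim=0, unflattened_size=(3, 3))(m)

assert layers.shape == (3, 3, 9)
assert torch.equal(layers, m.reshape(3, 3, 9))
# Layer 1 holds rows 3, 4 and 5 of the original matrix.
assert torch.equal(layers[1], m[3:6])

# 2 * 4 = 8 does not match the size of dim 0 (9), so the call is rejected.
try:
    nn.Unflatten(dim=0, unflattened_size=(2, 4))(m)
    rejected = False
except RuntimeError as err:
    rejected = True
    print("RuntimeError:", err)
```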

Applying Unflatten to the inner dimension (dim=1) instead splits each row of nine elements into three chunks of three. Each original row becomes a \(3 \times 3\) matrix, so the result has shape \(\left(9, 3, 3\right)\).

torch.nn.Unflatten(dim=1, unflattened_size=(3,3))(input)
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]],

        [[27, 28, 29],
         [30, 31, 32],
         [33, 34, 35]],

        [[36, 37, 38],
         [39, 40, 41],
         [42, 43, 44]],

        [[45, 46, 47],
         [48, 49, 50],
         [51, 52, 53]],

        [[54, 55, 56],
         [57, 58, 59],
         [60, 61, 62]],

        [[63, 64, 65],
         [66, 67, 68],
         [69, 70, 71]],

        [[72, 73, 74],
         [75, 76, 77],
         [78, 79, 80]]])
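Unflattening dim=1 is the usual way to turn the flat output of a linear layer back into a multi-dimensional feature map while keeping the batch dimension. The toy decoder below sketches this pattern. The code size (8) and output shape (2 channels of \(3 \times 3\)) are arbitrary illustrative choices. The last lines confirm that the example above is equivalent to reshaping the matrix to \(\left(9, 3, 3\right)\).

```python
import torch
from torch import nn

# Toy decoder: map an 8-dim code to a 2-channel 3x3 "image" (sizes are illustrative).
decoder = nn.Sequential(
    nn.Linear(8, 2 * 3 * 3),                          # (N, 8) -> (N, 18)
    nn.Unflatten(dim=1, unflattened_size=(2, 3, 3)),  # (N, 18) -> (N, 2, 3, 3)
)

codes = torch.randn(4, 8)
print(decoder(codes).shape)  # torch.Size([4, 2, 3, 3])

# Unflattening dim=1 of the 9x9 matrix equals reshaping it to (9, 3, 3).
m = torch.arange(81).reshape(9, 9)
assert torch.equal(nn.Unflatten(1, (3, 3))(m), m.reshape(9, 3, 3))
```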