Add dimension

Adding an extra dimension to a tensor is often needed to make shapes compatible, particularly for matrix multiplication and broadcasting.
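As an illustration of the matrix multiplication case, here is a minimal sketch. The vectors u and v are arbitrary example data. An outer product needs a column vector and a row vector, and unsqueeze and None indexing (both covered below) produce these from plain 1-D tensors.

```python
import torch

u = torch.arange(1.0, 4.0)  # example data, shape (3,)
v = torch.arange(1.0, 5.0)  # example data, shape (4,)

# A (3, 1) column times a (1, 4) row is an ordinary matrix product
outer_mm = u.unsqueeze(1) @ v.unsqueeze(0)

# The same table via broadcasting: (3, 1) * (1, 4) -> (3, 4)
outer_bc = u[:, None] * v[None, :]

print(outer_mm.shape)                            # torch.Size([3, 4])
print(torch.equal(outer_mm, outer_bc))           # True
print(torch.equal(outer_mm, torch.outer(u, v)))  # True
```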

Every method of creating a new dimension inserts a new axis of size 1 at the specified position. The data itself is unchanged. Whatever sits at that position in the original tensor (the whole tensor, each row, or each element) is wrapped in a one-element array along the new axis.
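Here is a quick check of that claim, using a small example tensor x chosen for illustration. None indexing and unsqueeze give the same result, and only the shape changes.

```python
import torch

x = torch.arange(6).reshape(2, 3)  # example data, shape (2, 3)

# Two ways to insert a size-1 axis at position 1
a = x[:, None]      # None indexing
b = x.unsqueeze(1)  # unsqueeze

print(a.shape)            # torch.Size([2, 1, 3])
print(torch.equal(a, b))  # True

# The elements and their order are unchanged; only the shape gains a 1
print(a.numel() == x.numel())                        # True
print(a.flatten().tolist() == x.flatten().tolist())  # True
```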

import torch

Understanding

Sometimes it’s difficult to visualize exactly what happens when you add a new axis to a tensor. In this section, we’ll consider a few examples to gain a better understanding.


Consider a simple example: the possible transformations of a 5×4 matrix.

show_tensor = torch.arange(20).reshape(5,4)
display(show_tensor)
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19]])

If we add the new dimension as the outermost one (position 0), we get a 3D tensor of shape (1, 5, 4) that contains a single matrix: the original one.

display(show_tensor.unsqueeze(0))
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15],
         [16, 17, 18, 19]]])

When we add the new dimension at position 1, where the rows dimension used to be, each row is wrapped as a one-row matrix. The result has shape (5, 1, 4): five layers along the first axis, each a 1×4 matrix.

display(show_tensor[:, None])
tensor([[[ 0,  1,  2,  3]],

        [[ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11]],

        [[12, 13, 14, 15]],

        [[16, 17, 18, 19]]])

Finally, if we add the new dimension as the innermost one (position 2), each element is nested in its own row. The result has shape (5, 4, 1): five layers, each a one-column 4×1 matrix.

display(show_tensor.unsqueeze(2))
tensor([[[ 0],
         [ 1],
         [ 2],
         [ 3]],

        [[ 4],
         [ 5],
         [ 6],
         [ 7]],

        [[ 8],
         [ 9],
         [10],
         [11]],

        [[12],
         [13],
         [14],
         [15]],

        [[16],
         [17],
         [18],
         [19]]])
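The three cases above can be summarised by their shapes alone. This sketch rebuilds the same show_tensor and inserts the new axis at every non-negative position that a 2-D tensor accepts.

```python
import torch

show_tensor = torch.arange(20).reshape(5, 4)

# Insert a size-1 axis at each valid non-negative position and record the shape
shapes = {
    dim: tuple(show_tensor.unsqueeze(dim).shape)
    for dim in range(show_tensor.dim() + 1)
}
for dim, shape in shapes.items():
    print(dim, shape)
# 0 (1, 5, 4)
# 1 (5, 1, 4)
# 2 (5, 4, 1)
```

In every case the new size-1 entry appears in the shape at exactly the requested position, and the original sizes 5 and 4 keep their relative order.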

None indexing

You can use None in indexing to create a new axis. Simply place None in the position where you want to introduce the new axis.


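The following example adds a new dimension to a vector. The first line puts it in front of the existing axis, giving a one-row matrix of shape (1, 5). The second line puts it after the existing axis, giving a one-column matrix of shape (5, 1).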

display(torch.arange(5)[None])
display(torch.arange(5)[:, None])
tensor([[0, 1, 2, 3, 4]])
tensor([[0],
        [1],
        [2],
        [3],
        [4]])
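The same indexing expression can contain several None entries, each adding its own axis. The sketch below also uses Ellipsis (...), which standard PyTorch indexing supports in the same way as NumPy. Ellipsis stands for "all remaining axes", so a None placed after it always appends a new last axis. The tensor x is example data.

```python
import torch

x = torch.arange(6).reshape(2, 3)  # example data, shape (2, 3)

# Each None inserts an axis; each ':' consumes one existing axis
y = x[None, :, None, :]
print(y.shape)  # torch.Size([1, 2, 1, 3])

# '...' covers all existing axes, so None after it becomes the innermost axis
z = x[..., None]
print(z.shape)  # torch.Size([2, 3, 1])
```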

Unsqueeze

The torch.unsqueeze function, or the equivalent unsqueeze method of a tensor, returns a tensor with a new dimension of size 1 inserted at the specified dimension index.


Here are examples of adding a dimension as the rows dimension and as the columns dimension of a one-dimensional vector.

display(torch.arange(5).unsqueeze(0))
display(torch.arange(5).unsqueeze(1))
tensor([[0, 1, 2, 3, 4]])
tensor([[0],
        [1],
        [2],
        [3],
        [4]])
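A few more properties of unsqueeze are worth knowing. The function and method forms are interchangeable. The result is a view, so the input keeps its shape while sharing memory with the output. The in-place variant unsqueeze_ (trailing underscore) changes the tensor's own shape instead. The sketch below checks all three on example vectors v and w.

```python
import torch

v = torch.arange(5)  # example data

# Function form and method form give the same result
a = torch.unsqueeze(v, 0)
b = v.unsqueeze(0)
print(torch.equal(a, b))  # True

# unsqueeze returns a view: v keeps its shape but shares memory with b
print(v.shape)  # torch.Size([5])
b[0, 0] = 100   # writing through the view...
print(v[0].item())  # 100  ...changes the original tensor

# unsqueeze_ modifies the tensor's shape in place
w = torch.arange(5)
w.unsqueeze_(1)
print(w.shape)  # torch.Size([5, 1])
```

Because of the shared memory, copy the result with .clone() if the original must stay untouched after writing to the unsqueezed tensor.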

Negative indices

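Negative values count from the innermost dimension of the result. A value of -1 inserts the new axis in the last position, -2 in the second-to-last, and so on. In general, a negative dim is equivalent to dim + input.dim() + 1.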


The following cell adds a new dimension to a 5×2 matrix:

  • In the -1 case, the new axis becomes the columns dimension, giving shape (5, 2, 1).

  • In the -2 case, the new axis becomes the rows dimension, giving shape (5, 1, 2).

display(torch.arange(10).reshape(5,2).unsqueeze(-1))
display(torch.arange(10).reshape(5,2).unsqueeze(-2))
tensor([[[0],
         [1]],

        [[2],
         [3]],

        [[4],
         [5]],

        [[6],
         [7]],

        [[8],
         [9]]])
tensor([[[0, 1]],

        [[2, 3]],

        [[4, 5]],

        [[6, 7]],

        [[8, 9]]])
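The correspondence between negative and non-negative positions can be checked directly. This sketch reuses the same 5×2 matrix. It applies the dim + input.dim() + 1 rule, which the torch.unsqueeze documentation states, to every negative value.

```python
import torch

m = torch.arange(10).reshape(5, 2)  # 2-D, so valid dims are -3..2

results = []
for neg in range(-m.dim() - 1, 0):
    pos = neg + m.dim() + 1  # negative dim -> equivalent non-negative dim
    assert torch.equal(m.unsqueeze(neg), m.unsqueeze(pos))
    results.append((neg, pos, tuple(m.unsqueeze(neg).shape)))

for row in results:
    print(row)
# (-3, 0, (1, 5, 2))
# (-2, 1, (5, 1, 2))
# (-1, 2, (5, 2, 1))
```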

Number of axes

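The result of unsqueeze has one more axis than its input, so dim can refer to at most one position past the input's last axis. For an input with n dimensions, valid values lie in [-n - 1, n]. A vector has one dimension, so its valid range is [-2, 1], and any value outside it raises an IndexError: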

torch.arange(5).unsqueeze(2)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[38], line 1
----> 1 torch.arange(5).unsqueeze(2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

torch.arange(5).unsqueeze(-3)
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[5], line 1
----> 1 torch.arange(5).unsqueeze(-3)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)
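The bounds can be computed ahead of time instead of being discovered through an error. valid_unsqueeze_range below is a hypothetical helper written for this sketch, not a PyTorch function. It encodes the [-n - 1, n] rule and checks that values just outside it raise IndexError, matching the tracebacks above.

```python
import torch


def valid_unsqueeze_range(t: torch.Tensor) -> tuple[int, int]:
    """Hypothetical helper: inclusive (low, high) bounds accepted by t.unsqueeze(dim)."""
    return -t.dim() - 1, t.dim()


v = torch.arange(5)
low, high = valid_unsqueeze_range(v)
print(low, high)  # -2 1

# Every dim inside the range is accepted
for dim in range(low, high + 1):
    v.unsqueeze(dim)

# Values just outside the range raise IndexError
rejected = []
for dim in (low - 1, high + 1):
    try:
        v.unsqueeze(dim)
    except IndexError:
        rejected.append(dim)
print(rejected)  # [-3, 2]
```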