Reshape like#

This section covers the following functions and methods:

  • torch.reshape function

  • torch.Tensor.reshape method

  • torch.Tensor.view method

All of them serve the same purpose but differ in their underlying mechanics.

import torch

The easiest way to change the shape (dimensionality) of a tensor is to use the torch.reshape function or the torch.Tensor.reshape method of the tensor you’re reshaping. All you have to do is specify the new shape. The number of elements implied by the new shape must equal the number of elements in the tensor being reshaped.
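As a minimal sketch of the rule above, the following cell checks that the function and method forms agree and that a shape with the wrong number of elements is rejected. The variable names are chosen for illustration only.

```python
import torch

t = torch.arange(12)

# The function form and the method form produce the same result.
a = torch.reshape(t, (3, 4))
b = t.reshape(3, 4)
same = torch.equal(a, b)

# The new shape must hold exactly t.numel() == 12 elements.
try:
    t.reshape(5, 2)  # only 10 slots for 12 elements
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)

print(same)
print(mismatch_error)
```

Catching the RuntimeError here is only for demonstration; in normal code the mismatch usually signals a bug worth surfacing.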


Consider an example tensor with 12 elements created by torch.arange.

original_tensor = torch.arange(12)
print("original", original_tensor)
original tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

The following cell reshapes it to the (6, 2) dimensionality.

original_tensor.reshape(6,2)
tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11]])

The elements are taken in their original order: they fill the innermost (last) dimension first, then the next one out, and so on.

The same rule applies to a shape with more dimensions.

original_tensor.reshape(3,2,2)
tensor([[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]]])

Similarly, when reshaping a multi-dimensional tensor, the elements are read starting from the innermost dimension. The following example breaks the original sequence of elements with a transposition, and then reshapes the resulting tensor to 1D.

transposed_tensor = original_tensor.reshape(2, 6).T
print("Transposed \n", transposed_tensor)

transposed_tensor.reshape(12)
Transposed 
 tensor([[ 0,  6],
        [ 1,  7],
        [ 2,  8],
        [ 3,  9],
        [ 4, 10],
        [ 5, 11]])
tensor([ 0,  6,  1,  7,  2,  8,  3,  9,  4, 10,  5, 11])

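Reshaping takes elements from the innermost dimension first, gradually moving to the outer dimensions. It follows the logical order of the transposed tensor, not the order of the original data in memory.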
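The ordering can be checked by hand. The sketch below rebuilds the flattened result with explicit loops, assuming the usual row-major rule that element [i, j] of a (rows, cols) tensor lands at position i * cols + j.

```python
import torch

original = torch.arange(12)
transposed = original.reshape(2, 6).T  # shape (6, 2)

flat = transposed.reshape(12)

# Outer loop over rows, inner loop over columns: the innermost
# dimension varies fastest, so [i, j] ends up at i * cols + j.
rows, cols = transposed.shape
manual = torch.stack(
    [transposed[i, j] for i in range(rows) for j in range(cols)]
)

print(torch.equal(flat, manual))
```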

Complete dimensionality#

Passing -1 as one of the sizes in reshape instructs PyTorch to automatically determine that dimension so that the result matches the total number of elements in the tensor. Only one dimension can be -1.


Applying the reshape(-1, 6) instruction to the original_tensor from the previous examples, which contains 12 elements, tells PyTorch to automatically determine the appropriate size for the first dimension. In this case, it calculates that the size of the first dimension should be 2.

original_tensor.reshape(-1, 6)
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])

For a more complex scenario, suppose we want to create 2 matrices with 3 columns each. To accommodate the 12 elements, each matrix must have 2 rows. The following example demonstrates how to achieve this using the reshape function:

original_tensor.reshape(2, -1, 3)
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])
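The arithmetic behind -1 is a single division: the total number of elements divided by the product of the known sizes. The helper below, infer_missing_dim, is hypothetical and written only to illustrate that calculation; it is not part of PyTorch.

```python
import math
import torch

def infer_missing_dim(numel, shape):
    """Hypothetical helper: return `shape` with its single -1 filled in."""
    if shape.count(-1) != 1:
        raise ValueError("exactly one dimension must be -1")
    known = math.prod(d for d in shape if d != -1)
    if known == 0 or numel % known != 0:
        raise ValueError(f"cannot fit {numel} elements into shape {shape}")
    return tuple(numel // known if d == -1 else d for d in shape)

t = torch.arange(12)
inferred = infer_missing_dim(t.numel(), (2, -1, 3))

print(inferred)
print(tuple(t.reshape(2, -1, 3).shape))
```

A shape such as (3, 3, -1) fails the divisibility check for 12 elements, since 12 is not a multiple of 9.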

If no size fits, PyTorch raises a RuntimeError describing the issue. For example, 12 elements cannot be arranged into 3 matrices with 3 rows each, because 12 is not divisible by 3 × 3 = 9.

original_tensor.reshape(3,3,-1)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[21], line 1
----> 1 original_tensor.reshape(3,3,-1)

RuntimeError: shape '[3, 3, -1]' is invalid for input of size 12

View method#

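The view method never copies data: it returns a tensor that shares memory with the original. It therefore works only when the requested shape is compatible with the tensor’s size and stride, which is always the case for contiguous tensors (tensors whose elements are stored sequentially in memory, in row-major order, without gaps). reshape returns a view when that is possible and falls back to a copy otherwise, so use view when you want a guarantee that no copy is made.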


The following cell creates the tensor we’ll use as an example.

X = torch.randn(4, 3)
X
tensor([[-0.0944, -0.0557,  0.5670],
        [ 0.3236,  2.5825,  0.0342],
        [ 1.0671, -1.3104, -2.0697],
        [ 1.3963,  0.2144,  1.2199]])

Simply applying the view method works as expected.

X.view(3, 4)
tensor([[-0.0944, -0.0557,  0.5670,  0.3236],
        [ 2.5825,  0.0342,  1.0671, -1.3104],
        [-2.0697,  1.3963,  0.2144,  1.2199]])
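To see that the result of view shares memory with its source, one can compare data pointers and write through the view. This is a small sketch; the names V, shares_memory and visible_in_original are illustrative.

```python
import torch

X = torch.randn(4, 3)
V = X.view(3, 4)

# A view points at the same underlying storage: no data is copied.
shares_memory = V.data_ptr() == X.data_ptr()

# Writing through the view is therefore visible in the original tensor.
V[0, 0] = 100.0
visible_in_original = X[0, 0].item() == 100.0

print(shares_memory, visible_in_original)
```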

Transposing, however, makes the tensor non-contiguous. The following cell applies view to the transposed tensor, which raises an error.

X.T.view(4, 3)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[20], line 1
----> 1 X.T.view(4, 3)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
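Two common ways around this error are sketched below: call reshape, as the message suggests, or make a contiguous copy first and then call view. Both produce a new tensor that no longer shares memory with X. The variable names are illustrative.

```python
import torch

X = torch.randn(4, 3)
XT = X.T  # shape (3, 4), non-contiguous

# is_contiguous reports whether the memory layout matches the logical order.
print(X.is_contiguous(), XT.is_contiguous())

# Option 1: reshape copies the data when a view is not possible.
r = XT.reshape(4, 3)

# Option 2: make an explicit contiguous copy, then view it.
c = XT.contiguous().view(4, 3)

same_values = torch.equal(r, c)
r_is_copy = r.data_ptr() != X.data_ptr()

print(same_values, r_is_copy)
```

Which option to prefer is a matter of intent: reshape is shorter, while contiguous().view(...) makes the copy explicit in the code.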