Reshape like#
This section covers the following functions and methods:
torch.reshape (function)
torch.Tensor.reshape (method)
torch.Tensor.view (method)
All of them serve the same purpose but differ in their underlying mechanics.
import torch
The easiest way to change the dimensionality of a tensor is to use the torch.reshape
function or the torch.Tensor.reshape
method of the tensor you’re reshaping. All you have to do is specify a new shape; obviously, the number of elements in the new shape must equal the number of elements in the tensor being reshaped.
Consider an example with 12 elements created by torch.arange.
original_tensor = torch.arange(12)
print("original", original_tensor)
original tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
The following cell reshapes it to the (6, 2)
dimensionality.
original_tensor.reshape(6,2)
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11]])
The elements are taken sequentially in their original order and fill the innermost dimension first, then the next one, and so on.
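As a quick sanity check of this ordering, flattening the reshaped tensor should recover the original sequence, since both operations traverse elements in the same row-major order:

```python
import torch

t = torch.arange(12)
reshaped = t.reshape(6, 2)

# Flattening traverses the innermost dimension first as well,
# so the original 0..11 sequence comes back unchanged.
print(torch.equal(reshaped.flatten(), t))  # True
```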
The same logic applies with a more complex target shape.
original_tensor.reshape(3,2,2)
tensor([[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]]])
Similarly, when reshaping, elements are extracted from the innermost dimensions first. The following example breaks the original sequence of elements with a transposition, and then reshapes the resulting tensor to 1D.
transposed_tensor = original_tensor.reshape(2, 6).T
print("Transposed \n", transposed_tensor)
transposed_tensor.reshape(12)
Transposed
tensor([[ 0, 6],
[ 1, 7],
[ 2, 8],
[ 3, 9],
[ 4, 10],
[ 5, 11]])
tensor([ 0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11])
It takes elements from the innermost dimension, gradually moving to more and more external dimensions.
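A side note worth verifying: reshape returns a view sharing the original storage when the tensor's memory layout allows it, but it has to copy the data when the layout is broken, as it is after a transposition. A small check of both cases:

```python
import torch

t = torch.arange(12)

# A contiguous tensor can be reinterpreted in place:
# the reshaped result points at the same underlying storage.
print(t.reshape(6, 2).data_ptr() == t.data_ptr())  # True

# A transposed tensor is no longer contiguous, so reshaping it
# to 1D forces a copy into fresh storage.
transposed = t.reshape(2, 6).T
print(transposed.reshape(12).data_ptr() == t.data_ptr())  # False
```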
Complete dimensionality#
When using -1
as an argument in the reshape
function, you instruct PyTorch to automatically determine the appropriate dimension size that will match the total number of elements in the tensor.
Applying the reshape(-1, 6)
instruction to the original_tensor
from the previous examples, which contains 12 elements, tells PyTorch to automatically determine the appropriate size for the first dimension. In this case, it calculates that the first dimension should have size 2.
original_tensor.reshape(-1, 6)
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]])
For a more complex scenario, suppose we want to create 2 matrices with 3 columns each. To accommodate the 12 elements, each matrix must have 2 rows. The following example demonstrates how to achieve this using the reshape
function:
original_tensor.reshape(2, -1, 3)
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
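Note that at most one dimension can be left for PyTorch to infer. Passing -1 for two dimensions is ambiguous, and PyTorch raises an error:

```python
import torch

t = torch.arange(12)

# Two inferred dimensions are ambiguous (e.g. 2x6, 3x4, ...),
# so PyTorch refuses with a RuntimeError.
try:
    t.reshape(-1, -1)
except RuntimeError as e:
    print(e)
```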
If it’s not possible to match the dimensions, PyTorch will display an error message indicating the issue. It is not possible to arrange 12 elements into 3 matrices with 3 rows each, so this will result in an error.
original_tensor.reshape(3,3,-1)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[21], line 1
----> 1 original_tensor.reshape(3,3,-1)
RuntimeError: shape '[3, 3, -1]' is invalid for input of size 12
View method#
This method works only with contiguous tensors (tensors whose elements are stored sequentially in memory without gaps). It is generally faster than reshape
.
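You can check contiguity directly with the is_contiguous method. A freshly allocated tensor is contiguous; transposing it only swaps the strides, leaving the data in place, so the result is not:

```python
import torch

t = torch.randn(4, 3)
print(t.is_contiguous())    # True: elements laid out sequentially
print(t.T.is_contiguous())  # False: transposition only swaps strides
```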
The following cell creates the tensor we’ll use as an example.
X = torch.randn(4, 3)
X
tensor([[-0.0944, -0.0557, 0.5670],
[ 0.3236, 2.5825, 0.0342],
[ 1.0671, -1.3104, -2.0697],
[ 1.3963, 0.2144, 1.2199]])
Simply applying the view
method works as expected.
X.view(3, 4)
tensor([[-0.0944, -0.0557, 0.5670, 0.3236],
[ 2.5825, 0.0342, 1.0671, -1.3104],
[-2.0697, 1.3963, 0.2144, 1.2199]])
But transposing breaks the contiguity of the tensor; the following cell applies view to the transposed tensor, which causes an error.
X.T.view(4, 3)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[20], line 1
----> 1 X.T.view(4, 3)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
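As the error message suggests, one fix is to use reshape instead. Another is to call contiguous first, which copies the transposed data into sequential memory, after which view works again:

```python
import torch

X = torch.randn(4, 3)

# .contiguous() materializes the transposed layout as a fresh
# contiguous tensor, so view can reinterpret it freely.
result = X.T.contiguous().view(4, 3)
print(result.shape)  # torch.Size([4, 3])
```

The trade-off is that contiguous performs a copy, so this path loses the zero-copy advantage that makes view attractive in the first place.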