Indexing#
Indexing in PyTorch follows patterns similar to other libraries that handle arrays. Using the `[]` operator, you can select various subsets of the original tensor. This page highlights the options and features available for indexing PyTorch tensors with the `[]` operator.
import torch
from math import prod
Slicing#
Slicing is the most popular way to select subsets of the tensor. We will explore these options using the following example:
experimental = torch.arange(15).reshape(3, 5)
experimental
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
It might seem obvious, but elements selected by slices are independent of the specifications of the other axes. When you use slices across multiple axes, the result includes every element that falls within all of the slices. This behavior differs from selecting elements using iterable objects, as we will explore further.
The following example shows that if we specify rows 1, 2 and columns 0, 2, 4 with slices, we select the elements at every combination of the specified indices: [1, 0], [1, 2], [1, 4], [2, 0], [2, 2], [2, 4].
experimental[1:3, 0:5:2]
tensor([[ 5, 7, 9],
[10, 12, 14]])
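This "all combinations" behavior can be checked explicitly. The following sketch (with `rows`, `cols`, `sliced`, and `manual` as illustrative names, not names from the text) rebuilds the slice result element by element:

```python
import torch

experimental = torch.arange(15).reshape(3, 5)

# The slice picks every combination of rows {1, 2} and columns {0, 2, 4}.
sliced = experimental[1:3, 0:5:2]

# Rebuild the same selection element by element to confirm the combinations.
rows, cols = [1, 2], [0, 2, 4]
manual = torch.tensor([[experimental[r, c].item() for c in cols] for r in rows])
```

Both tensors contain the same values, confirming that the slice covers every row/column pair.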
A specific case of slicing is selecting all elements along a particular axis, which can be done using the `:` operator for that axis. It works as a slice that covers the entire axis.
experimental[1:3, :]
tensor([[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
PyTorch does not directly support reverse slicing. However, you can achieve the same result using either the `torch.flip` function or the `torch.Tensor.flip` method.
The following syntax attempts to select columns of a tensor in reverse order using a negative slice step:
experimental[:, 3:0:-1]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[13], line 1
----> 1 experimental[:, 3:0:-1]
ValueError: step must be greater than zero
We encountered an error stating that the step value must be greater than zero.
However, we can easily achieve the reversed order by combining a regular forward slice with the `flip` method.
experimental[:, :3:2].flip(1)
tensor([[ 2, 0],
[ 7, 5],
[12, 10]])
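The function form `torch.flip` produces exactly the same result as the method; a quick check:

```python
import torch

experimental = torch.arange(15).reshape(3, 5)

# Select columns 0 and 2, then reverse their order along axis 1.
method_result = experimental[:, :3:2].flip(1)
function_result = torch.flip(experimental[:, :3:2], dims=(1,))
```

Both calls reverse the selected columns, so either spelling can be used interchangeably.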
Iterable objects#
The option with iterable objects allows you to specify the exact combinations of elements along different axes that you want to select. We’ll explore different options using the example provided in the cell below:
experimental = torch.arange(15).reshape(3, 5)
experimental
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
When specifying an iterable for selection along a single axis, the behavior is straightforward: you simply get a subset along that axis.
The following cell selects rows 1 and 2 from the original tensor.
experimental[[1,2]]
tensor([[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
But when iterables are used for selection along several axes, they are interpreted as coordinates of individual elements: the i-th entries of the iterables together address one element.
In the following cell we specify the selection `[[0, 1], [2, 3]]`, which instructs torch to take the elements at indices [0, 2] and [1, 3].
experimental[[0, 1], [2, 3]]
tensor([2, 8])
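If you do want all combinations (the slice-like behavior) from index lists, you can shape the index tensors so they broadcast against each other. A sketch, with `rows` and `cols` as illustrative names:

```python
import torch

experimental = torch.arange(15).reshape(3, 5)

# A (2, 1) column of row indices broadcasts against a (1, 3) row of
# column indices, selecting every combination, just like a slice would.
rows = torch.tensor([[1], [2]])
cols = torch.tensor([[0, 2, 4]])
combined = experimental[rows, cols]
```

The result matches the earlier slice `experimental[1:3, 0:5:2]`.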
If you specify index iterables whose lengths cannot be broadcast together, you'll encounter an `IndexError` in PyTorch.
In the following example, we pass iterables of incompatible sizes, resulting in an error.
experimental[[0, 1, 2], [0, 1]]
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[63], line 1
----> 1 experimental[[0, 1, 2], [0, 1]]
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]
Note: a single-element iterable behaves slightly differently: it is broadcast to match the length of the other index.
The following example demonstrates that selecting a single column index alongside three row indices works as expected; the column index is reused for every row.
experimental[[0, 1, 2], [1]]
tensor([ 1, 6, 11])
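This is ordinary broadcasting of the index tensors: the one-element list behaves as if it were repeated to match the longer index. A quick check:

```python
import torch

experimental = torch.arange(15).reshape(3, 5)

# The one-element column index is broadcast across the three row indices,
# so it behaves exactly like an explicitly repeated index.
short_index = experimental[[0, 1, 2], [1]]
repeated_index = experimental[[0, 1, 2], [1, 1, 1]]
```

Both expressions select column 1 of every row.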
Unspecified axes#
If you don’t specify indices for the trailing axes, PyTorch includes all elements along those axes by default, which is equivalent to using `:` for each of them. The following cell demonstrates this with a 3-dimensional tensor: the expressions `temp_tensor[i]`, `temp_tensor[i, :]`, and `temp_tensor[i, :, :]` all yield the same result.
temp_tensor = torch.arange(12).reshape(2, 2, 3)
print(temp_tensor[1])
print(temp_tensor[1, :])
print(temp_tensor[1,:,:])
tensor([[ 6, 7, 8],
[ 9, 10, 11]])
tensor([[ 6, 7, 8],
[ 9, 10, 11]])
tensor([[ 6, 7, 8],
[ 9, 10, 11]])
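The same idea underlies the `...` (Ellipsis) syntax, which PyTorch, like NumPy, accepts as a stand-in for all unspecified axes; a short sketch:

```python
import torch

temp_tensor = torch.arange(12).reshape(2, 2, 3)

# Ellipsis expands to ":" for every axis it covers.
assert torch.equal(temp_tensor[1], temp_tensor[1, ...])
# It can also stand in for the leading axes.
assert torch.equal(temp_tensor[:, :, 0], temp_tensor[..., 0])
```

This is convenient when the number of dimensions is not known in advance.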
New dimension#
Using `None` in indexing allows you to add a new dimension to the tensor. The new axis appears at the position where `None` is placed.
As an example, we'll consider the two-dimensional tensor created in the following cell.
none_example_tensor = torch.arange(12).reshape(4,3)
none_example_tensor
tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
By putting `None` in the first position, we get a new outer axis, so the whole matrix is wrapped in an extra layer.
res = none_example_tensor[None]
print(res)
res.shape
tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]])
torch.Size([1, 4, 3])
By placing `None` in the second position, we create a new axis there: each row of the matrix is wrapped as a separate one-row matrix.
res = none_example_tensor[:, None]
print(res)
print(res.shape)
tensor([[[ 0, 1, 2]],
[[ 3, 4, 5]],
[[ 6, 7, 8]],
[[ 9, 10, 11]]])
torch.Size([4, 1, 3])
By placing `None` in the last position, we create a new innermost axis: each number of the input tensor becomes a one-element vector.
res = none_example_tensor[:, :, None]
print(res)
print(res.shape)
tensor([[[ 0],
[ 1],
[ 2]],
[[ 3],
[ 4],
[ 5]],
[[ 6],
[ 7],
[ 8]],
[[ 9],
[10],
[11]]])
torch.Size([4, 3, 1])
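Indexing with `None` is equivalent to calling `torch.Tensor.unsqueeze` at the corresponding position; a quick check of all three cases:

```python
import torch

none_example_tensor = torch.arange(12).reshape(4, 3)

# None at a given position inserts an axis there, just like unsqueeze.
assert torch.equal(none_example_tensor[None], none_example_tensor.unsqueeze(0))
assert torch.equal(none_example_tensor[:, None], none_example_tensor.unsqueeze(1))
assert torch.equal(none_example_tensor[:, :, None], none_example_tensor.unsqueeze(2))
```

`unsqueeze` is often clearer when the new axis position is computed at runtime, while `None` reads well in literal indexing expressions.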