Broadcasting
Broadcasting is a mechanism that allows tensors of different shapes to be used together in element-wise operations.
import torch
Broadcasting rules
To understand broadcasting, consider its three rules:
Rule 1 (dimensioning): If the tensors have different numbers of dimensions, the shape of the tensor with fewer dimensions is padded with ones on the left until both tensors have the same number of dimensions.
Rule 2 (compatibility): Two dimensions are compatible when they are equal or one of them is 1.
Rule 3 (broadcasting): If all dimensions are compatible, every dimension of size 1 is stretched to match the corresponding dimension of the other tensor, and the operation is carried out on this common shape. If any pair of dimensions is not compatible, an error is raised.
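The three rules are compact enough to apply by hand in a few lines of code. The sketch below uses a hypothetical helper, `broadcast_shape`, written only for illustration; it computes the result shape of two operands from their shapes and compares one case with PyTorch's own `torch.broadcast_shapes`.

```python
import torch

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: apply the three broadcasting rules to two shapes."""
    # Rule 1 (dimensioning): pad the shorter shape with ones on the left.
    n = max(len(shape_a), len(shape_b))
    a = (1,) * (n - len(shape_a)) + tuple(shape_a)
    b = (1,) * (n - len(shape_b)) + tuple(shape_b)

    result = []
    for da, db in zip(a, b):
        # Rule 2 (compatibility): sizes must be equal or one of them must be 1.
        if da != db and da != 1 and db != 1:
            raise ValueError(f"incompatible shapes {tuple(shape_a)} and {tuple(shape_b)}")
        # Rule 3 (broadcasting): a size-1 dimension takes the other operand's size.
        result.append(db if da == 1 else da)
    return tuple(result)

print(broadcast_shape((2, 2), (2,)))                 # (2, 2)
print(broadcast_shape((1, 3, 2), (2, 3, 1)))         # (2, 3, 2)
print(torch.broadcast_shapes((1, 3, 2), (2, 3, 1)))  # torch.Size([2, 3, 2])
```

Writing `db if da == 1 else da` instead of `max(da, db)` keeps the sketch correct for empty dimensions, since PyTorch broadcasts a size-1 dimension against a size-0 one to size 0.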
Let's look at each rule separately, with examples.
Dimensioning
To apply element-wise operations, the operands must have the same number of dimensions. If two tensors have different numbers of dimensions, the shape of the tensor with fewer dimensions is padded with ones on the left until it matches the other. More formally, given two tensors \(A\) and \(B\) with shapes \((d_1, d_2, ..., d_n)\) and \((d'_1, d'_2, ... , d'_m)\) respectively, where \(m < n\), PyTorch treats \(B\) as having the shape \((d''_1, d''_2, ..., d''_{n-m}, d''_{n-m+1}, ..., d''_n)\), where:
\[
d''_i =
\begin{cases}
1, & 1 \le i \le n-m, \\
d'_{i-(n-m)}, & n-m < i \le n.
\end{cases}
\]
Consider an example where tensors \(A\) and \(B\) have shapes \((1, 2, 3, 2, 3)\) and \((3, 2, 3)\) respectively.
PyTorch treats \(B\) as having the shape \((1, 1, 3, 2, 3)\): it prepends ones until the shape has the same length as that of \(A\) (5 dimensions) and keeps the original dimensions of \(B\) at the end.
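One way to check this is to pad \(B\) by hand and compare the two results. In the sketch below, the values (ones for `A`, a range of numbers for `B`) are arbitrary and chosen only for illustration; only the shapes matter.

```python
import torch

A = torch.ones(1, 2, 3, 2, 3)
B = torch.arange(18.0).reshape(3, 2, 3)

# Rule 1 made explicit: prepend two size-1 dimensions to B.
B_padded = B.reshape(1, 1, 3, 2, 3)
print(B_padded.shape)  # torch.Size([1, 1, 3, 2, 3])

# Implicit broadcasting gives exactly the same result as the padded version.
print((A + B).shape)                     # torch.Size([1, 2, 3, 2, 3])
print(torch.equal(A + B, A + B_padded))  # True
```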
When a tensor is transformed to match another tensor’s dimensionality, it’s effectively being wrapped in a series of additional dimensions, each of size 1. This creates a “shell” around the original tensor, expanding it to have the same number of dimensions as the target tensor.
Let's look at a code example that demonstrates this transformation. We'll add a one-dimensional tensor to a two-dimensional tensor and observe how PyTorch handles the broadcasting.
a = torch.arange(4).reshape(2,2)
a
tensor([[0, 1],
[2, 3]])
As the second tensor, we'll use a one-dimensional tensor b:
b = torch.tensor([1, 2])
b
tensor([1, 2])
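There are two ways to add the two elements of b to the tensor a. We can turn b into a column of shape (2, 1), so that each element is added to one row of a, or into a row of shape (1, 2), so that each element is added to one column of a. The following cells show both options.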
print("Adding")
print(b[:, None])
a + b[:, None]
Adding
tensor([[1],
[2]])
tensor([[1, 2],
[4, 5]])
print("Adding", b[None])
a + b[None]
Adding tensor([[1, 2]])
tensor([[1, 3],
[3, 5]])
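Indexing with `None` inserts a new axis of size 1. The same reshaping can be written with `Tensor.unsqueeze`, which states the position of the new axis explicitly. This sketch redefines `a` and `b` so that it runs on its own.

```python
import torch

a = torch.arange(4).reshape(2, 2)
b = torch.tensor([1, 2])

col = b.unsqueeze(1)  # same as b[:, None], shape (2, 1)
row = b.unsqueeze(0)  # same as b[None], shape (1, 2)
print(col.shape, row.shape)  # torch.Size([2, 1]) torch.Size([1, 2])

print(torch.equal(a + col, a + b[:, None]))  # True
print(torch.equal(a + row, a + b[None]))     # True
```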
So far we have specified explicitly where the missing dimension goes. Now let's see which option PyTorch chooses when we don't:
a + b
tensor([[1, 3],
[3, 5]])
Without an explicit index, PyTorch treated b as a row, giving the same result as a + b[None]. This follows from Rule 1: the shape of b, (2,), is padded with a one on the left to (1, 2), which is a one-row matrix. The two operands then share the column dimension, and the single row of b is repeated for each row of a.
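Because padding always happens on the left, a one-dimensional tensor is never turned into a column automatically. The sketch below uses shapes chosen only for illustration: adding a tensor of shape (2,) to one of shape (2, 3) fails, since (2,) is padded to (1, 2) and the trailing sizes 3 and 2 are incompatible. An explicit `None` axis makes the intent clear and fixes it.

```python
import torch

a = torch.zeros(2, 3)
b = torch.tensor([1.0, 2.0])  # shape (2,), padded to (1, 2) by Rule 1

try:
    a + b  # trailing sizes 3 and 2 differ and neither is 1
except RuntimeError as err:
    print("RuntimeError:", err)

# Making b a column of shape (2, 1) lets it broadcast along the rows of a.
out = a + b[:, None]
print(out)
# tensor([[1., 1., 1.],
#         [2., 2., 2.]])
```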
Compatibility
After the shape of the lower-dimensional tensor has been padded (if necessary), the two shapes must be compatible: each pair of corresponding dimensions must be equal, or one of them must be 1.
More formally, if we have two tensors \(A\) and \(B\) with dimensions \((d'_1, d'_2, ..., d'_n)\) and \((d''_1, d''_2, ..., d''_n)\) after padding, then for element-wise operations in PyTorch, the following condition must hold:
\(d'_i = d''_i\) or \(d'_i = 1\) or \(d''_i = 1\) for \(i = 1, 2, ..., n\).
The tensors with shapes \((1, 3, 2)\) and \((2, 3, 1)\) are compatible for element-wise operations because every dimension satisfies the condition:
\(d'_1 = 1\): the first dimension of the first tensor is 1.
\(d'_2 = d''_2\): the second dimensions of both tensors are equal (3).
\(d''_3 = 1\): the third dimension of the second tensor is 1.
A = torch.arange(6).reshape(1,3,2)
A
tensor([[[0, 1],
[2, 3],
[4, 5]]])
B = torch.arange(6).reshape(2, 3, 1)
B
tensor([[[0],
[1],
[2]],
[[3],
[4],
[5]]])
Adding them works as expected, and the result has shape \((2, 3, 2)\):
A + B
tensor([[[ 0, 1],
[ 3, 4],
[ 6, 7]],
[[ 3, 4],
[ 6, 7],
[ 9, 10]]])
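To see what the operands look like after broadcasting, `torch.broadcast_tensors` returns both tensors expanded to their common shape. The sketch recreates `A` and `B` from above so that it runs on its own.

```python
import torch

A = torch.arange(6).reshape(1, 3, 2)
B = torch.arange(6).reshape(2, 3, 1)

# Both operands expanded to the common shape (2, 3, 2).
A_exp, B_exp = torch.broadcast_tensors(A, B)
print(A_exp.shape, B_exp.shape)  # torch.Size([2, 3, 2]) torch.Size([2, 3, 2])

# Adding the expanded tensors gives the same result as A + B.
print(torch.equal(A_exp + B_exp, A + B))  # True
```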
However, if any pair of corresponding dimensions violates this rule, PyTorch raises an error. In the following example, the second dimensions are 3 and 2, and neither of them is 1:
B = torch.arange(4).reshape(2, 2, 1)
A + B
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
Broadcasting
Broadcasting brings both operands to the same shape before performing the operation. Once Rules 1 and 2 have been applied, whenever the sizes of corresponding dimensions differ, one of them is guaranteed to be 1. That size-1 dimension is expanded to match the corresponding dimension of the other operand by repeating its single element as many times as needed.
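PyTorch does not need to materialize the repeated copies. `Tensor.expand`, which performs the same kind of stretching explicitly, returns a view in which the expanded size-1 dimension has stride 0, so the single row is reused rather than copied. A small illustration:

```python
import torch

row = torch.tensor([[1, 2]])  # shape (1, 2)
expanded = row.expand(3, 2)   # the single row repeated three times, as a view
print(expanded)
# tensor([[1, 2],
#         [1, 2],
#         [1, 2]])
print(expanded.stride())                      # (0, 1): dim 0 does not advance in memory
print(expanded.data_ptr() == row.data_ptr())  # True: no data was copied
```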
Let's reinforce this rule with a few examples. First, break down the broadcasting process for the operands [[1, 2]] and [[1], [2]]:
[[1, 2]] has shape (1, 2). Its first dimension (1) must match the first dimension of [[1], [2]] (2), so its single row is repeated twice, resulting in [[1, 2], [1, 2]].
[[1], [2]] has shape (2, 1). Its second dimension (1) must match the second dimension of [[1, 2]] (2), so the single element in each row is repeated twice, resulting in [[1, 1], [2, 2]].
After that, the result is an ordinary element-wise sum: [[1, 2], [1, 2]] + [[1, 1], [2, 2]] = [[2, 3], [3, 4]].
torch.tensor([[1,2]]) + torch.tensor([[1], [2]])
tensor([[2, 3],
[3, 4]])
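The two stretching steps described above can be reproduced with `Tensor.repeat`, which, unlike broadcasting, makes real copies. This makes the intermediate tensors visible and confirms that their sum matches the broadcast result.

```python
import torch

x = torch.tensor([[1, 2]])    # shape (1, 2)
y = torch.tensor([[1], [2]])  # shape (2, 1)

x_rep = x.repeat(2, 1)  # repeat the single row twice: [[1, 2], [1, 2]]
y_rep = y.repeat(1, 2)  # repeat each row's single element twice: [[1, 1], [2, 2]]
print(x_rep + y_rep)
# tensor([[2, 3],
#         [3, 4]])
print(torch.equal(x_rep + y_rep, x + y))  # True
```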
Consider the following more complex 3-dimensional example with tensors:
Tensor A, [[[1], [2]], [[3], [4]]], has shape (2, 2, 1).
Tensor B, [[[1, 2], [2, 3]]], has shape (1, 2, 2).
To match the last dimension of Tensor B (size 2), each single-element row of Tensor A is repeated along the last dimension, resulting in [[[1, 1], [2, 2]], [[3, 3], [4, 4]]]. This gives Tensor A the shape (2, 2, 2).
Tensor B, in turn, has a single layer along its first dimension, while Tensor A has two. That layer is therefore repeated, and the modified Tensor B becomes [[[1, 2], [2, 3]], [[1, 2], [2, 3]]]. This gives Tensor B the shape (2, 2, 2).
Finally, performing element-wise operations:
A = torch.tensor([[[1], [2]], [[3], [4]]])  # shape (2, 2, 1)
B = torch.tensor([[[1, 2], [2, 3]]])  # shape (1, 2, 2)
A + B
tensor([[[2, 3],
[4, 5]],
[[4, 5],
[6, 7]]])
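The stretched versions of A and B written out above can be checked with `torch.broadcast_to`, which expands a tensor to a given shape following the same rules. The sketch compares them with the hand-written tensors from the text.

```python
import torch

A = torch.tensor([[[1], [2]], [[3], [4]]])  # shape (2, 2, 1)
B = torch.tensor([[[1, 2], [2, 3]]])        # shape (1, 2, 2)

# The stretched versions described in the text, written out by hand.
A_stretched = torch.tensor([[[1, 1], [2, 2]], [[3, 3], [4, 4]]])
B_stretched = torch.tensor([[[1, 2], [2, 3]], [[1, 2], [2, 3]]])

print(torch.equal(torch.broadcast_to(A, (2, 2, 2)), A_stretched))  # True
print(torch.equal(torch.broadcast_to(B, (2, 2, 2)), B_stretched))  # True
print(torch.equal(A_stretched + B_stretched, A + B))               # True
```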
Practical example
Broadcasting makes data normalization concise and efficient. Imagine we have four objects, each represented by a 3×3 matrix; stacked together, they form a three-dimensional tensor of shape (4, 3, 3). However, each object has a different base value. The following code snippet generates and prints such a tensor (the values are random, so your output will differ):
value = torch.stack([
torch.normal(mean=mean, std=3., size=(3, 3))
for mean in [10, 100, 500, 1000]
])
value
tensor([[[ 12.8460, 11.6038, 8.2267],
[ 13.3707, 13.0595, 9.3532],
[ 10.6257, 10.8800, 5.5372]],
[[ 102.3054, 98.2987, 89.9491],
[ 103.9013, 98.5896, 98.1971],
[ 106.4493, 98.1544, 96.3743]],
[[ 501.0322, 499.2918, 504.5105],
[ 496.8454, 498.7290, 499.3281],
[ 499.4231, 504.0831, 499.4671]],
[[ 997.9377, 1001.2782, 1005.9980],
[ 999.5461, 1001.3517, 997.5663],
[1003.1572, 999.6476, 1001.6178]]])
Let's say we want to standardize each of these objects separately, bringing it to zero mean and unit variance.
First, we need to calculate the mean and standard deviation of each object over its 3×3 entries.
mean = value.mean(dim=[1, 2], keepdim=True)
print(mean)
std = value.std(dim=[1, 2], keepdim=True)
print(std)
tensor([[[ 10.6114]],
[[ 99.1355]],
[[ 500.3011]],
[[1000.9001]]])
tensor([[[2.5660]],
[[4.7580]],
[[2.5100]],
[[2.6277]]])
The results are two tensors of shape (4, 1, 1), each holding one value per object. Because of the keepdim=True argument, the reduced dimensions are kept with size 1, so both tensors have the same number of dimensions as value and are compatible with its shape (4, 3, 3).
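The effect of `keepdim=True` is easiest to see by comparing shapes with and without it. This sketch uses stand-in data from `torch.randn` with the same (4, 3, 3) shape as `value` above; only the shapes matter here.

```python
import torch

value = torch.randn(4, 3, 3)  # stand-in data with the same shape as above

mean_keep = value.mean(dim=[1, 2], keepdim=True)
mean_flat = value.mean(dim=[1, 2])
print(mean_keep.shape)  # torch.Size([4, 1, 1])
print(mean_flat.shape)  # torch.Size([4])

# (4, 1, 1) broadcasts against (4, 3, 3): one mean per object.
print((value - mean_keep).shape)  # torch.Size([4, 3, 3])

# (4,) is padded on the left to (1, 1, 4); the last sizes 3 and 4 are incompatible.
try:
    value - mean_flat
except RuntimeError as err:
    print("RuntimeError:", err)
```

Without `keepdim=True`, the failure here is the lucky case: if the number of objects happened to be 3, the shape (3,) would be padded to (1, 1, 3) and silently broadcast along the last axis, subtracting the wrong means without any error.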
Now, subtracting the means from value and dividing the result by the standard deviations brings every object to zero mean and unit standard deviation. Broadcasting stretches each object's single mean and standard deviation over its whole 3×3 matrix:
(value - mean)/std
tensor([[[ 0.8708, 0.3867, -0.9293],
[ 1.0753, 0.9540, -0.4904],
[ 0.0056, 0.1047, -1.9775]],
[[ 0.6662, -0.1759, -1.9307],
[ 1.0016, -0.1147, -0.1972],
[ 1.5372, -0.2062, -0.5803]],
[[ 0.2912, -0.4021, 1.6770],
[-1.3768, -0.6263, -0.3877],
[-0.3498, 1.5068, -0.3323]],
[[-1.1274, 0.1439, 1.9401],
[-0.5153, 0.1719, -1.2687],
[ 0.8590, -0.4767, 0.2731]]])
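A quick sanity check is to recompute the per-object statistics of the standardized tensor. The sketch regenerates random data the same way as above (so the numbers differ from the printed output) and prints each object's mean and standard deviation, which should be close to 0 and 1 up to floating-point error.

```python
import torch

value = torch.stack([
    torch.normal(mean=m, std=3.0, size=(3, 3))
    for m in [10.0, 100.0, 500.0, 1000.0]
])
mean = value.mean(dim=[1, 2], keepdim=True)
std = value.std(dim=[1, 2], keepdim=True)
standardized = (value - mean) / std

# Per-object statistics after standardization: about 0 and about 1.
print(standardized.mean(dim=[1, 2]))
print(standardized.std(dim=[1, 2]))
```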