Squeeze

The torch.squeeze function and the torch.Tensor.squeeze method remove dimensions of size one from a tensor. Such dimensions are redundant: they hold only a single position, so removing them changes the shape without losing any values.

import torch
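This sketch shows the general behaviour. The tensor x is an arbitrary example, not part of the original. Calling squeeze without a dim argument removes all size-one dimensions at once. The function and method forms agree, and the result is a view that shares storage with its input:

```python
import torch

x = torch.zeros(3, 1, 5, 1)

# With no `dim`, every dimension of size one is removed.
squeezed = x.squeeze()
print(squeezed.shape)                           # torch.Size([3, 5])

# The function form gives the same result as the method form.
print(torch.equal(torch.squeeze(x), squeezed))  # True

# The result is a view: it shares storage with the input,
# so writing into it also changes the original tensor.
squeezed[0, 0] = 7.0
print(x[0, 0, 0, 0].item())                     # 7.0
```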

Specify dimension

By passing an axis index to the dim parameter, we can remove one specific dimension.


Consider a tensor with two dimensions of size one. These dimensions carry no variation, since there is only one position along each of them. They can therefore be safely collapsed, with their single elements absorbed into the enclosing dimension.

example_tensor = torch.arange(15).reshape(3, 1, 5, 1)
example_tensor
tensor([[[[ 0],
          [ 1],
          [ 2],
          [ 3],
          [ 4]]],


        [[[ 5],
          [ 6],
          [ 7],
          [ 8],
          [ 9]]],


        [[[10],
          [11],
          [12],
          [13],
          [14]]]])

We can unwrap the one-element vectors on the innermost axis by squeezing along the last axis (dim=-1). This turns the shape (3, 1, 5, 1) into (3, 1, 5).

example_tensor.squeeze(-1)
tensor([[[ 0,  1,  2,  3,  4]],

        [[ 5,  6,  7,  8,  9]],

        [[10, 11, 12, 13, 14]]])
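Squeezing one axis at a time lets you choose exactly which size-one dimensions disappear. The following sketch repeats the example tensor so that it runs on its own, then removes both unit axes in turn:

```python
import torch

example_tensor = torch.arange(15).reshape(3, 1, 5, 1)

# Squeeze the last axis first: (3, 1, 5, 1) -> (3, 1, 5)
step1 = example_tensor.squeeze(-1)
print(step1.shape)   # torch.Size([3, 1, 5])

# Then squeeze axis 1: (3, 1, 5) -> (3, 5)
step2 = step1.squeeze(1)
print(step2.shape)   # torch.Size([3, 5])
```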

Note: if the specified axis has more than one element, the input tensor is returned unchanged. The following example tries to squeeze the third dimension (index 2, of size 5), and the result keeps the original shape (3, 1, 5, 1).

example_tensor.squeeze(2)
tensor([[[[ 0],
          [ 1],
          [ 2],
          [ 3],
          [ 4]]],


        [[[ 5],
          [ 6],
          [ 7],
          [ 8],
          [ 9]]],


        [[[10],
          [11],
          [12],
          [13],
          [14]]]])
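When a silent no-op could hide a shape bug, it can help to check the size first. The helper below, strict_squeeze, is a hypothetical name introduced for illustration and is not part of PyTorch. It is a minimal sketch that raises an error instead of returning the input unchanged:

```python
import torch


def strict_squeeze(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: squeeze `dim`, but fail loudly if its size is not one.

    Tensor.squeeze(dim) silently returns the input unchanged in that case.
    """
    if t.size(dim) != 1:
        raise ValueError(f"dimension {dim} has size {t.size(dim)}, expected 1")
    return t.squeeze(dim)


example_tensor = torch.arange(15).reshape(3, 1, 5, 1)

print(example_tensor.squeeze(2).shape)          # torch.Size([3, 1, 5, 1]), unchanged
print(strict_squeeze(example_tensor, 1).shape)  # torch.Size([3, 5, 1])

try:
    strict_squeeze(example_tensor, 2)
except ValueError as err:
    print(err)  # dimension 2 has size 5, expected 1
```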

Practical case

In regression tasks, a fully connected layer with dimensionality \([D, 1]\) is commonly used as the final layer of a neural network, where \(D\) is the number of inputs to this layer. This layer outputs a matrix with dimensionality \([n, 1]\), where \(n\) is the number of objects in the minibatch processed by the network. If your target is simply a vector of \(n\) elements, the \([n, 1]\) output and the \([n]\) target have different shapes. A loss function may then silently broadcast them into something you did not intend, producing an incorrect loss.


Consider a layer whose weights are all initialised to one, so that it multiplies the input data by the column vector:

\[\begin{split}\left(\begin{array}{c} 1 \\ 1 \\ \vdots \\ 1 \end{array}\right) \end{split}\]

Multiplying by this vector sums each row of the input data, producing one value per sample.

samples = 3
layer_size = 3

fc = torch.nn.Linear(layer_size, 1, bias=False)
with torch.no_grad():
    fc.weight.copy_(torch.ones_like(fc.weight))

input_data = torch.arange(
    samples*layer_size, dtype=torch.float
).reshape(samples, layer_size)

ans = fc(input_data)
print(input_data)
ans
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
tensor([[ 3.],
        [12.],
        [21.]], grad_fn=<MmBackward0>)
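A quick check confirms the two claims above. It assumes the same setup as the cell above, repeated here so the snippet runs on its own. Note that nn.Linear stores its weight as [out_features, in_features], which is [1, D] here, the transpose of the \([D, 1]\) matrix in the text. The output has shape [n, 1] and equals the row sums of the input:

```python
import torch

samples = 3
layer_size = 3

fc = torch.nn.Linear(layer_size, 1, bias=False)
with torch.no_grad():
    fc.weight.copy_(torch.ones_like(fc.weight))

input_data = torch.arange(
    samples * layer_size, dtype=torch.float
).reshape(samples, layer_size)
ans = fc(input_data)

# The weight is stored as [out_features, in_features] = [1, 3],
# so the layer computes input_data @ fc.weight.T.
print(fc.weight.shape)  # torch.Size([1, 3])

# The output keeps a trailing dimension of size one: [n, 1], not [n].
print(ans.shape)        # torch.Size([3, 1])

# With all-ones weights this equals the row sums of the input.
row_sums = input_data.sum(dim=1, keepdim=True)
print(torch.equal(ans.detach(), row_sums))  # True
```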

So, for the input matrix

\[\begin{split}\left(\begin{array}{ccc} 0&1&2 \\ 3&4&5 \\ 6&7&8 \end{array}\right)\end{split}\]

the layer outputs:

\[\begin{split}\left(\begin{array}{c} 3 \\ 12 \\ 21 \end{array}\right) \end{split}\]

Suppose the true target values are \((0, 1, 2)\).

y = torch.arange(0, samples, dtype=torch.float)  # float targets; integer targets break mse_loss's backward pass
y
tensor([0., 1., 2.])

The MSE loss we actually want compares each prediction with its own target:

\[\frac{(0-3)^2 + (1-12)^2 + (2-21)^2}{3} = \frac{9 + 121 + 361}{3} = \frac{491}{3} = 163.\overline{6}\]
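The same arithmetic can be checked directly on tensors. This sketch hard-codes the layer output as a flat vector (predictions, a name chosen here), so no broadcasting is involved and the manual result matches torch.nn.functional.mse_loss:

```python
import torch
import torch.nn.functional as F

predictions = torch.tensor([3.0, 12.0, 21.0])  # layer output, already shaped [n]
targets = torch.tensor([0.0, 1.0, 2.0])

# Mean of squared per-sample differences: (9 + 121 + 361) / 3
manual_mse = ((predictions - targets) ** 2).mean()
print(manual_mse)                         # tensor(163.6667)

# The built-in loss agrees when the shapes already match.
print(F.mse_loss(predictions, targets))   # tensor(163.6667)
```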

Passing the raw output of shape \([n, 1]\) directly to the MSE loss gives:

torch.functional.F.mse_loss(ans, y)
/tmp/ipykernel_7944/3648618532.py:1: UserWarning: Using a target size (torch.Size([3])) that is different to the input size (torch.Size([3, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  torch.functional.F.mse_loss(ans, y)
tensor(175.6667, grad_fn=<MseLossBackward0>)

PyTorch emits a warning but still returns a value, and that value is wrong. The discrepancy comes from broadcasting. The target of shape \([3]\) and the output of shape \([3, 1]\) are both expanded to shape \([3, 3]\), so the targets become the matrix \(Y'\) and the network output becomes \(\hat{Y'}\):

\[\begin{split}Y' = \left(\begin{array}{ccc} 0&1&2 \\ 0&1&2 \\ 0&1&2 \end{array}\right), \hat{Y'} = \left(\begin{array}{ccc} 3&3&3 \\ 12&12&12 \\ 21&21&21 \end{array}\right) \end{split}\]

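The mean of the nine squared element-wise differences between these matrices is \(\frac{14 + 365 + 1202}{9} = \frac{1581}{9} = 175.\overline{6}\), exactly the value returned above.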
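The expansion can be reproduced explicitly with torch.broadcast_tensors. In this sketch the output and targets are written out as literals that mirror the values above, and float targets are assumed:

```python
import torch

ans = torch.tensor([[3.0], [12.0], [21.0]])  # network output, shape [3, 1]
y = torch.tensor([0.0, 1.0, 2.0])            # targets, shape [3]

# Expand both operands to their common broadcast shape [3, 3].
y_b, ans_b = torch.broadcast_tensors(y, ans)
print(y_b)        # every row is [0., 1., 2.]
print(ans_b)      # each row repeats a single prediction
print(y_b.shape)  # torch.Size([3, 3])

# Mean over all nine squared differences: 1581 / 9
print(((ans_b - y_b) ** 2).mean())  # tensor(175.6667)
```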

Squeezing the network output to shape \([3]\) makes the shapes match, and the loss takes the expected value \(163.\overline{6}\):

torch.functional.F.mse_loss(ans.squeeze(), y)
tensor(163.6667, grad_fn=<MseLossBackward0>)
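One caveat applies here. Calling squeeze() without a dim argument removes every size-one dimension, including the batch dimension when a minibatch happens to contain a single sample. A sketch with a made-up one-sample batch shows why squeeze(-1), or reshaping the target with unsqueeze(1), is the safer choice:

```python
import torch
import torch.nn.functional as F

# A minibatch with a single sample: output shape [1, 1], target shape [1]
ans = torch.tensor([[4.0]])
y = torch.tensor([1.0])

# squeeze() with no `dim` removes both size-one dimensions -> shape []
print(ans.squeeze().shape)    # torch.Size([])

# squeeze(-1) removes only the output-feature dimension -> shape [1]
print(ans.squeeze(-1).shape)  # torch.Size([1])

loss = F.mse_loss(ans.squeeze(-1), y)
print(loss)                   # tensor(9.)

# Alternatively, reshape the target to match the [n, 1] output.
print(F.mse_loss(ans, y.unsqueeze(1)))  # tensor(9.)
```

Naming the dimension keeps the batch axis intact regardless of batch size.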