Basic

Here are the basics of working with images in PyTorch and torchvision.

import numpy as np

from PIL import Image

import torchvision
from torchvision import transforms

from IPython import display

import matplotlib.pyplot as plt

Load image

The PIL library can load images. To make patterns in the matrix representation easy to spot, we use a very small and simple picture. The following cell loads and displays it:

simple_image_path = "basic_files/test_picture.png"
image = Image.open(simple_image_path)
image
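Before converting anything, it helps to check what PIL reports about the picture. The sketch below does not read basic_files/test_picture.png. Instead it builds a hypothetical stand-in in memory (vertical stripes whose values are reconstructed from the tensor printed later on this page) so that it runs anywhere. It shows that PIL's size attribute is (width, height), while the NumPy array is indexed as (height, width, channels):

```python
import numpy as np
from PIL import Image

# Hypothetical stand-in for basic_files/test_picture.png: 10 rows x 8 columns of
# vertical stripes (red, green, blue, black, white and three dark greys).
stripes = np.array(
    [[255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 0, 0],
     [255, 255, 255], [25, 25, 25], [51, 51, 51], [76, 76, 76]],
    dtype=np.uint8,
)
pixels = np.tile(stripes, (10, 1, 1))  # shape (height=10, width=8, channels=3)
image = Image.fromarray(pixels)        # uint8 (H, W, 3) array -> RGB image

print(image.mode)               # RGB: three colour channels, no alpha
print(image.size)               # (8, 10): PIL reports (width, height)
print(np.asarray(image).shape)  # (10, 8, 3): NumPy uses (height, width, channels)
```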

The picture is very small, so it is hard to see what it actually shows. The following cell uses matplotlib to display it enlarged, with nearest-neighbour interpolation so that every pixel stays a sharp square:

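plt.imshow(
    plt.imread(simple_image_path),  # the file is a PNG; matplotlib detects this from the extension
    interpolation="nearest"  # draw each pixel as a flat block instead of blending neighbours
)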
plt.xticks([])
plt.yticks([])
plt.show()
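With interpolation="nearest", each source pixel is drawn as a flat block of one colour instead of being blended with its neighbours. For an integer scale factor this amounts to repeating every pixel. The helper upscale_nearest below is a hypothetical NumPy sketch of that idea, not what matplotlib runs internally:

```python
import numpy as np


def upscale_nearest(pixels: np.ndarray, factor: int) -> np.ndarray:
    """Enlarge an (H, W, C) image by repeating every pixel `factor` times
    along both spatial axes (nearest-neighbour upscaling for integer factors)."""
    return np.repeat(np.repeat(pixels, factor, axis=0), factor, axis=1)


# Tiny hypothetical 2x2 RGB image: red, green in the first row; blue, white in the second.
tiny = np.array(
    [[[255, 0, 0], [0, 255, 0]],
     [[0, 0, 255], [255, 255, 255]]],
    dtype=np.uint8,
)
big = upscale_nearest(tiny, 3)
print(big.shape)  # (6, 6, 3)
```

Every 3x3 block of big is a copy of one original pixel, so no new colours appear. This is why the enlarged plot above still shows clean stripes rather than smooth gradients.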

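Now let's transform the picture into a torch tensor. I deliberately saved the picture with only 3 colour channels (RGB, without an alpha channel). ToTensor returns a tensor of shape (3, <picture height>, <picture width>), with pixel values rescaled from the 0 to 255 range to [0, 1], and its entries literally follow the pattern of the picture: reading across the columns, the stripes are red, green, blue, black, white and three dark greys.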

transforms.ToTensor()(image)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980]],

        [[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980]],

        [[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
         [0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980]]])

Image sample normalisation

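cifar10 = torchvision.datasets.CIFAR10(
    "cifar10",  # directory where the dataset is stored
    download=True,  # train=True by default, so this is the 50,000-image training split
    transform=transforms.ToTensor()
)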
Files already downloaded and verified
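# .data is the raw uint8 array of shape (50000, 32, 32, 3); the ToTensor
# transform is only applied when indexing the dataset, so scale to [0, 1] here.
scaled = cifar10.data / 255
# Reduce over images, rows and columns, keeping one value per colour channel.
means = scaled.mean(axis=(0, 1, 2))
print("means", means)
stds = scaled.std(axis=(0, 1, 2))
print("stds", stds)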
means [0.49139968 0.48215841 0.44653091]
stds [0.24703223 0.24348513 0.26158784]