Basic
Here are the basics of working with images in PyTorch.
import numpy as np
from PIL import Image
import torchvision
from torchvision import transforms
from IPython import display
import matplotlib.pyplot as plt
Load image
The PIL library can be used to load images. Here I use a very simple, tiny picture so that patterns in its matrix representation are easy to see. In the following cell I load and display the picture:
simple_image_path = "basic_files/test_picture.png"
image = Image.open(simple_image_path)
image
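The test picture itself is not included here, so the sketch below builds a hypothetical stand-in with NumPy and `Image.fromarray`. It assumes the layout suggested by the tensor printed further down: 8 columns (red, green, blue, black, white and three grey levels) repeated over 10 rows. It also shows that PIL reports `size` as (width, height) and `mode` as "RGB" for a 3-channel image.

```python
import numpy as np
from PIL import Image

# Hypothetical stand-in for basic_files/test_picture.png (assumption, reconstructed
# from the tensor printed later): one colour per column, identical in every row.
column_colours = [
    (255, 0, 0),      # red
    (0, 255, 0),      # green
    (0, 0, 255),      # blue
    (0, 0, 0),        # black
    (255, 255, 255),  # white
    (25, 25, 25),     # grey level 25 (25 / 255 ≈ 0.0980)
    (51, 51, 51),     # grey level 51 (51 / 255 = 0.2000)
    (76, 76, 76),     # grey level 76 (76 / 255 ≈ 0.2980)
]
height = 10

# Build an (H, W, 3) uint8 array: every row repeats the same sequence of colours.
pixels = np.tile(np.array(column_colours, dtype=np.uint8), (height, 1, 1))
stand_in = Image.fromarray(pixels)  # uint8 array with 3 channels -> mode "RGB"

print(stand_in.size)  # (8, 10): PIL reports (width, height)
print(stand_in.mode)  # RGB
```

Note the order: NumPy arrays are indexed (row, column), i.e. (height, width), while PIL's `size` is (width, height).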

The picture is very small, so it can be difficult to see exactly what is on it. In the following cell I use matplotlib to display an enlarged version of the picture, with nearest-neighbour interpolation so that each pixel is drawn as a sharp square instead of being blurred:
plt.imshow(
    # the file is a PNG, so let matplotlib infer the format from the extension
    plt.imread(simple_image_path),
    interpolation='nearest'
)
plt.xticks([])
plt.yticks([])
plt.show()
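Nearest-neighbour interpolation simply repeats each source pixel over a block of output pixels. As an alternative to matplotlib, Pillow's `Image.resize` with `Image.NEAREST` resampling produces the same kind of blocky enlargement as a new image. The 2×2 test image and the scale factor of 20 below are arbitrary choices for illustration, not values from the notebook.

```python
import numpy as np
from PIL import Image

# Arbitrary 2x2 RGB test image (illustrative, not the picture from the notebook).
small = Image.fromarray(np.array(
    [[[255, 0, 0], [0, 255, 0]],
     [[0, 0, 255], [255, 255, 255]]],
    dtype=np.uint8,
))

scale = 20  # hypothetical enlargement factor
width, height = small.size
# NEAREST copies each source pixel into a scale x scale block, with no blending.
big = small.resize((width * scale, height * scale), resample=Image.NEAREST)

big_pixels = np.asarray(big)
print(big_pixels.shape)  # (40, 40, 3)

# The same result with pure NumPy: repeat rows and columns `scale` times.
manual = np.asarray(small).repeat(scale, axis=0).repeat(scale, axis=1)
print(np.array_equal(big_pixels, manual))  # True
```

For integer scale factors the two approaches agree; matplotlib's `interpolation='nearest'` only affects how the image is drawn on screen, while `resize` produces an enlarged image you can save or process further.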

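Now let's transform the picture into a torch tensor. I deliberately saved the picture with only 3 colour channels (RGB, no alpha channel). transforms.ToTensor() returns a tensor of shape (3, <picture height>, <picture width>), with the channels first and the values scaled from [0, 255] to [0.0, 1.0]. The tensor literally follows the pattern of the picture: each channel has 10 rows and 8 columns, and each column holds a single colour.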
transforms.ToTensor()(image)
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980]],
[[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980]],
[[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980],
[0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0980, 0.2000, 0.2980]]])
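For a `uint8` RGB image, `transforms.ToTensor()` does two things: it moves the channel axis first, turning (height, width, channels) into (channels, height, width), and it divides by 255 so the values land in [0, 1]. The following is a minimal NumPy sketch of the same conversion, using a random stand-in image (an assumption) instead of the test picture.

```python
import numpy as np

# Hypothetical 10x8 RGB image in the (H, W, C) uint8 layout that np.asarray(pil_image) returns.
rng = np.random.default_rng(0)
hwc = rng.integers(0, 256, size=(10, 8, 3), dtype=np.uint8)

# Same steps as ToTensor for uint8 input: HWC -> CHW, then scale to [0, 1].
chw = hwc.transpose(2, 0, 1).astype(np.float32) / 255.0

print(hwc.shape, "->", chw.shape)  # (10, 8, 3) -> (3, 10, 8)

# Channel 0 of the result is the red plane of the original image.
print(np.allclose(chw[0], hwc[:, :, 0] / 255.0))  # True
```

This is why the first channel in the tensor above has ones only in the red column and in the white column. If torchvision is available, comparing `chw` with `transforms.ToTensor()(Image.fromarray(hwc))` via `torch.allclose` is a quick way to check the equivalence.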
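Image sample normalisation
To normalise images, we first need the mean and standard deviation of each colour channel over the whole dataset. The following cells load the CIFAR-10 training set and compute these statistics from its raw pixel values.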
cifar10 = torchvision.datasets.CIFAR10(
"cifar10",
download = True,
transform = transforms.ToTensor()
)
Files already downloaded and verified
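# cifar10.data is a uint8 NumPy array of shape (50000, 32, 32, 3);
# dividing by 255 maps pixels to [0, 1], and reducing over axes (0, 1, 2)
# (images, rows, columns) leaves one value per colour channel.
means = (cifar10.data / 255).mean(axis = (0,1,2))
print("means", means)
stds = (cifar10.data / 255).std(axis = (0,1,2))
print("stds", stds)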
means [0.49139968 0.48215841 0.44653091]
stds [0.24703223 0.24348513 0.26158784]
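These per-channel statistics are typically passed to `transforms.Normalize(mean, std)`, which rescales each channel c of a (C, H, W) tensor as (x - mean[c]) / std[c]. The NumPy sketch below applies that formula with broadcasting. The batch is random stand-in data (an assumption) shaped like `cifar10.data`, so after normalising with its own statistics every channel ends up with mean ≈ 0 and standard deviation ≈ 1. `normalize_chw` is a hypothetical helper written for illustration, not a torchvision function.

```python
import numpy as np

# Stand-in batch in (N, H, W, C) layout like cifar10.data, scaled to [0, 1] (assumption: random data).
rng = np.random.default_rng(42)
data = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8) / 255

# Per-channel statistics over images, rows and columns, as in the cell above.
means = data.mean(axis=(0, 1, 2))
stds = data.std(axis=(0, 1, 2))


def normalize_chw(image_chw, mean, std):
    """Hypothetical helper: apply (x - mean[c]) / std[c] to each channel of a (C, H, W) array."""
    mean = np.asarray(mean).reshape(-1, 1, 1)  # shape (C, 1, 1) broadcasts over H and W
    std = np.asarray(std).reshape(-1, 1, 1)
    return (image_chw - mean) / std


# Convert every image to (C, H, W), as ToTensor would, then normalise it.
normalized = np.stack([normalize_chw(img.transpose(2, 0, 1), means, stds) for img in data])

print(normalized.shape)  # (100, 3, 32, 32)
print(np.round(normalized.mean(axis=(0, 2, 3)), 6))  # per-channel means, close to 0
print(np.round(normalized.std(axis=(0, 2, 3)), 6))   # per-channel stds, close to 1
```

With the real dataset, the CIFAR-10 values printed above would be used the same way, for example in a pipeline such as `transforms.Compose([transforms.ToTensor(), transforms.Normalize(tuple(means), tuple(stds))])`.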