Data primitives

PyTorch has its own approach to data management, built around two objects: Dataset and DataLoader. The Dataset holds the data and provides access to individual samples, while the DataLoader groups samples into mini-batches and provides an iterable interface over them.

For a full description, see the torch.utils.data tutorial.
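A minimal sketch of how the two objects cooperate, assuming a toy TensorDataset with features `X` and labels `y` that are made up purely for illustration. Indexing the dataset returns one sample; iterating over the loader returns mini-batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (hypothetical): 6 samples with 2 features each, plus one label per sample
X = torch.arange(12, dtype=torch.float32).reshape(6, 2)
y = torch.tensor([0, 1, 0, 1, 0, 1])

# The Dataset gives access to individual samples...
dataset = TensorDataset(X, y)
x0, y0 = dataset[0]  # one sample: (features, label)

# ...while the DataLoader groups samples into mini-batches
loader = DataLoader(dataset, batch_size=4)
for xb, yb in loader:
    print(xb.shape, yb.shape)  # first batch has 4 samples, the second only 2
```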

import torch
import torch.utils.data as td
from torch.utils.data import DataLoader

Data set

A dataset in PyTorch is an object that prepares data and returns individual samples through indexing syntax. It should implement the following methods:

  • __len__: returns the number of elements in the dataset.

  • __getitem__: implements the [] operator, returning the sample at a given index.

See the corresponding documentation page for more details.


The following cell defines a simple dataset that wraps a Python list as a PyTorch dataset.

class Example(td.Dataset):
    def __init__(self, data: list[int]) -> None:
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i: int) -> torch.Tensor:
        return torch.tensor(self.data[i])

Indexing the dataset returns the corresponding element of the list, converted into a torch tensor.

data_set = Example([3, 2, 7, 3])
data_set[2]
tensor(7)
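Since the DataLoader only needs `__len__` and `__getitem__`, the same `Example` class can be fed to it directly. A sketch, with the class repeated so the cell runs on its own; `batch_size=2` is an arbitrary choice. With the default collate function, the 0-d tensors of each batch are stacked into a 1-d tensor.

```python
import torch
import torch.utils.data as td
from torch.utils.data import DataLoader


class Example(td.Dataset):
    """Wraps a Python list; each item is returned as a 0-d tensor."""

    def __init__(self, data: list[int]) -> None:
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, i: int) -> torch.Tensor:
        return torch.tensor(self.data[i])


data_set = Example([3, 2, 7, 3])
print(len(data_set))  # 4, computed by __len__

# The DataLoader asks for indices 0..len-1 via __getitem__ and stacks each batch
for batch in DataLoader(data_set, batch_size=2):
    print(batch)  # tensor([3, 2]), then tensor([7, 3])
```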

Data loader

A DataLoader in PyTorch is an object that simplifies the process of splitting data into batches.

Find out more in the torch.utils.data.DataLoader section of the official documentation.
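A small sketch of the batching arithmetic, assuming 10 samples and `batch_size=4` (values chosen only for illustration). By default, the loader yields ceil(N / batch_size) batches, and the last one holds whatever remains.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

n_samples, batch_size = 10, 4
dataset = TensorDataset(torch.arange(n_samples))

loader = DataLoader(dataset, batch_size=batch_size)

# len(loader) is the number of batches: ceil(10 / 4) = 3
print(len(loader), math.ceil(n_samples / batch_size))  # 3 3

# Each batch is a list with one tensor (TensorDataset yields 1-tuples);
# the last batch holds the 10 - 2 * 4 = 2 remaining samples
print([batch[0].tolist() for batch in loader])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```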

Drop incomplete batch

The drop_last argument of torch.utils.data.DataLoader controls whether the final batch is dropped when it doesn't contain enough elements to form a full batch. If drop_last=True, any remaining samples that don't fit into a complete batch are skipped.


The following cell creates a tensor with 14 samples of 5 features each, prints it, and wraps it in a TensorDataset.

samples = 14
dimensionality = 5

input_tensor = (
    torch.arange(samples*dimensionality)
    .reshape(samples, dimensionality)
)
print(input_tensor)

dataset = torch.utils.data.TensorDataset(input_tensor)
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69]])

Suppose we use batch_size=4. Since 14 samples can't be split evenly into batches of 4 (14 = 3 × 4 + 2), the last batch would hold only 2 samples. The following cell defines a DataLoader with drop_last=True and prints all its batches.

data_loader = DataLoader(
    dataset, 
    batch_size=4,
    drop_last=True
)

for d in data_loader:
    print(d)
[tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])]
[tensor([[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]])]
[tensor([[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]])]

The last two samples (values 60 to 69) are not printed: they don't form a complete batch, so the DataLoader dropped them.
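For contrast, a sketch that builds the same 14-sample dataset and compares both settings of drop_last. Only the final, incomplete batch differs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Same shape as above: 14 samples with 5 features each
dataset = TensorDataset(torch.arange(14 * 5).reshape(14, 5))

for drop_last in (False, True):
    loader = DataLoader(dataset, batch_size=4, drop_last=drop_last)
    # batch[0] is the stacked tensor; its first dimension is the batch size
    sizes = [batch[0].shape[0] for batch in loader]
    print(f"drop_last={drop_last}: {len(loader)} batches, sizes {sizes}")
# drop_last=False: 4 batches, sizes [4, 4, 4, 2]
# drop_last=True: 3 batches, sizes [4, 4, 4]
```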

Collate function

You can specify how samples from the dataset are combined into batches by setting the collate_fn argument of the DataLoader.

The collate_fn receives a list of samples, each being whatever the dataset's __getitem__ returns, typically a tuple such as (X, y). It usually returns a torch.Tensor or a tuple of tensors, but it may return any object your training loop can consume.
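A sketch of a collate_fn for (X, y) samples that returns a tuple of tensors rather than a single tensor. The data and the name `xy_collate` are hypothetical, chosen only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 5 samples with 3 features and an integer label each
X = torch.randn(5, 3)
y = torch.tensor([0, 1, 1, 0, 1])
dataset = TensorDataset(X, y)


def xy_collate(
    batch: list[tuple[torch.Tensor, torch.Tensor]],
) -> tuple[torch.Tensor, torch.Tensor]:
    # batch looks like [(x_0, y_0), (x_1, y_1), ...]
    xs, ys = zip(*batch)
    # Stack features into (batch_size, 3) and labels into (batch_size,)
    return torch.stack(xs), torch.stack(ys)


loader = DataLoader(dataset, batch_size=2, collate_fn=xy_collate)
for xb, yb in loader:
    print(xb.shape, yb.shape)  # the last batch holds the single leftover sample
```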


Consider an example with a small TensorDataset built over the tensor created in the following cell.

samples = 4
dimensionality = 5

input_tensor = (
    torch.arange(samples*dimensionality)
    .reshape(samples, dimensionality)
)
print(input_tensor)

dataset = torch.utils.data.TensorDataset(input_tensor)
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])

Here is a function that we will pass as the collate_fn argument. It prints its input so we can check that the DataLoader passes exactly what we expect, and then returns the input tensors stacked into a single tensor.

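def check_function(batch: list[tuple[torch.Tensor]]) -> torch.Tensor:
    # Show exactly what the DataLoader passes in: a list of 1-tuples
    print("I got:", batch)
    # zip(*batch) regroups the tuples; its first element holds all the tensors
    return torch.stack(list(zip(*batch))[0])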
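For comparison, a sketch of what the DataLoader does without a custom collate_fn, applied by hand to the same kind of batch. It assumes a PyTorch version that exports `default_collate` from `torch.utils.data` (1.11 or later). The default keeps the tuple structure and returns a list holding one stacked tensor, which is why the drop_last batches printed earlier were wrapped in a list.

```python
import torch
from torch.utils.data import TensorDataset, default_collate

dataset = TensorDataset(torch.arange(20).reshape(4, 5))

# A batch as the DataLoader would pass it to collate_fn: a list of 1-tuples
batch = [dataset[0], dataset[1]]

# default_collate keeps the tuple structure: a list with one stacked tensor
print(default_collate(batch))

# Stacking the first element of each sample yields a bare tensor instead
print(torch.stack([sample[0] for sample in batch]))
```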

Here is an example of its usage; everything works just as expected.

data_loader = DataLoader(
    dataset=dataset,
    collate_fn=check_function,
    batch_size=2
)

for batch in data_loader:
    print(batch)
I got: [(tensor([0, 1, 2, 3, 4]),), (tensor([5, 6, 7, 8, 9]),)]
tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])
I got: [(tensor([10, 11, 12, 13, 14]),), (tensor([15, 16, 17, 18, 19]),)]
tensor([[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
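One common reason to write a custom collate_fn is that samples cannot be stacked directly, for example sequences of different lengths. A sketch assuming hypothetical integer sequences and a hypothetical `pad_collate` helper, built on `torch.nn.utils.rnn.pad_sequence`. It returns a tuple of the padded batch and the original lengths, another case where the output is not a single tensor.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical variable-length sequences; the default collate could not stack them
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]


def pad_collate(batch: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    # Keep the true lengths so padding can be masked out later
    lengths = torch.tensor([len(seq) for seq in batch])
    # Pad to the longest sequence in the batch: shape (batch_size, max_len)
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths


# A plain list works as a map-style dataset: it has __len__ and __getitem__
loader = DataLoader(sequences, batch_size=3, collate_fn=pad_collate)
padded, lengths = next(iter(loader))
print(padded.tolist())   # [[1, 2, 3], [4, 5, 0], [6, 0, 0]]
print(lengths.tolist())  # [3, 2, 1]
```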