Saving models#

Saving and loading PyTorch models is crucial because any model you build eventually needs to be transferred and deployed in some way. Check the official tutorial. Here, we’ll experiment with the options it describes.

import torch
from torch import nn

import itertools
from pathlib import Path

State dict#

The classical method to save and load a model’s state dictionary follows these steps:

  • Retrieve the model’s state dictionary with nn.Module.state_dict().

  • Save the state dictionary with torch.save.

  • Load the state dictionary with torch.load.

  • Load the weights into the model using nn.Module.load_state_dict().


In the following cell, we create a simple model and initialize its parameters with a constant value.

model = nn.Sequential(
    nn.Linear(3, 3),
    nn.Linear(3, 3)
)
for p in model.parameters():
    nn.init.constant_(p, 3)

model.state_dict()
OrderedDict([('0.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('0.bias', tensor([3., 3., 3.])),
             ('1.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('1.bias', tensor([3., 3., 3.]))])

Now, using torch.save, we save the state dictionary and discard the original model.

torch.save(obj=model.state_dict(), f=Path("/tmp")/"my_model")
del model

Now, with torch.load, we load the state dictionary. Since all parameter values were set to the constant 3 before saving, the loaded values are still 3.

state_dict = torch.load(Path("/tmp")/"my_model", weights_only=False)
state_dict
OrderedDict([('0.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('0.bias', tensor([3., 3., 3.])),
             ('1.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('1.bias', tensor([3., 3., 3.]))])
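
Since a state dictionary contains only tensors in standard containers, it can also be loaded with the safer weights_only=True. A minimal sketch:

# weights_only=True refuses to unpickle arbitrary objects, which is
# recommended for files you don't fully control; a plain state dict loads fine.
safe_state_dict = torch.load(Path("/tmp")/"my_model", weights_only=True)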

By creating a model with exactly the same architecture and loading the previously saved state dictionary into it, you can fully restore the original model.

model = nn.Sequential(
    nn.Linear(3, 3),
    nn.Linear(3, 3)
)

model.load_state_dict(state_dict)
<All keys matched successfully>

A successful call to nn.Module.load_state_dict returns a special value whose representation reads <All keys matched successfully>.
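
In fact, the returned value is a named tuple with missing_keys and unexpected_keys fields. With strict=False, you can load a partial state dictionary and inspect the mismatches. A minimal sketch (the dropped key is chosen arbitrarily):

# Simulate a partial checkpoint by dropping one key.
partial = {k: v for k, v in state_dict.items() if k != "1.bias"}

result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['1.bias']
print(result.unexpected_keys)  # []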

Save entire model#

By passing the entire model to the torch.save function, you save a serialized (pickled) Torch model. Then, with just one line of code, you can restore it using the torch.load function.


The following cell demonstrates creating a model, initializing its weights with a constant, and saving this model to disk by passing it as obj to the torch.save function.

model = nn.Sequential(
    nn.Linear(3, 3),
    nn.Linear(3, 3)
)
for p in model.parameters():
    nn.init.constant_(p, 3)

torch.save(model, Path("/tmp")/"model.pth")

With torch.load, the model can be restored. The following cell shows that the loaded model has weights identical to those initialized before saving.

torch.load(Path("/tmp")/"model.pth", weights_only=False).state_dict()
OrderedDict([('0.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('0.bias', tensor([3., 3., 3.])),
             ('1.weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('1.bias', tensor([3., 3., 3.]))])
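
One caveat: torch.save(model) pickles a reference to the model’s class, not the class itself, so the class definition must be importable when the file is loaded. A minimal sketch with a hypothetical custom module:

class MyNet(nn.Module):  # hypothetical class, used only for illustration
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 3)

torch.save(MyNet(), Path("/tmp")/"custom.pth")

# Loading this file from another script works only if MyNet is defined or
# importable there; otherwise unpickling fails with an AttributeError.
restored = torch.load(Path("/tmp")/"custom.pth", weights_only=False)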

Saving group#

A typical case is needing to save two torch objects at once, for example the generator and the discriminator of a GAN. This is not a problem: you can save a group of models wrapped in any Python collection.


As an example, consider two linear layers that we want to save as a single file.

val1 = nn.Linear(3, 3)
val2 = nn.Linear(3, 3)

for p in itertools.chain(val1.parameters(), val2.parameters()):
    nn.init.constant_(p, 3)

We’ll save the models as a dictionary.

torch.save(
    obj={"val1": val1, "val2": val2},
    f=Path("/tmp")/"model.pth"
)

So there is no problem loading that dictionary back:

ans = torch.load(Path("/tmp")/"model.pth", weights_only=False)
ans
{'val1': Linear(in_features=3, out_features=3, bias=True),
 'val2': Linear(in_features=3, out_features=3, bias=True)}

To be sure, let’s check that the extracted model has the weights exactly as we initialized them before saving.

ans["val1"].state_dict()
OrderedDict([('weight',
              tensor([[3., 3., 3.],
                      [3., 3., 3.],
                      [3., 3., 3.]])),
             ('bias', tensor([3., 3., 3.]))])

The same trick will work with tuples as well.

torch.save(obj=(val1, val2), f=Path("/tmp")/"model.pth")
torch.load(Path("/tmp")/"model.pth", weights_only=False)
(Linear(in_features=3, out_features=3, bias=True),
 Linear(in_features=3, out_features=3, bias=True))
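
The same pattern covers training checkpoints: group a model’s state dictionary with an optimizer’s state dictionary and any bookkeeping values. A minimal sketch (the optimizer and the epoch value are illustrative):

optimizer = torch.optim.SGD(val1.parameters(), lr=0.1)

torch.save(
    obj={
        "epoch": 10,
        "model_state_dict": val1.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    f=Path("/tmp")/"checkpoint.pth"
)

checkpoint = torch.load(Path("/tmp")/"checkpoint.pth", weights_only=False)
val1.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])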

Different devices#

A saved model remembers the device it was on when it was saved. A typical problem, therefore, is loading a model that was saved on a GPU onto a CPU-only machine. There is a dedicated section, “Saving & Loading Model Across Devices”, in the PyTorch tutorial.

The map_location parameter of the model-loading utilities lets you specify the device onto which the model has to be loaded.
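
map_location accepts several forms; the calls below are illustrative sketches (ckpt_path is a hypothetical file):

ckpt_path = Path("/tmp")/"some_checkpoint.pth"  # hypothetical path

# torch.load(ckpt_path, map_location="cpu")                         # device string
# torch.load(ckpt_path, map_location=torch.device("cuda:0"))        # device object
# torch.load(ckpt_path, map_location={"cuda:1": "cuda:0"})          # remap GPU 1 to GPU 0
# torch.load(ckpt_path, map_location=lambda storage, loc: storage)  # callable, keeps storages on CPU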


Here is the code that stores a simple model living on the GPU, using both the entire-model serialization and the state dict approaches.

model = nn.Sequential(
    nn.Linear(3, 3),
    nn.Linear(3, 3)
)
model.to(device=torch.device('cuda'))

files_path = Path("saving_model_fiels")
files_path.mkdir(exist_ok=True)

torch.save(model, f=files_path/"entire_model.pth")
torch.save(model.state_dict(), f=files_path/"state_dict.pth")

If you try to load something that was stored on the GPU on a CPU-only machine, you’ll get a corresponding error.

files_path = Path("saving_model_files")
try:
    torch.load(files_path/"entire_model.pth")
except Exception as e:
    print(e)
Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
/tmp/ipykernel_243924/253386758.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  torch.load(files_path/"entire_model.pth")

With the device set via the map_location argument, everything works correctly.

torch.load(
    files_path/"entire_model.pth", 
    map_location=torch.device('cpu'), 
    weights_only=False
)
Sequential(
  (0): Linear(in_features=3, out_features=3, bias=True)
  (1): Linear(in_features=3, out_features=3, bias=True)
)
torch.load(
    files_path/"state_dict.pth", 
    map_location=torch.device('cpu'), 
    weights_only=False
)
OrderedDict([('0.weight',
              tensor([[-0.5407, -0.0230, -0.4316],
                      [-0.0598,  0.4279, -0.1093],
                      [-0.1346,  0.4117, -0.0023]])),
             ('0.bias', tensor([-0.4625,  0.4542, -0.1207])),
             ('1.weight',
              tensor([[-0.4032, -0.2615, -0.4629],
                      [-0.3916,  0.3262, -0.0923],
                      [-0.4275, -0.5173,  0.2389]])),
             ('1.bias', tensor([ 0.2098, -0.2383, -0.2560]))])
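
A common device-agnostic pattern is to load onto the CPU first and then move the model to the best available device. A minimal sketch:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Sequential(
    nn.Linear(3, 3),
    nn.Linear(3, 3)
)
state_dict = torch.load(
    files_path/"state_dict.pth",
    map_location="cpu",
    weights_only=True
)
model.load_state_dict(state_dict)
model.to(device)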