6.6. File I/O

So far we have discussed how to process data and how to build, train, and test deep learning models. However, at some point we will hopefully be happy enough with the learned models that we will want to save the results for later use in various contexts (perhaps even to make predictions in deployment). Additionally, when running a long training process, the best practice is to periodically save intermediate results (checkpointing) to ensure that we do not lose several days’ worth of computation if we trip over the power cord of our server. Thus it is time to learn how to load and store both individual weight vectors and entire models. This section addresses both issues.
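The checkpointing idea can be sketched independently of any framework. The snippet below is a minimal illustration using plain NumPy, not the mechanism used by the frameworks in this section: the helper names `save_checkpoint` and `load_checkpoint`, the file name `ckpt.npz`, and the 5-step interval are all assumptions made for the example. It writes to a temporary file first and then renames it, so that an interruption during the write does not corrupt the previous checkpoint.

```python
import os
import tempfile

import numpy as np


def save_checkpoint(path, params, step):
    """Write params and the step counter to `path` via a temporary file."""
    tmp_path = path + '.tmp.npz'
    np.savez(tmp_path, step=step, **params)
    os.replace(tmp_path, path)  # atomic rename on the same filesystem


def load_checkpoint(path):
    """Return (params, step) from a file written by save_checkpoint."""
    with np.load(path) as data:
        step = int(data['step'])
        params = {k: data[k] for k in data.files if k != 'step'}
    return params, step


ckpt_path = os.path.join(tempfile.mkdtemp(), 'ckpt.npz')
params = {'w': np.zeros(3), 'b': np.zeros(1)}

for step in range(1, 11):
    params['w'] += 0.1   # stand-in for a real gradient update
    if step % 5 == 0:    # checkpoint every 5 steps (arbitrary interval)
        save_checkpoint(ckpt_path, params, step)

# After a crash, training could resume from the last saved step.
restored, last_step = load_checkpoint(ckpt_path)
```

Storing the step counter next to the parameters is a design choice that lets a restarted job know where to continue; real training code would typically also save the optimizer state.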

# PyTorch
import torch
from torch import nn
from torch.nn import functional as F

# MXNet
from mxnet import np, npx
from mxnet.gluon import nn

npx.set_np()

# JAX
import flax
import jax
from flax import linen as nn
from flax.training import checkpoints
from jax import numpy as jnp
from d2l import jax as d2l

# TensorFlow
import numpy as np
import tensorflow as tf

6.6.1. Loading and Saving Tensors

For individual tensors, we can directly invoke the load and save functions to read and write them, respectively. Both functions require that we supply a file name, and save additionally takes the variable to be saved as input.

x = torch.arange(4)
torch.save(x, 'x-file')
x = np.arange(4)
npx.save('x-file', x)
x = jnp.arange(4)
jnp.save('x-file.npy', x)
x = tf.range(4)
np.save('x-file.npy', x)

We can now read the data from the stored file back into memory.

x2 = torch.load('x-file')
x2
tensor([0, 1, 2, 3])
x2 = npx.load('x-file')
x2
[array([0., 1., 2., 3.])]
x2 = jnp.load('x-file.npy', allow_pickle=True)
x2
Array([0, 1, 2, 3], dtype=int32)
x2 = np.load('x-file.npy', allow_pickle=True)
x2
array([0, 1, 2, 3], dtype=int32)

We can store a list of tensors and read them back into memory.

y = torch.zeros(4)
torch.save([x, y],'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
y = np.zeros(4)
npx.save('x-files', [x, y])
x2, y2 = npx.load('x-files')
(x2, y2)
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))
y = jnp.zeros(4)
jnp.save('xy-files.npy', [x, y])
x2, y2 = jnp.load('xy-files.npy', allow_pickle=True)
(x2, y2)
(Array([0., 1., 2., 3.], dtype=float32),
 Array([0., 0., 0., 0.], dtype=float32))
y = tf.zeros(4)
np.save('xy-files.npy', [x, y])
x2, y2 = np.load('xy-files.npy', allow_pickle=True)
(x2, y2)
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))

We can even write and read a dictionary that maps from strings to tensors. This is convenient when we want to read or write all the weights in a model.

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
mydict = {'x': x, 'y': y}
npx.save('mydict', mydict)
mydict2 = npx.load('mydict')
mydict2
{'x': array([0., 1., 2., 3.]), 'y': array([0., 0., 0., 0.])}
mydict = {'x': x, 'y': y}
jnp.save('mydict.npy', mydict)
mydict2 = jnp.load('mydict.npy', allow_pickle=True)
mydict2
array({'x': Array([0, 1, 2, 3], dtype=int32), 'y': Array([0., 0., 0., 0.], dtype=float32)},
      dtype=object)
mydict = {'x': x, 'y': y}
np.save('mydict.npy', mydict)
mydict2 = np.load('mydict.npy', allow_pickle=True)
mydict2
array({'x': <tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, 'y': <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>},
      dtype=object)
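In the NumPy-based variants above, passing a dictionary to `np.save` wraps it in a zero-dimensional object array, which is why loading requires `allow_pickle=True` (and calling `.item()` on the result to get the dictionary back). As an alternative sketch, assuming the values are plain NumPy arrays rather than framework tensors, `np.savez` stores one named array per keyword and can be read back without pickling. The file name `mydict.npz` is chosen only for this illustration.

```python
import os
import tempfile

import numpy as np

x = np.arange(4)
y = np.zeros(4)

path = os.path.join(tempfile.mkdtemp(), 'mydict.npz')
np.savez(path, x=x, y=y)          # one archive entry per keyword argument

with np.load(path) as archive:    # allow_pickle defaults to False here
    mydict2 = {name: archive[name] for name in archive.files}
```

Avoiding pickle matters mostly when files come from untrusted sources, since unpickling can execute arbitrary code.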

6.6.2. Loading and Saving Model Parameters

Saving individual weight vectors (or other tensors) is useful, but it gets very tedious if we want to save (and later load) an entire model. After all, we might have hundreds of parameter groups sprinkled throughout. For this reason, deep learning frameworks provide built-in functionality to load and save entire networks. An important detail is that this saves the model parameters, not the entire model. For example, if we have a 3-layer MLP, we need to specify the architecture separately. The reason is that models can contain arbitrary code, so they cannot be serialized as naturally. Thus, to reinstate a model, we need to generate the architecture in code and then load the parameters from disk. Let’s start with our familiar MLP.

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.LazyLinear(256)
        self.output = nn.LazyLinear(10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
class MLP(nn.Block):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Dense(256, activation='relu')
        self.output = nn.Dense(10)

    def forward(self, x):
        return self.output(self.hidden(x))

net = MLP()
net.initialize()
X = np.random.uniform(size=(2, 20))
Y = net(X)
class MLP(nn.Module):
    def setup(self):
        self.hidden = nn.Dense(256)
        self.output = nn.Dense(10)

    def __call__(self, x):
        return self.output(nn.relu(self.hidden(x)))

net = MLP()
X = jax.random.normal(jax.random.PRNGKey(d2l.get_seed()), (2, 20))
Y, params = net.init_with_output(jax.random.PRNGKey(d2l.get_seed()), X)
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.hidden = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)
        self.out = tf.keras.layers.Dense(units=10)

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.hidden(x)
        return self.out(x)

net = MLP()
X = tf.random.uniform((2, 20))
Y = net(X)

Next, we store the parameters of the model in a file named “mlp.params” (the Flax example instead writes a checkpoint into the directory “ckpt_dir”).

torch.save(net.state_dict(), 'mlp.params')
net.save_parameters('mlp.params')
checkpoints.save_checkpoint('ckpt_dir', params, step=1, overwrite=True)
'ckpt_dir/checkpoint_1'
net.save_weights('mlp.params')

To recover the model, we instantiate a clone of the original MLP model. Instead of randomly initializing the model parameters, we read the parameters stored in the file directly.

clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
MLP(
  (hidden): LazyLinear(in_features=0, out_features=256, bias=True)
  (output): LazyLinear(in_features=0, out_features=10, bias=True)
)
clone = MLP()
clone.load_parameters('mlp.params')
clone = MLP()
cloned_params = flax.core.freeze(checkpoints.restore_checkpoint('ckpt_dir',
                                                                target=None))
clone = MLP()
clone.load_weights('mlp.params')
<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7fe86a880820>

Since both instances have the same model parameters, the computational result of the same input X should be the same. Let’s verify this.

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
Y_clone = clone(X)
Y_clone == Y
array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]])
Y_clone = clone.apply(cloned_params, X)
Y_clone == Y
Array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]], dtype=bool)
Y_clone = clone(X)
Y_clone == Y
<tf.Tensor: shape=(2, 10), dtype=bool, numpy=
array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]])>
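The same round trip can be sketched without any deep learning framework, which makes the division of labor explicit: the architecture lives in code, and only the parameter dictionary goes to disk. The class `TinyMLP` below is a hypothetical NumPy stand-in for the MLP above, and its `load` method with name and shape checks is an assumption about what a careful loader might do, not a description of any framework's behavior.

```python
import os
import tempfile

import numpy as np


class TinyMLP:
    """Hypothetical NumPy stand-in for the MLP above (20 -> 256 -> 10)."""

    def __init__(self, seed):
        rng = np.random.default_rng(seed)
        self.params = {
            'hidden_w': rng.normal(size=(20, 256)),
            'hidden_b': np.zeros(256),
            'output_w': rng.normal(size=(256, 10)),
            'output_b': np.zeros(10),
        }

    def __call__(self, x):
        p = self.params
        h = np.maximum(x @ p['hidden_w'] + p['hidden_b'], 0)  # ReLU
        return h @ p['output_w'] + p['output_b']

    def load(self, path):
        with np.load(path) as archive:
            loaded = {k: archive[k] for k in archive.files}
        # Refuse files whose names or shapes do not match the architecture.
        if loaded.keys() != self.params.keys():
            raise KeyError('parameter names do not match')
        for name, value in loaded.items():
            if value.shape != self.params[name].shape:
                raise ValueError(f'shape mismatch for {name}')
        self.params = loaded


path = os.path.join(tempfile.mkdtemp(), 'mlp.npz')
net = TinyMLP(seed=0)
np.savez(path, **net.params)

clone = TinyMLP(seed=1)      # same architecture, different initial weights
X = np.random.default_rng(2).normal(size=(2, 20))
differs_before = not np.array_equal(clone(X), net(X))
clone.load(path)
same_after = np.array_equal(clone(X), net(X))
```

Here `np.array_equal` collapses the elementwise comparison into a single boolean, which is easier to use in an automated check than a full matrix of `True` values.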

6.6.3. Summary

The save and load functions can be used to perform file I/O for tensor objects. We can save and load the entire set of parameters for a network via a parameter dictionary. The architecture itself is not stored with the parameters, so it has to be specified in code.

6.6.4. Exercises

  1. Even if there is no need to deploy trained models to a different device, what are the practical benefits of storing model parameters?

  2. Assume that we want to reuse only parts of a network to be incorporated into a network having a different architecture. How would you go about using, say, the first two layers from a previous network in a new network?

  3. How would you go about saving the network architecture and parameters? What restrictions would you impose on the architecture?