File I/O
========
So far we have discussed how to process data and how to build, train,
and test deep learning models. However, at some point we will hopefully
be satisfied enough with the learned models that we will want to save
them for later use in various contexts (perhaps even to make
predictions in deployment). Additionally, when running a long training
process, the best practice is to periodically save intermediate results
(checkpointing) so that we do not lose several days’ worth of
computation if we trip over the power cord of our server. Thus it is
time to learn how to load and store both individual weight vectors and
entire models. This section addresses both issues. Every code example in
this section is shown in four versions, for PyTorch, MXNet, JAX, and
TensorFlow, in that order.
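Periodic checkpointing can be illustrated without any deep learning
framework. The following is a minimal NumPy-only sketch, not the API of
any of the frameworks used in this section: the parameter dictionary,
the toy ``train_step`` update, the five-step interval, and the file name
``ckpt.npz`` are all hypothetical. Writing to a temporary file first and
then moving it into place with ``os.replace`` is one common way to avoid
leaving a half-written checkpoint behind if the process dies in the
middle of a save.

```python
import os
import tempfile

import numpy as np


def save_checkpoint(path, params, step):
    """Write `params` (a dict of arrays) plus the step counter to `path`."""
    directory = os.path.dirname(os.path.abspath(path))
    # Write into a temporary file in the same directory first ...
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            # Passing a file object keeps NumPy from appending '.npz'.
            np.savez(f, step=np.array(step), **params)
        # ... then move it over the old checkpoint. On POSIX systems this
        # rename is atomic when both paths are on the same file system.
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)
        raise


def load_checkpoint(path):
    """Return (params, step) from a file written by save_checkpoint."""
    with np.load(path) as data:
        step = int(data['step'])
        params = {k: data[k] for k in data.files if k != 'step'}
    return params, step


# Hypothetical "training": shrink a single weight vector by 10% per step.
def train_step(params):
    return {'w': params['w'] - 0.1 * params['w']}


params = {'w': np.ones(3)}
for step in range(1, 11):
    params = train_step(params)
    if step % 5 == 0:  # checkpoint every 5 steps (an arbitrary choice)
        save_checkpoint('ckpt.npz', params, step)

restored, last_step = load_checkpoint('ckpt.npz')
print(last_step, np.allclose(restored['w'], params['w']))
```

Because only the most recent checkpoint is kept under a fixed name, a
crash costs at most the steps since the last save; keeping several
numbered checkpoint files instead is an equally reasonable design.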
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
from torch.nn import functional as F
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import flax
import jax
from flax import linen as nn
from flax.training import checkpoints
from jax import numpy as jnp
from d2l import jax as d2l
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import numpy as np
import tensorflow as tf
Loading and Saving Tensors
--------------------------
For individual tensors, we can directly invoke the ``load`` and ``save``
functions to read and write them, respectively. Both functions require a
file name, and ``save`` additionally takes the variable to be saved as
input. (The JAX and TensorFlow versions below write NumPy ``.npy`` files
via ``jnp.save`` and ``np.save``.)
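The NumPy functions used in the TensorFlow version follow a file-naming
rule worth knowing: ``np.save`` appends ``.npy`` to a file name that does
not already end with it, so the matching ``np.load`` call must use the
full name. A minimal NumPy-only sketch; the file name ``x-sketch`` is
arbitrary.

```python
import os

import numpy as np

x = np.arange(4)
# np.save appends '.npy' because the name does not already end with it.
np.save('x-sketch', x)
print(os.path.exists('x-sketch.npy'))  # the file on disk is 'x-sketch.npy'

# np.load needs the full name, including the extension.
x2 = np.load('x-sketch.npy')
print(x2, x2.dtype == x.dtype, np.array_equal(x, x2))
```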
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = torch.arange(4)
torch.save(x, 'x-file')
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = np.arange(4)
npx.save('x-file', x)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = jnp.arange(4)
jnp.save('x-file.npy', x)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x = tf.range(4)
np.save('x-file.npy', x)
We can now read the data from the stored file back into memory.
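The ``allow_pickle=True`` argument in the JAX and TensorFlow load calls
is not needed for plain numeric arrays. NumPy's ``np.load`` defaults to
``allow_pickle=False`` because loading pickled data can execute arbitrary
code; only object arrays (for instance, a saved Python dictionary) need
it. The following NumPy-only sketch shows both cases; the file names
``numeric.npy`` and ``objects.npy`` are arbitrary.

```python
import numpy as np

np.save('numeric.npy', np.arange(4))
# Plain numeric arrays load fine with the default allow_pickle=False.
numeric = np.load('numeric.npy')
print(numeric)

# Object arrays (here: an array holding a dict) are stored via pickle.
np.save('objects.npy', np.array([{'a': 1}], dtype=object))
try:
    np.load('objects.npy')
    refused = False
except ValueError as err:
    refused = True
    print('refused:', err)

# Opt in only for files that come from a trusted source.
objects = np.load('objects.npy', allow_pickle=True)
print(objects[0])
```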
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x2 = torch.load('x-file')
x2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([0, 1, 2, 3])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x2 = npx.load('x-file')
x2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[array([0., 1., 2., 3.])]
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x2 = jnp.load('x-file.npy', allow_pickle=True)
x2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([0, 1, 2, 3], dtype=int32)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
x2 = np.load('x-file.npy', allow_pickle=True)
x2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([0, 1, 2, 3], dtype=int32)
We can store a list of tensors and read them back into memory.
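With NumPy's ``np.save`` (used directly in the TensorFlow version), the
list is first converted into one stacked array, so the integer tensor
``x`` is promoted to a floating-point type together with ``y``; the JAX
output shows the same effect. If the tensors had different shapes, the
conversion would not produce a regular array at all. ``np.savez``
instead stores each array under its own name and keeps its dtype. A
minimal NumPy-only sketch; the file names are arbitrary.

```python
import numpy as np

x = np.arange(4, dtype=np.int32)
y = np.zeros(4, dtype=np.float32)

# np.save turns the list into one (2, 4) array with a shared dtype.
np.save('xy-stacked.npy', [x, y])
stacked = np.load('xy-stacked.npy')
print(stacked.shape, stacked.dtype)  # int32 promoted to a float dtype

# np.savez stores each array under its own name and keeps its dtype.
np.savez('xy-files.npz', x=x, y=y)
with np.load('xy-files.npz') as data:
    x2, y2 = data['x'], data['y']
print(x2.dtype, y2.dtype)
```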
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
y = torch.zeros(4)
torch.save([x, y], 'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
y = np.zeros(4)
npx.save('x-files', [x, y])
x2, y2 = npx.load('x-files')
(x2, y2)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
y = jnp.zeros(4)
jnp.save('xy-files.npy', [x, y])
x2, y2 = jnp.load('xy-files.npy', allow_pickle=True)
(x2, y2)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([0., 1., 2., 3.], dtype=float32),
Array([0., 0., 0., 0.], dtype=float32))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
y = tf.zeros(4)
np.save('xy-files.npy', [x, y])
x2, y2 = np.load('xy-files.npy', allow_pickle=True)
(x2, y2)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))
We can even write and read a dictionary that maps from strings to
tensors. This is convenient when we want to read or write all the
weights in a model.
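In the JAX and TensorFlow versions, ``np.save`` wraps the dictionary in
a zero-dimensional object array and pickles it, which is why the output
is printed as ``array({...}, dtype=object)`` and why loading needs
``allow_pickle=True``. Calling ``.item()`` unwraps the dictionary. A
pickle-free alternative is to store one named array per entry with
``np.savez``. A minimal NumPy-only sketch; the file names are arbitrary.

```python
import numpy as np

mydict = {'x': np.arange(4), 'y': np.zeros(4)}
# np.save wraps the dict in a 0-d object array and pickles it.
np.save('mydict-sketch.npy', mydict)

loaded = np.load('mydict-sketch.npy', allow_pickle=True)
print(loaded.shape, loaded.dtype)  # () object
mydict2 = loaded.item()  # unwrap the 0-d array to get the dict back
print(type(mydict2), mydict2['x'])

# Pickle-free alternative: one named array per dictionary entry.
np.savez('mydict-sketch.npz', **mydict)
with np.load('mydict-sketch.npz') as data:
    mydict3 = {k: data[k] for k in data.files}
print(sorted(mydict3))
```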
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
mydict = {'x': x, 'y': y}
npx.save('mydict', mydict)
mydict2 = npx.load('mydict')
mydict2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
{'x': array([0., 1., 2., 3.]), 'y': array([0., 0., 0., 0.])}
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
mydict = {'x': x, 'y': y}
jnp.save('mydict.npy', mydict)
mydict2 = jnp.load('mydict.npy', allow_pickle=True)
mydict2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array({'x': Array([0, 1, 2, 3], dtype=int32), 'y': Array([0., 0., 0., 0.], dtype=float32)},
dtype=object)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
mydict = {'x': x, 'y': y}
np.save('mydict.npy', mydict)
mydict2 = np.load('mydict.npy', allow_pickle=True)
mydict2
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array({'x': , 'y': },
dtype=object)
Loading and Saving Model Parameters
-----------------------------------
Saving individual weight vectors (or other tensors) is useful, but it
gets very tedious if we want to save (and later load) an entire model.
After all, we might have hundreds of parameter groups sprinkled
throughout. For this reason deep learning frameworks provide built-in
functionality to load and save entire networks. An important detail to
note is that this saves the model *parameters*, not the entire model.
For example, if we have a 3-layer MLP, we need to specify the
architecture separately. The reason is that models can contain
arbitrary code, so they cannot be serialized as naturally as plain
arrays of numbers. Thus, to reinstate a model, we need to generate the
architecture in code and then load the parameters from disk. Let’s
start with our familiar MLP.
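The split between "architecture in code" and "parameters on disk" can be
seen in a framework-free form. In this NumPy-only sketch, the functions
``init_mlp`` and ``mlp_forward``, the parameter names, and the file name
``mlp-numpy.npz`` are hypothetical stand-ins for what the frameworks do
internally: the file holds only named arrays, and we need the same
forward function to make use of them again.

```python
import numpy as np


def init_mlp(rng, n_in=20, n_hidden=256, n_out=10):
    """Hypothetical 2-layer MLP: the architecture lives in this code."""
    return {
        'hidden.weight': rng.normal(0, 0.01, (n_in, n_hidden)),
        'hidden.bias': np.zeros(n_hidden),
        'output.weight': rng.normal(0, 0.01, (n_hidden, n_out)),
        'output.bias': np.zeros(n_out),
    }


def mlp_forward(params, X):
    """ReLU hidden layer followed by a linear output layer."""
    H = np.maximum(X @ params['hidden.weight'] + params['hidden.bias'], 0)
    return H @ params['output.weight'] + params['output.bias']


rng = np.random.default_rng(0)
params = init_mlp(rng)
X = rng.normal(size=(2, 20))
Y = mlp_forward(params, X)

# Only the arrays go to disk; the file says nothing about how to use them.
np.savez('mlp-numpy.npz', **params)

# Reusing them requires the same forward function (the "architecture").
with np.load('mlp-numpy.npz') as data:
    loaded = {k: data[k] for k in data.files}
print(np.array_equal(mlp_forward(loaded, X), Y))
```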
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.LazyLinear(256)
        self.output = nn.LazyLinear(10)
    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class MLP(nn.Block):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Dense(256, activation='relu')
        self.output = nn.Dense(10)
    def forward(self, x):
        return self.output(self.hidden(x))
net = MLP()
net.initialize()
X = np.random.uniform(size=(2, 20))
Y = net(X)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class MLP(nn.Module):
    def setup(self):
        self.hidden = nn.Dense(256)
        self.output = nn.Dense(10)
    def __call__(self, x):
        return self.output(nn.relu(self.hidden(x)))
net = MLP()
X = jax.random.normal(jax.random.PRNGKey(d2l.get_seed()), (2, 20))
Y, params = net.init_with_output(jax.random.PRNGKey(d2l.get_seed()), X)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.hidden = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)
        self.out = tf.keras.layers.Dense(units=10)
    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.hidden(x)
        return self.out(x)
net = MLP()
X = tf.random.uniform((2, 20))
Y = net(X)
Next, we store the parameters of the model in a file named
“mlp.params”. (The JAX version instead writes a Flax checkpoint into the
directory ``ckpt_dir``.)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
torch.save(net.state_dict(), 'mlp.params')
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net.save_parameters('mlp.params')
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
checkpoints.save_checkpoint('ckpt_dir', params, step=1, overwrite=True)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'ckpt_dir/checkpoint_1'
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net.save_weights('mlp.params')
To recover the model, we instantiate a clone of the original MLP model.
Instead of randomly initializing the model parameters, we read the
parameters stored in the file directly.
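Loading parameters into a freshly built clone only works if the stored
names and shapes match what the rebuilt architecture expects. PyTorch's
``load_state_dict``, for example, rejects missing or unexpected keys by
default (``strict=True``). The NumPy-only sketch below imitates that kind
of check; the helper ``load_params_strict``, the tiny parameter
dictionaries, and the file name ``tiny.npz`` are hypothetical.

```python
import numpy as np


def load_params_strict(template, path):
    """Load arrays from `path`, insisting they match `template` exactly.

    `template` is a freshly initialized parameter dict for the architecture
    rebuilt in code; only its names and shapes are used.
    """
    with np.load(path) as data:
        loaded = {k: data[k] for k in data.files}
    missing = set(template) - set(loaded)
    unexpected = set(loaded) - set(template)
    if missing or unexpected:
        raise KeyError(f'missing={sorted(missing)}, '
                       f'unexpected={sorted(unexpected)}')
    for name, value in template.items():
        if loaded[name].shape != value.shape:
            raise ValueError(f'{name}: expected shape {value.shape}, '
                             f'got {loaded[name].shape}')
    return loaded


# Hypothetical parameters of a tiny one-layer model.
saved = {'w': np.arange(6.0).reshape(2, 3), 'b': np.ones(3)}
np.savez('tiny.npz', **saved)

fresh = {'w': np.zeros((2, 3)), 'b': np.zeros(3)}  # the "clone" before loading
restored = load_params_strict(fresh, 'tiny.npz')
print(np.array_equal(restored['w'], saved['w']))

wrong = {'w': np.zeros((3, 3)), 'b': np.zeros(3)}  # mismatched architecture
shape_error = None
try:
    load_params_strict(wrong, 'tiny.npz')
except ValueError as err:
    shape_error = err
    print('rejected:', err)
```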
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
MLP(
(hidden): LazyLinear(in_features=0, out_features=256, bias=True)
(output): LazyLinear(in_features=0, out_features=10, bias=True)
)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
clone = MLP()
clone.load_parameters('mlp.params')
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
clone = MLP()
cloned_params = flax.core.freeze(checkpoints.restore_checkpoint('ckpt_dir',
target=None))
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
clone = MLP()
clone.load_weights('mlp.params')
Since both instances have the same model parameters, the computational
result of the same input ``X`` should be the same. Let’s verify this.
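The comparison ``Y_clone == Y`` is elementwise and yields a table of
booleans; reducing it with ``.all()`` (or ``np.array_equal``) gives a
single answer. Outputs computed in a different way, for example on
another device or with a different summation order, can differ in the
last few bits, so a tolerance-based check such as ``np.allclose`` is often
more robust. In this NumPy-only sketch, ``Y`` is random data standing in
for a model output and the ``1e-12`` perturbation is an arbitrary
example.

```python
import numpy as np

rng = np.random.default_rng(0)
Y = rng.normal(size=(2, 10))
Y_clone = Y.copy()  # stand-in for the clone's output

# `==` is elementwise; reduce it to a single boolean.
print((Y_clone == Y).all(), np.array_equal(Y_clone, Y))

# A tiny numerical difference breaks exact equality but not allclose.
Y_other = Y + 1e-12
print(np.array_equal(Y_other, Y), np.allclose(Y_other, Y))
```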
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Y_clone = clone(X)
Y_clone == Y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Y_clone = clone(X)
Y_clone == Y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[ True, True, True, True, True, True, True, True, True,
True],
[ True, True, True, True, True, True, True, True, True,
True]])
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Y_clone = clone.apply(cloned_params, X)
Y_clone == Y
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[ True, True, True, True, True, True, True, True, True,
True],
[ True, True, True, True, True, True, True, True, True,
True]], dtype=bool)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Y_clone = clone(X)
Y_clone == Y
Summary
-------
The ``save`` and ``load`` functions can be used to perform file I/O for
tensor objects. We can save and load the entire set of parameters of a
network via a parameter dictionary. The architecture, however, is not
saved along with the parameters; it has to be defined in code.
Exercises
---------
1. Even if there is no need to deploy trained models to a different
   device, what are the practical benefits of storing model parameters?
2. Assume that we want to reuse only parts of a network to be
   incorporated into a network having a different architecture. How
   would you go about using, say, the first two layers from a previous
   network in a new network?
3. How would you go about saving the network architecture and
   parameters? What restrictions would you impose on the architecture?
`Discussions `__