Parameter Initialization
========================

Now that we know how to access the parameters, let’s look at how to initialize them properly. We discussed the need for proper initialization in :numref:`sec_numerical_stability`. The deep learning framework provides default random initializations to its layers. However, we often want to initialize our weights according to various other protocols. The framework provides the most commonly used protocols and also allows us to create a custom initializer.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import torch
    from torch import nn

By default, PyTorch initializes weight and bias matrices uniformly by drawing from a range that is computed according to the input and output dimension. PyTorch’s ``nn.init`` module provides a variety of preset initialization methods.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
    X = torch.rand(size=(2, 4))
    net(X).shape

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    torch.Size([2, 1])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    from mxnet import init, np, npx
    from mxnet.gluon import nn

    npx.set_np()

By default, MXNet initializes weight parameters by randomly drawing from a uniform distribution :math:`U(-0.07, 0.07)`, clearing bias parameters to zero. MXNet’s ``init`` module provides a variety of preset initialization methods.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential()
    net.add(nn.Dense(8, activation='relu'))
    net.add(nn.Dense(1))
    net.initialize()  # Use the default initialization method

    X = np.random.uniform(size=(2, 4))
    net(X).shape

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [22:10:04] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (2, 1)
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import jax
    from flax import linen as nn
    from jax import numpy as jnp
    from d2l import jax as d2l

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

By default, Flax initializes weights using ``jax.nn.initializers.lecun_normal``, i.e., by drawing samples from a truncated normal distribution centered on 0 with the standard deviation set to the square root of :math:`1 / \textrm{fan}_{\textrm{in}}`, where ``fan_in`` is the number of input units in the weight tensor. The bias parameters are all set to zero. JAX’s ``nn.initializers`` module provides a variety of preset initialization methods.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
    X = jax.random.uniform(d2l.get_key(), (2, 4))
    params = net.init(d2l.get_key(), X)
    net.apply(params, X).shape

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (2, 1)
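To make the scale of this default concrete, the following framework-independent sketch evaluates the standard deviation :math:`\sqrt{1 / \textrm{fan}_{\textrm{in}}}` for a few layer widths. The helper name ``lecun_std`` is hypothetical and is not part of Flax or JAX; it only evaluates the formula stated in the text.

```python
import math


def lecun_std(fan_in):
    """Standard deviation sqrt(1 / fan_in) (hypothetical helper, not a Flax API)."""
    if fan_in <= 0:
        raise ValueError("fan_in must be positive")
    return math.sqrt(1.0 / fan_in)


# The first Dense layer above receives inputs with 4 features.
for fan_in in (4, 8, 256):
    print(fan_in, lecun_std(fan_in))
```

Wider layers get proportionally smaller weights, which keeps the variance of each unit's pre-activation roughly independent of the number of inputs.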
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import tensorflow as tf

By default, Keras initializes weight matrices uniformly by drawing from a range that is computed according to the input and output dimension, and the bias parameters are all set to zero. TensorFlow provides a variety of initialization methods both in the root module and the ``keras.initializers`` module.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(4, activation=tf.nn.relu),
        tf.keras.layers.Dense(1),
    ])

    X = tf.random.uniform((2, 4))
    net(X).shape

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    TensorShape([2, 1])
Built-in Initialization
-----------------------

Let’s begin by calling on built-in initializers. The code below initializes all weight parameters as Gaussian random variables with standard deviation 0.01, while bias parameters are cleared to zero.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def init_normal(module):
        if type(module) == nn.Linear:
            nn.init.normal_(module.weight, mean=0, std=0.01)
            nn.init.zeros_(module.bias)

    net.apply(init_normal)
    net[0].weight.data[0], net[0].bias.data[0]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (tensor([-0.0129, -0.0007, -0.0033, 0.0276]), tensor(0.))
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # Here force_reinit ensures that parameters are freshly initialized even if
    # they were already initialized previously
    net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
    net[0].weight.data()[0]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([ 0.00354961, -0.00614133, 0.0107317 , 0.01830765])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    weight_init = nn.initializers.normal(0.01)
    bias_init = nn.initializers.zeros

    net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                         nn.relu,
                         nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

    params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
    layer_0 = params['params']['layers_0']
    layer_0['kernel'][:, 0], layer_0['bias'][0]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (Array([ 0.00457076, 0.01890736, -0.0014968 , 0.00327491], dtype=float32),
     Array(0., dtype=float32))
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(
            4, activation=tf.nn.relu,
            kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
            bias_initializer=tf.zeros_initializer()),
        tf.keras.layers.Dense(1)])

    net(X)
    net.weights[0], net.weights[1]
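Independently of any framework, the effect of this initializer can be sketched with the Python standard library alone. The function ``init_normal_weights`` below is a hypothetical illustration, not a framework API: it draws every weight from a Gaussian with standard deviation 0.01 and sets every bias to zero, and the printed statistics let us check that the sample mean and standard deviation come out close to 0 and 0.01.

```python
import random
import statistics


def init_normal_weights(fan_in, fan_out, std=0.01, seed=0):
    """Return (weights, biases): weights ~ N(0, std^2), biases all zero."""
    rng = random.Random(seed)  # local generator so the sketch is reproducible
    weights = [[rng.gauss(0.0, std) for _ in range(fan_in)]
               for _ in range(fan_out)]
    biases = [0.0] * fan_out
    return weights, biases


weights, biases = init_normal_weights(fan_in=100, fan_out=100)
flat = [w for row in weights for w in row]
# Sample mean and (population) standard deviation of all 10,000 weights
print(statistics.mean(flat), statistics.pstdev(flat))
```

The layer sizes here are arbitrary and chosen only so that the sample statistics are stable.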
We can also initialize all the weights to a given constant value (say, 1), again leaving the biases at zero.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def init_constant(module):
        if type(module) == nn.Linear:
            nn.init.constant_(module.weight, 1)
            nn.init.zeros_(module.bias)

    net.apply(init_constant)
    net[0].weight.data[0], net[0].bias.data[0]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (tensor([1., 1., 1., 1.]), tensor(0.))
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net.initialize(init=init.Constant(1), force_reinit=True)
    net[0].weight.data()[0]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([1., 1., 1., 1.])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    weight_init = nn.initializers.constant(1)

    net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                         nn.relu,
                         nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

    params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
    layer_0 = params['params']['layers_0']
    layer_0['kernel'][:, 0], layer_0['bias'][0]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(
            4, activation=tf.nn.relu,
            kernel_initializer=tf.keras.initializers.Constant(1),
            bias_initializer=tf.zeros_initializer()),
        tf.keras.layers.Dense(1),
    ])

    net(X)
    net.weights[0], net.weights[1]
We can also apply different initializers for certain blocks. For example, below we initialize the first layer with the Xavier initializer and initialize the second layer to a constant value of 42.
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def init_xavier(module):
        if type(module) == nn.Linear:
            nn.init.xavier_uniform_(module.weight)

    def init_42(module):
        if type(module) == nn.Linear:
            nn.init.constant_(module.weight, 42)

    net[0].apply(init_xavier)
    net[2].apply(init_42)
    print(net[0].weight.data[0])
    print(net[2].weight.data)

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([-0.0974, 0.1707, 0.5840, -0.5032])
    tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
    net[1].initialize(init=init.Constant(42), force_reinit=True)
    print(net[0].weight.data()[0])
    print(net[1].weight.data())

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [-0.26102373 0.15249556 -0.19274211 -0.24742058]
    [[42. 42. 42. 42. 42. 42. 42. 42.]]
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                                  bias_init=bias_init),
                         nn.relu,
                         nn.Dense(1, kernel_init=nn.initializers.constant(42),
                                  bias_init=bias_init)])

    params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
    params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    (Array([ 0.38926104, -0.4023119 , -0.41848803, -0.6341998 ], dtype=float32),
     Array([[42.],
            [42.],
            [42.],
            [42.],
            [42.],
            [42.],
            [42.],
            [42.]], dtype=float32))
.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(
            4,
            activation=tf.nn.relu,
            kernel_initializer=tf.keras.initializers.GlorotUniform()),
        tf.keras.layers.Dense(
            1, kernel_initializer=tf.keras.initializers.Constant(42)),
    ])

    net(X)
    print(net.layers[1].weights[0])
    print(net.layers[2].weights[0])
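The range used by the Xavier initializer can be sketched without any framework. The snippet below assumes the uniform variant, which samples from :math:`U(-a, a)` with :math:`a = \sqrt{6 / (\textrm{fan}_{\textrm{in}} + \textrm{fan}_{\textrm{out}})}`. The helper names ``xavier_uniform_bound`` and ``xavier_uniform`` are hypothetical and written only for illustration.

```python
import math
import random


def xavier_uniform_bound(fan_in, fan_out):
    """Half-width a of the interval U(-a, a), assuming the uniform Xavier variant."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def xavier_uniform(fan_in, fan_out, seed=0):
    """Sample a (fan_out x fan_in) weight matrix as nested lists."""
    rng = random.Random(seed)  # local generator so the sketch is reproducible
    a = xavier_uniform_bound(fan_in, fan_out)
    return [[rng.uniform(-a, a) for _ in range(fan_in)]
            for _ in range(fan_out)]


# A layer with 4 inputs and 8 outputs, as in the first layer of the
# PyTorch, MXNet, and JAX networks above
bound = xavier_uniform_bound(4, 8)
weights = xavier_uniform(4, 8)
print(bound)
```

For 4 inputs and 8 outputs the bound is :math:`\sqrt{6/12} \approx 0.707`, which is consistent with the magnitudes of the first-layer weights printed above for those networks.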
Custom Initialization
~~~~~~~~~~~~~~~~~~~~~

Sometimes, the initialization methods we need are not provided by the deep learning framework. In the example below, we define an initializer for any weight parameter :math:`w` using the following strange distribution:

.. math::

    \begin{aligned}
        w \sim \begin{cases}
            U(5, 10) & \textrm{ with probability } \frac{1}{4} \\
            0 & \textrm{ with probability } \frac{1}{2} \\
            U(-10, -5) & \textrm{ with probability } \frac{1}{4}
        \end{cases}
    \end{aligned}
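One way to sample from this distribution, which the implementations in this section also rely on, is to draw :math:`u \sim U(-10, 10)` and zero it out whenever :math:`|u| < 5`. Half of the interval is mapped to zero, and the two remaining quarters are uniform on :math:`[-10, -5]` and :math:`[5, 10]`. The following standard-library sketch checks this empirically; the name ``sample_custom`` is hypothetical.

```python
import random


def sample_custom(rng):
    """Draw one weight: U(-10, 10), zeroed when its magnitude is below 5."""
    w = rng.uniform(-10.0, 10.0)
    return w if abs(w) >= 5 else 0.0


rng = random.Random(0)  # fixed seed so the sketch is reproducible
samples = [sample_custom(rng) for _ in range(100_000)]

# Empirical probabilities of the three cases
frac_zero = sum(s == 0.0 for s in samples) / len(samples)
frac_pos = sum(s > 0 for s in samples) / len(samples)
frac_neg = sum(s < 0 for s in samples) / len(samples)
print(frac_zero, frac_pos, frac_neg)
```

The three fractions should come out close to 1/2, 1/4, and 1/4, matching the probabilities in the definition.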
Again, we implement a ``my_init`` function to apply to ``net``.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def my_init(module):
        if type(module) == nn.Linear:
            print("Init", *[(name, param.shape)
                            for name, param in module.named_parameters()][0])
            nn.init.uniform_(module.weight, -10, 10)
            module.weight.data *= module.weight.data.abs() >= 5

    net.apply(my_init)
    net[0].weight[:2]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Init weight torch.Size([8, 4])
    Init weight torch.Size([1, 8])

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([[ 0.0000, -7.6364, -0.0000, -6.1206],
            [ 9.3516, -0.0000, 5.1208, -8.4003]], grad_fn=)

Note that we always have the option of setting parameters directly.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net[0].weight.data[:] += 1
    net[0].weight.data[0, 0] = 42
    net[0].weight.data[0]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tensor([42.0000, -6.6364, 1.0000, -5.1206])
Here we define a subclass of the ``Initializer`` class. Usually, we only need to implement the ``_init_weight`` function which takes a tensor argument (``data``) and assigns to it the desired initialized values.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class MyInit(init.Initializer):
        def _init_weight(self, name, data):
            print('Init', name, data.shape)
            data[:] = np.random.uniform(-10, 10, data.shape)
            data *= np.abs(data) >= 5

    net.initialize(MyInit(), force_reinit=True)
    net[0].weight.data()[:2]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Init dense0_weight (8, 4)
    Init dense1_weight (1, 8)

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([[-6.0683527, 8.991421 , -0. , 0. ],
           [ 6.4198647, -9.728567 , -8.057975 , 0. ]])

Note that we always have the option of setting parameters directly.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net[0].weight.data()[:] += 1
    net[0].weight.data()[0, 0] = 42
    net[0].weight.data()[0]

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    array([42. , 9.991421, 1. , 1. ])
JAX initialization functions take as arguments the ``PRNGKey``, ``shape`` and ``dtype``. Here we implement the function ``my_init`` that returns a desired tensor given the shape and data type.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def my_init(key, shape, dtype=jnp.float_):
        data = jax.random.uniform(key, shape, minval=-10, maxval=10)
        return data * (jnp.abs(data) >= 5)

    net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
    params = net.init(d2l.get_key(), X)
    print(params['params']['layers_0']['kernel'][:, :2])

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [[ 0. -5.891962 ]
     [ 0. -9.597271 ]
     [-5.809202 6.3091564]
     [ 0. 0. ]]

When initializing parameters in JAX and Flax, the dictionary of parameters returned has a ``flax.core.frozen_dict.FrozenDict`` type. It is not advisable in the JAX ecosystem to directly alter the values of an array, hence the datatypes are generally immutable. One might use ``params.unfreeze()`` to make changes.
Here we define a subclass of ``Initializer`` and implement the ``__call__`` function that returns a desired tensor given the shape and data type.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class MyInit(tf.keras.initializers.Initializer):
        def __call__(self, shape, dtype=None):
            data = tf.random.uniform(shape, -10, 10, dtype=dtype)
            # Keep only the entries whose magnitude is at least 5
            factor = (tf.abs(data) >= 5)
            factor = tf.cast(factor, tf.float32)
            return data * factor

    net = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(
            4,
            activation=tf.nn.relu,
            kernel_initializer=MyInit()),
        tf.keras.layers.Dense(1),
    ])

    net(X)
    print(net.layers[1].weights[0])

Note that we always have the option of setting parameters directly.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    net.layers[1].weights[0][:].assign(net.layers[1].weights[0] + 1)
    net.layers[1].weights[0][0, 0].assign(42)
    net.layers[1].weights[0]
Summary
-------

We can initialize parameters using built-in and custom initializers.

Exercises
---------

Look up the online documentation for more built-in initializers.