.. _sec_softmax_scratch:
Softmax Regression Implementation from Scratch
==============================================
Because softmax regression is so fundamental, we believe that you ought
to know how to implement it yourself. Here, we limit ourselves to
defining the softmax-specific aspects of the model and reuse the other
components from our linear regression section, including the training
loop.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
X.sum(0, keepdims=True), X.sum(1, keepdims=True)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([[5., 7., 9.]]),
tensor([[ 6.],
[15.]]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
X.sum(0, keepdims=True), X.sum(1, keepdims=True)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([[5., 7., 9.]]),
array([[ 6.],
[15.]]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
X.sum(0, keepdims=True), X.sum(1, keepdims=True)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([[5., 7., 9.]], dtype=float32),
Array([[ 6.],
[15.]], dtype=float32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
tf.reduce_sum(X, 0, keepdims=True), tf.reduce_sum(X, 1, keepdims=True)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
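(<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[5., 7., 9.]], dtype=float32)>,
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[ 6.],
[15.]], dtype=float32)>)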
.. raw:: html
.. raw:: html
Computing the softmax requires three steps: (i) exponentiation of each
term; (ii) a sum over each row to compute the normalization constant for
each example; (iii) division of each row by its normalization constant,
ensuring that the result sums to 1:
.. math:: \mathrm{softmax}(\mathbf{X})_{ij} = \frac{\exp(\mathbf{X}_{ij})}{\sum_k \exp(\mathbf{X}_{ik})}.
The (logarithm of the) denominator is called the (log) *partition
function*. It was introduced in statistical physics to sum over all
possible states in a thermodynamic ensemble. The implementation is
straightforward:
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def softmax(X):
        X_exp = torch.exp(X)
        partition = X_exp.sum(1, keepdims=True)
        return X_exp / partition  # The broadcasting mechanism is applied here
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def softmax(X):
        X_exp = np.exp(X)
        partition = X_exp.sum(1, keepdims=True)
        return X_exp / partition  # The broadcasting mechanism is applied here
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def softmax(X):
        X_exp = jnp.exp(X)
        partition = X_exp.sum(1, keepdims=True)
        return X_exp / partition  # The broadcasting mechanism is applied here
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def softmax(X):
        X_exp = tf.exp(X)
        partition = tf.reduce_sum(X_exp, 1, keepdims=True)
        return X_exp / partition  # The broadcasting mechanism is applied here
.. raw:: html
.. raw:: html
For any input ``X``, we turn each element into a nonnegative number.
Each row sums up to 1, as is required of a probability distribution.
Caution: the code above is *not* robust against very large or very small
arguments. A large positive entry makes the exponential overflow, while
uniformly very negative entries make every exponential underflow to
zero, so that the normalization divides by zero. While the code is
sufficient to illustrate what is happening, you should *not* use it
verbatim for any serious purpose. Deep learning frameworks have such
protections built in and we will be using the built-in softmax going
forward.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = torch.rand((2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([[0.2511, 0.1417, 0.1158, 0.2529, 0.2385],
[0.2004, 0.1419, 0.1957, 0.2504, 0.2117]]),
tensor([1., 1.]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = np.random.rand(2, 5)
X_prob = softmax(X)
X_prob, X_prob.sum(1)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(array([[0.17777154, 0.1857739 , 0.20995119, 0.23887765, 0.18762572],
[0.24042214, 0.1757977 , 0.23786479, 0.15572716, 0.19018826]]),
array([1., 1.]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = jax.random.uniform(jax.random.PRNGKey(d2l.get_seed()), (2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([[0.21010879, 0.13711403, 0.17836188, 0.19433393, 0.28008142],
[0.13819133, 0.12869641, 0.1519639 , 0.32253352, 0.25861484]], dtype=float32),
Array([1., 1.], dtype=float32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.random.uniform((2, 5))
X_prob = softmax(X)
X_prob, tf.reduce_sum(X_prob, 1)
.. raw:: html
.. raw:: html
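The caution about very large or very small arguments can be made
concrete. The following is a minimal sketch in plain NumPy, chosen here
only so that the example does not depend on any one framework; the names
``naive_softmax`` and ``stable_softmax`` are hypothetical and not part of
``d2l``. It relies on the fact that subtracting the row-wise maximum from
every entry leaves the softmax unchanged, because the common factor
cancels between numerator and denominator, while keeping every exponent
at or below zero.

```python
import numpy as np

def naive_softmax(X):
    # Direct transcription of the mathematical definition
    X_exp = np.exp(X)
    return X_exp / X_exp.sum(axis=1, keepdims=True)

def stable_softmax(X):
    # Shift each row so that its largest entry is 0; exp cannot overflow
    shifted = X - X.max(axis=1, keepdims=True)
    X_exp = np.exp(shifted)
    return X_exp / X_exp.sum(axis=1, keepdims=True)

X_small = np.array([[1.0, 2.0, 3.0]])
X_large = np.array([[1000.0, 1001.0, 1002.0]])

# On moderate inputs both versions agree
print(np.allclose(naive_softmax(X_small), stable_softmax(X_small)))

# On large inputs exp overflows to inf, and inf / inf yields nan
with np.errstate(over='ignore', invalid='ignore'):
    print(np.isnan(naive_softmax(X_large)).any())

# The shifted version still returns a valid distribution
print(stable_softmax(X_large), stable_softmax(X_large).sum(axis=1))
```

Because the two inputs differ only by a constant offset of 999 in every
entry, the stable version returns the same probabilities for both.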
The Model
---------
We now have everything that we need to implement the softmax regression
model. As in our linear regression example, each instance will be
represented by a fixed-length vector. Since the raw data here consists
of :math:`28 \times 28` pixel images, we flatten each image, treating
them as vectors of length 784. In later chapters, we will introduce
convolutional neural networks, which exploit the spatial structure in a
more satisfying way.
In softmax regression, the number of outputs from our network should be
equal to the number of classes. Since our dataset has 10 classes, our
network has an output dimension of 10. Consequently, our weights
constitute a :math:`784 \times 10` matrix plus a :math:`1 \times 10` row
vector for the biases. As with linear regression, we initialize the
weights ``W`` with Gaussian noise. The biases are initialized as zeros.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class SoftmaxRegressionScratch(d2l.Classifier):
        def __init__(self, num_inputs, num_outputs, lr, sigma=0.01):
            super().__init__()
            self.save_hyperparameters()
            self.W = torch.normal(0, sigma, size=(num_inputs, num_outputs),
                                  requires_grad=True)
            self.b = torch.zeros(num_outputs, requires_grad=True)

        def parameters(self):
            return [self.W, self.b]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class SoftmaxRegressionScratch(d2l.Classifier):
        def __init__(self, num_inputs, num_outputs, lr, sigma=0.01):
            super().__init__()
            self.save_hyperparameters()
            self.W = np.random.normal(0, sigma, (num_inputs, num_outputs))
            self.b = np.zeros(num_outputs)
            self.W.attach_grad()
            self.b.attach_grad()

        def collect_params(self):
            return [self.W, self.b]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class SoftmaxRegressionScratch(d2l.Classifier):
        num_inputs: int
        num_outputs: int
        lr: float
        sigma: float = 0.01

        def setup(self):
            self.W = self.param('W', nn.initializers.normal(self.sigma),
                                (self.num_inputs, self.num_outputs))
            self.b = self.param('b', nn.initializers.zeros, self.num_outputs)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    class SoftmaxRegressionScratch(d2l.Classifier):
        def __init__(self, num_inputs, num_outputs, lr, sigma=0.01):
            super().__init__()
            self.save_hyperparameters()
            self.W = tf.random.normal((num_inputs, num_outputs), 0, sigma)
            self.b = tf.zeros(num_outputs)
            self.W = tf.Variable(self.W)
            self.b = tf.Variable(self.b)
.. raw:: html
.. raw:: html
The code below defines how the network maps each input to an output.
Note that we flatten each :math:`28 \times 28` pixel image in the batch
into a vector using ``reshape`` before passing the data through our
model.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    @d2l.add_to_class(SoftmaxRegressionScratch)
    def forward(self, X):
        X = X.reshape((-1, self.W.shape[0]))
        return softmax(torch.matmul(X, self.W) + self.b)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    @d2l.add_to_class(SoftmaxRegressionScratch)
    def forward(self, X):
        X = X.reshape((-1, self.W.shape[0]))
        return softmax(np.dot(X, self.W) + self.b)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    @d2l.add_to_class(SoftmaxRegressionScratch)
    def forward(self, X):
        X = X.reshape((-1, self.W.shape[0]))
        return softmax(jnp.matmul(X, self.W) + self.b)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    @d2l.add_to_class(SoftmaxRegressionScratch)
    def forward(self, X):
        X = tf.reshape(X, (-1, self.W.shape[0]))
        return softmax(tf.matmul(X, self.W) + self.b)
.. raw:: html
.. raw:: html
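To make the shapes in ``forward`` concrete, here is a small
framework-independent sketch in NumPy. The random batch of four
"images", the seed, and the helper name ``forward_sketch`` are
assumptions made purely for illustration; the parameter shapes follow the
text above (a :math:`784 \times 10` weight matrix and 10 biases).

```python
import numpy as np

rng = np.random.default_rng(0)
num_inputs, num_outputs, sigma = 784, 10, 0.01

# Parameters as described in the text: Gaussian weights, zero biases
W = rng.normal(0, sigma, size=(num_inputs, num_outputs))
b = np.zeros(num_outputs)

def forward_sketch(X):
    # Flatten every image in the batch into a vector of length 784
    X = X.reshape((-1, W.shape[0]))
    O = X @ W + b                      # logits, shape (batch, 10)
    O_exp = np.exp(O)
    return O_exp / O_exp.sum(axis=1, keepdims=True)

X = rng.random((4, 1, 28, 28))         # hypothetical batch of 4 images
Y_hat = forward_sketch(X)
print(Y_hat.shape)                     # (4, 10)
print(W.size + b.size)                 # 7850 learnable parameters
```

Each row of ``Y_hat`` is a distribution over the 10 classes, and the
bias vector of length 10 is broadcast across the batch dimension.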
The Cross-Entropy Loss
----------------------
Next we need to implement the cross-entropy loss function (introduced in
:numref:`subsec_softmax-regression-loss-func`). This may be the most
common loss function in all of deep learning. At the moment,
applications of deep learning easily cast as classification problems far
outnumber those better treated as regression problems.
Recall that cross-entropy takes the negative log-likelihood of the
predicted probability assigned to the true label. For efficiency we
avoid Python for-loops and use indexing instead. In particular, the
one-hot encoding in :math:`\mathbf{y}` allows us to select the matching
terms in :math:`\hat{\mathbf{y}}`.
To see this in action we create sample data ``y_hat`` with 2 examples of
predicted probabilities over 3 classes and their corresponding labels
``y``. The correct labels are :math:`0` and :math:`2` respectively
(i.e., the first and third class). Using ``y`` as the indices of the
probabilities in ``y_hat``, we can pick out terms efficiently.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([0.1000, 0.5000])
Now we can implement the cross-entropy loss function by averaging the
negative logarithms of the selected probabilities.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def cross_entropy(y_hat, y):
        return -torch.log(y_hat[list(range(len(y_hat))), y]).mean()

    cross_entropy(y_hat, y)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor(1.4979)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    @d2l.add_to_class(SoftmaxRegressionScratch)
    def loss(self, y_hat, y):
        return cross_entropy(y_hat, y)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
y = np.array([0, 2])
y_hat = np.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([0.1, 0.5])
Now we can implement the cross-entropy loss function by averaging the
negative logarithms of the selected probabilities.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def cross_entropy(y_hat, y):
        return -np.log(y_hat[list(range(len(y_hat))), y]).mean()

    cross_entropy(y_hat, y)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array(1.4978662)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    @d2l.add_to_class(SoftmaxRegressionScratch)
    def loss(self, y_hat, y):
        return cross_entropy(y_hat, y)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
y = jnp.array([0, 2])
y_hat = jnp.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([0.1, 0.5], dtype=float32)
Now we can implement the cross-entropy loss function by averaging the
negative logarithms of the selected probabilities.

Note that to use ``jax.jit`` to speed up JAX implementations, ``loss``
must be a pure function. For this reason ``cross_entropy`` is redefined
inside ``loss``, which avoids any global variables or functions that
might render ``loss`` impure. We refer interested readers to the JAX
documentation on ``jax.jit`` and pure functions.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def cross_entropy(y_hat, y):
        return -jnp.log(y_hat[list(range(len(y_hat))), y]).mean()

    cross_entropy(y_hat, y)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array(1.4978662, dtype=float32)
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    @d2l.add_to_class(SoftmaxRegressionScratch)
    @partial(jax.jit, static_argnums=(0))
    def loss(self, params, X, y, state):
        def cross_entropy(y_hat, y):
            return -jnp.log(y_hat[list(range(len(y_hat))), y]).mean()
        y_hat = state.apply_fn({'params': params}, *X)
        # The returned empty dictionary is a placeholder for auxiliary data,
        # which will be used later (e.g., for batch norm)
        return cross_entropy(y_hat, y), {}
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
y_hat = tf.constant([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = tf.constant([0, 2])
tf.boolean_mask(y_hat, tf.one_hot(y, depth=y_hat.shape[-1]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.1, 0.5], dtype=float32)>
Now we can implement the cross-entropy loss function by averaging the
negative logarithms of the selected probabilities.
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    def cross_entropy(y_hat, y):
        return -tf.reduce_mean(tf.math.log(tf.boolean_mask(
            y_hat, tf.one_hot(y, depth=y_hat.shape[-1]))))

    cross_entropy(y_hat, y)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
<tf.Tensor: shape=(), dtype=float32, numpy=1.4978662>
.. raw:: latex
\diilbookstyleinputcell
.. code:: python

    @d2l.add_to_class(SoftmaxRegressionScratch)
    def loss(self, y_hat, y):
        return cross_entropy(y_hat, y)
.. raw:: html
.. raw:: html
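The indexing used in ``cross_entropy`` is a shortcut for the one-hot sum
that defines the loss. As a sanity check, the sketch below computes the
loss both ways on the toy data from this section. It uses plain NumPy,
and the helper names ``cross_entropy_indexed`` and
``cross_entropy_onehot`` are hypothetical, introduced only for this
comparison.

```python
import numpy as np

y = np.array([0, 2])
y_hat = np.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])

def cross_entropy_indexed(y_hat, y):
    # Pick the probability assigned to the true class in each row
    return -np.log(y_hat[np.arange(len(y_hat)), y]).mean()

def cross_entropy_onehot(y_hat, y):
    # Build one-hot labels, then evaluate -sum_j y_j log(y_hat_j) per row
    Y = np.eye(y_hat.shape[1])[y]
    return -(Y * np.log(y_hat)).sum(axis=1).mean()

loss_indexed = cross_entropy_indexed(y_hat, y)
loss_onehot = cross_entropy_onehot(y_hat, y)
print(loss_indexed, loss_onehot)
```

Both values equal :math:`-(\log 0.1 + \log 0.5)/2 \approx 1.4979`. The
one-hot version takes the logarithm of every entry of ``y_hat`` and then
multiplies most of them by zero, whereas the indexed version only touches
the entries that contribute to the loss.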
Training
--------
We reuse the ``fit`` method defined in :numref:`sec_linear_scratch` to
train the model for 10 epochs. Note that the number of epochs
(``max_epochs``), the minibatch size (``batch_size``), and learning rate
(``lr``) are adjustable hyperparameters. That means that while these
values are not learned during our primary training loop, they still
influence the performance of our model, both vis-à-vis training and
generalization performance. In practice you will want to choose these
values based on the *validation* split of the data and then, ultimately,
to evaluate your final model on the *test* split. As discussed in
:numref:`subsec_generalization-model-selection`, we will regard the
test data of Fashion-MNIST as the validation set, thus reporting
validation loss and validation accuracy on this split.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegressionScratch(num_inputs=784, num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
.. figure:: output_softmax-regression-scratch_aecde8_120_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegressionScratch(num_inputs=784, num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
.. figure:: output_softmax-regression-scratch_aecde8_123_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegressionScratch(num_inputs=784, num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
.. figure:: output_softmax-regression-scratch_aecde8_126_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegressionScratch(num_inputs=784, num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
.. figure:: output_softmax-regression-scratch_aecde8_129_0.svg
.. raw:: html
.. raw:: html
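Choosing a hyperparameter on the validation split can be sketched
without any framework. The example below is a minimal illustration under
stated assumptions: it uses synthetic three-class data rather than
Fashion-MNIST, a hand-written minibatch SGD loop in place of
``d2l.Trainer``, and a hypothetical ``train`` helper. It trains one
softmax regression model per candidate learning rate and keeps the rate
with the lowest validation loss.

```python
import numpy as np

def softmax(O):
    # Subtract the row maximum for numerical stability
    O_exp = np.exp(O - O.max(axis=1, keepdims=True))
    return O_exp / O_exp.sum(axis=1, keepdims=True)

def cross_entropy(y_hat, y):
    return -np.log(y_hat[np.arange(len(y_hat)), y]).mean()

def train(X, y, num_outputs, lr, max_epochs, batch_size, seed=0):
    rng = np.random.default_rng(seed)
    W = rng.normal(0, 0.01, size=(X.shape[1], num_outputs))
    b = np.zeros(num_outputs)
    for _ in range(max_epochs):
        order = rng.permutation(len(X))
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            y_hat = softmax(X[idx] @ W + b)
            # Gradient of the mean cross-entropy w.r.t. the logits:
            # (predicted probabilities - one-hot labels) / batch size
            grad = y_hat
            grad[np.arange(len(idx)), y[idx]] -= 1
            grad /= len(idx)
            W -= lr * X[idx].T @ grad
            b -= lr * grad.sum(axis=0)
    return W, b

# Synthetic 3-class data (an assumption; not Fashion-MNIST)
rng = np.random.default_rng(42)
centers = np.array([[2.0, 0.0], [-2.0, 0.0], [0.0, 2.0]])
y_all = rng.integers(0, 3, size=600)
X_all = centers[y_all] + rng.normal(0, 1.0, size=(600, 2))
X_train, y_train = X_all[:400], y_all[:400]
X_val, y_val = X_all[400:], y_all[400:]

# Evaluate each candidate learning rate on the validation split only
val_losses = {}
for lr in [0.001, 0.01, 0.1]:
    W, b = train(X_train, y_train, 3, lr, max_epochs=10, batch_size=32)
    val_losses[lr] = cross_entropy(softmax(X_val @ W + b), y_val)

best_lr = min(val_losses, key=val_losses.get)
print(val_losses, best_lr)
```

The same loop could range over ``max_epochs`` or ``batch_size``. In each
case the test split should stay untouched until the final model has been
chosen.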
Prediction
----------
Now that training is complete, our model is ready to classify some
images.
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X, y = next(iter(data.val_dataloader()))
preds = model(X).argmax(axis=1)
preds.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([256])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X, y = next(iter(data.val_dataloader()))
preds = model(X).argmax(axis=1)
preds.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(256,)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X, y = next(iter(data.val_dataloader()))
preds = model.apply({'params': trainer.state.params}, X).argmax(axis=1)
preds.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(256,)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X, y = next(iter(data.val_dataloader()))
preds = tf.argmax(model(X), axis=1)
preds.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([256])
.. raw:: html
.. raw:: html
We are more interested in the images we label *incorrectly*. We
visualize them by comparing their actual labels (first line of text
output) with the predictions from the model (second line of text
output).
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
wrong = preds.type(y.dtype) != y
X, y, preds = X[wrong], y[wrong], preds[wrong]
labels = [a+'\n'+b for a, b in zip(
data.text_labels(y), data.text_labels(preds))]
data.visualize([X, y], labels=labels)
.. figure:: output_softmax-regression-scratch_aecde8_150_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
wrong = preds.astype(y.dtype) != y
X, y, preds = X[wrong], y[wrong], preds[wrong]
labels = [a+'\n'+b for a, b in zip(
data.text_labels(y), data.text_labels(preds))]
data.visualize([X, y], labels=labels)
.. figure:: output_softmax-regression-scratch_aecde8_153_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
wrong = preds.astype(y.dtype) != y
X, y, preds = X[wrong], y[wrong], preds[wrong]
labels = [a+'\n'+b for a, b in zip(
data.text_labels(y), data.text_labels(preds))]
data.visualize([X, y], labels=labels)
.. figure:: output_softmax-regression-scratch_aecde8_156_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
wrong = tf.cast(preds, y.dtype) != y
X, y, preds = X[wrong], y[wrong], preds[wrong]
labels = [a+'\n'+b for a, b in zip(
data.text_labels(y), data.text_labels(preds))]
data.visualize([X, y], labels=labels)
.. figure:: output_softmax-regression-scratch_aecde8_159_0.svg
.. raw:: html
.. raw:: html
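The selection of misclassified examples boils down to an ``argmax``
followed by a Boolean mask. The following NumPy sketch shows the idea on
a tiny batch; the probabilities and labels are made up for illustration
and do not come from the trained model.

```python
import numpy as np

# Hypothetical predicted probabilities for 4 examples over 3 classes
y_hat = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.5, 0.4, 0.1],
                  [0.2, 0.2, 0.6]])
y = np.array([0, 2, 1, 0])             # hypothetical true labels

preds = y_hat.argmax(axis=1)           # most likely class per row
wrong = preds.astype(y.dtype) != y     # Boolean mask of the mistakes

print(preds)                           # [0 2 0 2]
print(np.flatnonzero(wrong))           # [2 3]
print(1 - wrong.mean())                # accuracy: 0.5
```

Indexing any per-example array with ``wrong`` keeps only the rows that
the model got wrong, which is how the inputs, labels, and predictions
are filtered before visualization.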
Summary
-------
By now we are starting to get some experience with solving linear
regression and classification problems. With it, we have reached what
would arguably be the state of the art of statistical modeling in the
1960s and 1970s. In the next section, we will show you how to leverage
deep learning frameworks to implement this model much more efficiently.
Exercises
---------
1. In this section, we directly implemented the softmax function based
   on the mathematical definition of the softmax operation. As discussed
   in :numref:`sec_softmax`, this can cause numerical instabilities.

   1. Test whether ``softmax`` still works correctly if an input has a
      value of :math:`100`.
   2. Test whether ``softmax`` still works correctly if the largest of
      all inputs is smaller than :math:`-100`.
   3. Implement a fix by looking at the value relative to the largest
      entry in the argument.

2. Implement a ``cross_entropy`` function that follows the definition of
   the cross-entropy loss function :math:`-\sum_i y_i \log \hat{y}_i`.

   1. Try it out in the code example of this section.
   2. Why do you think it runs more slowly?
   3. Should you use it? When would it make sense to?
   4. What do you need to be careful of? Hint: consider the domain of
      the logarithm.

3. Is it always a good idea to return the most likely label? For
   example, would you do this for medical diagnosis? How would you try
   to address this?
4. Assume that we want to use softmax regression to predict the next
   word based on some features. What are some problems that might arise
   from a large vocabulary?
5. Experiment with the hyperparameters of the code in this section. In
   particular:

   1. Plot how the validation loss changes as you change the learning
      rate.
   2. Do the validation and training loss change as you change the
      minibatch size? How large or small do you need to go before you
      see an effect?