10.1. Attention Mechanisms

In Section 9.7, we encode the source sequence information in the recurrent unit state and then pass it to the decoder to generate the target sequence. A token in the target sequence may be closely related to only one or a few tokens in the source sequence, rather than to the whole source sequence. For example, when translating “Hello world.” to “Bonjour le monde.”, “Bonjour” maps to “Hello” and “monde” maps to “world”. In the seq2seq model, the decoder may implicitly select the corresponding information from the state passed by the encoder. The attention mechanism, however, makes this selection explicit.

Attention is a generalized pooling method that biases its weighting toward inputs that align with a query. The core component of the attention mechanism is the attention layer, or simply attention for short. The input of the attention layer is called a query. For a query, attention returns an output based on the memory—a set of key-value pairs encoded in the attention layer. To be more specific, assume that the memory contains \(n\) key-value pairs, \((\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_n, \mathbf{v}_n)\), with \(\mathbf{k}_i \in \mathbb R^{d_k}\), \(\mathbf{v}_i \in \mathbb R^{d_v}\). Given a query \(\mathbf{q} \in \mathbb R^{d_q}\), the attention layer returns an output \(\mathbf{o} \in \mathbb R^{d_v}\) with the same shape as the values.


Fig. 10.1.1 The attention layer returns an output based on the input query and its memory.

The full process of the attention mechanism is illustrated in Fig. 10.1.2. To compute the output of attention, we first use a score function \(\alpha\) that measures the similarity between the query and a key. For each key \(\mathbf{k}_1, \ldots, \mathbf{k}_n\), we compute the scores \(a_1, \ldots, a_n\) by

(10.1.1)\[a_i = \alpha(\mathbf q, \mathbf k_i).\]

Next we use softmax to obtain the attention weights, i.e.,

(10.1.2)\[\mathbf{b} = \mathrm{softmax}(\mathbf{a}), \quad \text{where } b_i = \frac{\exp(a_i)}{\sum_j \exp(a_j)} \text{ and } \mathbf{b} = [b_1, \ldots, b_n]^\top.\]

Finally, the output is a weighted sum of the values:

(10.1.3)\[\mathbf o = \sum_{i=1}^n b_i \mathbf v_i.\]

Fig. 10.1.2 The attention output is a weighted sum of the values.
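
To make these three steps concrete, here is a minimal PyTorch sketch that carries out (10.1.1)–(10.1.3) for a single query and three made-up key-value pairs, using a plain dot product as the score function for illustration.

import torch

# Made-up memory: n = 3 key-value pairs with d_k = d_v = 2
keys = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
values = torch.tensor([[0., 10.], [5., 5.], [10., 0.]])
query = torch.tensor([1., 0.])

scores = keys @ query                   # (10.1.1): a_i = alpha(q, k_i), here a dot product
weights = torch.softmax(scores, dim=0)  # (10.1.2): b = softmax(a)
output = weights @ values               # (10.1.3): o = sum_i b_i v_i
print(weights, output)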

Different choices of the score function lead to different attention layers. Below, we introduce two commonly used attention layers. Before diving into the implementations, we first introduce two operators to get you up and running: a masked version of the softmax operator masked_softmax and a specialized batched dot product operator batch_dot.

import math
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()
import math
import torch
from torch import nn

The masked softmax takes a 3-dimensional input and enables us to filter out some elements by specifying a valid length for the last dimension. (Refer to Section 9.5 for the definition of a valid length.) As a result, any value outside the valid length will be masked as \(0\). Let us implement the masked_softmax function.

#@save
def masked_softmax(X, valid_len):
    # X: 3-D tensor, valid_len: 1-D or 2-D tensor
    if valid_len is None:
        return npx.softmax(X)
    else:
        shape = X.shape
        if valid_len.ndim == 1:
            valid_len = valid_len.repeat(shape[1], axis=0)
        else:
            valid_len = valid_len.reshape(-1)
        # Fill masked elements with a large negative, whose exp is 0
        X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_len, True,
                              axis=1, value=-1e6)
        return npx.softmax(X).reshape(shape)
#@save
def masked_softmax(X, valid_len):
    # X: 3-D tensor, valid_len: 1-D or 2-D tensor
    if valid_len is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_len.dim() == 1:
            valid_len = torch.repeat_interleave(valid_len, repeats=shape[1], dim=0)
        else:
            valid_len = valid_len.reshape(-1)
        # Fill masked elements with a large negative, whose exp is 0
        X = X.reshape(-1, shape[-1])
        mask = torch.arange(shape[-1], device=X.device)[None, :] >= valid_len.to(X.device)[:, None]
        X = X.masked_fill(mask, -1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

To illustrate how this function works, we construct two \(2 \times 4\) matrices as the input. In addition, we specify that the valid length equals 2 for the first example and 3 for the second example. Then, as we can see from the following outputs, the values beyond the valid lengths are masked as zero.

masked_softmax(np.random.uniform(size=(2, 2, 4)), np.array([2, 3]))
array([[[0.488994  , 0.511006  , 0.        , 0.        ],
        [0.43654838, 0.56345165, 0.        , 0.        ]],

       [[0.28817102, 0.3519408 , 0.3598882 , 0.        ],
        [0.29034293, 0.25239873, 0.45725834, 0.        ]]])
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
tensor([[[0.4263, 0.5737, 0.0000, 0.0000],
         [0.5726, 0.4274, 0.0000, 0.0000]],

        [[0.3552, 0.2617, 0.3830, 0.0000],
         [0.2175, 0.3439, 0.4386, 0.0000]]])
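
The valid_len argument may also be a 2-D tensor that specifies one valid length per query within each example. A quick sketch of this case, reusing the masked_softmax defined above with made-up lengths:

# One valid length per query: lengths 1 and 3 for the first example, 2 and 4 for the second
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))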

Moreover, the second operator batch_dot takes two inputs \(X\) and \(Y\) with shapes \((b, n, m)\) and \((b, m, k)\), respectively, and returns an output with shape \((b, n, k)\). To be specific, it computes \(b\) matrix products, for \(i = 1, \ldots, b\), i.e.,

(10.1.4)\[Z[i,:,:] = X[i,:,:] Y[i,:,:].\]
npx.batch_dot(np.ones((2, 1, 3)), np.ones((2, 3, 2)))
array([[[3., 3.]],

       [[3., 3.]]])
torch.bmm(torch.ones(2,1,3), torch.ones(2,3,2))
tensor([[[3., 3.]],

        [[3., 3.]]])
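
As a sanity check on (10.1.4), each slice of the batched product should equal the corresponding ordinary matrix product; a small PyTorch sketch:

X, Y = torch.randn(2, 1, 3), torch.randn(2, 3, 2)
Z = torch.bmm(X, Y)
# Each slice of the batched product matches the per-example matrix product
print(all(torch.allclose(Z[i], X[i] @ Y[i]) for i in range(2)))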

10.1.1. Dot Product Attention

Equipped with the above two operators, masked_softmax and batch_dot, let us dive into the details of two widely used attention layers. The first one is the dot product attention: it assumes that the query has the same dimension as the keys, namely \(\mathbf q, \mathbf k_i \in\mathbb R^d\) for all \(i\). The dot product attention computes the scores by taking the dot product between the query and a key, which is then divided by \(\sqrt{d}\) to reduce the influence of the dimension \(d\) on the scale of the scores. In other words,

(10.1.5)\[\alpha(\mathbf q, \mathbf k) = \langle \mathbf q, \mathbf k \rangle /\sqrt{d}.\]
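
The \(1/\sqrt{d}\) factor keeps the magnitude of the scores roughly independent of \(d\): if the entries of \(\mathbf q\) and \(\mathbf k\) are independent with zero mean and unit variance, their dot product has variance \(d\), so dividing by \(\sqrt{d}\) brings the variance back to about \(1\). A quick empirical sketch with random vectors (the sample size and dimension are arbitrary):

d = 256
q, k = torch.randn(10000, d), torch.randn(10000, d)
raw = (q * k).sum(dim=1)     # unscaled dot products: variance is close to d
scaled = raw / math.sqrt(d)  # scaled scores: variance is close to 1
print(raw.var().item(), scaled.var().item())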

Beyond a single query and key, we can generalize to multiple queries and keys. Assume that \(\mathbf Q\in\mathbb R^{m\times d}\) contains \(m\) queries and \(\mathbf K\in\mathbb R^{n\times d}\) contains all the \(n\) keys. We can compute all \(mn\) scores by

(10.1.6)\[\alpha(\mathbf Q, \mathbf K) = \mathbf Q \mathbf K^\top /\sqrt{d}.\]

With (10.1.6), we can implement the dot product attention layer DotProductAttention that supports a batch of queries and key-value pairs. In addition, we apply a dropout layer to the attention weights for regularization.

#@save
class DotProductAttention(nn.Block):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_len: either (batch_size, ) or (batch_size, #queries)
    def forward(self, query, key, value, valid_len=None):
        d = query.shape[-1]
        # Set transpose_b=True to swap the last two dimensions of key
        scores = npx.batch_dot(query, key, transpose_b=True) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_len))
        return npx.batch_dot(attention_weights, value)
#@save
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_len: either (batch_size, ) or (batch_size, #queries)
    def forward(self, query, key, value, valid_len=None):
        d = query.shape[-1]
        # Swap the last two dimensions of key via transpose(1, 2)
        scores = torch.bmm(query, key.transpose(1,2)) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_len))
        return torch.bmm(attention_weights, value)

Let us test the class DotProductAttention in a toy example. First, we create a minibatch of two examples, where each example has one query and 10 key-value pairs. Via the valid_len argument, we specify that we attend to the first \(2\) key-value pairs for the first example and the first \(6\) for the second one. Therefore, even though both examples have the same query and key-value pairs, we obtain different outputs.

atten = DotProductAttention(dropout=0.5)
atten.initialize()
keys = np.ones((2, 10, 2))
values = np.arange(40).reshape(1, 10, 4).repeat(2, axis=0)
atten(np.ones((2, 1, 2)), keys, values, np.array([2, 6]))
array([[[ 2.      ,  3.      ,  4.      ,  5.      ]],

       [[10.      , 11.      , 12.000001, 13.      ]]])
atten = DotProductAttention(dropout=0.5)
atten.eval()
keys = torch.ones(2,10,2)
values = torch.arange(40, dtype=torch.float32).reshape(1,10,4).repeat(2,1,1)
atten(torch.ones(2,1,2), keys, values, torch.tensor([2, 6]))
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])

As we can see above, dot product attention simply multiplies the query and key together and hopes to derive their similarity from there. However, the query and the keys may not have the same dimension. To address this issue, we can resort to multilayer perceptron attention.

10.1.2. Multilayer Perceptron Attention

In multilayer perceptron attention, we project both the query and the keys into \(\mathbb R^{h}\) with learnable weight parameters. Assume that the learnable weights are \(\mathbf W_k\in\mathbb R^{h\times d_k}\), \(\mathbf W_q\in\mathbb R^{h\times d_q}\), and \(\mathbf v\in\mathbb R^{h}\). Then the score function is defined by

(10.1.7)\[\alpha(\mathbf k, \mathbf q) = \mathbf v^\top \text{tanh}(\mathbf W_k \mathbf k + \mathbf W_q\mathbf q).\]

Intuitively, you can imagine \(\mathbf W_k \mathbf k + \mathbf W_q\mathbf q\) as concatenating the key and the query in the feature dimension and feeding them into a perceptron with a single hidden layer of size \(h\) and an output layer of size \(1\). In the hidden layer, the activation function is \(\tanh\) and no bias is applied.
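
This concatenation view can be checked numerically: stacking \(\mathbf W_q\) and \(\mathbf W_k\) side by side and multiplying by the concatenated vector gives the same hidden-layer input as \(\mathbf W_q\mathbf q + \mathbf W_k\mathbf k\). A minimal sketch with made-up sizes:

h, d_q, d_k = 8, 2, 3
W_q, W_k = torch.randn(h, d_q), torch.randn(h, d_k)
q, k = torch.randn(d_q), torch.randn(d_k)
# Block multiplication: [W_q, W_k] applied to [q; k] equals W_q q + W_k k
W = torch.cat([W_q, W_k], dim=1)
x = torch.cat([q, k])
print(torch.allclose(W @ x, W_q @ q + W_k @ k))

Now let us implement the multilayer perceptron attention.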

#@save
class MLPAttention(nn.Block):
    def __init__(self, units, dropout, **kwargs):
        super(MLPAttention, self).__init__(**kwargs)
        # Use flatten=False to keep query's and key's 3-D shapes
        self.W_k = nn.Dense(units, use_bias=False, flatten=False)
        self.W_q = nn.Dense(units, use_bias=False, flatten=False)
        self.v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)


    def forward(self, query, key, value, valid_len):
        query, key = self.W_q(query), self.W_k(key)
        # Expand query to (batch_size, #queries, 1, units) and key to
        # (batch_size, 1, #kv_pairs, units). Then sum them with broadcasting
        features = np.expand_dims(query, axis=2) + np.expand_dims(key, axis=1)
        features = np.tanh(features)
        scores = np.squeeze(self.v(features), axis=-1)
        attention_weights = self.dropout(masked_softmax(scores, valid_len))
        return npx.batch_dot(attention_weights, value)
#@save
class MLPAttention(nn.Module):
    def __init__(self, key_size, query_size, units, dropout, **kwargs):
        super(MLPAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, units, bias=False)
        self.W_q = nn.Linear(query_size, units, bias=False)
        self.v = nn.Linear(units, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_len):
        query, key = self.W_q(query), self.W_k(key)
        # Expand query to (batch_size, #queries, 1, units) and key to
        # (batch_size, 1, #kv_pairs, units). Then sum them with broadcasting
        features = query.unsqueeze(2) + key.unsqueeze(1)
        features = torch.tanh(features)
        scores = self.v(features).squeeze(-1)
        attention_weights = self.dropout(masked_softmax(scores, valid_len))
        return torch.bmm(attention_weights, value)

To test the above MLPAttention class, we use the same inputs as in the previous toy example. As we can see below, despite MLPAttention containing an additional MLP model, we obtain the same outputs as for DotProductAttention. This is because all the keys are identical, so both score functions produce uniform attention weights over the valid key-value pairs.

atten = MLPAttention(units=8, dropout=0.1)
atten.initialize()
atten(np.ones((2, 1, 2)), keys, values, np.array([2, 6]))
array([[[ 2.      ,  3.      ,  4.      ,  5.      ]],

       [[10.      , 11.      , 12.000001, 13.      ]]])
atten = MLPAttention(key_size=2, query_size=2, units=8, dropout=0.1)
atten.eval()
atten(torch.ones(2,1,2), keys, values, torch.tensor([2, 6]))
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward>)
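
Unlike dot product attention, MLP attention also works when the query and key dimensions differ, since both are projected into \(\mathbb R^h\). A small sketch of this case, with a made-up query dimension of \(3\) and the keys and values from the example above:

atten = MLPAttention(key_size=2, query_size=3, units=8, dropout=0.1)
atten.eval()
# The query dimension (3) differs from the key dimension (2), which dot product attention cannot handle
atten(torch.ones(2, 1, 3), keys, values, torch.tensor([2, 6]))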

10.1.3. Summary

  • An attention layer explicitly selects related information.

  • An attention layer’s memory consists of key-value pairs, so its output is close to the values whose keys are similar to the queries.

  • Two commonly used attention models are dot product attention and multilayer perceptron attention.

10.1.4. Exercises

  1. What are the advantages and disadvantages of dot product attention and multilayer perceptron attention, respectively?

10.1.5. Discussions
