.. _sec_multihead-attention:

Multi-Head Attention
====================

In practice, given the same set of queries, keys, and values, we may want
our model to combine knowledge from different behaviors of the same
attention mechanism, such as capturing dependencies of various ranges
(e.g., shorter-range vs. longer-range) within a sequence. Thus, it may be
beneficial to allow our attention mechanism to jointly use different
representation subspaces of queries, keys, and values.

To this end, instead of performing a single attention pooling, the queries,
keys, and values can be transformed with :math:`h` independently learned
linear projections. Then these :math:`h` projected queries, keys, and
values are fed into attention pooling in parallel. In the end, the
:math:`h` attention-pooling outputs are concatenated and transformed with
another learned linear projection to produce the final output. This design
is called *multi-head attention*, where each of the :math:`h`
attention-pooling outputs is a *head*
:cite:`Vaswani.Shazeer.Parmar.ea.2017`. Using fully connected layers to
perform learnable linear transformations,
:numref:`fig_multi-head-attention` describes multi-head attention.

.. _fig_multi-head-attention:

.. figure:: ../img/multi-head-attention.svg

   Multi-head attention, where multiple heads are concatenated and then
   linearly transformed.
**PyTorch**

.. code:: python

   import math
   import torch
   from torch import nn
   from d2l import torch as d2l
**MXNet**

.. code:: python

   import math
   from mxnet import autograd, np, npx
   from mxnet.gluon import nn
   from d2l import mxnet as d2l

   npx.set_np()
**JAX**

.. code:: python

   import jax
   from flax import linen as nn
   from jax import numpy as jnp
   from d2l import jax as d2l
**TensorFlow**

.. code:: python

   import tensorflow as tf
   from d2l import tensorflow as d2l
Model
-----

Before providing the implementation of multi-head attention, let's
formalize this model mathematically. Given a query
:math:`\mathbf{q} \in \mathbb{R}^{d_q}`, a key
:math:`\mathbf{k} \in \mathbb{R}^{d_k}`, and a value
:math:`\mathbf{v} \in \mathbb{R}^{d_v}`, each attention head
:math:`\mathbf{h}_i` (:math:`i = 1, \ldots, h`) is computed as

.. math::

   \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k, \mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},

where :math:`\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}`,
:math:`\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}`, and
:math:`\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}` are learnable
parameters and :math:`f` is attention pooling, such as additive attention
or scaled dot-product attention in
:numref:`sec_attention-scoring-functions`. The multi-head attention output
is another linear transformation via learnable parameters
:math:`\mathbf W_o\in\mathbb R^{p_o\times h p_v}` of the concatenation of
the :math:`h` heads:

.. math::

   \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.

Based on this design, each head may attend to different parts of the
input, so functions more sophisticated than a simple weighted average can
be expressed.

Implementation
--------------

In our implementation, we choose scaled dot-product attention for each
head of the multi-head attention. To avoid significant growth of
computational cost and parameterization cost, we set
:math:`p_q = p_k = p_v = p_o / h`. Note that the :math:`h` heads can be
computed in parallel if we set the number of outputs of the linear
transformations for the query, key, and value to
:math:`p_q h = p_k h = p_v h = p_o`. In the following implementation,
:math:`p_o` is specified via the argument ``num_hiddens``.
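Before turning to the parallel implementation below, it may help to compute
the equations above head by head. The following minimal sketch (not from
the original text) does so in PyTorch for a single query, using scaled
dot-product attention as :math:`f`; the dimensions, the number of
key-value pairs, and the random "learned" parameters are purely
illustrative assumptions. With :math:`p_o = 100` and :math:`h = 5`, each
head works in a :math:`p = 20`-dimensional subspace.

.. code:: python

   import math
   import torch
   import torch.nn.functional as F

   # Illustrative sizes only (assumptions for this sketch)
   h, p_o = 5, 100              # number of heads and output dimension p_o
   p = p_o // h                 # per-head dimension, since p_q = p_k = p_v = p_o / h
   d_q = d_k = d_v = 32         # input dimensions of the query, key, and value

   q = torch.randn(d_q)         # one query q
   K = torch.randn(6, d_k)      # six keys
   V = torch.randn(6, d_v)      # six values

   # Independently "learned" projections, here just random placeholders
   W_q = [torch.randn(p, d_q) for _ in range(h)]   # W_i^{(q)}
   W_k = [torch.randn(p, d_k) for _ in range(h)]   # W_i^{(k)}
   W_v = [torch.randn(p, d_v) for _ in range(h)]   # W_i^{(v)}
   W_o = torch.randn(p_o, h * p)                   # W_o

   heads = []
   for i in range(h):
       q_i = W_q[i] @ q                    # project the query into head i's subspace
       K_i = K @ W_k[i].T                  # project the keys, shape (6, p)
       V_i = V @ W_v[i].T                  # project the values, shape (6, p)
       scores = K_i @ q_i / math.sqrt(p)   # scaled dot-product scores, shape (6,)
       weights = F.softmax(scores, dim=0)  # attention weights over the key-value pairs
       heads.append(weights @ V_i)         # h_i = f(...), a vector in R^p
   # Concatenate the h heads and apply the final linear transformation W_o
   output = W_o @ torch.cat(heads)
   print(output.shape)                     # torch.Size([100]), i.e., p_o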
**PyTorch**

.. code:: python

   class MultiHeadAttention(d2l.Module):  #@save
       """Multi-head attention."""
       def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
           super().__init__()
           self.num_heads = num_heads
           self.attention = d2l.DotProductAttention(dropout)
           self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
           self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
           self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
           self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

       def forward(self, queries, keys, values, valid_lens):
           # Shape of queries, keys, or values:
           # (batch_size, no. of queries or key-value pairs, num_hiddens)
           # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
           # After transposing, shape of output queries, keys, or values:
           # (batch_size * num_heads, no. of queries or key-value pairs,
           # num_hiddens / num_heads)
           queries = self.transpose_qkv(self.W_q(queries))
           keys = self.transpose_qkv(self.W_k(keys))
           values = self.transpose_qkv(self.W_v(values))

           if valid_lens is not None:
               # On axis 0, copy the first item (scalar or vector) for num_heads
               # times, then copy the next item, and so on
               valid_lens = torch.repeat_interleave(
                   valid_lens, repeats=self.num_heads, dim=0)

           # Shape of output: (batch_size * num_heads, no. of queries,
           # num_hiddens / num_heads)
           output = self.attention(queries, keys, values, valid_lens)
           # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
           output_concat = self.transpose_output(output)
           return self.W_o(output_concat)
**MXNet**

.. code:: python

   class MultiHeadAttention(d2l.Module):  #@save
       """Multi-head attention."""
       def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
                    **kwargs):
           super().__init__()
           self.num_heads = num_heads
           self.attention = d2l.DotProductAttention(dropout)
           self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
           self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
           self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
           self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

       def forward(self, queries, keys, values, valid_lens):
           # Shape of queries, keys, or values:
           # (batch_size, no. of queries or key-value pairs, num_hiddens)
           # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
           # After transposing, shape of output queries, keys, or values:
           # (batch_size * num_heads, no. of queries or key-value pairs,
           # num_hiddens / num_heads)
           queries = self.transpose_qkv(self.W_q(queries))
           keys = self.transpose_qkv(self.W_k(keys))
           values = self.transpose_qkv(self.W_v(values))

           if valid_lens is not None:
               # On axis 0, copy the first item (scalar or vector) for num_heads
               # times, then copy the next item, and so on
               valid_lens = valid_lens.repeat(self.num_heads, axis=0)

           # Shape of output: (batch_size * num_heads, no. of queries,
           # num_hiddens / num_heads)
           output = self.attention(queries, keys, values, valid_lens)
           # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
           output_concat = self.transpose_output(output)
           return self.W_o(output_concat)
**JAX**

.. code:: python

   class MultiHeadAttention(nn.Module):  #@save
       num_hiddens: int
       num_heads: int
       dropout: float
       bias: bool = False

       def setup(self):
           self.attention = d2l.DotProductAttention(self.dropout)
           self.W_q = nn.Dense(self.num_hiddens, use_bias=self.bias)
           self.W_k = nn.Dense(self.num_hiddens, use_bias=self.bias)
           self.W_v = nn.Dense(self.num_hiddens, use_bias=self.bias)
           self.W_o = nn.Dense(self.num_hiddens, use_bias=self.bias)

       @nn.compact
       def __call__(self, queries, keys, values, valid_lens, training=False):
           # Shape of queries, keys, or values:
           # (batch_size, no. of queries or key-value pairs, num_hiddens)
           # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
           # After transposing, shape of output queries, keys, or values:
           # (batch_size * num_heads, no. of queries or key-value pairs,
           # num_hiddens / num_heads)
           queries = self.transpose_qkv(self.W_q(queries))
           keys = self.transpose_qkv(self.W_k(keys))
           values = self.transpose_qkv(self.W_v(values))

           if valid_lens is not None:
               # On axis 0, copy the first item (scalar or vector) for num_heads
               # times, then copy the next item, and so on
               valid_lens = jnp.repeat(valid_lens, self.num_heads, axis=0)

           # Shape of output: (batch_size * num_heads, no. of queries,
           # num_hiddens / num_heads)
           output, attention_weights = self.attention(
               queries, keys, values, valid_lens, training=training)
           # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
           output_concat = self.transpose_output(output)
           return self.W_o(output_concat), attention_weights
**TensorFlow**

.. code:: python

   class MultiHeadAttention(d2l.Module):  #@save
       """Multi-head attention."""
       def __init__(self, key_size, query_size, value_size, num_hiddens,
                    num_heads, dropout, bias=False, **kwargs):
           super().__init__()
           self.num_heads = num_heads
           self.attention = d2l.DotProductAttention(dropout)
           self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
           self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
           self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
           self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)

       def call(self, queries, keys, values, valid_lens, **kwargs):
           # Shape of queries, keys, or values:
           # (batch_size, no. of queries or key-value pairs, num_hiddens)
           # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
           # After transposing, shape of output queries, keys, or values:
           # (batch_size * num_heads, no. of queries or key-value pairs,
           # num_hiddens / num_heads)
           queries = self.transpose_qkv(self.W_q(queries))
           keys = self.transpose_qkv(self.W_k(keys))
           values = self.transpose_qkv(self.W_v(values))

           if valid_lens is not None:
               # On axis 0, copy the first item (scalar or vector) for num_heads
               # times, then copy the next item, and so on
               valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)

           # Shape of output: (batch_size * num_heads, no. of queries,
           # num_hiddens / num_heads)
           output = self.attention(queries, keys, values, valid_lens, **kwargs)
           # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
           output_concat = self.transpose_output(output)
           return self.W_o(output_concat)
To allow for parallel computation of multiple heads, the above
``MultiHeadAttention`` class uses the two transposition methods defined
below. Specifically, the ``transpose_output`` method reverses the
operation of the ``transpose_qkv`` method; a small round-trip check
follows the code.
**PyTorch**

.. code:: python

   @d2l.add_to_class(MultiHeadAttention)  #@save
   def transpose_qkv(self, X):
       """Transposition for parallel computation of multiple attention heads."""
       # Shape of input X: (batch_size, no. of queries or key-value pairs,
       # num_hiddens). Shape of output X: (batch_size, no. of queries or
       # key-value pairs, num_heads, num_hiddens / num_heads)
       X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
       # Shape of output X: (batch_size, num_heads, no. of queries or key-value
       # pairs, num_hiddens / num_heads)
       X = X.permute(0, 2, 1, 3)
       # Shape of output: (batch_size * num_heads, no. of queries or key-value
       # pairs, num_hiddens / num_heads)
       return X.reshape(-1, X.shape[2], X.shape[3])

   @d2l.add_to_class(MultiHeadAttention)  #@save
   def transpose_output(self, X):
       """Reverse the operation of transpose_qkv."""
       X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
       X = X.permute(0, 2, 1, 3)
       return X.reshape(X.shape[0], X.shape[1], -1)
**MXNet**

.. code:: python

   @d2l.add_to_class(MultiHeadAttention)  #@save
   def transpose_qkv(self, X):
       """Transposition for parallel computation of multiple attention heads."""
       # Shape of input X: (batch_size, no. of queries or key-value pairs,
       # num_hiddens). Shape of output X: (batch_size, no. of queries or
       # key-value pairs, num_heads, num_hiddens / num_heads)
       X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
       # Shape of output X: (batch_size, num_heads, no. of queries or key-value
       # pairs, num_hiddens / num_heads)
       X = X.transpose(0, 2, 1, 3)
       # Shape of output: (batch_size * num_heads, no. of queries or key-value
       # pairs, num_hiddens / num_heads)
       return X.reshape(-1, X.shape[2], X.shape[3])

   @d2l.add_to_class(MultiHeadAttention)  #@save
   def transpose_output(self, X):
       """Reverse the operation of transpose_qkv."""
       X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
       X = X.transpose(0, 2, 1, 3)
       return X.reshape(X.shape[0], X.shape[1], -1)
**JAX**

.. code:: python

   @d2l.add_to_class(MultiHeadAttention)  #@save
   def transpose_qkv(self, X):
       """Transposition for parallel computation of multiple attention heads."""
       # Shape of input X: (batch_size, no. of queries or key-value pairs,
       # num_hiddens). Shape of output X: (batch_size, no. of queries or
       # key-value pairs, num_heads, num_hiddens / num_heads)
       X = X.reshape((X.shape[0], X.shape[1], self.num_heads, -1))
       # Shape of output X: (batch_size, num_heads, no. of queries or key-value
       # pairs, num_hiddens / num_heads)
       X = jnp.transpose(X, (0, 2, 1, 3))
       # Shape of output: (batch_size * num_heads, no. of queries or key-value
       # pairs, num_hiddens / num_heads)
       return X.reshape((-1, X.shape[2], X.shape[3]))

   @d2l.add_to_class(MultiHeadAttention)  #@save
   def transpose_output(self, X):
       """Reverse the operation of transpose_qkv."""
       X = X.reshape((-1, self.num_heads, X.shape[1], X.shape[2]))
       X = jnp.transpose(X, (0, 2, 1, 3))
       return X.reshape((X.shape[0], X.shape[1], -1))
**TensorFlow**

.. code:: python

   @d2l.add_to_class(MultiHeadAttention)  #@save
   def transpose_qkv(self, X):
       """Transposition for parallel computation of multiple attention heads."""
       # Shape of input X: (batch_size, no. of queries or key-value pairs,
       # num_hiddens). Shape of output X: (batch_size, no. of queries or
       # key-value pairs, num_heads, num_hiddens / num_heads)
       X = tf.reshape(X, shape=(X.shape[0], X.shape[1], self.num_heads, -1))
       # Shape of output X: (batch_size, num_heads, no. of queries or key-value
       # pairs, num_hiddens / num_heads)
       X = tf.transpose(X, perm=(0, 2, 1, 3))
       # Shape of output: (batch_size * num_heads, no. of queries or key-value
       # pairs, num_hiddens / num_heads)
       return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))

   @d2l.add_to_class(MultiHeadAttention)  #@save
   def transpose_output(self, X):
       """Reverse the operation of transpose_qkv."""
       X = tf.reshape(X, shape=(-1, self.num_heads, X.shape[1], X.shape[2]))
       X = tf.transpose(X, perm=(0, 2, 1, 3))
       return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))
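As a quick sanity check (not part of the original notebook), the sketch
below runs a random tensor through the PyTorch ``transpose_qkv`` and
``transpose_output`` methods just defined. The batch size, number of
queries, and ``num_hiddens`` are arbitrary choices; since only reshaping
and permuting are involved, the round trip should reproduce the input
exactly.

.. code:: python

   # Round-trip check on the PyTorch transposition methods defined above
   attention = MultiHeadAttention(num_hiddens=100, num_heads=5, dropout=0.5)
   X = torch.randn(2, 4, 100)            # (batch_size, no. of queries, num_hiddens)
   X_split = attention.transpose_qkv(X)  # (batch_size * num_heads, no. of queries, 20)
   print(X_split.shape)                  # torch.Size([10, 4, 20])
   X_back = attention.transpose_output(X_split)
   print(torch.equal(X, X_back))         # True: transpose_output undoes transpose_qkv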
Let's test our implemented ``MultiHeadAttention`` class using a toy
example where the keys and values are the same. The resulting multi-head
attention output has shape (``batch_size``, ``num_queries``,
``num_hiddens``).
**PyTorch**

.. code:: python

   num_hiddens, num_heads = 100, 5
   attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)

   batch_size, num_queries, num_kvpairs = 2, 4, 6
   valid_lens = torch.tensor([3, 2])
   X = torch.ones((batch_size, num_queries, num_hiddens))
   Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
   d2l.check_shape(attention(X, Y, Y, valid_lens),
                   (batch_size, num_queries, num_hiddens))
**MXNet**

.. code:: python

   num_hiddens, num_heads = 100, 5
   attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
   attention.initialize()

   batch_size, num_queries, num_kvpairs = 2, 4, 6
   valid_lens = np.array([3, 2])
   X = np.ones((batch_size, num_queries, num_hiddens))
   Y = np.ones((batch_size, num_kvpairs, num_hiddens))
   d2l.check_shape(attention(X, Y, Y, valid_lens),
                   (batch_size, num_queries, num_hiddens))
**JAX**

.. code:: python

   num_hiddens, num_heads = 100, 5
   attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)

   batch_size, num_queries, num_kvpairs = 2, 4, 6
   valid_lens = jnp.array([3, 2])
   X = jnp.ones((batch_size, num_queries, num_hiddens))
   Y = jnp.ones((batch_size, num_kvpairs, num_hiddens))
   d2l.check_shape(attention.init_with_output(d2l.get_key(), X, Y, Y,
                                              valid_lens, training=False)[0][0],
                   (batch_size, num_queries, num_hiddens))
**TensorFlow**

.. code:: python

   num_hiddens, num_heads = 100, 5
   attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                  num_hiddens, num_heads, 0.5)

   batch_size, num_queries, num_kvpairs = 2, 4, 6
   valid_lens = tf.constant([3, 2])
   X = tf.ones((batch_size, num_queries, num_hiddens))
   Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
   d2l.check_shape(attention(X, Y, Y, valid_lens, training=False),
                   (batch_size, num_queries, num_hiddens))
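As a brief aside (not in the original notebook), PyTorch also ships a
built-in layer implementing this design. The hedged sketch below only
checks that ``torch.nn.MultiheadAttention`` produces an output of the same
shape on similar toy tensors; the choice of ``batch_first=True`` and the
toy sizes are illustrative assumptions.

.. code:: python

   # Shape check with PyTorch's built-in multi-head attention layer
   Xq = torch.ones((2, 4, 100))  # queries: (batch_size, num_queries, num_hiddens)
   Yk = torch.ones((2, 6, 100))  # keys and values: (batch_size, num_kvpairs, num_hiddens)
   builtin = torch.nn.MultiheadAttention(embed_dim=100, num_heads=5,
                                         dropout=0.5, batch_first=True)
   out, weights = builtin(Xq, Yk, Yk)
   print(out.shape)      # torch.Size([2, 4, 100])
   print(weights.shape)  # torch.Size([2, 4, 6]); weights averaged over heads by default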
Summary
-------

Multi-head attention combines knowledge of the same attention pooling via
different representation subspaces of queries, keys, and values. To
compute multiple heads of multi-head attention in parallel, proper tensor
manipulation is needed.

Exercises
---------

1. Visualize attention weights of multiple heads in this experiment.
2. Suppose that we have a trained model based on multi-head attention and
   we want to prune less important attention heads to increase the
   prediction speed. How can we design experiments to measure the
   importance of an attention head?
**PyTorch:** `Discussions `__

**MXNet:** `Discussions `__

**JAX:** `Discussions `__

**TensorFlow:** `Discussions `__