Notes on Attention and Transformers

Introduction

  • Sequence modeling has been dominated by recurrent (RNN-based) models, typically in an encoder-decoder architecture.
  • Their fundamental constraint is sequential processing: each hidden state depends on the previous one, which prevents parallelization within a training example.
  • Attention has become an important part of sequence modeling because it lets the model capture dependencies regardless of their distance in the input or output sequence. However, it is still mostly used together with recurrent networks.
  • The paper proposes the Transformer, which avoids recurrence entirely and relies only on attention to draw global dependencies between input and output.

Model Architecture

Encoder

  • Maps an input sequence (x_1, ..., x_n) to a sequence of continuous representations (z_1, ..., z_n).
  • A stack of N = 6 identical layers. Each layer has 2 sub-layers: (1) multi-head self-attention and (2) a position-wise fully connected feed-forward network.
  • Each of the 2 sub-layers has a residual connection around it, followed by layer normalization, i.e. LayerNorm(x + Sublayer(x)).
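The residual-plus-normalization wrapper can be sketched as follows. This is a minimal sketch using NumPy (an assumed choice, since the notes contain no code). The learned gain and bias of LayerNorm and the dropout are omitted. The sub-layer is a hypothetical linear stand-in for attention or the feed-forward network.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token vector over its feature axis (learned gain/bias omitted).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sublayer_connection(x, sublayer):
    # Residual connection around the sub-layer, then layer normalization:
    # LayerNorm(x + Sublayer(x))
    return layer_norm(x + sublayer(x))

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))        # [B, T, D]
W = rng.standard_normal((8, 8)) * 0.1     # hypothetical stand-in for a real sub-layer
out = sublayer_connection(x, lambda h: h @ W)
print(out.shape)  # (2, 5, 8)
```

Because the residual sum keeps shape [B, T, D], the same wrapper can be applied around every sub-layer in both stacks.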

Decoder

  • Also a stack of N = 6 identical layers, but each layer has 3 sub-layers. In addition to the two sub-layers of the encoder, it inserts a third sub-layer that performs multi-head attention over the output of the encoder stack. The queries come from the output of the decoder's (masked) self-attention sub-layer, and the keys and values come from the encoder output.
  • Like the encoder, each sub-layer has a residual connection around it, followed by layer normalization.
  • The decoder's self-attention sub-layer is masked so that each position cannot attend to subsequent positions (i.e. the model is never shown what comes next, so attention can only use the sequence up to the current position).

Attention

The basic idea of attention can be captured as follows:


f( q , k , v ) => output

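Think of attention as a function that maps a query and a set of key-value pairs to an output, where the query, keys, values and output are all vectors. The output is a weighted sum of the values. The weight for each value comes from a compatibility function between the query and the corresponding key.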
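For a single query, "weighted sum of values" can be sketched as below. This assumes that a softmax normalizes the compatibility scores into weights, as it does in the Transformer. The helper name `attend` and the toy vectors are hypothetical, and the dot product is used as the compatibility function.

```python
import numpy as np

def attend(q, keys, values, compat):
    # One score per key-value pair, from the query/key compatibility function.
    scores = np.array([compat(q, k) for k in keys])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()            # softmax: weights are positive and sum to 1
    # Output = weighted sum of the values.
    return weights @ values, weights

q = np.array([1.0, 0.0])
keys = np.array([[1.0, 0.0], [0.0, 1.0]])
values = np.array([[10.0, 0.0], [0.0, 10.0]])
out, w = attend(q, keys, values, compat=np.dot)
```

The query is aligned with the first key, so the first value receives the larger weight and dominates the output.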

Scaled Dot Product Attention

  • Queries and keys have dimension d_k; values have dimension d_v.
  • Dot-product attention is used instead of additive attention because it is faster and more space-efficient, since it can be implemented with highly optimized matrix multiplication. For large d_k, however, the dot products grow large in magnitude and push the softmax into regions with extremely small gradients (saturation). That is why the scores are scaled by 1 / sqrt(d_k).
Attention(Q, K, V) = softmax(Q · K^T / sqrt(d_k)) · V
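The reason for the sqrt(d_k) scaling can be checked numerically. Assume the components of q and k are independent with mean 0 and variance 1. Then q · k has variance d_k, and dividing by sqrt(d_k) brings it back to about 1. This is a minimal NumPy sketch under that assumption, not a measurement on real activations.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 512
n = 4000
# Components of q and k drawn i.i.d. with mean 0 and variance 1 (an assumption).
q = rng.standard_normal((n, d_k))
k = rng.standard_normal((n, d_k))

raw = (q * k).sum(axis=1)           # unscaled dot products, one per row
scaled = raw / np.sqrt(d_k)         # scaled as in the formula above

print(raw.var())     # roughly d_k
print(scaled.var())  # roughly 1
```

Unscaled scores with a standard deviation around sqrt(512) ≈ 22.6 make the softmax nearly one-hot, which is the saturation described above.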

For a single head, take Q.shape = K.shape = V.shape = [B, T, D] (so d_k = d_v = D).

The first product, Q · K^T, gives a scores tensor of shape [B, T, T] whose axes are:

  • axis 0 (B) → batch index b
  • axis 1 (T) → query position index i
  • axis 2 (T) → key position index j

So element-wise:

scores[b, i, j] = dot(Q[b, i, :], K[b, j, :])

Each row scores[b, i, :] answers: “Query token i — how much attention do I give to each key token j?”
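A batched version of scaled dot-product attention can be sketched in NumPy as follows. The `einsum` subscripts "bid,bjd->bij" mirror the element-wise definition scores[b, i, j] = dot(Q[b, i, :], K[b, j, :]). The function names are hypothetical, and masking is left out here.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    # Q, K: [B, T, d_k]; V: [B, T, d_v]
    d_k = Q.shape[-1]
    # scores[b, i, j] = dot(Q[b, i, :], K[b, j, :]) / sqrt(d_k)
    scores = np.einsum("bid,bjd->bij", Q, K) / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)        # each row i is normalized over keys j
    return weights @ V, weights               # [B, T, d_v], [B, T, T]

rng = np.random.default_rng(0)
B, T, D = 2, 4, 8
Q, K, V = (rng.standard_normal((B, T, D)) for _ in range(3))
out, w = scaled_dot_product_attention(Q, K, V)
```

The `einsum` call computes the same thing as `Q @ K.transpose(0, 2, 1)`. It is written this way only to make the (b, i, j) indexing explicit.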

When we mask in the decoder, what we really mean is that we never block queries: each query position i always produces an output vector.

What we block is which keys a given query is allowed to look at.

So masking operates on the key axis (j), conditioned on the query axis (i).

This produces a mask that sets the strictly upper triangle of the scores to -inf before the softmax, and lets only the lower triangle (including the diagonal) pass:

scores[b, i, j] = -inf          if j > i
scores[b, i, j] = real value    otherwise

After the softmax, exp(-inf) = 0, so masked keys receive zero attention weight. In other words, query 0 can only attend to key 0, query 1 can attend to keys 0 and 1, and so on.

i\j   0   1   2   3
 0   ok   -   -   -
 1   ok  ok   -   -
 2   ok  ok  ok   -
 3   ok  ok  ok  ok

That is why:

  • upper triangle = masked
  • lower triangle + diagonal = allowed

Note that the diagonal is included, meaning each query can see keys up to and including its own position.
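The causal mask from the table can be built and applied as follows. This is a minimal NumPy sketch. The scores are random stand-ins for Q · K^T / sqrt(d_k), and `allowed` is a hypothetical name for the boolean mask (True = the query may see that key).

```python
import numpy as np

T = 4
# allowed[i, j] is True when query i may look at key j, i.e. j <= i
# (lower triangle plus diagonal).
allowed = np.tril(np.ones((T, T), dtype=bool))

rng = np.random.default_rng(0)
scores = rng.standard_normal((T, T))            # stand-in for Q K^T / sqrt(d_k)
masked = np.where(allowed, scores, -np.inf)     # -inf wherever j > i

# Row-wise softmax; exp(-inf) = 0, so masked keys get exactly zero weight.
e = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

print(allowed.astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]
```

Every row keeps at least its diagonal entry, so no row is fully masked and the softmax never divides by zero.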

Multi-head Attention

Instead of performing a single attention function with d_model-dimensional queries, keys and values, the Transformer runs several attention functions in parallel on different learned projections of them.

We linearly project Q, K and V h times with different learned projections, to d_k, d_k and d_v dimensions respectively. So one (Q, K, V) triple is projected h times, and the h projected versions are processed in parallel.

Three things happen here:

a. We get multiple representations of the same data by projecting it h times with learned (parameterized) layers, which learn better representations over training. This lets the model jointly attend to information from different representation subspaces at different positions.

b. Attention is computed h times, once per representation, and these computations run in parallel by vectorizing across heads.

c. The h outputs (each of dimension d_v, so h · d_v in total) are concatenated into a single matrix and projected back to d_model dimensions, producing the output of the multi-head attention block.

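When implementing this, choose h so that d_model is divisible by h; then d_k = d_v = d_model / h. The paper's base model uses d_model = 512, h = 8, d_k = d_v = 64. Because each head has reduced dimension, the total cost is similar to that of single-head attention with full dimensionality.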
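The divisibility constraint amounts to a one-line check. This is a small sketch; `head_dim` is a hypothetical helper name.

```python
def head_dim(d_model: int, h: int) -> int:
    # d_k = d_v = d_model / h only works when h divides d_model evenly.
    if d_model % h != 0:
        raise ValueError(f"d_model={d_model} is not divisible by h={h}")
    return d_model // h

print(head_dim(512, 8))  # 64, the base-model setting in the paper
```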

Implementation Mental Model

The input is a D-dimensional embedding for each of T tokens, with B sequences per batch, so its shape is [B, T, D].

Here the prefix “big_” means that all heads are processed jointly.

We have projection matrices for big_Q, big_K and big_V of shape [D, D], or more precisely [D, H * d_head], so each projection is [B, T, D] @ [D, H * d_head] = [B, T, H * d_head].

We reshape [B, T, H * d_head] into [B, T, H, d_head] and transpose it to [B, H, T, d_head], so that the heads stay separate and do not mix. Scaled dot-product attention is then computed on this tensor.

scores: (B, H, T, d_head) @ (B, H, d_head, T) → (B, H, T, T)

weights: softmax over the last (key) dimension, after scaling the scores by 1 / sqrt(d_head)

(B, H, T, T) @ (B, H, T, d_head) → (B, H, T, d_head)

The result of the scaled attention is again [B, H, T, d_head]. We concatenate the heads along the feature axis, eliminating the H axis. In code this is transpose(1, 2) (giving [B, T, H, d_head]) followed by reshape([B, T, D]), assuming D = H * d_head.
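The split and merge steps can be sketched in NumPy as follows. NumPy's `transpose(0, 2, 1, 3)` plays the role of `transpose(1, 2)` above. The helper names are hypothetical, and the sizes are illustrative.

```python
import numpy as np

def split_heads(x, H):
    # [B, T, H * d_head] -> [B, T, H, d_head] -> [B, H, T, d_head]
    B, T, D = x.shape
    return x.reshape(B, T, H, D // H).transpose(0, 2, 1, 3)

def merge_heads(x):
    # [B, H, T, d_head] -> [B, T, H, d_head] -> [B, T, H * d_head]
    B, H, T, d_head = x.shape
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * d_head)

B, T, H, d_head = 2, 5, 8, 64
x = np.random.default_rng(0).standard_normal((B, T, H * d_head))
heads = split_heads(x, H)
print(heads.shape)                                # (2, 8, 5, 64)
print(np.array_equal(merge_heads(heads), x))      # True
```

Head h of token t is simply the contiguous slice x[b, t, h*d_head:(h+1)*d_head]. Merging is the exact inverse of splitting, so no information moves between heads until W_o is applied.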

This is then passed through an output projection W_o of shape [D, D], i.e. [B, T, D] @ [D, D], which mixes and shares information across heads.
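The whole multi-head self-attention forward pass can be sketched by putting these steps together. This is a minimal NumPy sketch with several assumptions: no biases, no masking, no dropout, and random weights scaled by 1 / sqrt(D) for illustration. The function and weight names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_q, W_k, W_v, W_o, H):
    # Self-attention: queries, keys and values all come from x: [B, T, D].
    B, T, D = x.shape
    d_head = D // H

    def project_and_split(W):
        # [B, T, D] @ [D, H * d_head] -> [B, T, H * d_head] -> [B, H, T, d_head]
        return (x @ W).reshape(B, T, H, d_head).transpose(0, 2, 1, 3)

    q, k, v = project_and_split(W_q), project_and_split(W_k), project_and_split(W_v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)   # [B, H, T, T]
    weights = softmax(scores, axis=-1)                        # over keys
    heads = weights @ v                                       # [B, H, T, d_head]
    concat = heads.transpose(0, 2, 1, 3).reshape(B, T, D)     # [B, T, H * d_head]
    return concat @ W_o                                       # mix across heads

rng = np.random.default_rng(0)
B, T, D, H = 2, 5, 64, 8
x = rng.standard_normal((B, T, D))
W_q, W_k, W_v, W_o = (rng.standard_normal((D, D)) / np.sqrt(D) for _ in range(4))
out = multi_head_attention(x, W_q, W_k, W_v, W_o, H)
print(out.shape)  # (2, 5, 64)
```

With H = 1 this reduces exactly to single-head scaled dot-product attention followed by W_o, which is a useful sanity check when implementing the head split.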

Position-wise Feed-Forward Networks

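Each layer in both the encoder and decoder contains, after its attention sub-layer(s), a fully connected feed-forward network applied to each position separately and identically. It consists of two linear transformations with a ReLU in between. The input and output dimension is d_model = 512, and the inner (hidden) dimension is d_ff = 2048.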

FFN(x) = max ( 0, x . W_1 + b_1 ) . W_2 + b_2
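The formula above can be sketched directly in NumPy. The random weight initialization is purely illustrative. Because the same W_1, b_1, W_2, b_2 are applied to every token vector, running the FFN on one position alone gives the same result as running it on the whole batch.

```python
import numpy as np

def ffn(x, W_1, b_1, W_2, b_2):
    # FFN(x) = max(0, x W_1 + b_1) W_2 + b_2, applied to every position independently.
    return np.maximum(0.0, x @ W_1 + b_1) @ W_2 + b_2

d_model, d_ff = 512, 2048
rng = np.random.default_rng(0)
W_1 = rng.standard_normal((d_model, d_ff)) * 0.02   # illustrative random init
b_1 = np.zeros(d_ff)
W_2 = rng.standard_normal((d_ff, d_model)) * 0.02
b_2 = np.zeros(d_model)

x = rng.standard_normal((2, 10, d_model))            # [B, T, d_model]
y = ffn(x, W_1, b_1, W_2, b_2)
print(y.shape)  # (2, 10, 512)
```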

Embedding & Softmax

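  • Learned embeddings convert input and output tokens to vectors of dimension d_model.
  • The same weight matrix is shared between the two embedding layers and the pre-softmax linear transformation that converts the decoder output into next-token probabilities. In the embedding layers, these weights are multiplied by sqrt(d_model).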
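Weight sharing between the embedding and the pre-softmax projection can be sketched as follows. This is a minimal NumPy sketch. The vocabulary size, token ids and helper names are hypothetical. The embedded tokens are fed straight into the output layer only as a stand-in for real decoder output.

```python
import numpy as np

vocab_size, d_model = 1000, 512
rng = np.random.default_rng(0)
E = rng.standard_normal((vocab_size, d_model)) * 0.02   # one shared weight matrix

def embed(token_ids):
    # Embedding lookup, scaled by sqrt(d_model).
    return E[token_ids] * np.sqrt(d_model)

def next_token_probs(decoder_out):
    # Pre-softmax linear layer reuses E transposed: [..., d_model] -> [..., vocab_size]
    logits = decoder_out @ E.T
    logits -= logits.max(axis=-1, keepdims=True)
    p = np.exp(logits)
    return p / p.sum(axis=-1, keepdims=True)

tokens = np.array([[3, 14, 15, 92]])   # hypothetical token ids, [B=1, T=4]
h = embed(tokens)                      # stand-in for decoder output, [1, 4, 512]
probs = next_token_probs(h)            # [1, 4, 1000]
```

Sharing E means the model stores a single vocab_size x d_model matrix instead of separate input and output matrices.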

Positional Encoding

This is similar to a binary encoding, where each bit position flips with a different period, so every position gets a unique pattern. Instead of binary bits we use sine and cosine signals. The encoding is a vector of dimension d_model, and each dimension is a sinusoid with a different wavelength (a geometric progression from 2π to 10000 · 2π). Its value is a function of (position, dimension index i): PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The main reason for choosing this representation is that it should make it easy for the model to attend by relative position: for any fixed offset k, PE(pos + k) can be represented as a linear function of PE(pos).
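The sinusoidal encoding and its relative-position property can be sketched in NumPy. `shift_by` is a hypothetical helper. It applies, to each (sin, cos) pair, the fixed rotation that follows from the angle-addition formulas. The rotation depends only on k, not on pos, which is the "linear function" the notes refer to.

```python
import numpy as np

def positional_encoding(max_len, d_model):
    # PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    # PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    pos = np.arange(max_len)[:, None]                  # [max_len, 1]
    two_i = np.arange(0, d_model, 2)[None, :]          # [1, d_model / 2]
    angles = pos / np.power(10000.0, two_i / d_model)  # [max_len, d_model / 2]
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def shift_by(pe_row, k, d_model):
    # Map PE(pos) to PE(pos + k) with a rotation per (sin, cos) pair:
    # sin(a + b) = sin a cos b + cos a sin b
    # cos(a + b) = cos a cos b - sin a sin b,  with a = pos * w, b = k * w
    w = 1.0 / np.power(10000.0, np.arange(0, d_model, 2) / d_model)
    s, c = pe_row[0::2], pe_row[1::2]
    out = np.empty_like(pe_row)
    out[0::2] = s * np.cos(k * w) + c * np.sin(k * w)
    out[1::2] = c * np.cos(k * w) - s * np.sin(k * w)
    return out

pe = positional_encoding(50, 16)
```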

Why Self Attention?

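Self-attention layers are compared with recurrent and convolutional layers on three criteria:

  • Computational complexity per layer.
  • The amount of computation that can be parallelized, measured by the minimum number of sequential operations required.
  • The path length between long-range dependencies in the network.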
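The paper's comparison gives self-attention a per-layer complexity of O(n² · d), with O(1) sequential operations and O(1) maximum path length. A recurrent layer is O(n · d²), with O(n) sequential operations and O(n) path length. Here n is the sequence length and d is the representation dimension. A rough sketch that ignores constant factors shows when self-attention is cheaper: exactly when n < d.

```python
def self_attention_cost(n, d):
    # Dominant term for one self-attention layer: O(n^2 * d)
    return n * n * d

def recurrent_cost(n, d):
    # Dominant term for one recurrent layer: O(n * d^2)
    return n * d * d

d = 512
for n in (64, 512, 4096):
    sa, rnn = self_attention_cost(n, d), recurrent_cost(n, d)
    print(n, "self-attention cheaper" if sa < rnn else "recurrent cheaper or equal")
```

For typical sentence lengths n is smaller than d (e.g. d = 512), so self-attention is cheaper per layer, and it also needs only O(1) sequential steps.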

Training

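  • Hardware: 8 NVIDIA P100 GPUs.
  • Optimizer: Adam (β1 = 0.9, β2 = 0.98, ε = 10^-9) with a learning-rate schedule. The rate increases linearly for the first warmup_steps = 4000 steps and then decreases proportionally to the inverse square root of the step number.
  • Regularization: residual dropout (applied to the output of each sub-layer before it is added to the sub-layer input and normalized), dropout on the sums of the input embeddings and positional encodings, and label smoothing.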
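The learning-rate schedule can be written as lrate = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5) and sketched in Python as below. The function name is hypothetical, and clamping step to at least 1 is an added assumption to avoid 0 ** -0.5 at step 0.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    step = max(step, 1)  # assumption: avoid 0 ** -0.5 at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

peak = transformer_lr(4000)
print(round(peak, 6))  # 0.000699
```

The two terms of the min are equal at step = warmup_steps. Before that point the linear warmup term is smaller; after it the inverse-square-root decay takes over, so the peak learning rate occurs at the end of warmup.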

Results

Not summarized here; see the paper for the results.