Related: Batch Norm, Layer Norm
Quick recap:
- Batch Norm: Create batches, compute normalization stats over the batch (mean, var) and use them to normalize the input distribution for the subsequent layer.
- Layer Norm: Relying on batch statistics becomes tricky in sequence models trained with BPTT. Instead of BN's batch-level stats (and the running mean/var it keeps outside the gradient graph for inference), LN is a layer inside the differentiable graph: it computes mean and var per sample over the feature dimension, then applies learned scale and shift (projection) params during training. Normalizing per sample and then projecting via learned params gives expressivity back to the model while still doing normalization.
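To make the difference in normalization axis concrete, here is a minimal pure-Python sketch. The function names `batch_norm` and `layer_norm` are hypothetical, the learned scale/shift (`gamma`, `beta`) are plain lists rather than trainable params, and BN is shown in training mode only (its inference-time running stats are not modeled).

```python
import math
from typing import List

Matrix = List[List[float]]  # shape [batch, features]


def _standardize(values: List[float], eps: float) -> List[float]:
    """Shift to zero mean and scale to (roughly) unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def batch_norm(x: Matrix, gamma: List[float], beta: List[float], eps: float = 1e-5) -> Matrix:
    """Training-mode BN: one (mean, var) per feature, computed across the batch."""
    n_batch, n_feat = len(x), len(x[0])
    cols = [_standardize([x[b][f] for b in range(n_batch)], eps) for f in range(n_feat)]
    return [[gamma[f] * cols[f][b] + beta[f] for f in range(n_feat)] for b in range(n_batch)]


def layer_norm(x: Matrix, gamma: List[float], beta: List[float], eps: float = 1e-5) -> Matrix:
    """LN: one (mean, var) per sample, computed across that sample's features."""
    rows = [_standardize(row, eps) for row in x]
    # learned per-feature scale (gamma) and shift (beta) give expressivity back
    return [[gamma[f] * row[f] + beta[f] for f in range(len(row))] for row in rows]


x = [[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]]
gamma, beta = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]
bn_out = batch_norm(x, gamma, beta)  # each column (feature) now has zero mean
ln_out = layer_norm(x, gamma, beta)  # each row (sample) now has zero mean
print(bn_out)
print(ln_out)
```

With a batch of one, BN collapses every feature to `beta` (the batch variance is zero), while LN gives a sample the same output whether or not it is batched with others; that per-sample independence is what makes LN convenient for sequence models.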
Code/Flow Comparison (Pre-LN vs Post-LN)
“Attention Is All You Need” Transformer code reference, referred to as the Post-LN Transformer (LN is applied after each residual add).
```python
# self.norm{N} = nn.LayerNorm(d_model)
# Decoder (simplified: the encoder-decoder attention sub-layer is omitted)
# [B,T,d_model]
self_attn = self.self_MHA(Q,K,V,mask,padding)
layer_norm1 = self.norm1(self_attn + output_embedding)  # residual add, THEN LN
feed_forward1 = self.linear1(layer_norm1)
a1 = self.relu1(feed_forward1)
feed_forward2 = self.linear2(a1)
layer_norm2 = self.norm2(feed_forward2 + layer_norm1)  # residual add, THEN LN
```
Pre-LN Transformer flow:
```diff
# Decoder
# [B,T,d_model]
+ norm_X = self.norm1(X)
+ self_attn = self.self_MHA(norm_X,norm_X,norm_X,mask,padding) # self-attn on the normalized input
- self_attn = self.self_MHA(Q,K,V,mask,padding)
- layer_norm1 = self.norm1(self_attn + output_embedding)
+ self_attn = self_attn + X # residual add, no LN afterwards
- feed_forward1 = self.linear1(layer_norm1)
+ layer_norm2 = self.norm2(self_attn)
+ feed_forward1 = self.linear1(layer_norm2)
a1 = self.relu1(feed_forward1)
feed_forward2 = self.linear2(a1)
- layer_norm2 = self.norm2(feed_forward2 + layer_norm1)
+ res = feed_forward2 + self_attn # residual add; Pre-LN stacks add one final LN after the last block
```
Core problem
In Post-LN models, LN sits between residual blocks, and training is sensitive to the max learning rate. These models need a slow learning-rate ramp-up (warm-up), which slows down training and adds more hyper-params to tune. It's best to get rid of this altogether.
The paper tries to alleviate this problem by finding ways to safely remove the learning-rate warm-up stage.
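A toy pure-Python sketch of the two LN placements, assuming LN without learned params and a hypothetical `toy_sublayer` (a fixed scaling) standing in for attention or the FFN. It is not the paper's model; it only illustrates a structural consequence of where LN sits: Post-LN re-normalizes the stream after every block, while in Pre-LN the residual stream itself is never normalized and grows with depth, which is why a final LN goes before the output layer.

```python
import math
from typing import Callable, List

Vec = List[float]
Sublayer = Callable[[Vec], Vec]


def layer_norm(x: Vec, eps: float = 1e-5) -> Vec:
    """Per-sample LN with gamma=1, beta=0 (learned params omitted for brevity)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def std(x: Vec) -> float:
    mean = sum(x) / len(x)
    return math.sqrt(sum((v - mean) ** 2 for v in x) / len(x))


def post_ln_block(x: Vec, sublayer: Sublayer) -> Vec:
    # Post-LN: x_next = LN(x + F(x)); LN sits between residual blocks
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_ln_block(x: Vec, sublayer: Sublayer) -> Vec:
    # Pre-LN: x_next = x + F(LN(x)); the residual stream itself is never normalized
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def toy_sublayer(x: Vec) -> Vec:
    # Hypothetical stand-in for attention/FFN: a fixed scaling, easy to reason about
    return [2.0 * v for v in x]


def stack(block: Callable[[Vec, Sublayer], Vec], x: Vec, n_layers: int) -> Vec:
    # apply the same block type n_layers times, like a stack of Transformer layers
    for _ in range(n_layers):
        x = block(x, toy_sublayer)
    return x


x0 = [0.5, -1.0, 1.5, -1.0]
post_out = stack(post_ln_block, x0, n_layers=6)
pre_out = stack(pre_ln_block, x0, n_layers=6)
pre_final = layer_norm(pre_out)  # Pre-LN Transformers add a final LN before the output layer

print(f"Post-LN std: {std(post_out):.3f}")
print(f"Pre-LN stream std: {std(pre_out):.3f}, after final LN: {std(pre_final):.3f}")
```

With this toy sublayer the Pre-LN stream's std grows by roughly 2 per block, while the Post-LN output stays at unit std after every block.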
Q. Why is warm-up required for Post-LN Transformers? The analysis suggests that, with Post-LN, the gradients of the params near the output layer are large at initialization. Hence, if a large lr is used in the initial steps, it can destabilise the optimization process.
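A toy analogy for "large gradients plus a large lr destabilise training" (this is not the paper's analysis): plain gradient descent on f(w) = 0.5 · a · w², where a larger curvature `a` means larger gradients. Each step multiplies w by (1 - lr · a), so a learning rate that is safe when gradients are small diverges once lr · a > 2. The function name `gd_on_quadratic` is hypothetical.

```python
def gd_on_quadratic(curvature: float, lr: float, steps: int, w0: float = 1.0) -> float:
    """Run plain gradient descent on f(w) = 0.5 * curvature * w**2 and return the final w."""
    w = w0
    for _ in range(steps):
        grad = curvature * w  # f'(w); larger curvature means larger gradients
        w -= lr * grad        # same as w *= (1 - lr * curvature)
    return w


lr = 0.03
small_grads = gd_on_quadratic(curvature=1.0, lr=lr, steps=20)    # |1 - 0.03| < 1: shrinks toward 0
large_grads = gd_on_quadratic(curvature=100.0, lr=lr, steps=20)  # |1 - 3| = 2 > 1: blows up
small_lr = gd_on_quadratic(curvature=100.0, lr=0.001, steps=20)  # |1 - 0.1| < 1: stable again
print(small_grads, large_grads, small_lr)
```

Shrinking the lr (which is what a warm-up does in the early steps) restores stability for the large-gradient case.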
In olden times, CNNs and RNNs were usually trained by starting with a high lr and then toning it down; a ramp-up only proved useful when very large batch sizes were used. But Transformers are different: with no (or too little) warm-up, the “optimization diverges”.
The paper goes on to repeat a lot of what's already in Attention Is All You Need, but that's ok.
Learning Rate Definition (Vaswani et al.)
This is how Post-LN does warm-ups:
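A minimal sketch of two warm-up schedules. `noam_lr` follows the schedule from Vaswani et al., lr = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5), i.e. a linear ramp for `warmup_steps` steps followed by inverse-square-root decay; the defaults d_model=512 and warmup_steps=4000 are the base-model values from that paper. `linear_warmup_lr` is a plain linear ramp, lr_t = (t / T_warmup) · lr_max; we believe this is the form the LN paper analyses, but treat that, both function names, and the omitted post-warm-up decay as assumptions.

```python
def noam_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Vaswani et al. schedule: linear ramp for warmup_steps, then decay proportional to step**-0.5."""
    step = max(step, 1)  # our guard: avoid 0 ** -0.5 at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


def linear_warmup_lr(step: int, lr_max: float, warmup_steps: int) -> float:
    """Plain linear warm-up: lr_t = t / T_warmup * lr_max, then held at lr_max (decay omitted)."""
    return lr_max * min(step / warmup_steps, 1.0)


for step in (1, 1000, 4000, 16000):
    print(step, noam_lr(step), linear_warmup_lr(step, lr_max=1e-3, warmup_steps=4000))
```

The two terms inside `min` are equal at step = warmup_steps, so the Vaswani et al. schedule peaks there at (d_model · warmup_steps)^-0.5.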
Experiments show that having a warm-up stage is good for training Post-LN irrespective of the optimizer used (they experiment with SGD and Adam), and that results are sensitive to the warm-up length.
Hypothesis and Empirical Evidence
The paper goes on to analyse theoretically (using mean field theory) why Post-LN can produce large gradients near the output layer while Pre-LN's gradients are comparatively much smaller, and then verifies this empirically with experiments.
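For reference, the headline bound as we recall it from the paper (check the exact assumptions and probability statement there): at initialization, with hidden size d and L layers, the gradient of the loss with respect to the last layer's second FFN weight W^{2,L} satisfies

$$
\text{Post-LN: } \left\lVert \frac{\partial \tilde{\mathcal{L}}}{\partial W^{2,L}} \right\rVert_F \le \mathcal{O}\!\left(d\sqrt{\ln d}\right),
\qquad
\text{Pre-LN: } \left\lVert \frac{\partial \tilde{\mathcal{L}}}{\partial W^{2,L}} \right\rVert_F \le \mathcal{O}\!\left(d\sqrt{\frac{\ln d}{L}}\right)
$$

The extra 1/√L factor is what makes Pre-LN's gradients shrink as depth grows, while the Post-LN bound does not depend on L.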
Conclusive evidence can be seen in Figure 3: with Pre-LN, gradient norms stay consistently low as depth increases, while with Post-LN they start to explode without warm-up or become too small with warm-up. Both cases are sub-optimal or destructive for learning.
Well, that pretty much captures the core ideas.