Getting fatigued just reading and following along with this video, since this paper is mostly about experimental findings: a sort of "hyperparameter" search over model specifications. It also reads like foreshadowing of what LLaMA does, so I'm going to speed-run this one.

Since it is typically only feasible to train these large models once, accurately estimating the best model hyperparameters for a given compute budget is critical.

Section 3 is where the core ideas are. They try to figure out the following:

Given a fixed FLOPs budget, how should one trade-off model size and the number of training tokens?

And they try to answer this with 3 different approaches.

Approach 1: Fix model sizes and vary number of training tokens

  • Fix a family of model sizes ranging from 70M to 10B parameters
  • For each model size, train on 4 different numbers of training tokens
  • For each fixed FLOP budget, take the run that reaches the lowest loss — this gives the optimal (model size, token count) pair for that budget

Approach 2: IsoFLOP profiles

  • Vary Model Sizes
  • Fixed FLOP counts (ranging from 6 × 10^18 to 3 × 10^21 FLOPs)
  • Compute the required token count from the model size and FLOP budget

This allows us to directly answer the question: For a given FLOP budget, what is the optimal parameter count?
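Computing the token count per (model size, budget) pair falls out of the paper's own approximation that training costs roughly C ≈ 6·N·D FLOPs for N parameters and D tokens. A minimal sketch of how an IsoFLOP profile's token counts could be derived (the function name and example sizes are my own):

```python
# Approximation used throughout the paper: training FLOPs C ~= 6 * N * D,
# where N = parameter count and D = number of training tokens.
def tokens_for_budget(flops_budget: float, n_params: float) -> float:
    """Tokens needed to exhaust a FLOP budget at a given model size."""
    return flops_budget / (6 * n_params)

# Example IsoFLOP profile: one fixed budget, several model sizes.
budget = 1e21  # within the paper's 6e18 .. 3e21 range
for n in [400e6, 1e9, 2.5e9, 7e9]:
    print(f"N = {n:.0e} params -> D = {tokens_for_budget(budget, n):.2e} tokens")
```

Sweeping model size at a fixed budget like this traces out a loss valley whose minimum is the optimal parameter count for that budget.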

Approach 3: Fitting a parametric loss function
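Here they fit a single loss function of the form L(N, D) = E + A/N^α + B/D^β to all the experimental runs, where E is the irreducible loss. A sketch using the fitted constants reported in the paper (E ≈ 1.69, A ≈ 406.4, B ≈ 410.7, α ≈ 0.34, β ≈ 0.28):

```python
# Parametric loss from Approach 3: L(N, D) = E + A / N**alpha + B / D**beta.
# Constants are the fitted values reported in the paper.
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28

def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    """Predicted training loss for a model of n_params trained on n_tokens."""
    return E + A / n_params**ALPHA + B / n_tokens**BETA

# More parameters and more tokens both drive the loss toward the floor E.
print(chinchilla_loss(70e9, 1.4e12))   # Chinchilla-scale run
print(chinchilla_loss(280e9, 300e9))   # Gopher-scale run
```

Minimizing this function subject to the constraint C = 6ND is what yields the optimal N and D for a given budget.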

Main Takeaway

All three approaches suggest that as compute budget increases, model size and the amount of training data should be increased in approximately equal proportions. The first and second approaches yield very similar predictions for optimal model sizes, as shown in Figure 1 and Figure A3. The third approach predicts that even smaller models will be optimal at larger compute budgets.
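Scaling in equal proportions means N_opt ∝ C^0.5 and D_opt ∝ C^0.5. Combined with C ≈ 6ND and the roughly 20-tokens-per-parameter ratio this paper popularized (D ≈ 20N — an approximation, not an exact constant), a budget can be split mechanically. A sketch, with the function name my own:

```python
import math

TOKENS_PER_PARAM = 20  # approximate compute-optimal ratio implied by the paper

def optimal_split(flops_budget: float) -> tuple[float, float]:
    """Split a FLOP budget into (params, tokens) via C ~= 6*N*D and D ~= 20*N."""
    n = math.sqrt(flops_budget / (6 * TOKENS_PER_PARAM))
    return n, TOKENS_PER_PARAM * n

# Chinchilla's own budget (~5.9e23 FLOPs) recovers roughly 70B params / 1.4T tokens.
n, d = optimal_split(6 * 70e9 * 1.4e12)
print(f"N ~= {n:.2e} params, D ~= {d:.2e} tokens")
```

Plugging in Gopher's budget instead would tell you to train a much smaller model on far more tokens — which is exactly the paper's critique of prior large models.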

The Chinchilla model itself is simply trained according to the insights from these extensive experiments on different model specs.

We test this hypothesis by training a model on the larger end of this range — 70B parameters — for 1.4T tokens, due to both dataset and computational efficiency considerations.

  • We train Chinchilla on MassiveText
  • We use AdamW
  • We train Chinchilla with a slightly modified SentencePiece Tokenizer.
  • Whilst the forward and backward pass are computed in bfloat16, we store a float32 copy of the weights in the distributed optimiser state. (Not entirely clear why, though presumably so that tiny per-step updates aren't rounded away by bfloat16's short mantissa.)

    (Rajbhandari et al., 2020). See Lessons Learned from Rae et al. (2021) for additional details.

Chinchilla outperforms other, larger models by a clear margin despite having a fraction of the parameter count.