Background
Now would be a great time to refresh Kaplan et al. ( my notes: Scaling Laws for Neural Language Models ) and Chinchilla.
Building A Not-So-Small Language Model
Predominantly, I’ve refactored the code ( Pythonista7/transformer-room ) to the following config:
Click to view sample Config JSON
```python
{
    "base_vocab_size": 10_000,
    "num_special_tokens": 2,  # EOS and PAD
    "vocab_size": 10_002,
    "d_model": 128,
    "n_heads": 8,
    "layers": 2,
    "learning_rate": 0.001,
    "epochs": 3,
    "training_seq_len": 128,
    "training_stride": 128,
    "data_fraction": 1,
    "batch_size": 256,
    "checkpoint_every_n_steps": 250,
    "checkpoint_path": "gpt2_baseline_checkpoint.pt",
    "final_model_path": "gpt2_baseline_model.pt",
    "dataset_path": "../datasets/tiny_shakespeare.txt",
    "tokenizer_vocab_path": "tiny_shakespeare_bpe_vocab.txt",
    "resume_from_checkpoint": True,
    "use_torch_compile": True,
    "torch_compile_mode": "default",
    "torch_compile_fullgraph": False,
    "torch_compile_dynamic": False,
}
```

We only need to tweak the above to get to the GPT-2 config; the architectural details mostly remain the same as the Shakespeare baseline we trained, just with more layers and more data. The original paper trains the following sizes of GPT-2:
| Parameters | Layers | d_model |
|---|---|---|
| 117M | 12 | 768 |
| 345M | 24 | 1024 |
| 762M | 36 | 1280 |
| 1542M | 48 | 1600 |
I am currently using an Nvidia A100 with 40GB of VRAM, so I should be able to replicate the smallest 117M model, or a slightly smaller variant of it. I will be focusing primarily on the language modelling tasks.
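As a sanity check on the table above, a back-of-envelope parameter count gets close to those sizes. This is a hypothetical helper, not code from the repo; it assumes the paper's vocab (50257) and context length (1024) and ignores biases and LayerNorm params:

```python
# Back-of-envelope GPT-2 parameter count (hypothetical helper, not from the repo).
# Assumes the paper's vocab (50257) and context length (1024); biases and
# LayerNorm params are ignored, so this is only an approximation.
def gpt2_param_count(layers, d_model, vocab=50257, seq_len=1024):
    embed = vocab * d_model + seq_len * d_model  # token + positional embeddings
    attn = 4 * d_model * d_model                 # Q, K, V and output projections
    mlp = 2 * d_model * (4 * d_model)            # two linear layers, 4x expansion
    return embed + layers * (attn + mlp)

for layers, d_model in [(12, 768), (48, 1600)]:
    print(layers, d_model, f"{gpt2_param_count(layers, d_model)/1e6:.0f}M")
```

This lands at ~124M for the smallest config (the paper's 117M headline is counted slightly differently) and ~1.56B for the largest, so the formula is in the right ballpark for sizing experiments.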
For this purpose we will switch to a bigger dataset, tokenising it and generating a bigger vocab. This is actually more involved than just a dataset change, so let’s talk through each design choice below.
Bigger Dataset
The original GPT-2 was trained on WebText, which is claimed to be about ~40GB of text data and was never released by OpenAI; however, OpenWebText is a community-curated dataset with similar characteristics, and that is what I will be using here. But first I would like to get used to the long training process on something a bit smaller and measurable, yet still in the same direction, before committing resources to a huge run. So I picked WikiText, which is an order of magnitude smaller but still contains ample learnable signal for LM tasks.
It is, however, worth recalling the learnings from the Background section: scaling laws propose that larger models are more sample-efficient (Kaplan et al.), and a sufficiently large model is severely under-trained and can converge given a lot more data than previously expected (Chinchilla). Even though our goal is to train a model of GPT-2 size, we need to calibrate our expectations for performance, and its ceiling, against these findings and a realistic experiment budget (which isn’t a whole lot for me right now xd).
Wiki Text
Both are on Hugging Face: first we will use wikitext-2 for a base training run, and once we see a good run on it we can also try wikitext-103, which contains over 100 million tokens, as we improve the architecture.
My tokenizer wants to send me on an optimization side quest!

Ain’t no way I’m waiting a couple of hours for BPE to train on the corpus right now!
Sure, the dataset itself has a tokens property I could just access, but that’s no fun if I have no idea how to improve my BPE, especially if it can’t even practically handle a small-to-medium dataset. How do the big players do it? We did use tries to improve prefix matching, but that only sped up encoding; we need to speed up the merges somehow! I will proceed with the dataset’s tokenised vocab; this is a problem for a separate post.
Avoiding this battle for another day, I’m using the tokenizer at hand.
Batch Size
This is predominantly a trade-off between parallelism and the amount of VRAM available. Bigger batch sizes need more VRAM but also process faster; with less VRAM we will need techniques like micro-batching and gradient accumulation, or we must trade off other hyper-params like sequence length and vocab size, which may not be desirable.
The optimal batch size can probably only be derived empirically for a given scenario and set of constraints. But this can be done in an organised fashion using estimates for memory allocations, data-type awareness, profiling, and calculating peak memory requirements, which I’ve attempted in the following sections.
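For the gradient-accumulation route mentioned above, here is a minimal sketch (toy model and data, not the repo's trainer): four micro-batches of 8 accumulate into one update, behaving like an effective batch of 32 at roughly a quarter of the activation memory.

```python
import torch

model = torch.nn.Linear(16, 4)  # stand-in for a real model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4  # 4 micro-batches of 8 ~ one effective batch of 32

batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(8)]
for step, (x, y) in enumerate(batches):
    # Scale the loss so accumulated gradients average instead of summing.
    loss = torch.nn.functional.cross_entropy(model(x), y) / accum_steps
    loss.backward()                    # gradients accumulate in p.grad
    if (step + 1) % accum_steps == 0:  # optimiser steps only every accum_steps
        opt.step()
        opt.zero_grad()
```

The trade-off is wall-clock time: the forward/backward passes are serialised instead of parallelised across the full batch.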
Schedulers and Learning Rate
The original Attention is All You Need paper, as well as the reading on Layer Norm in Transformers, points to needing a ramp-up of the learning rate first, followed by a long decay, for optimal performance. (Will need to track this on wandb.)
Also, the GPT-2 paper doesn’t mention any scheduling config or learning rates; it only says:
The learning rate of each model was manually tuned for the best perplexity on a 5% held-out sample of WebText. All models still underfit WebText and held-out perplexity has as of yet improved given more training time.
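Since the paper doesn't pin down a schedule, a common recipe (my assumption here, not something the paper prescribes) is linear warmup followed by cosine decay, which `LambdaLR` expresses compactly:

```python
import math
import torch

model = torch.nn.Linear(8, 8)  # stand-in model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)  # peak learning rate
warmup_steps, total_steps = 100, 1000

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear ramp-up to the peak lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay down to 0

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
lrs = []
for _ in range(total_steps):
    opt.step()   # (the real forward/backward/step would go here)
    sched.step()
    lrs.append(sched.get_last_lr()[0])
```

Logging `sched.get_last_lr()` each step to wandb gives exactly the ramp-then-decay curve to track.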
Epoch / Training Time
An epoch can be a deceptive measure to think from; think in terms of tokens and steps instead. Then we can ask:
- How many tokens does each `optim.step()` backprop update cover, and is that good enough?
- From there we can build an expectation for: “How many steps, and in turn how many tokens, does the model need to see and learn from before the metrics show it ‘learning’ meaningfully?”
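Concretely, with the GPT-2 run config used later in this post (batch 512, seq_len 1024) and WikiText-2's roughly 4.1M tokens:

```python
# Tokens seen per optimiser step, and updates per epoch, for this post's
# wikitext-2 run (batch_size=512, seq_len=1024; ~4.1M-token dataset).
batch_size, seq_len = 512, 1024
tokens_per_step = batch_size * seq_len       # tokens per optim.step()
dataset_tokens = 4_100_000                   # approximate WikiText-2 size
steps_per_epoch = dataset_tokens // tokens_per_step
print(tokens_per_step, steps_per_epoch)      # 524288 7
```

Only ~7 optimiser updates per epoch at this batch and sequence length, which makes it clear why "epochs" is a very coarse unit here.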
The different loss-curve stories
- In the happy path, both train and val loss track each other and keep dropping as we feed the model a good amount of data.
- An obvious overfit signal is the train and val loss diverging after a certain point: the model is big enough (has enough capacity) to start memorising the training data, so it performs worse on val.
- If train loss keeps dropping but val loss drops slightly and then plateaus, our dataset is too small for the model size and regularisation is strong: the model is forced not to overfit, but there is also nothing more to learn from the data, so it stops improving beyond a certain point despite having the capacity.
Token Budgeting
The background reads give us good estimates of the model-params to data-tokens ratios that support optimal learning. The earlier Kaplan et al. favours bigger models that underfit, while Chinchilla suggests smaller models trained on a lot more data for longer to obtain optimal compute efficiency.
Let N = number of parameters (in billions) and D = training tokens (in billions).
- Scaling Laws (2020) suggest D ∝ N^0.74
- Chinchilla (2022) suggest D ≈ 20N
So in my case: WikiText-2 has around 4.1M tokens, while the model has ~130M params, most of them in the decoder layers (which is what the papers count as model size). In that regard our dataset is still very small and needs to be orders of magnitude larger for optimal compute utilization (vastly more tokens even in the underfitting scenario of the 2020 scaling laws, and around 20N ≈ 2.6B tokens as per Chinchilla).
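The gap can be made concrete with the Chinchilla rule of thumb of ~20 tokens per parameter (the 130M figure is this post's rough model size; dataset sizes are the approximate counts mentioned above):

```python
params = 130e6                    # ~130M-parameter model (this post's estimate)
chinchilla_tokens = 20 * params   # Chinchilla rule of thumb: ~20 tokens/param
for name, tokens in [("wikitext-2", 4.1e6), ("wikitext-103", 103e6)]:
    shortfall = chinchilla_tokens / tokens
    print(f"{name}: {shortfall:,.0f}x short of ~{chinchilla_tokens/1e9:.1f}B tokens")
```

So even wikitext-103 is ~25x short of the Chinchilla-optimal budget for this model size, which is exactly why expectations need calibrating.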
GPT-2
Using the GPT-2 config as defined in the paper in my run config below:
JSON Config for the model
{
"run": {
"value": {
"run_name": "wikitext2_gpt2_v1",
"project_name": "transformer-room-baseline",
"artifacts_root": "/workspace/transformer-room/baseline/models",
"use_torch_compile": true,
"torch_compile_mode": "default",
"checkpoint_filename": "baseline_checkpoint.pt",
"final_model_filename": "baseline_model.pt",
"torch_compile_dynamic": false,
"resume_from_checkpoint": true,
"torch_compile_fullgraph": false,
"checkpoint_every_n_steps": 250
}
},
"model": {
"value": {
"name": "baseline_decoder",
"layers": 12,
"d_model": 768,
"n_heads": 8
}
},
"split": {
"value": {
"name": "holdout",
"seed": 42,
"shuffle": false,
"train_fraction": 0.9
}
},
"train": {
"value": {
"epochs": 3,
"stride": 1024,
"seq_len": 1024,
"batch_size": 512,
"data_fraction": 1,
"learning_rate": 0.001
}
},
"dataset": {
"value": {
"name": "hf_text",
"split": "train",
"max_rows": 0,
"streaming": false,
"text_field": "text",
"dataset_name": "Salesforce/wikitext",
"dataset_config": "wikitext-2-v1"
}
},
"logging": {
"value": {
"provider": "wandb"
}
},
"tokenizer": {
"value": {
"name": "bpe",
"vocab_path": "/workspace/transformer-room/baseline/tokenizers/wikitext2_v1_hf_vocab_bpe.txt",
"base_vocab_size": 33280,
"num_special_tokens": 3
}
}
}
Running on an A100 (Befriending OOMs)
Spinning up a A100 on vast.ai for this and I still run into:
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
OutOfMemoryError: CUDA out of memory. Tried to allocate 65.01 GiB. GPU 0 has a total capacity of 39.49 GiB of which 38.66 GiB is free. Process 2081528 has 848.00 MiB memory in use. Of the allocated memory 339.82 MiB is allocated by PyTorch, and 18.18 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Why is this happening?! Even after reducing the number of layers I still get the same Tried to allocate 65.01 GiB. GPU 0 error, so let’s dig in.
Since the issue is not layer-specific, there seems to be one specific dense tensor allocation eating up that big alloc. The first culprit seems to be the large batch size; maybe it’s too big for a single-GPU run. Doing some quick math:
```python
T = 1024; H = 8
for B in [1, 2, 4, 8, 16, 32, 64]:
    attn_fp32 = B * H * T * T * 4 / 1024**3
    logits_fp32 = B * T * 33283 * 4 / 1024**3
    print(B, 'attn_scores_fp32_GiB=', round(attn_fp32, 3), 'logits_fp32_GiB=', round(logits_fp32, 3))
```

```
1 attn_scores_fp32_GiB= 0.031 logits_fp32_GiB= 0.127
2 attn_scores_fp32_GiB= 0.062 logits_fp32_GiB= 0.254
4 attn_scores_fp32_GiB= 0.125 logits_fp32_GiB= 0.508
8 attn_scores_fp32_GiB= 0.25 logits_fp32_GiB= 1.016
16 attn_scores_fp32_GiB= 0.5 logits_fp32_GiB= 2.031
32 attn_scores_fp32_GiB= 1.0 logits_fp32_GiB= 4.063
64 attn_scores_fp32_GiB= 2.0 logits_fp32_GiB= 8.126
```

So the attention scores look manageable if we use a smaller B. The other place where batch size matters is the output projection back to the vocab, which has the following shape:
batch=512, seq_len=1024, vocab=33283
512 * 1024 * 33283 * 4 bytes = 69,799,510,016 bytes ≈ 65 GiB
This number matches the OOM error, so let’s ramp the batch size down to something within 40GB; batch size 256 should halve the mem-alloc. That didn’t work either, so I went straight down to batch_size 64 and still got this:
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 39.49 GiB of which 817.56 MiB is free. Process 2166408 has 38.69 GiB memory in use. Of the allocated memory 38.15 GiB is allocated by PyTorch, and 48.04 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF = expandable_segments:True to avoid fragmentation.
See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Picking the right dtype
It was recommended in multiple places on the internet to switch to BF16 instead of FP32 for training on an A100, since it mostly provides the same numerical range at half the bytes! Moving to bf16 should halve the memory requirement, but we still run into OOM, so I added a quick util script to check it here:
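For reference, switching the forward/backward to bf16 usually goes through autocast. This is a generic sketch with a stand-in model, not the repo's actual training loop:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(768, 33283).to(device)  # stand-in for the real model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(4, 768, device=device)
y = torch.randint(0, 33283, (4,), device=device)

# Autocast runs matmul-heavy ops in bf16 (half the bytes, ~same exponent
# range as fp32) while master weights stay in fp32.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits = model(x)
    loss = torch.nn.functional.cross_entropy(logits, y)
loss.backward()
opt.step()
```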
So even with bf16, deconstructing my attention implementation for the config:
B, T, D, H, L, V = 64, 1024, 768, 8, 12, 33283
- scores shape `[B,H,T,T] = (64, 8, 1024, 1024)` → bf16 → 1 GiB
- softmax: stores both its input and output for use in backprop, so the above 1 GiB becomes 2 GiB. Checkout Backprop = chain of VJP’s! to understand why we need both softmax input and output during backprop.
- attention mask = `[B,1,T,T] = (64, 1, 1024, 1024)` → bool → 0.0625 GiB
- causal mask = `[T,T] = (1024, 1024)` → bool → 0.00097 GiB (but does broadcasting enlarge it?)
- key padding mask = `[B,T] = (64, 1024)` → bool → 0.00006 GiB
- attn = `[B,H,T,D/H] = (64, 8, 1024, 96)` → bf16 → 0.09375 GiB
- joined_heads = `[B,T,D] = (64, 1024, 768)` → bf16 → 0.09375 GiB
Now if we add the MHA’s input and output projection weights:
- input = `[d_model, 3*d_model]` = 768 × 3 × 768 params → bf16 (2 bytes) → ~0.0033 GiB
- output = `[d_model, d_model]` = 768 × 768 params → bf16 (2 bytes) → ~0.0011 GiB
MHA Total = ~2.4 GiB in just a single MHA
Then we have the feed-forward projections:
- L1 = `[d_model, 4*d_model]` = 768 × 3072 params → bf16 (2 bytes) → ~0.0044 GiB
- L2 = `[4*d_model, d_model]` = 3072 × 768 params → bf16 (2 bytes) → ~0.0044 GiB
Decoder Total = ~2.45 GiB
Then we have the embedding and vocab projection weights:
- embeddings = `[vocab_size, d_model]` = 33283 × 768 params → bf16 (2 bytes) → ~0.05 GiB
- output = `[d_model, vocab_size]` = 768 × 33283 params → bf16 (2 bytes) → ~0.05 GiB
And the logits tensor produced is [batch, seq_len, vocab_size], that is:
`[64, 1024, 33283]` ≈ 2.03G elements → bf16 (2 bytes each) → ~4.06 GiB
That brings the total to roughly:
- ~2.5 GiB × 12 decoder layers ≈ 30 GiB
- ~4 GiB of logits
- ~1 GiB of optimiser state
That already totals to ~35 GiB!
And I have not even carefully accounted for bias terms and other params, not to mention gradient memory!
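The hand accounting above can be reproduced in a few lines. This is a simplified sketch of the arithmetic only (masks, weights, and gradient buffers omitted), not the repo's util:

```python
# Simplified activation-memory estimate in GiB, mirroring the accounting above
# (bf16 = 2 bytes/element; masks, weights and gradient buffers are omitted).
def activation_gib(B=64, T=1024, D=768, H=8, L=12, V=33283, bytes_per=2):
    gib = 1024**3
    scores = 2 * B * H * T * T * bytes_per / gib       # softmax input + output
    attn_out = B * H * T * (D // H) * bytes_per / gib  # per-head attention output
    joined = B * T * D * bytes_per / gib               # concatenated heads
    logits = B * T * V * bytes_per / gib               # final vocab projection
    return L * (scores + attn_out + joined) + logits

print(round(activation_gib(), 1))  # ~30 GiB before grads/optimiser state
```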
Let’s use the code in transformer-room/baseline/util.py to generate accurate stats for the current model implementation and compare the two dtypes.

That’s roughly in the ballpark of what we expected, which is still huge! We will need to go beyond picking the right dtype here.
So, sadly, this is the point where we start cutting down the batch size and, if required, other model-size hyper-params like seq_len and d_model.
Optimisations & Trade-Offs
- Use the SDPA kernel for MHA instead of my own implementation; the current implementation just materialises too many tensors.
- A memory-vs-speed trade-off can be achieved with activation checkpointing (torch has a nice Blog, which is really good!); I can use the SAC concept there to trade compute for memory in my case.
- The output projection is huge (4 GiB), so look into:
  - Micro-batching + grad accumulation, so the memory peak towards the end (which explodes due to the vocab size) can be reduced.
  - Chunked vocab cross-entropy (avoiding full logits materialization): a Megatron-style idea; adding Megatron to the reading list.
- More optimisations like moving to better attention implementations (flash, paged).
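The first item is easy to sketch: `F.scaled_dot_product_attention` replaces the scores → mask → softmax → weighted-sum chain with one call, and its fused backends avoid materialising the `[B,H,T,T]` score tensor that dominated the accounting above (toy shapes below, not the real run's):

```python
import torch
import torch.nn.functional as F

B, H, T, Dh = 4, 8, 128, 96  # toy sizes; the real run uses T=1024
q, k, v = (torch.randn(B, H, T, Dh) for _ in range(3))

# One call replaces scores -> mask -> softmax -> weighted sum; fused backends
# never allocate the full [B, H, T, T] attention matrix.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([4, 8, 128, 96])
```

`is_causal=True` also removes the need for the explicit causal-mask tensor from the accounting above.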
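The chunked-vocab idea can also be sketched: compute the loss over row chunks so only a `[chunk, V]` slice of logits is ever live at once. This is a forward-only illustration with made-up sizes (Megatron-style implementations also handle the backward pass carefully):

```python
import torch
import torch.nn.functional as F

def chunked_ce(hidden, weight, targets, chunk=128):
    # hidden: [N, D] final hidden states; weight: [V, D] output projection.
    total, n = hidden.new_zeros(()), hidden.shape[0]
    for i in range(0, n, chunk):
        logits = hidden[i:i + chunk] @ weight.T  # only [chunk, V] is live here
        total = total + F.cross_entropy(logits, targets[i:i + chunk], reduction="sum")
    return total / n  # same mean loss as materialising all logits at once

hidden = torch.randn(1024, 768)
weight = torch.randn(33283, 768) * 0.02
targets = torch.randint(0, 33283, (1024,))
loss = chunked_ce(hidden, weight, targets)
```

Peak logits memory drops from `N × V` to `chunk × V` elements, directly attacking the ~4 GiB term above.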
The Upgrades
So I will be trying to implement each of these upgrades to measure:
- Peak memory requirement, until we get running on an A100
- Once we are running, expected epoch time as a comparison between the upgrades
First we will see how much compromise on the batch size is required just to fit the unchanged model, and then we will implement SDPA and try to bring back as much batch size as possible.
On the A100 instance
Estimates
From running the memory estimator util without any allocations:

Real One-Step Train Memory Profile (CUDA)
Replicated the OOM :

Estimates on the CUDA Machine | Cutting the Batch-Size Bills
After playing around a bit with the batch size, I found B=20 to occupy just about 38.9 GiB at peak; that’s just on the edge of our ~40GB A100.

Running into OOM even with the estimate
So the barely-fitting model was able to complete the forward pass; since it OOMs after the first loop of the epoch, it probably could not finish the loss.backward() and/or optim.step(), or maybe it failed while loading the validation set. Nonetheless, we will need to create a bit of headroom.

Figuring it out!
This:
Tried to allocate 2.54 GiB. GPU 0 has a total capacity of 39.49 GiB of which 95.56 MiB is free. Process 3691180 has 39.39 GiB memory in use. Of the allocated memory 35.69 GiB is allocated by PyTorch, and 3.20 GiB is reserved by PyTorch but unallocated.
This really is the issue: we have 3.2 GiB reserved but unallocated, which sucks! Let’s set the suggested env
PYTORCH_ALLOC_CONF=expandable_segments:True (see docs) and try it again!
And it works! The fragmentation issue is gone! We are able to run 3 training epochs now!

The train loss does drop slightly, and the lr schedule and other hyper-params can be tuned better, but to be clear, this is just the baseline; we will run multiple experiments with improvements along the way to keep pushing it. Considering it takes around 2 min 10 sec on wikitext-2, I am eager to try a slightly longer training run to see overfitting behaviour on my first “not-so-small” LM before we move on to the larger wikitext-103.
Further Experiments