It is very tempting to just create a model zoo of all the different modern architectures at this point, but I wish to capture the learnings instead of just reproducing architectures from papers. That said, I will try to mention them and draw relations whenever they seem relevant along the way.
Current Status: Working on Stage 1 🚧
Overview
What's been done?
Check out Phase-0 for what's already covered; it mostly includes the fundamentals and the early tweaks to the transformer around the GPT-2/3 era.
What I don’t know but wish to know/do
- Positional Encodings, especially RoPE and ALiBi, where we don't add a separate embedding but instead inject the positional info directly inside the decoder block. Interested in understanding the mechanism, and also where its limits are and how it can break.
- Attention for training
- Attention for inference:
  - Flash Attention applies to both training and inference, but paged attention is inference-only.
- MoE has a lot in it, especially w.r.t. sparse networks and such. I need to read the paper trail and dig deeper into it. The following are relevant things I wish to understand more deeply:
  - MoE routing
  - top-k gating
  - auxiliary load-balancing loss
  - capacity factor / token dropping
  - grouped experts
  - top-1 vs top-2 routing
- This also opens up other high-level themes, like "Dense vs Sparse Models", total vs active parameter counts, how to train sparse models, etc.
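To make the routing vocabulary above concrete, here is a minimal numpy sketch of top-k gating with a Switch-Transformer-style auxiliary load-balancing loss. The function names are my own, and real implementations also handle capacity factors and token dropping; this is only the core mechanism:

```python
import numpy as np

def top_k_route(logits, k=2):
    # logits: (tokens, experts) raw router scores.
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)                     # softmax over experts
    topk_idx = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]    # k best experts per token
    topk_w = np.take_along_axis(probs, topk_idx, axis=-1)
    topk_w /= topk_w.sum(-1, keepdims=True)                   # renormalize over chosen experts
    return topk_idx, topk_w, probs

def load_balancing_loss(probs, topk_idx, num_experts):
    # Fraction of tokens whose top-1 choice is expert e, times the mean
    # router probability for e; minimized (value 1.0) when perfectly balanced.
    tokens = probs.shape[0]
    frac_tokens = np.bincount(topk_idx[:, 0], minlength=num_experts) / tokens
    mean_prob = probs.mean(0)
    return num_experts * float((frac_tokens * mean_prob).sum())
```

With top-2 routing, each token's output is the `topk_w`-weighted sum of its two experts' outputs; the auxiliary loss is added to the task loss with a small coefficient to keep experts evenly used.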
Implementation know-hows to acquire
- Setting up an end-to-end pipeline for training and evaluating LLMs, and easily measuring them against benchmarks.
- Inference-focused metric collection:
  - KV cache size, other cache memory, and how they grow w.r.t. other variables
  - prefill time
  - decode time
  - prefill vs decode
  - flash-attention and paged-attention performance improvements
  - expert utilization in MoE
- Understanding the overall system architecture for large-scale inference. Look at how custom models are served with tools like vLLM, and lmcache/mooncake for the KV cache, etc.
- Training techniques for LLMs:
  - Data Parallel
  - Tensor Parallel
  - Expert Parallel (in MoEs)
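The KV-cache metric above is a good warm-up because it has a closed form: two tensors (K and V) per layer, each of shape (batch, n_kv_heads, seq_len, head_dim). A sketch, with illustrative Llama-2-7B-like shapes that are assumptions on my part, not measurements:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, dtype_bytes=2):
    """Bytes held by the KV cache: one K and one V tensor per layer, fp16 by default."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

# MHA (32 KV heads) vs GQA (8 KV heads) at 4k context, fp16:
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=4096)
print(mha / 2**30, gqa / 2**30)  # prints: 2.0 0.5 (GiB)
```

The cache grows linearly in seq_len and batch, which is exactly why KV-sharing (fewer KV heads) and sliding windows (capped seq_len per layer) show up later in the plan.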
The Plan
The main goal for this phase is to get the hang of "Dense Decoders" in terms of architectures, attention mechanisms, and eval/inference measures. There are a few gears I will be shifting through in this phase, hence setting it up as different stages to neatly build on top of my work.
Stage 1 - Large Model Training Baseline
In this stage I first want to establish and lock in the training recipe, and incorporate all the missing eval metrics (for inference, KV cache, etc.), looking beyond just training. Naturally, to see these metrics work, a proper testing and benchmarking pipeline also needs to be added.
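For the prefill/decode split, a tiny timing harness is enough to start. `model_step` below is a hypothetical stand-in for a real forward pass over n tokens; the point is only the shape of the measurement (one big batched pass for prefill, a loop of single-token passes for decode):

```python
import time

def measure(model_step, prompt_len, gen_len):
    # model_step(n): hypothetical callable that processes n tokens;
    # in real code this would be the model forward (with KV cache for decode).
    t0 = time.perf_counter()
    model_step(prompt_len)                 # prefill: whole prompt in one pass
    prefill_s = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(gen_len):               # decode: one token at a time
        model_step(1)
    decode_s = time.perf_counter() - t0

    return {
        "prefill_s": prefill_s,
        "prefill_tok_per_s": prompt_len / prefill_s,
        "decode_tok_per_s": gen_len / decode_s,
    }
```

Prefill is compute-bound and decode is typically memory-bandwidth-bound, so reporting them separately (rather than one aggregate tokens/sec) is what makes the later attention ablations interpretable.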
Stage 2 - Positional Encoding Ablation
- RoPE : Revisit the paper and core ideas and implement it as a standalone module.
- ALiBi: Revisit the paper and core ideas and implement it as a standalone module.
- Ablation study: measure convergence speed, long-context extrapolation, and failure cases in generation; when and why one works where the other does not.
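As a starting point for the standalone module, here is a minimal numpy sketch of RoPE's core mechanism: rotate consecutive channel pairs by position-dependent angles. The defining property, checked below, is that the dot product of a rotated query at position m and a rotated key at position n depends only on the offset m - n:

```python
import numpy as np

def rope(x, base=10000.0):
    # x: (seq, dim) with even dim. Pair (x[2i], x[2i+1]) is rotated by
    # angle pos * base**(-2i/dim), so frequency falls with channel index.
    seq, dim = x.shape
    pos = np.arange(seq)[:, None]                        # (seq, 1)
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)     # (dim/2,)
    ang = pos * inv_freq                                 # (seq, dim/2)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin                   # 2D rotation per pair
    out[:, 1::2] = x1 * sin + x2 * cos
    return out
```

ALiBi, by contrast, rotates nothing: it adds a per-head linear bias proportional to -(i - j) to the attention logits, so the two mechanisms inject relative position in very different places, which is what this ablation should expose.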
Stage 3: Attention Flavour - KV Sharing
- Implement MQA (the config variant of GQA where num_kv_groups = 1).
- Ablation study, MHA vs MQA vs GQA: measure quality, prefill latency, decode throughput, and KV-cache growth.
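The whole KV-sharing family is one implementation with one knob. A numpy sketch (single sequence, causal, helper name my own) where the KV head count of `k`/`v` selects the variant: equal to the query head count gives MHA, 1 gives MQA, anything in between gives GQA:

```python
import numpy as np

def gqa_attention(q, k, v):
    # q: (n_q_heads, seq, d); k, v: (n_kv_heads, seq, d).
    # Each group of n_q_heads // n_kv_heads query heads shares one KV head.
    n_q_heads, seq, d = q.shape
    group = n_q_heads // k.shape[0]
    k = np.repeat(k, group, axis=0)            # broadcast each KV head to its group
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    causal = np.triu(np.ones((seq, seq), dtype=bool), k=1)
    scores = np.where(causal, -np.inf, scores)  # mask future positions
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)
    return w @ v
```

Note that only `k`/`v` shrink with fewer KV heads; `q` and the output keep full width, which is why quality degrades far more slowly than the KV cache shrinks.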
Stage 4: Attention Flavour - Sliding Window Interleaving
- Implement Sliding Window Attention; measure long-context memory scaling.
- Alternate local and global attention widths between layers; measure the change in cache growth. (Gemma-3 arch*)
- Implement Gated Attention (Gemma-2 arch*) and compare with above.
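The memory-scaling claim in this stage falls straight out of the mask shape: a local layer only ever attends to the last `window` positions, so its KV cache can be capped at `window` entries regardless of context length. A small sketch of the mask (helper name my own):

```python
import numpy as np

def sliding_window_mask(seq, window):
    # True where query i may attend key j: causal AND within `window` tokens.
    i = np.arange(seq)[:, None]
    j = np.arange(seq)[None, :]
    return (j <= i) & (i - j < window)
```

Each row has at most `window` allowed positions, so interleaving these local layers with full-attention layers (the Gemma-style pattern above) means only the global layers' caches keep growing with context.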
Stage 5: Synthesis
Combine the best-performing configs from all the above ablations and train a modern dense decoder model!