In Phase 2, I want to focus on the inference and serving side of things, mainly KV caching and the specialised attention mechanisms tailored for inference. Core experiments I want to run here:
- KV-Caching
  - Cache growth ablation between MHA, GQA, and MQA, across model depths and context lengths.
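Before running the ablation, the expected cache sizes can be worked out analytically. A minimal sketch, assuming fp16 (2 bytes/element) and an illustrative 7B-ish config (32 layers, 32 query heads, head_dim 128); the function name and config are my own, not from any specific model:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # K and V each hold batch * seq_len * n_kv_heads * head_dim elements per layer,
    # hence the leading factor of 2.
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem

# Illustrative config: only n_kv_heads changes between the three variants.
for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(n_layers=32, n_kv_heads=kv_heads, head_dim=128, seq_len=4096) / 2**30
    print(f"{name}: {gib:.2f} GiB at 4k context")
# MHA: 2.00 GiB, GQA-8: 0.50 GiB, MQA: 0.06 GiB
```

The cache scales linearly in every argument, so the MHA/GQA/MQA ratio is just the ratio of KV-head counts, independent of depth or context length.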
- Build a Prefill vs Decode Benchmark
  - Compare across:
    - attention types
    - input prompt lengths
  - Measure tokens/sec for both phases
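The benchmark harness itself can be sketched independently of any model. A rough outline, where `model_step(n_tokens)` is a hypothetical stand-in for one forward pass (a real run would add warmup and, on GPU, `torch.cuda.synchronize()` around the timers):

```python
import time

def throughput(model_step, prompt_len, gen_len):
    """Return rough (prefill, decode) tokens/sec for a callable model_step(n_tokens)."""
    t0 = time.perf_counter()
    model_step(prompt_len)              # prefill: the whole prompt in one pass
    prefill = prompt_len / (time.perf_counter() - t0)

    t0 = time.perf_counter()
    for _ in range(gen_len):            # decode: one token per step
        model_step(1)
    decode = gen_len / (time.perf_counter() - t0)
    return prefill, decode

# Dummy "model" whose cost grows with token count, just to exercise the harness.
def fake_step(n_tokens):
    sum(range(n_tokens * 1000))

prefill_tps, decode_tps = throughput(fake_step, prompt_len=512, gen_len=64)
```

This shape makes the key asymmetry visible: prefill is one large compute-bound pass, while decode is many tiny memory-bound passes, which is exactly why the two need separate tokens/sec numbers.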
- Flash Attention
  - Compare training speeds (torch SDPA already ships a FlashAttention backend)
  - Measure memory-efficiency gains
- Paged Attention
  - Possibly use vLLM and do a model-serving study
  - Measure memory efficiency during inference
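The core idea being measured is block-granular KV allocation. A toy sketch of PagedAttention-style bookkeeping (conceptual only, not vLLM's internals; the class and its methods are made up for illustration):

```python
BLOCK_SIZE = 16  # tokens per physical KV block (vLLM's default block size is also 16)

class BlockAllocator:
    """Map each sequence's logical token positions onto pooled physical blocks."""

    def __init__(self, n_blocks):
        self.free = list(range(n_blocks))   # pool of free physical block ids
        self.tables = {}                    # seq_id -> list of physical block ids

    def append_token(self, seq_id, pos):
        # Grab a new physical block only when a block boundary is crossed,
        # so per-sequence memory grows in BLOCK_SIZE steps instead of
        # reserving one max-context contiguous slab up front.
        table = self.tables.setdefault(seq_id, [])
        if pos % BLOCK_SIZE == 0:
            table.append(self.free.pop())
        return table[pos // BLOCK_SIZE]

alloc = BlockAllocator(n_blocks=64)
for pos in range(20):                       # 20 tokens -> ceil(20/16) = 2 blocks
    alloc.append_token(seq_id=0, pos=pos)
```

The serving study would then compare peak KV memory and max batch size under this scheme (via vLLM) against contiguous per-sequence caches.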
- Bonus:
  - MLA (DeepSeek's Multi-head Latent Attention): interesting KV-compression angle
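The compression win is easy to estimate up front: MLA caches one low-rank latent per token instead of full per-head K/V. A back-of-envelope sketch using DeepSeek-V2-flavoured numbers (128 heads, head_dim 128, latent dim 512, decoupled RoPE key dim 64; treat these as illustrative, and the function names as mine):

```python
def mha_kv_elems_per_token(n_heads=128, head_dim=128):
    return 2 * n_heads * head_dim        # full K and V cached per layer

def mla_cache_elems_per_token(d_c=512, d_rope=64):
    return d_c + d_rope                  # compressed KV latent + decoupled RoPE key

ratio = mha_kv_elems_per_token() / mla_cache_elems_per_token()
print(f"~{ratio:.0f}x fewer cached elements per token")  # 32768 / 576 ≈ 57x
```

That order-of-magnitude gap is why MLA slots naturally into the cache-growth ablation above as a fourth variant.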
Note: add inference- and cache-related metrics to the logging before running these experiments.
Reads
- PyTorch Blog: Disaggregated Inference at Scale with PyTorch & vLLM
- Fun tweet to revisit when I get here: https://x.com/asmah2107/status/2030109681839198628?s=20