Good blog from torch: https://pytorch.org/blog/activation-checkpointing-techniques/
Good discussion here about the "Mincut"-based implementation of activation checkpointing.
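For context, the basic technique both resources cover can be sketched with `torch.utils.checkpoint` (a minimal illustrative example, not our actual training code):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """A feed-forward block whose activations are recomputed in backward."""
    def __init__(self, dim=256):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        # checkpoint() discards this block's intermediate activations after the
        # forward pass and recomputes them during backward: compute for memory.
        return checkpoint(self.ff, x, use_reentrant=False)

x = torch.randn(4, 256, requires_grad=True)
Block()(x).sum().backward()
```

Gradients still flow through the checkpointed block; only the stored activations change.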
Real questions to answer from this experiment:
- For the same training setup, what memory do I save and what slowdown do I pay?
- Given the same GPU memory budget, which strategy gives me the most throughput or the largest trainable setup?
Measurement Conditions
- Fixed tokens per run: ✅
- Same seed: ✅
- Same effective batch: ✅
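By "same seed" we mean seeding every RNG the runs touch; a minimal sketch (the `set_seed` helper is hypothetical, not our actual harness):

```python
import random
import torch

def set_seed(seed: int = 1234) -> None:
    """Seed the RNGs the training loop touches so runs are comparable."""
    random.seed(seed)
    torch.manual_seed(seed)           # CPU generator
    torch.cuda.manual_seed_all(seed)  # all CUDA devices; no-op on CPU-only machines

set_seed(1234)
```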
Wandb Report
torch.compile gives us the best performance gain, providing the best trade-off of recompute for time in our A100 setup. For a detailed analysis, see the wandb report above. But a budget of 0.8 isn't far off, especially in an extremely memory-constrained setup, as it further reduces the memory requirement by another 20%!
Conclusion
- Winner: torch.compile
- Memory saved:
  - With no optimization and no checkpointing, the baseline was at 14.8GB & 130s
  - With the best option (torch.compile) we are at 10.6GB & 48s
- Speed change:
  - 2.7x speed gain while reducing the memory footprint by ~28%
- Now we should be able to accommodate a bigger batch size!
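The headline numbers above can be double-checked directly:

```python
baseline_gb, best_gb = 14.8, 10.6  # peak memory: baseline vs torch.compile
baseline_s, best_s = 130.0, 48.0   # wall time per run

speedup = baseline_s / best_s                 # ~2.71x
mem_reduction = 1.0 - best_gb / baseline_gb   # ~0.284, i.e. ~28%
print(f"{speedup:.1f}x faster, {mem_reduction:.0%} less memory")
```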
Experimental Notes
Group 1 — 3-memory-frontier-step1-compile-budget-20260314-091315-256427
- Ran ~10am March 14
- Git commit: 509d1bc06bee27c632e74f20d1f0d2a406ee2d65
- Tested AC budgets: 0.12, 0.38, 0.75 at mb=16, 32, 40–64
- This is the one with the weird results (AC=0.75 mb=32 appearing faster than no-AC, etc.)
Group 2 — 4-memory-frontier-step1-compile-20260314-143138-930527
- Ran ~2:30pm March 14 (after group 3)
- Git commit: 4e12565103f48d744baf18b868040dfcd2d1884d — a different, later commit
- Only tested no-AC (activation_memory_budget=null) at mb=16, 24, 28, 30, 31, 32
- Has a summary_stage="final" run — the only group with a final summary tag