Project: Transformer Efficiency Lab (TEL), Phase 1 Stage 1
Repo: Pythonista7/transformer-room & feat branch ph1-stg1
W&B: ashwin-ms-does-ai/transformer-room-baseline
Hardware (probes): A100 (Thunder Compute) and local M-series MPS
TL;DR
Metric-collection callbacks were migrated to forward_hooks to avoid graph breaks under torch.compile. Instead of getting faster, a metric-heavy “collision” step became ~4× slower than the equivalent pre-migration run. Root cause: guard-failure-driven recompilations, not graph breaks. Three independent bugs each forced a full recompile:
- A lazily-grown positional-encoding buffer (pos_enc_cache) whose shape changed on the first forward.
- An activation-norm forward hook whose closure captured a mutable boolean. Dynamo walked __closure__[0].cell_contents and installed a guard on collector.capture_activation_norms.
- The attention-entropy hook had the same closure-cell pattern, first on capture_attention_entropy, then on the _attention_stash_by_layer dict.
Fix: pre-materialize the pos-enc buffer in __init__; decorate all hook bodies with @torch._dynamo.disable. Final state: 5 dynamo frames (one per @disable boundary), each compiled once, zero guard-failure recompiles.
1. The setup and the surprise
Phase 1 latency-cadence probe, collision variant fires every metric group at step 12:
- d_model=768, n_heads=12, n_layers=12 (~124M)
- seq_len=1024, micro_batch=96, accum=5, effective_batch=480
- log_every_n_steps=4, max_steps=16
Migration commits:
abcc124 fx: use hooks for attn entropy instead of callbacks that cause graph-breaks on compile
7a8ebaf ch: torch compile on everything but cpu
60654f6 fx: calc attn entropy off graph
f94a7ca fx: efficiently compute norms
a04faf7 fx: compute on device and only move scalars to avoid D2H overheads
W&B collision-run step timings (ms)
| Group | Run ID | step 4 | step 8 | step 12 | step 16 | epoch_s |
|---|---|---|---|---|---|---|
| test-3 (callbacks) | 9s1nfa9e | 13,162 | 13,242 | 42,878 | 13,119 | 261.57 |
| test-3-full-graph-mode | nnees7u2 | 12,947 | 10,938 | 8,072 | 3,567 | 155.04 |
| test-4-hooked-metrics | 0n8chzxe | 4,108 | 4,110 | 49,781 | 10,348 | 228.41 |
| test-4-full-graph | 1z8r401d | 47,548 | 21,945 | 79,144 | 4,111 | 627.90 |
Forward/backward breakdown for 1z8r401d:
| step | forward_ms | backward_ms | step_ms |
|---|---|---|---|
| 4 | 1,337 | 43,555 | 47,548 |
| 8 | 1,430 | 17,986 | 21,945 |
| 12 | 30,044 | 45,501 | 79,144 |
| 16 | 1,523 | 103 | 4,111 |
Backward dominating at step 4/8 (where no cadence fires) pointed at compilation, not metric cost.
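The backward share of each step can be checked directly from the table above. This is a small arithmetic sketch using only the numbers already reported for run 1z8r401d and the two full-graph epoch times; the variable and function names are illustrative.

```python
# Step timings (ms) copied from the forward/backward table for run 1z8r401d.
timings = {
    4: {"forward_ms": 1337, "backward_ms": 43555, "step_ms": 47548},
    8: {"forward_ms": 1430, "backward_ms": 17986, "step_ms": 21945},
    12: {"forward_ms": 30044, "backward_ms": 45501, "step_ms": 79144},
    16: {"forward_ms": 1523, "backward_ms": 103, "step_ms": 4111},
}


def backward_share(row):
    """Fraction of the step's wall time spent in backward."""
    return row["backward_ms"] / row["step_ms"]


shares = {step: backward_share(row) for step, row in timings.items()}

for step, share in shares.items():
    print(f"step {step:>2}: backward = {share:.1%} of step time")

# Epoch-time ratio between the two full-graph runs in the W&B table.
epoch_ratio = 627.90 / 155.04
print(f"test-4-full-graph / test-3-full-graph-mode epoch time: {epoch_ratio:.2f}x")
```

Backward takes over 90% of step 4 and a few percent of step 16, which is the pattern one would expect from a one-off compile cost rather than a steady metric cost.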
2. Key passages from torch.compile, the missing manual
Ref “Torch Compile the missing manual” - https://youtu.be/rew5CSUaIXg + Doc
- Hook compilation is deferred to backward() - explains the huge first-few-step backward times.
- fullgraph=True constrains graph breaks, not recompilations.
- “If you have a branching conditional and you need both branches to be compiled, we MUST compile twice, or you must figure out if you can use torch.cond.”
3. tlparse diagnostic
TORCH_TRACE="/tmp/tracedir" python experiments/phase-1/ph1_latency_cadence_probe.py --variant collision
tlparse /tmp/tracedir
tl_out/ showed 4 compilations of frame 0. Reasons (verbatim from recompile_reasons_68.json):
0/0: tensor 'self._modules['pos_encoding']._buffers['pos_enc_cache']'
size mismatch at index 0. expected 0, actual 128
0/1: self._modules['dec_layers']._modules['0']._modules['multi_head_attention']
._forward_hooks[list(dict.keys(...))[0]]
.__closure__[0].cell_contents.capture_activation_norms == False
# forward_hook_metrics.py:266 in activation_hook
0/2: not self._modules['dec_layers']._modules['0']._modules['multi_head_attention']
._forward_hooks[list(dict.keys(...))[0]]
.__closure__[0].cell_contents.activation_norms
# forward_hook_metrics.py:272 in activation_hook
fullgraph=True produced the same reasons. Confirmed: fullgraph doesn’t stop recompiles.
4. Bug 1 - lazy pos_enc_cache
# original
self.register_buffer('pos_enc_cache', torch.empty((0, d_model)), persistent=False)
# in forward:
self.pos_enc_cache = torch.cat((self.pos_enc_cache, extra), dim=0)
Dynamo guarded on shape [0, 768]; first real forward grew it to [128, 768]; guard failed; recompile.
Fix: materialize full cache at __init__ using max_seq_len. Never reassign in forward.
Validation: tl_out_pos-enc-fix/ shows 3 frames (0/0 0/1 0/2). The pos-enc guard is gone.
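A minimal sketch of the pre-materialized buffer, assuming a standard sinusoidal encoding and an even d_model. The class name PositionalEncoding and the max_seq_len constructor argument are hypothetical here; the actual code in positional_encoder.py is not reproduced in these notes.

```python
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):  # hypothetical name, not the repo's class
    def __init__(self, d_model: int, max_seq_len: int):
        super().__init__()
        # Build the full table once so its shape never changes after __init__.
        # Assumes d_model is even.
        position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pos_enc_cache", table, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model). Slice the cache; never reassign it.
        seq_len = x.size(1)
        return x + self.pos_enc_cache[:seq_len]


enc = PositionalEncoding(d_model=8, max_seq_len=16)
shape_before = tuple(enc.pos_enc_cache.shape)
out = enc(torch.zeros(2, 5, 8))
shape_after = tuple(enc.pos_enc_cache.shape)
print(shape_before, shape_after, tuple(out.shape))
```

Because forward only reads a slice, the buffer's shape is identical before and after the first call, so a shape guard installed on it at trace time keeps passing.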
5. Compiler caches - what fx_graph_cache_miss / aotautograd_cache_miss mean
These are compile-time caches, not runtime caches.
- fx_graph_cache: keyed on FX graph + input metadata + inductor config. On hit, skips Inductor lowering + Triton codegen.
- aotautograd_cache: keyed on joint graph. On hit, skips fw/bw partitioning.
In tl_out_pos-enc-fix, [0/2] shows aotautograd_cache_miss + fx_graph_cache_miss - the third compilation is a structurally different graph (the now-active capture_activation_norms=True branch pulls in new tensor ops), not a re-validation. Frame count × compile_id is what matters for runtime; cache state is informational.
6. Bug 2 - activation-hook closure guard
Dynamo’s guard path:
self._modules[...].multi_head_attention
._forward_hooks[...]
.__closure__[0].cell_contents
.capture_activation_norms
Config has skip_nnmodule_hook_guards: true - but that flag only skips guards on hook registration, not closure cell contents.
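The guard path is easier to read once the closure cell is inspected by hand. The following plain-Python sketch uses a hypothetical Collector class and make_hook factory as stand-ins for the real collector and hook registration; it only shows what __closure__[0].cell_contents points at, not anything Dynamo-specific.

```python
class Collector:  # hypothetical stand-in for the metrics collector
    def __init__(self):
        self.capture_activation_norms = False


def make_hook(collector):
    def activation_hook(_module, _inputs, output):
        # The hook reads mutable state through its free variable `collector`.
        if not collector.capture_activation_norms:
            return None
        return output

    return activation_hook


collector = Collector()
hook = make_hook(collector)

# The closure holds one cell, and that cell holds the collector object itself.
free_vars = hook.__code__.co_freevars
captured = hook.__closure__[0].cell_contents
print(free_vars, captured is collector)

# Flipping the flag changes state reachable through the closure cell,
# which is the attribute the recompile reason names.
collector.capture_activation_norms = True
flag_via_cell = hook.__closure__[0].cell_contents.capture_activation_norms
print(flag_via_cell)
```

Any code that traces through the hook body can reach the flag by the same attribute chain, which matches the path shown in the recompile reason.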
Fix:
--- a/src/training/metrics/plugins/forward_hook_metrics.py
+++ b/src/training/metrics/plugins/forward_hook_metrics.py
@@ -261,7 +261,7 @@ def register_forward_metric_hooks(
layer = dec_layers[layer_idx]
label_tuple = tuple(labels)
collector.register_attention_layer(layer_idx, label_tuple)
-
+ @torch._dynamo.disable
def activation_hook(_module, _inputs, output, label_tuple=label_tuple):
if not collector.capture_activation_norms:
return
Validation: tl_out_diabled_dynamo/ shows [0/0] [1/0] [2/0] [3/0] [3/1] [3/2] [4/0]. Frame explosion expected (disable = graph break). But [3/1], [3/2] exposed a second closure-cell problem on the attention-entropy hook, previously masked because dynamo reports only the first-failing guard. From recompile_reasons_109.json:
3/0: ...cell_contents.capture_attention_entropy == False
# forward_hook_metrics.py:79 in stash_attention_projection_once
3/1: not ...cell_contents._attention_stash_by_layer
# forward_hook_metrics.py:100 in stash_attention_projection_once
7. Bug 3 - attention-hook closures
Same @torch._dynamo.disable treatment on packed_proj_hook and entropy_hook.
Validation: tl_out_diabled_dynamo_attn_entropy/ shows 5 frames, no guard-failure recompiles. Two show an _1 suffix: [3/0_1], [4/0_1].
The tlparse compile-id format is [frame_id/compile_id_attempt]:
- compile_id incrementing = recompile (distinct graphs).
- attempt incrementing = dynamo restarted the same compilation (speculative rollback + retry). Only the final attempt produces a runnable graph.
Each frame compiled exactly once.
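The compile-id convention above can be checked mechanically. This is a small sketch that parses ids in the [frame_id/compile_id_attempt] form described in these notes and lists frames compiled more than once. The first id string is the one reported for tl_out_diabled_dynamo/; the second is an assumed listing for the final run, since the notes only name [3/0_1] and [4/0_1] explicitly.

```python
import re
from collections import defaultdict

# Matches ids such as "[3/1]" or "[3/0_1]" (frame_id/compile_id, optional _attempt).
COMPILE_ID = re.compile(r"\[(\d+)/(\d+)(?:_(\d+))?\]")


def summarize(ids: str):
    """Map frame_id -> sorted list of distinct compile_ids seen."""
    seen = defaultdict(set)
    for frame, compile_id, _attempt in COMPILE_ID.findall(ids):
        seen[int(frame)].add(int(compile_id))
    return {frame: sorted(cids) for frame, cids in sorted(seen.items())}


def recompiled_frames(ids: str):
    """Frames with more than one distinct compile_id, i.e. real recompiles."""
    return [frame for frame, cids in summarize(ids).items() if len(cids) > 1]


before = "[0/0] [1/0] [2/0] [3/0] [3/1] [3/2] [4/0]"
after = "[0/0] [1/0] [2/0] [3/0_1] [4/0_1]"  # assumed listing for the final run

print(recompiled_frames(before))
print(recompiled_frames(after))
```

An attempt suffix does not add a compile_id, so a frame such as [3/0_1] still counts as compiled once.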
8. Mental model clarifications
- The graph = FX graph built by tracing tensor ops inside compiled_model(...). Timing code, loss_fn, loss.backward(), optimizer.step(), metric Python are all outside.
- The compiled region = scope of torch.compile(...). fullgraph=True is a constraint on that region, not a different region.
- Callbacks ran in eager Python after compiled_model(...) returned - invisible to dynamo.
- Hooks run during the compiled forward - fully visible to dynamo; any mutable state they close over becomes a guard candidate.
- @torch._dynamo.disable makes a function a black box. Tradeoff: graph break at each disable boundary. Incompatible with fullgraph=True, fine with fullgraph=False.
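As an illustration of the last point, here is a minimal sketch of a forward hook whose body is wrapped in @torch._dynamo.disable. The Collector class, the register_norm_hook helper, and the nn.Linear stand-in are hypothetical; the demo runs in eager mode to stay lightweight, and wrapping the model in torch.compile(model) with fullgraph=False is the setting the notes describe.

```python
import torch
import torch._dynamo
from torch import nn


class Collector:  # hypothetical stand-in for the metrics collector
    def __init__(self):
        self.capture_activation_norms = False
        self.activation_norms = []


def register_norm_hook(module: nn.Module, collector: Collector):
    @torch._dynamo.disable
    def activation_hook(_module, _inputs, output):
        # With the decorator, Dynamo does not trace this body, so the branch
        # on collector state runs as ordinary eager Python.
        if not collector.capture_activation_norms:
            return
        # Keep the norm as a tensor on the output's device.
        collector.activation_norms.append(output.detach().norm())

    return module.register_forward_hook(activation_hook)


model = nn.Linear(4, 4)
collector = Collector()
handle = register_norm_hook(model, collector)

x = torch.randn(2, 4)
model(x)  # flag off: nothing recorded
recorded_when_off = len(collector.activation_norms)

collector.capture_activation_norms = True
model(x)  # flag on: one norm recorded
recorded_when_on = len(collector.activation_norms)

handle.remove()
print(recorded_when_off, recorded_when_on)
```

The hook returns None in both branches, so the module output is passed through unchanged; the cost is one graph break per decorated hook when the model is compiled.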
9. Tracking recompilation in production
- TORCH_LOGS=recompiles - prints failing guard on each recompile.
- TORCH_TRACE=/tmp/tracedir + tlparse /tmp/tracedir - full HTML report.
- torch._dynamo.utils.counters["frames"]["ok"] - identifies which steps trigger compilation:
import torch._dynamo.utils as dynamo_utils
# Number of frames Dynamo has compiled successfully so far.
before = dynamo_utils.counters["frames"]["ok"]
out = model(x)
after = dynamo_utils.counters["frames"]["ok"]
if after > before:
    print(f"Step {step}: compiled {after - before} new frame(s)")
10. Implications for long runs
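Before: every cadence boundary (steps 12, 24, 36, …) risked flipping a closure guard → full recompile. On a 10,000-step run at cadence-12 that is 833 boundaries, so potentially 800+ guard flips. The realistic count of full recompiles is lower, because Dynamo reuses an already-compiled entry when a previously seen guard state returns and stops recompiling once its recompile limit is reached, but each new guard state still costs a multi-second compile mid-run.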
After: guards no longer depend on mutable hook state. Compilation is a one-time warmup cost, amortizes to zero.
11. Takeaways
- fullgraph=True does not prevent recompilation. Guard failures restart compilation regardless.
- Hooks run inside the compiled region. Mutable state they close over = guaranteed recompile when it mutates.
- @torch._dynamo.disable on hook bodies is the blunt-but-correct fix. Accept the graph break; keep fullgraph=False.
- Pre-materialize buffers. Lazy reassignment in forward triggers shape-guard recompiles.
- tlparse recompile reasons name the exact guard with file and line. Read the guard; don’t guess.
- Frame count matters, cache state doesn’t. Target: one frame, compiled once.
Artifacts
- W&B runs: 9s1nfa9e, nnees7u2, 0n8chzxe, 1z8r401d in ashwin-ms-does-ai/transformer-room-baseline
- Local tlparse: tl_out/, tl_out_full_graph/, tl_out_pos-enc-fix/, tl_out_diabled_dynamo/, tl_out_diabled_dynamo_attn_entropy/
- Modified: src/components/positional/positional_encoder.py, src/training/metrics/plugins/forward_hook_metrics.py
- Reference: torch.compile, the missing manual (Ed Yang, Jul 2024)
Resources to read, in order
- “Torch Compile the missing manual” - https://youtu.be/rew5CSUaIXg + Doc
- Guards and such in torch compile: https://torchcompile-guards.hashnode.dev/inside-torchcompile-guards-how-they-work-what-they-cost-and-ways-to-optimize
1. Start here - the official conceptual overview: https://pytorch.org/docs/stable/torch.compiler_deepdive.html
This explains TorchDynamo (the Python bytecode tracer), AOTAutograd (how backward graphs are compiled separately from forward), and Inductor (the backend). The key insight is that the compiler stack is layered:
- Dynamo traces Python bytecode and extracts sub-graphs
- AOTAutograd uses those graphs to generate both fwd and bwd graphs ahead of time
- Inductor compiles those to CUDA kernels
fullgraph=True makes Dynamo raise an error on any graph break instead of splitting the function, so a successful compile yields exactly one graph for the entire function, which then lets AOTAutograd create a clean fwd/bwd pair.
2. Graph breaks - what causes them and how to detect them: https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html
You can run TORCH_LOGS=graph_breaks python train.py to see exactly what’s causing breaks in your training loop. I’d strongly recommend doing this once - it’ll tell you exactly where your metrics code is breaking the graph.
3. The AOTAutograd design doc (deeper, worth reading if you want to understand why bwd is a separate graph): https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html
The punchline: AOTAutograd traces both forward and backward at compile time using a single joint trace, then partitions it into a forward graph (saves activations) and backward graph (uses those saved activations). That’s why in fullgraph=True the backward didn’t retrace - its graph was already fixed at compile time and doesn’t depend on what Python branching happens in the forward at runtime.