Project: Transformer Efficiency Lab (TEL), Phase 1 Stage 1
Repo: Pythonista7/transformer-room, feat branch ph1-stg1
W&B: ashwin-ms-does-ai/transformer-room-baseline
Hardware (probes): A100 (Thunder Compute) and local M-series MPS

TL;DR

Metric-collection callbacks were migrated to forward hooks to avoid graph breaks under torch.compile. Instead of getting faster, the metric-heavy “collision” run became ~4× slower than the equivalent pre-migration run (627.90 s vs 155.04 s per epoch, full-graph variants). Root cause: recompilations driven by guard failures, not graph breaks. Three independent bugs each forced a full recompile:

  1. A lazily-grown positional-encoding buffer (pos_enc_cache) whose shape changed on the first forward.
  2. An activation-norm forward hook whose closure captured a mutable boolean. Dynamo walked __closure__[0].cell_contents and installed a guard on collector.capture_activation_norms.
  3. The attention-entropy hook had the same closure-cell pattern, first on capture_attention_entropy, then on the _attention_stash_by_layer dict.

Fix: pre-materialize the pos-enc buffer in __init__ and decorate all hook bodies with @torch._dynamo.disable. Final state: 5 Dynamo frames (the forward split at the @disable boundaries), each compiled once, with zero guard-failure recompiles.


1. The setup and the surprise

The Phase 1 latency-cadence probe has a collision variant that fires every metric group at step 12. Configuration:

  • d_model=768, n_heads=12, n_layers=12 (~124M parameters)
  • seq_len=1024, micro_batch=96, accum=5, effective_batch=480
  • log_every_n_steps=4, max_steps=16

Migration commits:

abcc124  fx: use hooks for attn entropy instead of callbacks that cause graph-breaks on compile
7a8ebaf  ch: torch compile on everything but cpu
60654f6  fx: calc attn entropy off graph
f94a7ca  fx: efficiently compute norms
a04faf7  fx: compute on device and only move scalars to avoid D2H overheads

W&B collision-run step timings (ms)

Group                    Run ID     step 4   step 8   step 12  step 16  epoch_s
test-3 (callbacks)       9s1nfa9e   13,162   13,242   42,878   13,119   261.57
test-3-full-graph-mode   nnees7u2   12,947   10,938    8,072    3,567   155.04
test-4-hooked-metrics    0n8chzxe    4,108    4,110   49,781   10,348   228.41
test-4-full-graph        1z8r401d   47,548   21,945   79,144    4,111   627.90

Forward/backward breakdown for 1z8r401d:

step   forward_ms   backward_ms   step_ms
4           1,337        43,555    47,548
8           1,430        17,986    21,945
12         30,044        45,501    79,144
16          1,523           103     4,111

Backward time dominating at steps 4 and 8, where no metric cadence fires, pointed at compilation rather than metric cost.
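A quick arithmetic check of that reading, using only the 1z8r401d numbers from the breakdown table; the helper name is illustrative.

```python
# Forward/backward breakdown for run 1z8r401d, copied from the table (all in ms).
breakdown = {
    4: (1_337, 43_555, 47_548),
    8: (1_430, 17_986, 21_945),
    12: (30_044, 45_501, 79_144),
    16: (1_523, 103, 4_111),
}

def backward_share(fwd_ms, bwd_ms, step_ms):
    # Fraction of the whole step spent in backward.
    return bwd_ms / step_ms

shares = {step: backward_share(*row) for step, row in breakdown.items()}
for step, share in shares.items():
    print(f"step {step:>2}: backward = {share:.0%} of step time")
```

Backward is roughly 92% and 82% of the step at steps 4 and 8, versus about 3% at step 16 once compilation has settled.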


2. Key passages from torch.compile, the missing manual

Reference: “torch.compile, the missing manual” (video: https://youtu.be/rew5CSUaIXg, plus the accompanying doc)

  • Hook compilation is deferred to backward() - explains the huge first-few-step backward times.
  • fullgraph=True constrains graph breaks, not recompilations.
  • “If you have a branching conditional and you need both branches to be compiled, we MUST compile twice, or you must figure out if you can use torch.cond.”
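A minimal sketch of the branching point, assuming a hypothetical `Collector` object whose boolean flag the traced code reads. It counts compilations with a custom backend (a callable that receives Dynamo's FX `GraphModule` plus example inputs and returns something callable). Even with `fullgraph=True`, each flag value gets its own compiled entry. Flipping back to a value Dynamo has already seen reuses the cached entry instead of compiling again.

```python
import torch
import torch._dynamo

torch._dynamo.reset()  # start from an empty compile cache

compiled_graphs = 0

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Called once per graph Dynamo hands over; runs the graph unoptimized.
    global compiled_graphs
    compiled_graphs += 1
    return gm.forward

class Collector:
    # Hypothetical stand-in for a metrics collector with a mutable flag.
    def __init__(self):
        self.capture = False

collector = Collector()

def step(x):
    y = torch.relu(x)
    if collector.capture:  # Dynamo specializes on this bool and guards on its value
        y = y * 2
    return y.sum()

compiled_step = torch.compile(step, backend=counting_backend, fullgraph=True)

x = torch.randn(8)
for flag in [False, True, False, True]:
    collector.capture = flag
    compiled_step(x)

print(f"graphs compiled: {compiled_graphs}")  # one per distinct flag value
```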

3. tlparse diagnostic

TORCH_TRACE="/tmp/tracedir" python experiments/phase-1/ph1_latency_cadence_probe.py --variant collision
tlparse /tmp/tracedir

tl_out/ showed 4 compilations of frame 0. Reasons (verbatim from recompile_reasons_68.json):

0/0: tensor 'self._modules['pos_encoding']._buffers['pos_enc_cache']'
     size mismatch at index 0. expected 0, actual 128

0/1: self._modules['dec_layers']._modules['0']._modules['multi_head_attention']
     ._forward_hooks[list(dict.keys(...))[0]]
     .__closure__[0].cell_contents.capture_activation_norms == False
     # forward_hook_metrics.py:266 in activation_hook

0/2: not self._modules['dec_layers']._modules['0']._modules['multi_head_attention']
     ._forward_hooks[list(dict.keys(...))[0]]
     .__closure__[0].cell_contents.activation_norms
     # forward_hook_metrics.py:272 in activation_hook

fullgraph=True produced the same reasons. Confirmed: fullgraph doesn’t stop recompiles.


4. Bug 1 - lazy pos_enc_cache

# original
self.register_buffer('pos_enc_cache', torch.empty((0, d_model)), persistent=False)
# in forward:
self.pos_enc_cache = torch.cat((self.pos_enc_cache, extra), dim=0)

Dynamo guarded on shape [0, 768]; first real forward grew it to [128, 768]; guard failed; recompile.

Fix: materialize the full cache in __init__ using max_seq_len, and never reassign it in forward.
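A sketch of the pre-materialized pattern. It assumes a standard sinusoidal encoding and an even d_model, and the class name is hypothetical; the real module (src/components/positional/positional_encoder.py) may compute its values differently. The point is that the buffer's shape is fixed at [max_seq_len, d_model] in __init__, and forward only slices it.

```python
import math

import torch
import torch.nn as nn

class PrecomputedPositionalEncoding(nn.Module):
    """Hypothetical sinusoidal encoder whose buffer is fully built in __init__."""

    def __init__(self, d_model: int, max_seq_len: int):
        super().__init__()
        position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # assumes d_model is even
        pe[:, 1::2] = torch.cos(position * div_term)
        # Shape stays [max_seq_len, d_model] for the module's lifetime: no shape guard flips.
        self.register_buffer("pos_enc_cache", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, d_model]. Slice the buffer; never reassign it.
        seq_len = x.size(1)
        if seq_len > self.pos_enc_cache.size(0):
            raise ValueError(f"seq_len {seq_len} exceeds max_seq_len {self.pos_enc_cache.size(0)}")
        return x + self.pos_enc_cache[:seq_len]

enc = PrecomputedPositionalEncoding(d_model=8, max_seq_len=16)
out = enc(torch.zeros(2, 5, 8))
```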

Validation: tl_out_pos-enc-fix/ shows 3 compilations of frame 0 (0/0, 0/1, 0/2), down from 4. The pos-enc guard is gone.


5. Compiler caches - what fx_graph_cache_miss / aotautograd_cache_miss mean

These are compile-time caches, not runtime caches.

  • fx_graph_cache: keyed on FX graph + input metadata + inductor config. On hit, skips Inductor lowering + Triton codegen.
  • aotautograd_cache: keyed on the Dynamo-produced graph plus input metadata and config. On hit, skips joint-graph tracing and fw/bw partitioning.

In tl_out_pos-enc-fix/, [0/2] shows aotautograd_cache_miss and fx_graph_cache_miss: the third compilation is a structurally different graph (the now-active capture_activation_norms=True branch pulls in new tensor ops), not a re-validation of an existing one. For runtime cost, what matters is how many (frame, compile_id) pairs exist; cache hit/miss state is informational.


6. Bug 2 - activation-hook closure guard

Dynamo’s guard path:

self._modules[...].multi_head_attention
    ._forward_hooks[...]
    .__closure__[0].cell_contents
    .capture_activation_norms

The config has torch._dynamo.config.skip_nnmodule_hook_guards = True, but that flag only skips guards on the hook dictionaries (which hooks are registered), not on the contents of a hook's closure cells.

Fix:

--- a/src/training/metrics/plugins/forward_hook_metrics.py
+++ b/src/training/metrics/plugins/forward_hook_metrics.py
@@ -261,7 +261,7 @@ def register_forward_metric_hooks(
         layer = dec_layers[layer_idx]
         label_tuple = tuple(labels)
         collector.register_attention_layer(layer_idx, label_tuple)
-
+        @torch._dynamo.disable
         def activation_hook(_module, _inputs, output, label_tuple=label_tuple):
             if not collector.capture_activation_norms:
                 return

Validation: tl_out_diabled_dynamo/ shows [0/0] [1/0] [2/0] [3/0] [3/1] [3/2] [4/0]. More frames are expected, since each disable boundary is a graph break. But [3/1] and [3/2] exposed a second closure-cell problem, on the attention-entropy hook. It had been masked because Dynamo reports only the first failing guard. From recompile_reasons_109.json:

3/0: ...cell_contents.capture_attention_entropy == False
     # forward_hook_metrics.py:79 in stash_attention_projection_once

3/1: not ...cell_contents._attention_stash_by_layer
     # forward_hook_metrics.py:100 in stash_attention_projection_once

7. Bug 3 - attention-hook closures

Applied the same @torch._dynamo.disable treatment to packed_proj_hook and entropy_hook.

Validation: tl_out_diabled_dynamo_attn_entropy/ shows 5 frames, no guard-failure recompiles. Two show an _1 suffix: [3/0_1], [4/0_1].

The tlparse compile-id format is [frame_id/compile_id_attempt], with the _attempt suffix omitted for the first attempt:

  • compile_id incrementing = recompile (distinct graphs).
  • attempt incrementing = dynamo restarted the same compilation (speculative rollback + retry). Only the final attempt produces a runnable graph.

Each frame compiled exactly once.
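A small helper that applies this rule to a list of compile ids. It is hypothetical and not part of tlparse. Restart attempts of the same compile_id count once, and more than one distinct compile_id for a frame means that frame was recompiled.

```python
import re
from collections import defaultdict

# Matches ids such as "[3/0]" or "[3/0_1]"; the optional _N suffix is the restart attempt.
COMPILE_ID = re.compile(r"\[?(\d+)/(\d+)(?:_(\d+))?\]?")

def parse_compile_id(text):
    m = COMPILE_ID.fullmatch(text.strip())
    if m is None:
        raise ValueError(f"not a compile id: {text!r}")
    frame, compile_id, attempt = m.groups()
    return int(frame), int(compile_id), int(attempt or 0)

def compiles_per_frame(ids):
    # Distinct compile_ids per frame; attempts of the same compile are not double-counted.
    seen = defaultdict(set)
    for text in ids:
        frame, compile_id, _attempt = parse_compile_id(text)
        seen[frame].add(compile_id)
    return {frame: len(cids) for frame, cids in sorted(seen.items())}

# Ids from tl_out_diabled_dynamo/: frame 3 was compiled three times.
print(compiles_per_frame(["[0/0]", "[1/0]", "[2/0]", "[3/0]", "[3/1]", "[3/2]", "[4/0]"]))
```

Applied to the final run, and assuming its five frames are numbered 0 through 4, ids like [3/0_1] and [4/0_1] still count as a single compilation each.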


8. Mental model clarifications

  • The graph = FX graph built by tracing tensor ops inside compiled_model(...). Timing code, loss_fn, loss.backward(), optimizer.step(), metric Python are all outside.
  • The compiled region = scope of torch.compile(...). fullgraph=True is a constraint on that region, not a different region.
  • Callbacks ran in eager Python after compiled_model(...) returned - invisible to dynamo.
  • Hooks run during the compiled forward - fully visible to dynamo; any mutable state they close over becomes a guard candidate.
  • @torch._dynamo.disable makes a function a black box. Tradeoff: graph break at each disable boundary. Incompatible with fullgraph=True, fine with fullgraph=False.
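A minimal sketch of the last point, using a toy function and a hypothetical `log_mean` helper. Under `fullgraph=True`, calling a disabled function is reported as an error. Under `fullgraph=False`, Dynamo compiles the code before and after the call as separate graphs and runs the disabled function eagerly. The exact exception class varies across PyTorch versions, so the sketch catches `Exception`.

```python
import torch
import torch._dynamo

graphs = 0
logged = {}

def counting_backend(gm, example_inputs):
    # Custom backend: count graphs handed over by Dynamo, run them unoptimized.
    global graphs
    graphs += 1
    return gm.forward

@torch._dynamo.disable
def log_mean(t):
    # Opaque to Dynamo: runs as ordinary eager Python, so .item() is harmless here.
    logged["mean"] = t.mean().item()

def f(x):
    a = torch.sin(x)
    log_mean(a)  # calling a disabled function forces a graph break
    return torch.cos(a)

x = torch.randn(8)

# fullgraph=True: the graph break is an error, not a split.
torch._dynamo.reset()
try:
    torch.compile(f, backend=counting_backend, fullgraph=True)(x)
    fullgraph_failed = False
except Exception as err:
    fullgraph_failed = True
    print(f"fullgraph=True raised {type(err).__name__}")

# fullgraph=False: Dynamo compiles around the disabled call.
torch._dynamo.reset()
graphs = 0
out = torch.compile(f, backend=counting_backend)(x)
print(f"fullgraph=False: {graphs} graph(s), logged mean = {logged['mean']:.4f}")
```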

9. Tracking recompilation in production

  • TORCH_LOGS=recompiles - prints failing guard on each recompile.
  • TORCH_TRACE=/tmp/tracedir + tlparse /tmp/tracedir - full HTML report.
  • torch._dynamo.utils.counters["frames"]["ok"] - identifies which steps trigger compilation:
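import torch._dynamo.utils as dynamo_utils

def run_step(model, x, step):
    # Dynamo increments counters["frames"]["ok"] for each successful frame compilation.
    before = dynamo_utils.counters["frames"]["ok"]
    out = model(x)
    after = dynamo_utils.counters["frames"]["ok"]
    if after > before:
        print(f"Step {step}: compiled {after - before} new frame(s)")
    return out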

10. Implications for long runs

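Before: every cadence boundary (steps 12, 24, 36, …) could flip a closure-captured flag and fail a guard, and each hook state Dynamo had not seen before cost a full recompile. Dynamo keeps one cached entry per guard set, so a state it has already compiled for is reused (step 16 of 1z8r401d, back down to 4.1 s, is consistent with this). The recompile count is therefore bounded by the number of distinct hook states, not by the number of cadence boundaries. Once a frame exceeds the recompile limit (torch._dynamo.config.cache_size_limit), though, Dynamo stops compiling it and runs it eagerly.

After: guards no longer depend on mutable hook state, so compilation is a one-time warmup cost that amortizes toward zero over a long run.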


11. Takeaways

  1. fullgraph=True does not prevent recompilation. Guard failures restart compilation regardless.
  2. Hooks run inside the compiled region. Mutable state they close over becomes a guard, and every change to a value Dynamo has not yet compiled for costs a full recompile.
  3. @torch._dynamo.disable on hook bodies is the blunt-but-correct fix. Accept the graph break; keep fullgraph=False.
  4. Pre-materialize buffers. Lazy reassignment in forward triggers shape-guard recompiles.
  5. tlparse recompile reasons name the exact guard with file and line. Read the guard; don’t guess.
  6. Compilations per frame matter; cache hit/miss state doesn’t. Target: every frame compiled exactly once.

Artifacts

  • W&B runs: 9s1nfa9e, nnees7u2, 0n8chzxe, 1z8r401d in ashwin-ms-does-ai/transformer-room-baseline
  • Local tlparse: tl_out/, tl_out_full_graph/, tl_out_pos-enc-fix/, tl_out_diabled_dynamo/, tl_out_diabled_dynamo_attn_entropy/
  • Modified: src/components/positional/positional_encoder.py, src/training/metrics/plugins/forward_hook_metrics.py
  • Reference: torch.compile, the missing manual (Ed Yang, Jul 2024)

Resources to read, in order

1. Start here - the official conceptual overview: https://pytorch.org/docs/stable/torch.compiler_deepdive.html

This explains TorchDynamo (the Python bytecode tracer), AOTAutograd (how backward graphs are compiled separately from forward), and Inductor (the backend). The key insight is that the compiler stack is layered:

  • Dynamo traces Python bytecode and extracts sub-graphs
  • AOTAutograd uses those graphs to generate both fwd and bwd graphs ahead of time
  • Inductor compiles those graphs into fused kernels (Triton on CUDA GPUs, C++ on CPU)

fullgraph=True makes Dynamo raise an error on any graph break instead of splitting the function, so a successful compile yields exactly one graph for the entire function, which then lets AOTAutograd create a clean fwd/bwd pair.

2. Graph breaks - what causes them and how to detect them: https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html

You can run TORCH_LOGS=graph_breaks python train.py to see exactly what’s causing breaks in your training loop. I’d strongly recommend doing this once - it’ll tell you exactly where your metrics code is breaking the graph.

3. AOTAutograd (deeper, worth reading if you want to understand why bwd is a separate graph). Note that https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html covers AOTInductor, the ahead-of-time compilation path for deployment, rather than AOTAutograd.

The punchline: when the forward is compiled, AOTAutograd traces forward and backward together as a single joint graph, then partitions it into a forward graph (which saves activations) and a backward graph (which consumes them). That’s why, under fullgraph=True, the backward didn’t retrace on its own: its graph is fixed when its forward is compiled and doesn’t depend on Python branching at runtime. A forward recompile does bring a new backward graph with it, compiled on the next backward() call.
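A minimal sketch that makes the partition visible, using `functorch.compile.aot_function` (a lower-level AOTAutograd entry point) on a toy function. The two compiler callbacks simply print and count the forward and backward graphs they receive and return them unchanged.

```python
import torch
from functorch.compile import aot_function

compiled = {"fw": 0, "bw": 0}

def make_compiler(kind):
    def compiler(fx_module: torch.fx.GraphModule, example_inputs):
        # AOTAutograd hands each partition over as its own FX graph.
        compiled[kind] += 1
        print(f"--- {kind} graph ---")
        print(fx_module.code)
        return fx_module  # run the graph as-is
    return compiler

def fn(x):
    return torch.sin(x).cos()

aot_fn = aot_function(fn, fw_compiler=make_compiler("fw"), bw_compiler=make_compiler("bw"))

x = torch.randn(4, requires_grad=True)
out = aot_fn(x)
out.sum().backward()  # the backward graph runs here
```

The printed forward graph returns saved tensors alongside the output, and the backward graph takes them as inputs. That is the "saves activations / uses those saved activations" split in practice.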