Tasks

  • Finalise a base model for this stage of experiments.
  • Benchmark/test setup for truly testing model generalisation beyond a single training set.
  • Set up a HuggingFace upload pipeline post-training.

Finding Problems

Tokenizer Bottlenecks

I have already started to see some bottlenecks and compounding issues! I intend to switch to the open-web-text dataset and have at least 1B tokens for training, but the way I currently load, tokenise and use the data is very memory-bound and inefficient. The current code cannot handle things like:

  • "\n\n".join(segments), which builds the entire corpus as one string in memory.
  • BPETokenizer._train() doing corpus = list(text.encode("utf-8")), which would probably overflow the RAM.
  • The current dataset and pipeline fundamentally do not support streaming, which I need in order to scale up the experiments.
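A minimal sketch of what streaming could look like, assuming documents arrive as an iterator of strings (for example from a Hugging Face `datasets` load with `streaming=True`) and that the tokeniser exposes some `encode(str) -> list[int]` callable. `pack_stream`, `encode` and `eos_id` are hypothetical names, not the existing code. Instead of `"\n\n".join(segments)`, each document is encoded on its own, separated by an end-of-text id, and emitted as fixed-length blocks, so only one partial block is ever buffered.

```python
from typing import Callable, Iterable, Iterator, List


def pack_stream(
    docs: Iterable[str],
    encode: Callable[[str], List[int]],
    eos_id: int,
    block_size: int,
) -> Iterator[List[int]]:
    """Yield blocks of block_size + 1 token ids (inputs plus the shifted target)
    without ever materialising the whole corpus in memory."""
    buffer: List[int] = []
    chunk = block_size + 1
    for doc in docs:
        # Encode one document at a time and mark the boundary with EOS
        # instead of joining every segment into one giant string.
        buffer.extend(encode(doc))
        buffer.append(eos_id)
        while len(buffer) >= chunk:
            yield buffer[:chunk]
            # Keep the last token so consecutive blocks overlap by one,
            # matching the usual x = block[:-1], y = block[1:] split.
            buffer = buffer[block_size:]
    # A trailing partial block is dropped, which is common for LM pretraining.
```

Because it is a generator, it works unchanged on an unbounded stream, e.g. `pack_stream((row["text"] for row in ds), tok.encode, eos_id, 1024)` where `ds` is a streaming dataset and `tok` is the (hypothetical) tokeniser object.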

The good news is that the training pipeline shouldn't really be affected by this; most of the issue seems to be within the tokeniser code. As I see it, there are two issues:

  1. BPE training on a 1B-token corpus is just going to take too long. Even a 50-100k run would be time-consuming, and might still not be in my best interest as I scale to bigger experiments.
  2. Currently the corpus is tokenised and materialised completely at training time; there is no true streaming capability as of now.
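For the first issue (and the `list(text.encode("utf-8"))` memory problem), a common way to make byte-level BPE training tractable is to stop operating on one flat byte list. Instead, count pre-tokenised chunks (roughly, words) while the text streams past, then run merges over the unique chunks weighted by their counts. The sketch below is not the current `BPETokenizer._train()`: the function names are hypothetical and the regex pre-tokeniser is an assumption, deliberately much simpler than GPT-2's.

```python
import re
from collections import Counter
from typing import Dict, Iterable, List, Tuple

# Assumed, simplified pre-tokeniser: optional leading space + word characters,
# optional leading space + punctuation run, or a whitespace run.
PRETOKEN_RE = re.compile(r" ?\w+| ?[^\s\w]+|\s+")


def count_chunks(texts: Iterable[str]) -> Counter:
    """One streaming pass over the corpus, keeping only counts of unique byte chunks."""
    counts: Counter = Counter()
    for text in texts:
        for piece in PRETOKEN_RE.findall(text):
            counts[tuple(piece.encode("utf-8"))] += 1
    return counts


def merge_pair(seq: Tuple[int, ...], pair: Tuple[int, int], new_id: int) -> Tuple[int, ...]:
    """Replace every left-to-right occurrence of pair in seq with new_id."""
    out: List[int] = []
    i = 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return tuple(out)


def train_bpe(counts: Counter, num_merges: int) -> Dict[Tuple[int, int], int]:
    """Learn up to num_merges merges; cost scales with unique chunks, not corpus bytes."""
    words = dict(counts)
    merges: Dict[Tuple[int, int], int] = {}
    next_id = 256  # ids 0..255 are the raw bytes
    for _ in range(num_merges):
        pair_counts: Counter = Counter()
        for word, freq in words.items():
            for a, b in zip(word, word[1:]):
                pair_counts[(a, b)] += freq
        if not pair_counts:
            break
        best = max(pair_counts, key=pair_counts.get)
        merges[best] = next_id
        words = {merge_pair(w, best, next_id): f for w, f in words.items()}
        next_id += 1
    return merges
```

Memory is bounded by the number of distinct chunks rather than corpus size, and each merge pass touches each distinct chunk once, which is what makes larger corpora feasible.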

Solved in this commit

Metric calculation and aggregation stalling training

# Cadences as per the run config
	log_every_n_steps=25,  # step metrics
	diagnostics_every_n_steps=100,  # default for diagnostics
	layer_grad_norms_every_n_steps=500,
	parameter_optimizer_norms_every_n_steps=500,  # update/weight ratio frequency
	attention_entropy_every_n_steps=250,

Some of the metric computation is offloaded to the CPU, which takes too long. I need to find a way to perform this computation either asynchronously or on the GPU, and only transfer the final scalar to the CPU for logging.
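As an illustration of keeping the reduction on the device, here is a sketch for a global gradient norm: per-parameter norms are computed and stacked on whatever device the gradients live on, and only the final 0-dim tensor crosses to the CPU through a single `.item()` at log time. The function name is hypothetical. (If gradient clipping already runs, `torch.nn.utils.clip_grad_norm_` returns this same total norm as a tensor.)

```python
import torch
import torch.nn as nn


@torch.no_grad()
def global_grad_norm(model: nn.Module) -> torch.Tensor:
    """L2 norm over all gradients, computed on the gradients' device.

    Returns a 0-dim tensor; calling .item() on it is the only host sync."""
    norms = [
        # Cast to float32 so bf16/fp32 grads can be stacked together.
        torch.linalg.vector_norm(p.grad, ord=2, dtype=torch.float32)
        for p in model.parameters()
        if p.grad is not None
    ]
    if not norms:
        return torch.zeros(())
    # The norm of the per-parameter norms equals the norm of all grads concatenated.
    return torch.linalg.vector_norm(torch.stack(norms), ord=2)


if __name__ == "__main__":
    model = nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()
    total = global_grad_norm(model)  # stays on the model's device
    print(total.item())              # single device-to-host transfer at log time
```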

  • Why there isn't a spike at step 750 (also an attention-entropy step, given the 250-step cadence) is a mystery to me!
  • Also, I don't know why only step 100 takes that long for basic diagnostics while every other multiple of 100 stays low.

I need a way to isolate the timing of each metric computation; currently there is no way to tell how much each one contributes:

Step | What fires                      | Spike magnitude
-----|---------------------------------|----------------
100  | diagnostics (activation norms)  | ~19s
250  | attention entropy only          | ~29s (largest)
500  | diagnostics + attention entropy | ~20s

I attempted to solve this in this commit.
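One way to isolate per-metric cost is to wrap each metric computation in a named timer that synchronises the GPU before reading the clock. That way, queued asynchronous kernels are billed to the metric that launched them rather than to whichever later call happens to sync. `MetricTimer` is a hypothetical helper for illustration, not the existing metrics engine's API.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Iterator

import torch


class MetricTimer:
    """Accumulates wall-clock milliseconds per (hook, metric) name."""

    def __init__(self, sync_cuda: bool = True) -> None:
        self.sync_cuda = sync_cuda and torch.cuda.is_available()
        self.totals_ms: Dict[str, float] = defaultdict(float)
        self.calls: Dict[str, int] = defaultdict(int)

    def _sync(self) -> None:
        if self.sync_cuda:
            # Drain queued kernels so they are not billed to the next timer.
            torch.cuda.synchronize()

    @contextmanager
    def measure(self, name: str) -> Iterator[None]:
        self._sync()
        start = time.perf_counter()
        try:
            yield
        finally:
            self._sync()
            self.totals_ms[name] += (time.perf_counter() - start) * 1000.0
            self.calls[name] += 1

    def summary(self) -> Dict[str, float]:
        # Slowest first, like a cumulative timing report.
        return dict(sorted(self.totals_ms.items(), key=lambda kv: kv[1], reverse=True))
```

Usage would look like `with timer.measure("after_backward/global_grad_norm"): ...` around each metric call. The forced syncs add their own overhead, so this is best kept to probe runs rather than long training runs.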

Stateful Loader Multiprocessing and Prefetch

Currently I have had to disable workers and prefetching for the stateful loader, but as we scale up this might become a bottleneck. I also need to pass a Hugging Face API token when loading data to avoid strict API rate limits.

Trying it in this commit:

 train_loader = StatefulDataLoader(
     train_dataset,
     batch_size=batching.loader_batch_size,
     pin_memory=pin_memory,
-    num_workers=0,
+    num_workers=2,
+    prefetch_factor=2,
+    multiprocessing_context="spawn",
 )

After some tinkering I found that this does not help either, so I will stay at num_workers=0 until I identify the dataloader as the bottleneck.
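To decide whether the dataloader actually is the bottleneck, one option is to time how long each step blocks on `next(iterator)` versus the whole step. If the data-wait fraction stays near zero with `num_workers=0`, adding workers cannot buy much. `measure_data_wait` and `train_step` are hypothetical names for illustration.

```python
import time
from typing import Callable, Iterable, Tuple


def measure_data_wait(
    loader: Iterable,
    train_step: Callable[[object], None],
    num_steps: int,
) -> Tuple[float, float]:
    """Return (total_data_wait_s, total_step_s) over up to num_steps iterations.

    On GPU, train_step should end with a sync (e.g. loss.item()) so that the
    step time reflects finished work rather than queued kernels."""
    it = iter(loader)
    data_wait = 0.0
    total = 0.0
    for _ in range(num_steps):
        t0 = time.perf_counter()
        try:
            batch = next(it)  # time spent here is pure data loading
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)
        t2 = time.perf_counter()
        data_wait += t1 - t0
        total += t2 - t0
    return data_wait, total
```

The ratio `data_wait / total` is the number to watch: a value close to 0 means the GPU is the limiting factor, while a sizeable fraction would justify revisiting workers and prefetching.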

Possible Improvements

  • Use a multi-process dataloader to speed things up: commit 👎🏻
  • Maybe run metric diagnostics asynchronously or on a background thread, because metrics like attention entropy and large grad norms perform matrix-heavy ops on the CPU, which stalls training for too long.

I attempted to solve this in this commit, which computes attention a second time because SDPA does not expose the intermediate tensors (the attention probabilities).
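Since `torch.nn.functional.scaled_dot_product_attention` returns only the attention output, the entropy has to come from explicitly recomputing the probabilities. The following is a minimal sketch, assuming `q` and `k` of shape (batch, heads, seq, head_dim) and causal attention. The result stays a 0-dim tensor on the device until logging. Materialising the full seq × seq probability matrix is exactly the cost that makes this metric expensive, so it only makes sense at a sparse cadence.

```python
import math

import torch


@torch.no_grad()
def causal_attention_entropy(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy (nats) of causal softmax attention rows.

    q, k: (batch, heads, seq, head_dim). Returns a 0-dim tensor on q's device."""
    seq = q.size(-2)
    scores = (q.float() @ k.float().transpose(-2, -1)) / math.sqrt(q.size(-1))
    # True above the diagonal = future positions a query may not attend to.
    mask = torch.triu(torch.ones(seq, seq, device=q.device), diagonal=1).bool()
    scores = scores.masked_fill(mask, float("-inf"))
    log_p = torch.log_softmax(scores, dim=-1)
    p = log_p.exp()
    # Masked entries have p == 0 and log_p == -inf; zero log_p there so 0 * -inf is not nan.
    ent = -(p * log_p.masked_fill(mask, 0.0)).sum(dim=-1)
    return ent.mean()
```

As a sanity check, with all-zero queries and keys, row i attends uniformly over i + 1 positions, so its entropy is log(i + 1).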


Bigger GPUs

The original GPT-2 was probably trained on V100s or the TPUs of that era, while the GPT-2 reproductions made famous by @karpathy ran on A100s and, more recently, H100s. So using a 40GB A100 or RTX A6000 doesn't really make sense at the moment.

Some back-of-the-napkin budgeting

We are using a 10B-token dataset: HuggingFaceFW/fineweb sample-10BT.

Chinchilla recommends 20 tokens per param as the token budget for a compute-optimal training run, while Llama-style runs go much further, around 1000:1 token:param. That would land us at around 125B tokens for a single run, which is kinda insane to try at the moment! (Also, the above dataset is only 10B tokens.)

On a 40GB A100, including torch.compile and the final model upload, training with B=128, MB=32 for 1k steps took 34 mins.

So approximating the Chinchilla regime (20 tokens : 1 param) would mean ~19k steps, taking around 11-11.25 hrs on the 40GB A100.
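The arithmetic behind the ~19k-step estimate, as a quick sanity check. Two assumptions are not stated above: a ~124M-parameter, GPT-2-small-sized model (which is what 125B tokens at 1000:1 implies) and a sequence length of 1024, so B=128 means 128 × 1024 tokens per optimizer step.

```python
import math
from typing import Tuple

N_PARAMS = 124e6           # assumption: GPT-2 small scale
SEQ_LEN = 1024             # assumption: context length
BATCH = 128                # effective batch (sequences per optimizer step)
MIN_PER_1K_STEPS = 34.0    # measured on the 40GB A100, incl. compile + upload


def chinchilla_budget(tokens_per_param: float = 20.0) -> Tuple[float, int, float]:
    """Return (total tokens, optimizer steps, hours) by straight linear scaling."""
    tokens = tokens_per_param * N_PARAMS
    tokens_per_step = BATCH * SEQ_LEN
    steps = math.ceil(tokens / tokens_per_step)
    hours = steps / 1000 * MIN_PER_1K_STEPS / 60
    return tokens, steps, hours
```

This gives about 2.48B tokens, 18,921 steps and roughly 10.7 hours, in the same ballpark as the figures above. It also shows that the 10B-token sample is about 4× the Chinchilla budget at this model size.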

That's a bit too expensive, and it makes sense to upgrade to a better GPU at this point. Looking at my options, I see the A100 and H100, both at 80GB, with the H100 costing close to 2x the price of the A100.

First, the extra VRAM on the 80GB card easily allows doubling the micro-batch size to 64 and quadrupling the effective batch to 512 (mb 2x, eff_batch 4x). Assuming the same throughput, increasing the mb should not increase the time per micro-batch much, since the number of matrix ops is the same and only the size of the matrices has increased. The larger eff_batch means 4x fewer optimizer steps. However, each step now runs 512/64 = 8 gradient-accumulation micro-batches instead of 128/32 = 4, so the step time should roughly double.

Basically, we should get about 2x the speed of the 40GB A100 on the same task, cutting the train time to around 5.5 hrs on an 80GB A100 (assuming all other latencies, like torch.compile and data loading, stay constant).
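The same reasoning as a small model of wall time, under the optimistic assumption that a 64-sample micro-batch takes about as long as a 32-sample one on the bigger card. `t_micro` is a placeholder unit, not a measurement, and fixed overheads are ignored.

```python
def relative_train_time(
    eff_batch: int, micro_batch: int, total_seqs: float, t_micro: float = 1.0
) -> float:
    grad_accum = eff_batch // micro_batch   # micro-batches per optimizer step
    steps = total_seqs / eff_batch          # optimizer steps for a fixed token budget
    return steps * grad_accum * t_micro     # wall time, ignoring fixed overheads


TOTAL_SEQS = 19_000 * 128  # ~19k steps at B=128
old = relative_train_time(eff_batch=128, micro_batch=32, total_seqs=TOTAL_SEQS)
new = relative_train_time(eff_batch=512, micro_batch=64, total_seqs=TOTAL_SEQS)
speedup = old / new
```

Note that the wall time reduces to `total_seqs / micro_batch * t_micro`: the effective batch size drops out entirely. Any speedup therefore has to come from the larger micro-batch being processed in less than twice the time; the bigger effective batch only changes the optimization dynamics.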

From some Reddit threads and googling, I hear that an H100 can provide a 2-3x performance gain over an A100, which could drop the training run to ~2-2.5 hrs. That would be great!

I will run a small experiment comparing the two step times before committing to a large run. Experiment details below.

Moving from the 40GB to the 80GB A100

wandb run group: phase1/stage-1

In a quick dry-run comparison, I immediately spotted random stalls during training. These spikes coincide with the metrics aggregation cadence, which is the likely cause, so I ran a small probe run to isolate which metric gates might be the root cause. I found the following:

wandb run group: phase1/stage-1/latency-cadence-probe

These results are on wandb under the group phase1/stage-1/latency-cadence-probe. It is pretty evident that the issue stems from metric collection and affects both the forward and backward passes:

  • Forward pass:

    • seems severely affected by diagnostics, probe-collision and attention entropy.
  • Backward pass:

    • the control has very low latency, but every other variant is equally slow, and I'm not sure why.

These are the cumulative totals:

MetricsEngine timing summary (ms):
  after_backward:
    global_grad_norm: 586.589 !!
    layer_grad_norm: 7.411
    layernorm_grad_norm: 0.014
    loss_metrics: 0.010
    parameter_optimizer_norms: 0.003
    step_timing_memory: 0.003
    forward_hook_metrics: 0.002
  after_microbatch_backward:
    loss_metrics: 0.094
    step_timing_memory: 0.016
    global_grad_norm: 0.014
    layer_grad_norm: 0.012
    forward_hook_metrics: 0.012
    layernorm_grad_norm: 0.012
    parameter_optimizer_norms: 0.010
  after_optimizer_step:
    parameter_optimizer_norms: 1224.786 !!
    step_timing_memory: 0.070
    loss_metrics: 0.008
    forward_hook_metrics: 0.003
    layernorm_grad_norm: 0.003
    global_grad_norm: 0.003
    layer_grad_norm: 0.002
  collect_epoch_metrics:
    loss_metrics: 0.013
    step_timing_memory: 0.009
    global_grad_norm: 0.001
    layer_grad_norm: 0.000
    layernorm_grad_norm: 0.000
    parameter_optimizer_norms: 0.000
    forward_hook_metrics: 0.000
  collect_step_metrics:
    forward_hook_metrics: 3.022
    loss_metrics: 0.060
    step_timing_memory: 0.016
    layer_grad_norm: 0.007
    global_grad_norm: 0.006
    layernorm_grad_norm: 0.006
    parameter_optimizer_norms: 0.005
  on_step_start:
    parameter_optimizer_norms: 948.666 !!
    step_timing_memory: 0.210
    forward_hook_metrics: 0.085
    loss_metrics: 0.018
    layer_grad_norm: 0.018
    layernorm_grad_norm: 0.013
    global_grad_norm: 0.012
  on_train_end:
    forward_hook_metrics: 0.088
    parameter_optimizer_norms: 0.003
    step_timing_memory: 0.001
    loss_metrics: 0.001
    global_grad_norm: 0.001
    layer_grad_norm: 0.000
    layernorm_grad_norm: 0.000
  on_train_start:
    forward_hook_metrics: 0.081
    step_timing_memory: 0.004
    loss_metrics: 0.001
    global_grad_norm: 0.000
    layer_grad_norm: 0.000
    layernorm_grad_norm: 0.000
    parameter_optimizer_norms: 0.000
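`parameter_optimizer_norms` dominates both `on_step_start` (~949 ms) and `after_optimizer_step` (~1225 ms), while `global_grad_norm` dominates `after_backward` (~587 ms). The cause isn't confirmed here, but a common culprit for this pattern is a per-parameter `.item()` or `.cpu()` call, each of which forces a device sync. The following is a hedged sketch of an alternative for the update/weight ratio: snapshot parameters on the device, compute every ratio there, and transfer one stacked tensor. `UpdateRatioTracker` is a hypothetical name.

```python
from typing import List

import torch
import torch.nn as nn


class UpdateRatioTracker:
    """Hypothetical helper: ||W_after - W_before|| / ||W_after|| per tensor, on device."""

    def __init__(self, params: List[torch.Tensor]) -> None:
        self.params = [p for p in params if p.requires_grad]
        self.before: List[torch.Tensor] = []

    @torch.no_grad()
    def snapshot(self) -> None:
        # Call just before optimizer.step(); device-to-device copies, no host sync.
        self.before = [p.detach().clone() for p in self.params]

    @torch.no_grad()
    def ratios(self, eps: float = 1e-12) -> torch.Tensor:
        # Call just after optimizer.step(); returns a 1-D tensor on the params' device.
        out = [
            torch.linalg.vector_norm(p - b, dtype=torch.float32)
            / (torch.linalg.vector_norm(p, dtype=torch.float32) + eps)
            for p, b in zip(self.params, self.before)
        ]
        return torch.stack(out)
```

Only `tracker.ratios().cpu()` at log time touches the host. The clone costs a full extra copy of the weights in memory, which is another reason to keep this metric at a sparse cadence.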
 

Check the handling of optimizer param groups in the optimizer metrics: some of them assume an Adam optimizer and its parameters. Make sure such metrics are gated, and that config validation catches these cases and fails early rather than during a training run.
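The following sketches the kind of up-front check meant here. It assumes hypothetical metric names and that the Adam-specific metrics read the `exp_avg` / `exp_avg_sq` state that `torch.optim.Adam` and `torch.optim.AdamW` keep. The check runs once after the optimizer is built, so a bad combination fails before the first step instead of mid-run.

```python
from typing import Iterable

import torch

# Hypothetical metric names that read Adam moment buffers from optimizer.state.
ADAM_STATE_METRICS = {"adam_moment_norms", "adam_effective_lr"}


def validate_optimizer_metrics(optimizer: torch.optim.Optimizer, enabled: Iterable[str]) -> None:
    """Raise early if an enabled metric assumes optimizer internals that aren't there."""
    needs_adam = ADAM_STATE_METRICS.intersection(enabled)
    if needs_adam and not isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
        raise ValueError(
            f"metrics {sorted(needs_adam)} require Adam/AdamW state, "
            f"got {type(optimizer).__name__}"
        )
```

Whatever the real gating looks like, the key point is to call it right after the optimizer is constructed, before compile and data loading, so misconfigurations cost seconds rather than a partially completed run.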


Side quest: torch.compile recompilation spikes (resolved)

Before kicking off the Stage 1 baseline training run at scale, I got pulled into what looked like a small instrumentation question, which turned into a several-day detour into torch.compile internals.

The trigger was a latency-cadence probe: five short runs that sweep different metric-collection cadences, including a collision variant that fires every diagnostic group at the same step. The collision variant was showing very large step-time spikes (tens of seconds on the A100), and in a particularly bad configuration the end-to-end run was ~4× slower than the supposedly worse pre-fix baseline. That didn't line up with any reasonable theory of “metrics are expensive,” so I dug in.

The short version: three independent bugs were each forcing torch.compile to recompile the forward graph.

  1. A lazily-grown positional-encoding buffer whose shape changed on the first forward.
  2. A forward hook whose closure captured a mutable boolean: dynamo walked into __closure__[0].cell_contents and guarded on it, and the flag flipped on every metric-cadence step.
  3. The attention-entropy hook had the same closure-cell pattern, first on a capture flag, then on a stash dict transitioning from empty to populated.

The fix was to pre-materialize the pos-enc buffer in __init__ and decorate the hook bodies with @torch._dynamo.disable. End state: five distinct dynamo frames (one per disable boundary), each compiled exactly once, and zero guard-failure recompiles. Cadence boundaries are now free: only the genuine cost of the metric computation shows up in the step timing.
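A minimal sketch of both fix patterns, assuming a sinusoidal positional encoding and an activation-norm forward hook. The class names and the `capture` flag are hypothetical stand-ins for the real modules. The buffer is built once at `max_len` in `__init__`, so its shape never changes. The hook body sits behind `torch._dynamo.disable`, so dynamo does not trace into it and has no closure contents to guard on when the flag flips.

```python
import math
from typing import Dict

import torch
import torch._dynamo
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 1024) -> None:
        super().__init__()
        # Pre-materialise the full table once (d_model assumed even); forward only
        # ever slices it, so the buffer shape is fixed before the first compiled call.
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_model) with seq <= max_len
        return x + self.pe[: x.size(1)]


class ActivationNormRecorder:
    """Collects output norms only on metric-cadence steps (capture=True)."""

    def __init__(self) -> None:
        self.capture = False
        self.norms: Dict[str, torch.Tensor] = {}

    def make_hook(self, name: str):
        @torch._dynamo.disable
        def hook(module: nn.Module, inputs, output: torch.Tensor) -> None:
            # Runs eagerly outside the compiled graph, so reading self.capture
            # here does not turn into a dynamo guard.
            if self.capture:
                self.norms[name] = output.detach().norm()

        return hook
```

With the model wrapped in `torch.compile` and `TORCH_LOGS=recompiles` set, flipping `recorder.capture` between steps should then produce no guard-failure recompile messages.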

The full writeup, with verbatim tlparse guard strings, W&B run IDs, step-level timing tables, and the commit-level diff of the fixes, is in Debugging torch.compile & stories of recompilation.

Two things to carry forward from this:

  • The metrics infrastructure is now safe to use with torch.compile at any cadence. This unblocks serious use of activation norms and attention entropy on long runs without worrying about recompilation cost scaling with metric frequency.
  • The tlparse + TORCH_LOGS=recompiles workflow is now second nature and is the default diagnostic any time a training loop shows unexplained step-time variance under compile.

Back to the plan

With the compile pipeline stable, Stage 1 returns to its original trajectory.