Context

Now that we have Activation Checkpointing and torch.compile in place, both the memory requirement and the train time have dropped significantly! Now we take another step in that direction!

Activation checkpointing reduces the memory required to store intermediate results from the forward pass that might be needed in the backprop step. This also helps with supporting bigger seq-lens, because attention and activation memory scale with seq-len.

Now that the footprint of the required intermediates is smaller, we can try to perform weight updates over longer sequences by employing the ideas of micro-batching and gradient accumulation! Our current batch size (20) is very small even for the tiny 4M-token dataset currently in use; moving to datasets an order of magnitude bigger would make this almost look like “SGD”.

Bigger batches might still be hard to support, but the same effect can be achieved by combining a gradient accumulation strategy with micro-batches to deploy compute optimally during training. With grad-acc we can increase the “effective batch size” without incurring the memory costs.

Assumptions

effective_batch_size = micro_batch_size * accumulation_steps

This is stringent and can drop some samples at the boundaries, but it's fine for now.
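
As a quick sketch of the assumption above (the function name and the numbers are illustrative, not from the training code):

```python
def accumulation_schedule(num_samples: int, micro_batch_size: int, accumulation_steps: int):
    """Number of optimizer steps and samples dropped at the boundary under
    the strict rule effective_batch_size = micro_batch_size * accumulation_steps."""
    effective_batch_size = micro_batch_size * accumulation_steps
    num_optimizer_steps = num_samples // effective_batch_size
    dropped = num_samples - num_optimizer_steps * effective_batch_size
    return num_optimizer_steps, dropped

# e.g. 1000 samples, micro-batch of 20, 4 accumulation steps:
# effective batch of 80 → 12 optimizer steps, 40 samples dropped at the boundary
print(accumulation_schedule(1000, 20, 4))  # → (12, 40)
```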

Questions

  1. What is the largest possible micro-batch size we can use?

Increasing micro-batch size (m-b-sz) should increase peak GPU usage, effectively improving hardware utilisation. It will also impact memory indirectly if activation checkpointing is in place!

  2. What is the largest effective batch size that produces meaningful updates?

Increasing effective batch size does not increase peak GPU usage, but it does affect the optimizer updates. Too small, and the change in gradients per step is tiny and we need longer runs; too big, and gradients can cause updates to explode, leading to big weights that overfit.

Experiment

Well there are 3 variables at play here:

  1. Activation Checkpointing: allows us to trade time for GPU memory.
  2. Gradient Accumulation: trades optimizer-step frequency for gradient calculation over bigger batches.
  3. Micro Batching: helps increase the effective batch size without needing the real amount of GPU memory to fit it.
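
A minimal sketch of how the three knobs combine in a single optimizer step (the model/optimizer wiring here is illustrative, not the actual training code):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

def train_step(blocks, head, optimizer, micro_batches, accumulation_steps):
    """One optimizer step accumulated over several micro-batches, with
    activation checkpointing applied to each block's forward pass."""
    optimizer.zero_grad(set_to_none=True)
    for inputs, targets in micro_batches:          # micro-batching
        x = inputs
        for block in blocks:                       # recompute activations in backward
            x = checkpoint(block, x, use_reentrant=False)
        loss = nn.functional.mse_loss(head(x), targets)
        (loss / accumulation_steps).backward()     # gradient accumulation
    optimizer.step()                               # one update per accumulation window
    return loss.item()
```

Dividing each micro-batch loss by `accumulation_steps` keeps the accumulated gradient on the same scale as a single big-batch gradient.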

So the optimal configuration will be the one that does 2 things: maximize GPU utilization without OOMs, and achieve the highest possible training throughput (tokens/sec). The other important metrics to track are the optimizer states and updates and the val loss; we need to understand the new learning dynamics for the bigger batch size and what the hyperparams need to look like in the new regime.

To break this experiment out meaningfully we need to do it in stages.

  1. First, understand the memory frontier on our hardware. Do a sweep over micro-batch size and activation checkpointing to find the best GPU utilisation.

    From the previous AC experiment we saw that baseline-torch.compile outperformed all other variants, so we will adopt the same for this experiment and only vary the micro-batch size.

    So what we will basically map out is the trade-off between memory and compute time, so that we can pick a micro-batch size that maximizes throughput.

  2. Next, find the throughput limit by picking the best-performing config from step 1 and varying the effective batch size. Say m is the best micro-batch size from step 1: try effective batch sizes of {2m, 4m, 8m, 16m} until we OOM, and measure the training throughput (tokens/sec) for each.

  3. Now we can pick and compare the configs with full training runs. The candidates should be:

    1. [baseline] Current model with just compile and eff_batch = microbatch.
    2. [from step1] Model with highest micro-batch size
    3. [from step1] Model with 2nd highest micro-batch size
    4. [from step2] Model with the best throughput (likely the 16m variant)

    These candidates can then be validated over a longer training run to measure loss and perplexity. Looking at the update norms, the change in param updates from using adaptive optimizers, and the gradient norms, we should be able to get an idea about the stability and quality of training.

Note: once we find the best config in step 3, we might need to do a new LR sweep to ensure that the model is operating in the correct LR regime for the new effective batch size.
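
The throughput measurement in step 2 can be sketched with a simple timing wrapper. Here `run_step` is a placeholder that is assumed to execute one full optimizer step, including all accumulated micro-batches (on GPU you would also want a `torch.cuda.synchronize()` before each timestamp so async kernels are counted):

```python
import time

def measure_throughput(run_step, micro_batch_size, seq_len, accumulation_steps, num_steps=10):
    """Rough training throughput in tokens/sec for one configuration."""
    run_step()  # warm-up step (compile caches, allocator warm-up)
    start = time.perf_counter()
    for _ in range(num_steps):
        run_step()
    elapsed = time.perf_counter() - start
    tokens_processed = num_steps * accumulation_steps * micro_batch_size * seq_len
    return tokens_processed / elapsed
```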


Step 1: Micro Batch Size Sweep

From the previous experiment it can clearly be seen that baseline-compile with no budgeting gives us the fastest throughput on a fixed batch size of 20 while having relatively low peak reserved memory on the GPU.
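
The sweep itself can be sketched as an OOM-guarded loop. `run_step(mb)` is a stand-in for one compiled forward/backward pass at micro-batch size `mb`, not the actual training code:

```python
import torch

def sweep_micro_batch(run_step, candidate_sizes):
    """Try each micro-batch size, record peak reserved GPU memory,
    and mark sizes that OOM with None."""
    results = {}
    for mb in candidate_sizes:
        try:
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            run_step(mb)
            results[mb] = (torch.cuda.max_memory_reserved()
                           if torch.cuda.is_available() else 0)
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()  # release cached blocks before the next size
            results[mb] = None        # did not fit
    return results
```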

Result

Observations:

Drilling deeper into the data, we find we really only have 2 valuable candidates: mb=[28,30]. We can see that a batch size smaller than 28 takes considerably more time to execute while memory still seems to scale proportionally; why is that?

It is likely a quirk of torch.compile, maybe related to paging/tiling, whereby the generated kernels handle a certain range of batch sizes better than others (probably something on the order of 2^x being favoured?).

See the micro-batching section on this wandb report:

Now that we have a frontier micro-batch size for our model (28 or 30), let's pick mb=28, since it gives us slight headroom and also slightly better throughput as per the above experiment.

Next we want to pick the best possible effective batch size, which involves gradient accumulation. In this case we are not only interested in achieving the lowest final loss but also care about the “quality” of gradients per update step; effectively, the goal is to reduce the wall-clock time required to train the model to convergence on given hardware.

Just having more updates might provide more “steps” on the same token budget to reduce the loss, but it might not be as effective in wall-clock time if the batch size is suboptimal, since a larger batch can produce better-“quality” updates at the same rate (ps: this completely depends on hardware; it just says that if a bigger batch fits into the GPU, use it, since that is faster than doing 2 sequential GPU passes).

This is what we addressed in Step 1, the Micro Batch Size Sweep (where effective_batch = micro_batch @ accum=1), when we figured out the largest possible batch size.

Also, just having a larger effective batch by turning on grad accumulation, without getting better “quality of gradients”, is equally wasteful. And even if we did generate better gradients, without Learning Rate Scaling the rate of convergence would be greatly impacted, by a factor of the accum_freq.

For example: for a sample size of 120, if a model is trained on micro-batch `mb` with learning rate `lr`, then for `accum_freq=1` we get `120/mb` update steps, with the effective update being:

`θ_{i+1} = θ_i − lr · g_i`

For `mb=30`, `i = 1..4`, the weights are updated 4 times, so to first order they move by:

`θ_4 ≈ θ_0 − lr · (g_1 + g_2 + g_3 + g_4)`

But with `accum_freq=4`, with gradient accumulation and uniform normalization before the update, this becomes a single step:

`θ_1 = θ_0 − lr · (g_1 + g_2 + g_3 + g_4) / 4`

This effectively curbs the convergence speed: even when a bigger batch with accumulation is capable of producing a better direction for the update, the total movement shrinks by a factor of `accum_freq`, so it isn't really comparable to the high update frequency without accumulation. Hence we need “learning rate scaling”, like in:

Linear scaling rule (Goyal et al. 2017, Facebook ResNet paper): when the minibatch size is multiplied by k, multiply the learning rate by k. A softer alternative is the sqrt scaling rule (Hoffer et al. 2017): multiply the learning rate by √k.
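
As a sketch, both rules reduce to a multiplier on the base learning rate (function name and numbers are illustrative):

```python
def scaled_lr(base_lr, base_batch, effective_batch, rule="linear"):
    """Scale the LR when growing the effective batch size.
    linear (Goyal et al. 2017): lr * k;  sqrt (Hoffer et al. 2017): lr * sqrt(k),
    where k = effective_batch / base_batch."""
    k = effective_batch / base_batch
    if rule == "linear":
        return base_lr * k
    if rule == "sqrt":
        return base_lr * k ** 0.5
    raise ValueError(f"unknown rule: {rule}")
```

For example, going from mb=28 at accum=1 to an effective batch of 112 (accum=4) would scale a base LR of 3e-4 to 1.2e-3 linearly, or to 6e-4 with sqrt scaling.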


So in reality we are asking a couple of questions that seem confounded but can be isolated as follows.

First we need to understand things at a step/accumulation level by finding out: how does effective batch size affect the update quality?

The update quality can be measured by gradient norms and SNR, as in previous experiments. We can additionally track a new metric, gradient coherence: the cosine similarity between successive gradients within each update step, so that the “alignment” of intermediate accumulations can be vetted. Along with this we can also measure some basic statistics of the accumulated gradients, namely the norm, mean, and variance between successive accumulations.
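
A sketch of the gradient-coherence metric, written in pure Python for clarity; in the training loop the inputs would be the flattened per-micro-batch gradient vectors:

```python
import math

def cosine(u, v):
    """Cosine similarity between two flattened gradient vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def gradient_coherence(micro_grads):
    """Mean cosine similarity between successive micro-batch gradients
    within one accumulation window."""
    sims = [cosine(g0, g1) for g0, g1 in zip(micro_grads, micro_grads[1:])]
    return sum(sims) / len(sims)

# perfectly aligned gradients give coherence 1.0
print(gradient_coherence([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]))  # → 1.0
```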

This can be tested with a small number of runs and can tell us whether the batch size can be increased. We stop increasing the effective batch size once we start seeing the gradient quality degrade.

Notice that the above metrics are independent of the learning rate and should not be affected by it. This lets us discover a good estimate for the “critical batch size”; we can then use this estimate to empirically search for the optimal batch size around it, with learning-rate scaling included. Doing so while being confident about the gradient quality is key.

Now we can ask, “What is the critical batch size (when lr-scaling is introduced)?”

From McCandlish et al. 2018 (“An Empirical Model of Large-Batch Training”), the gradient noise scale tells you the critical batch size - the batch size below which doubling B roughly halves required steps, and above which you get diminishing returns.
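
A sketch of the paper's “simple” noise-scale estimator, built from squared gradient norms measured at two batch sizes; the variable names are mine, and the formulas follow the unbiased estimators in the paper's appendix:

```python
def simple_noise_scale(b_small, sq_norm_small, b_big, sq_norm_big):
    """Estimate B_simple = S / |G|^2 from the model E[|g_B|^2] = |G|^2 + S/B
    evaluated at two batch sizes (McCandlish et al. 2018)."""
    # unbiased estimate of the true squared gradient norm |G|^2
    g2 = (b_big * sq_norm_big - b_small * sq_norm_small) / (b_big - b_small)
    # unbiased estimate of the per-example gradient noise S = tr(Sigma)
    s = (sq_norm_small - sq_norm_big) / (1.0 / b_small - 1.0 / b_big)
    return s / g2

# synthetic check: |G|^2 = 1, S = 100 gives E|g_25|^2 = 5 and E|g_100|^2 = 2
print(simple_noise_scale(25, 5.0, 100, 2.0))  # ≈ 100.0
```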

This, along with the sqrt scaling rule above, can be used to modify the learning rate. With this modification we can measure loss-vs-tokens, which should tell us whether the change in batch size actually has a meaningful impact on outcomes (loss).


Experiment: Assessing Gradient Quality With Accumulation

My observation:

  • Epoch time drops steadily as we increase the effective batch size, naturally due to the smaller number of grad updates.
  • Loss, as expected, benefits from the runs with more updates; without lr-scaling, the accumulated-batch runs lag behind. But that's fine for this run.
  • Gradient coherence shows a strong signal for all effective batch sizes tried. We are constantly in the 0.99+ range, which shows strong correlation of gradients within an accumulation. It might also indicate that the learning rate can be higher (lr scaling again).
  • There seems to be a sharp fall in the loss landscape around steps 30-40, with the optimizer momentum picking up.
  • Global grad norms seem to fall and plateau around 12-13 for all runs, which is the expected behaviour: the initial updates move the model into optimal regions, and updates get smaller as the degree of weight tweaking required drops.
  • I didn't calculate the SNR within an accumulated step, but I do have the optimizer-state SNR, which again seems to drop and level off.

Wandb Run Group: `4-effective-batch-fixed-lr-20260316-151845-575080`

Learning and Realisations: What I realised from running this experiment is that the gradient signal was already very strong and coherent at micro-batch=28 and acc=1; scaling acc did not really bring in any gradient diversity. It's quite possible that the dataset I'm using, Wikitext-2, is already quite homogeneous: gradients across batches point in the same direction, and the model training can benefit more from increasing the learning rate than from increasing the effective batch size with accumulation. Accumulation helps when the micro-batches alone produce incoherent signals; it smooths them out and provides the best effective direction. But that is not the case in our runs, at least on this particular dataset. I did, however, learn to care a lot about all the different things that happen within a step, which was nice.

While it's true that implementing LR-scaling along with accumulation would help, that only compensates for the degraded learning rate as discussed above. What we originally intended with this experiment was to unlock better gradient quality with increased effective batch size and gradient accumulation, which is not happening right now. Now that we know where to use this, I will include it in the implementation as a config so that it can be used where it can really make a meaningful difference.