TL;DR: New speedrun here. Details below.
I generally experiment in PyTorch on my local GPUs. Recently, I received some TPU credits from the TPU Research Cloud (TRC) program to do some research, which gave me a good reason to finally learn JAX. Just about all I knew was that you can often replace `numpy` with `jax.numpy` and get code to run on accelerators, and that JAX code is supposed to be functional. This was also the first time I wrote code for TPUs. With this, I set out on a small side project: porting the modded-nanoGPT speedrun to pure JAX, with the goal of achieving the best possible performance.
This whole process took about seven days. Two were for learning the basics, three were for the work described here, and the other two were for, well, other things. Writing (even raw) JAX code has been an absolute joy (compared to PyTorch). The functional paradigm feels like home, and XLA is a powerful compiler. The integration with Google’s ecosystem, especially for TPUs, is excellent - the JAX profiler, in particular, is an experience I have never had with PyTorch. And many forms of parallelism are trivial to express in JAX (due to how it handles sharding), while they are non-trivial and require extra libraries in PyTorch (though PyTorch’s H2 2025 plans indicate a shift towards automatic sharding).
My goal with this post is to document the process, the challenges, the optimizations, and the remaining questions. The code for my initial JAX port can be found here. Note that the code is in a single file, true to the spirit of the original speedrun.
Also, if you’re a JAX enjoyer or like speedruns (or just enjoy low level optimizations), feel free to optimize this further! I haven’t decided on the rules of the speedrun yet, though I believe they should be in the spirit of the original speedrun.
The naive port
As a tutorial for myself, I decided to port the version of the nanoGPT speedrun just before the introduction of “flex attention” - this version had a record of around 7 minutes on 8xH100 GPUs. I implemented the model and training loop in pure JAX, without relying on libraries from the ecosystem like Flax, Optax, or Orbax. This meant writing everything from scratch: raw PyTree manipulations for parameters and optimizer states, and a custom training loop.
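To give a sense of what “no libraries” means in practice, here is a minimal sketch of that style: a dict-of-arrays PyTree for parameters, a jitted step, and tree_map updates. The names and the momentum-SGD update are illustrative only; the actual speedrun uses Muon/Adam and a real GPT.

```python
import jax
import jax.numpy as jnp

def init_params(key, d_model=1024, vocab=50304):
    # Parameters are a plain dict (a PyTree) of arrays - no Flax/Equinox modules.
    k1, k2 = jax.random.split(key)
    return {
        "embed": jax.random.normal(k1, (vocab, d_model), jnp.bfloat16) * 0.02,
        "head": jax.random.normal(k2, (d_model, vocab), jnp.bfloat16) * 0.02,
    }

def loss_fn(params, tokens, targets):
    # Toy "model": embed -> project -> cross-entropy. Stands in for the GPT.
    x = params["embed"][tokens]                        # [batch, seq, d_model]
    logits = x @ params["head"]                        # [batch, seq, vocab]
    logp = jax.nn.log_softmax(logits.astype(jnp.float32), axis=-1)
    return -jnp.mean(jnp.take_along_axis(logp, targets[..., None], axis=-1))

@jax.jit
def train_step(params, opt_state, tokens, targets, lr=3e-4, mu=0.9):
    loss, grads = jax.value_and_grad(loss_fn)(params, tokens, targets)
    # "Raw PyTree manipulation": optimizer state is just another PyTree, and
    # the update is a tree_map over matching leaves.
    opt_state = jax.tree_util.tree_map(lambda m, g: mu * m + g, opt_state, grads)
    params = jax.tree_util.tree_map(lambda p, m: p - lr * m, params, opt_state)
    return params, opt_state, loss

# Usage: the optimizer state starts as zeros shaped like the parameters.
# params = init_params(jax.random.key(0))
# opt_state = jax.tree_util.tree_map(jnp.zeros_like, params)
```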
My first run clocked in at around 20 minutes. A 3x slowdown is unacceptable. It was time to understand what was going wrong.
Hardware, Bottlenecks, and a Performance Bug
The first step was to understand the hardware. I was running on a TPU v6e-8. While its compute capabilities are roughly the same as an 8xH100 SXM, its memory bandwidth is about half of that, and comparable to an 8xH100 PCIe setup. The profiler confirmed my suspicions: the model was entirely HBM-bound, with a Model FLOPs Utilization (MFU) of around 15% (some bad performance was expected because the model was quite small). Even worse, HBM utilization was hovering around 50%. At this point I should mention that the JAX profiler gives you information in an extremely convenient manner - MFU, resource utilization, roofline models, and what not. I’ve been told that it sometimes gives inaccurate results on some GPUs, but for TPUs you’d be hard-pressed to find a better profiler.
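For completeness, capturing a trace is just a context manager around a few steps; `train_step` below stands in for whatever jitted step function you have, and the log directory is arbitrary.

```python
import jax

# Trace a handful of steps; open the resulting directory with TensorBoard's
# profile plugin (XProf) to get the timeline, roofline, and memory views.
with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(5):
        params, opt_state, loss = train_step(params, opt_state, tokens, targets)
    loss.block_until_ready()  # ensure the async device work lands inside the trace
```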
My initial attempts to improve HBM utilization failed - this was mostly about writing better data loading, but that was not the bottleneck. On another note, I tried to integrate a “FlaxAttention” kernel I found, thinking it might be a JAX equivalent to PyTorch’s flex attention, but it didn’t support the block-sparse operations I needed.
The first real improvement came from an architectural tweak. The systolic array on a v6e TPU core is \(256 \times 256\). My initial configuration had 6 attention heads with a `d_head` of 128. I changed this to 4 heads with a `d_head` of 256 to better align with the hardware, and increased `d_model` from 768 to 1024 to maintain model capacity and achieve loss parity.
This change seemed to help, and with this and some other changes that came later in the original speedrun, I was able to reduce the number of steps from 3000 to 1675.
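In config terms the change is tiny. The snippet below is purely illustrative (made-up names, not the actual speedrun config); the point is just that a `d_head` of 256 lines up with the \(256 \times 256\) MXU tile while 128 does not.

```python
# Illustrative only - not the actual speedrun config.
# Before: 6 heads x 128 dims (d_model = 768); the 128-wide attention matmuls
#         don't line up with the 256x256 MXU tile on a v6e core.
# After:  4 heads x 256 dims, with d_model bumped to 1024 for loss parity.
old_cfg = dict(n_head=6, d_head=128, d_model=768)
new_cfg = dict(n_head=4, d_head=256, d_model=1024)
assert new_cfg["n_head"] * new_cfg["d_head"] == new_cfg["d_model"]
```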
It’s interesting to note that at this point, I actually had an embarrassing performance bug: when updating the optimizer state, the parameters were being upcast to 32-bit floats, for which JAX uses tf32 by default in matmuls (instead of fp32 - which, incidentally, also leads to some divergence from PyTorch). With the new head size this didn’t amount to a full 2x performance hit, but not knowing about the bug made it super confusing why the profiler reported an arithmetic intensity of ~190 when, by my calculations with bf16 flops, it should have been much higher. The way I actually found out was when I did ahead-of-time compilation (for the different attention window sizes used in the window warmup) and the compiled function complained that bf16 and fp32 don’t actually mix. The bug was also causing a second `jax.jit` compile for the new types, which I hadn’t noticed. Fun fact - I switched to ahead-of-time compilation after a recompile triggered in the middle of my training run due to the attention window warmup, making it a minute slower. So `jax.jit` can come back to bite you sometimes, but `jax.log_compiles` is there to help.
Note that after I switched over to bf16, the original head dimension became much faster too, but even at the end of 3000 steps, the loss was higher than the expected 3.28 validation loss (and 3000 steps already took 12+ minutes). I suspect some issues with optimizer precision, but I didn’t look into it much. Looking into this further can probably help shave off some seconds from the speedrun. All in all, I stuck to my newer choice of head dimension, and learnt the lesson of not taking the functional programming analogy too far (JAX does not complain about types unless you compile it, which makes sense in retrospect).
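For what it’s worth, this is roughly what the ahead-of-time setup looks like - a sketch with placeholder shapes and a stubbed-out step body, where `static_argnames` bakes the window size into each compiled executable.

```python
import functools
import jax
import jax.numpy as jnp

# Log every (re)compilation so a stray jit can't silently fire mid-training again.
jax.config.update("jax_log_compiles", True)

@functools.partial(jax.jit, static_argnames="window")
def train_step_windowed(params, opt_state, tokens, targets, *, window):
    # Stub body - imagine the real step with a sliding-window mask of width
    # `window` baked in at trace time.
    del window
    return params, opt_state, jnp.float32(0.0)

# Dummy inputs with the real shapes/dtypes, one compiled executable per window.
tokens = jnp.zeros((8, 2048), dtype=jnp.int32)
compiled = {
    w: train_step_windowed.lower(params, opt_state, tokens, tokens, window=w).compile()
    for w in (1024, 2048)
}

# In the loop (the static `window` is already baked in, so it isn't passed):
#   params, opt_state, loss = compiled[window](params, opt_state, tokens, targets)
# Compiled executables are strict about input shapes and dtypes - that
# strictness is what surfaced the accidental fp32 upcast described above.
```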
The optimizations
Both before and after the dimension change and the bugfix, I made a series of changes. Note that these changes were applied in one particular order, so I didn’t do many ablations after the fact. It’s completely possible that mixing and matching some of these gives better performance than my speedrun.
- Replacing `vmap` with a batch-aware-from-the-start implementation gave a slight improvement.
- It turns out that the canonical `tf.data` recommendation for the data loader does not work on TPU v6e-8, so I wrote my own asynchronous (well, fully prefetching) data loader (and the metric logger is written to avoid unnecessary blocking calls when printing the current loss - it prints a stale loss). This didn’t help much (perhaps at all), but I still kept it in.
- A minor rewrite of the cross-entropy loss function also helped performance slightly.
- I tried using some custom flex attention kernels for JAX (turns out they didn’t have block sparsity, so were not useful), flash attention kernels (were slow), and splash attention kernels (were also slow). I didn’t spend much time with the latter two, so it’s possible I was missing something.
- On attention - I noticed that my hand-written attention implementation was faster than `jax.nn.dot_product_attention` - I’ve been told that the XLA compiler pattern-matches for attention, which is nice. However, I stuck to `jax.nn.dot_product_attention` in order to get improvements from JAX version updates, in case this behavior stops working arbitrarily. A rough sketch of both variants follows this list.
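To make that last bullet concrete, here is roughly what the two variants look like - a hand-rolled einsum attention versus the built-in call. This is a sketch only: no sliding-window mask, no learned scale, and it assumes the `[batch, seq, heads, d_head]` layout that `jax.nn.dot_product_attention` expects.

```python
import jax
import jax.numpy as jnp

def manual_attention(q, k, v, scale):
    # q, k, v: [batch, seq, heads, d_head].
    # Reportedly the XLA compiler pattern-matches hand-rolled attention like
    # this into an efficient fused implementation.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    causal = jnp.tril(jnp.ones((q.shape[1], k.shape[1]), dtype=bool))
    logits = jnp.where(causal, logits, -jnp.inf)
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(q.dtype)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)

def builtin_attention(q, k, v, scale):
    # Same layout; sticking with this call lets future JAX releases swap in
    # better implementations underneath.
    return jax.nn.dot_product_attention(q, k, v, scale=scale, is_causal=True)
```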
Regardless, I tried porting some of the improvements that came after flex attention in the official speedrun, up until the sub-3-minute record. These are the ones that I ended up keeping:
- Attention window warmup (from 1024 to 2048, halfway through training). Note that dynamic shapes are not supported, so I did ahead-of-time compilation for both shapes.
- Improved RoPE implementation (half RoPE, half NoPE); a sketch follows this list.
- Logit softcap change.
- Merged QKV weights (my original code already did this).
- Learned attention scale (instead of inverse square root of dimension).
- Reduced Adam epsilon.
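For the RoPE bullet above, the gist is to rotate only the first half of each head’s dimensions and pass the second half through position-free. Below is a minimal sketch with illustrative names - I’m not claiming it matches the speedrun’s exact frequency schedule.

```python
import jax.numpy as jnp

def half_rope(x, base=10_000.0):
    # x: [batch, seq, heads, d_head]. Rotate the first d_head/2 dims (RoPE)
    # and leave the second half untouched (NoPE).
    seq_len, d = x.shape[1], x.shape[-1]
    rot, keep = x[..., : d // 2], x[..., d // 2 :]

    half = rot.shape[-1] // 2
    freqs = 1.0 / (base ** (jnp.arange(half) / half))           # [half]
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]      # [seq, half]
    cos = jnp.cos(angles)[None, :, None, :].astype(x.dtype)     # [1, seq, 1, half]
    sin = jnp.sin(angles)[None, :, None, :].astype(x.dtype)

    r1, r2 = rot[..., :half], rot[..., half:]
    rotated = jnp.concatenate([r1 * cos - r2 * sin,
                               r1 * sin + r2 * cos], axis=-1)
    return jnp.concatenate([rotated, keep], axis=-1)
```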
These are the ones I didn’t keep/implement (might still work though):
- Value embeddings - earlier on, I couldn’t get either the vanilla implementation or the U-Net variant to work without a major slowdown.
- Block sliding window - I was lazy and didn’t want to implement sliding window attention, though it’s certainly possible.
- FP8 - Not available on TPU v6e (though v7p will likely have it). I am considering `int8` training for future versions, given that the flops and bytes per second both double.
- Batched Muon - I was lazy and did not implement this.
- Custom hardware/PyTorch optimizations - Skipped custom communication strategies and other parallelism tricks from the official speedrun.
After these changes, the number of required iterations went down and the MFU improved to 23%, but HBM utilization remained stuck at around 50%. My latest run finished in 10 minutes. Being generous and assuming that, because the workload is HBM-bound, performance scales linearly with memory bandwidth, this might be equivalent to 5 minutes on an 8xH100 SXM machine. That’s a significant improvement, but still leaves a 2-minute gap to the sub-3-minute record. Some other architectural changes I tried, like token mixing, generally worsened the wall-clock time or didn’t affect it much after showing initial promise in the early parts of the loss curve. It’s also possible that they need more hyperparameter tuning, which I didn’t want to do just yet.
Future Work and Open Questions
This was pretty fun, but there’s still a lot of performance on the table. Some things I’d like to try next (or see someone else try):
- I’m very curious how this JAX implementation performs on an 8xH100 SXM node.
- More hparam tuning could shave off some steps.
- The `tf32` bug was a slightly embarrassing but good lesson. There may be other, similar low-hanging fruit.
- The question remains: why is HBM utilization only ~50%? Some of the candidates could be suboptimal overlapping, kernel launch overhead, or communication issues not fully saturating the memory bus.
- The official speedrun saw some benefit from manual communication primitives. This could be explored in JAX.
- Custom kernels in Pallas (lowers to Mosaic on TPUs and Triton on GPUs):
  - A Pallas kernel for block-sparse flex attention could be useful. I wasted a lot of time trying to integrate a kernel that claimed to do this but wasn’t actually block-sparse.
  - I tried integrating existing Flash/Splash Attention implementations, but they didn’t work on the first attempt, so I moved on. It would be great if someone could get these working.
  - A custom Pallas kernel for the cut cross-entropy loss could help improve MFU (according to someone whose MFU on GPUs got fixed after using a CUDA kernel for that).
- The computation for the optimizers likely happens on every PyTree leaf individually and with replicated computation across shards. This may or may not be a bottleneck right now, but it could be helpful to keep in mind (to avoid extraneous computation/unfused ops); a sketch of the per-leaf pattern follows this list. Sharding the computation requires an all-gather, which incurs overhead and should be measured, as with everything else.
- Using Microsoft’s Dion optimizer instead of Muon could be interesting.
- Going a bit further, other parallelism strategies like FSDP/TP/hybrid approaches could be explored.
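To make the per-leaf point above concrete, this is the pattern I mean: an Adam-style update written as tree_maps, so XLA sees a separate little op chain per parameter leaf and, under plain replication, every device repeats the same work. The hyperparameters here are made up.

```python
import jax
import jax.numpy as jnp

def adam_update(params, m, v, grads, step, lr=3e-4, b1=0.9, b2=0.95, eps=1e-10):
    # Each tree_map runs its function once per PyTree leaf; nothing ties the
    # leaves together, so there is no cross-parameter fusion, and under plain
    # replication every device repeats the same (small) update work.
    m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
    v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g * g, v, grads)

    def leaf_update(p, m_, v_):
        m_hat = m_ / (1 - b1 ** step)   # bias correction; `step` is a Python int here
        v_hat = v_ / (1 - b2 ** step)
        return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)

    new_params = jax.tree_util.tree_map(leaf_update, params, m, v)
    return new_params, m, v
```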
The JAX Ecosystem
For this project, I deliberately stuck to pure JAX to force myself to learn the basics (and avoid any issues with performance due to opaque APIs). However, the JAX ecosystem is rich and mature (albeit slightly fragmented). For real-world projects, you would likely use libraries like:
- Marin, Levanter, MaxText, or Tunix for large-scale training/post-training.
- Optax for optimizers.
- Orbax for checkpointing.
- Equinox (highly recommended) or Flax for writing models.
JAX’s device-agnostic nature is also a huge plus, with native support for TPUs, GPUs, and CPUs, and the same code Just Works without modifications.
Note that this doesn’t mean I will stop using PyTorch - it only means that I would use JAX wherever it makes sense to use it (e.g., personal projects, collaborative projects where everyone likes to use JAX, for portability reasons etc.). PyTorch is a huge part of the ML ecosystem and extremely valuable considering most research codebases are in PyTorch.
Resources
Since this was a learning project, these are the resources I found super useful (but didn’t necessarily complete), and my implementation of the speedrun:
- My code: modded-nanogpt-jax. Also on GitHub for collaborating on speedruns.
- The unusually well-written JAX documentation (the current training cookbook has some bugs though - `jax.ref` didn’t work for me and `jax.P` was wrong).
- JAX ML scaling book for understanding performance.
- Rafi Witten’s High Performance LLMs 2024, also for understanding performance.
The JAX community is also super welcoming - in this short timespan, I also learnt a bunch from people who are passionate about JAX.