The modded nanogpt speedrun, but in JAX and on TPUs
TL;DR: New speedrun here. Details below. I generally experiment in PyTorch on my local GPUs. Recently, I received some TPU credits from the TPU Research Cloud (TRC) program to do some research, which gave me a good reason to finally learn JAX. Nearly all I knew was that you can often replace numpy with jax.numpy and get code to run on accelerators, and that JAX code is supposed to be functional. This was also the first time I wrote code for TPUs. With this, I set out on a small side project: porting the modded nanoGPT speedrun to pure JAX with the goal of achieving the best possible performance. ...
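As a minimal sketch of what "replace numpy with jax.numpy" and "JAX is functional" mean in practice (the `gelu` function here is just an illustrative example, not from the speedrun code):

```python
import jax
import jax.numpy as jnp  # near drop-in replacement for numpy


def gelu(x):
    # A pure function: no in-place mutation, output depends only on inputs.
    # This is the tanh approximation of GELU, written with jnp ops.
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))


# jax.jit traces the pure function once and XLA-compiles it for whatever
# backend is available: CPU, GPU, or TPU, with no code changes.
gelu_jit = jax.jit(gelu)

x = jnp.linspace(-3.0, 3.0, 7)
y = gelu_jit(x)
```

Because `jit` works by tracing, it only behaves predictably on functions without side effects, which is the "functional" constraint the rest of the port has to respect.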