The modded nanogpt speedrun, but in JAX and on TPUs

TL;DR: New speedrun here. Details below. I generally experiment in PyTorch on my local GPUs. Recently, I received some TPU credits from the TPU Research Cloud (TRC) program to do some research, which gave me a good reason to finally learn JAX. About all I knew was that you can often replace numpy with jax.numpy to get code running on accelerators, and that JAX code is supposed to be functional. This was also the first time I wrote code for TPUs. With that, I set out on a small side project: porting the modded nanoGPT speedrun to pure JAX, with the goal of achieving the best possible performance. ...

August 21, 2025 · 10 min · 2125 words · nor

Theoretical properties of optimizers on a toy problem, and some intuition

Introduction We want to answer the question: how do optimizers fundamentally differ in their approach to finding a minimum? To explore this question in a controlled environment, we focus on a simple problem: minimizing the matrix error term \(\min \lVert AX - B \rVert\). ...

August 2, 2025 · 33 min · 6841 words · nor

Deriving RoPE the proper way

Figure 1: Attention score similarity with our 3D RoPE. These plots show how positional similarity changes across an axis. Introduction RoPE has become the de facto positional embedding for transformer models. Its popularity mainly stems from its performance, but the “derivation” in the paper is also quite elegant (if flawed). Implementing high-dimensional RoPE also pushes us to think about generalizing the underlying ideas as much as possible (alongside using signal-processing intuition) - there’s code at the end of the post that implements things based on the ideas we develop here. ...

July 28, 2025 · 20 min · 4050 words · nor

Quantizing LLMs for inference

Motivation Let’s start by doing some arithmetic about large language models (LLMs). These are neural networks with huge parameter counts: state-of-the-art open-weights models (i.e., ones you can download) have parameter counts on the order of 100B (\(10^{11}\)) or so, with usable ones around one order of magnitude smaller. Take the latest SOTA release Qwen 3 235B-A22B, for instance, which has roughly 235B parameters. If all these parameters were stored in a naive array of 32-bit (4-byte) floating-point numbers, this model would require around 940 GB of storage, and roughly as much memory to run at a usable speed. Running this model purely on CPU with dual-channel DDR4 RAM (which is likely the kind of RAM you have on your computer) would take multiple seconds to output a single token/word (and even this is quite fast relative to the model’s total size, because the architecture is what is called a Mixture of Experts - more on that later, so don’t worry yet). ...
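The arithmetic in the excerpt above can be sketched in a few lines. This is only a back-of-the-envelope estimate: the dual-channel DDR4 bandwidth figure of ~50 GB/s is an assumption, and a real MoE model reads only its active parameters per token, so the true per-token latency is lower than this worst case.

```python
# Back-of-the-envelope memory and speed estimate for a 235B-parameter model
# stored naively in fp32. Bandwidth figure is an assumed ballpark, not measured.
params = 235e9          # Qwen 3 235B-A22B total parameter count
bytes_per_param = 4     # 32-bit floats

size_gb = params * bytes_per_param / 1e9
print(f"model size: {size_gb:.0f} GB")

# Token generation is roughly memory-bandwidth bound: in the worst case every
# parameter is read once per generated token. (MoE models read far fewer.)
ddr4_bandwidth_gb_s = 50  # assumed dual-channel DDR4 throughput
print(f"lower bound: {size_gb / ddr4_bandwidth_gb_s:.1f} s per token")
```

At ~50 GB/s, streaming 940 GB of weights takes on the order of twenty seconds per token, which is why quantization (and MoE sparsity) matters so much for CPU inference.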

May 14, 2025 · 31 min · 6410 words · nor