- Introduction
- What goes wrong at long context
- Training stability and QK-norm
- Looking at the distribution
- Local vs global behavior and inductive biases
- Similar existing literature
- Positional encodings and hybrid attention (local and global)
- Attention sinks and gating
- Revisiting QK-norm and norm information
- Experimental details
- Acknowledgements
- References
- Final notes
Introduction
Meta
One of my recent side-quests is to understand which architectural choices for models scale well, to ensure that future work that interests me remains meaningful.
This post is a byproduct of some exploration along those lines - initially I was going to write about some minor original work (theoretical, with some experimental results) on how to scale QK-norm well to long contexts, but I quickly realized it might be better to also write a summary of what I think is good for long context from an attention perspective (which might be more speculative than the rest). This became especially clear when I started writing the post after getting experimental results - that was when I started searching the literature for non-trivial baselines and found lots of places where potentially-not-ideal heuristics were used (including Llama 4 and Grok 2).
So in this post, I will talk about some theory, some small-scale experimental results, and what I think is a reasonable starting point for long-context attention design.
This is also organized in more of a research-notebook style than a polished paper or blog post, and is best read that way.
Spoiler
I think the following is roughly a good start (but likely not sufficient by itself) for training models for long context; the rest of the post covers this in more detail:
- QK-norm or QK-clip for stability,
- Logit scaling ~ \(\sqrt{2 \log n}\) for attention specificity,
- Explicit architectural choices that have some distinction between long and short context,
- Gating, norming/scaling the SDPA/attention output, and/or something similar, for sinks and for long-context behavior.
Scoping
As far as the new (or independently-found but existing) results are concerned (not the general discussion), since I only have finite compute/time and this is not my main thing, it’s important to define the scope of what we’ll look at.
Note that the most reliable way to get strong long context behavior is still the boring one (train at the target context lengths with enough compute and data). The goal here is to improve the architecture so that it does not quietly sabotage long context behavior (or at least makes the model’s life during training easier) and makes it easier/cheaper to extend context without as many pathological failures as before. There is an increasing amount of literature on getting strong long-context behavior, and many top LLM labs seem to be pointing to architectural improvements as the way forward.
These experiments were done on ~150M models (training to a 2K/4K context length and evaluation done up until a 16K/32K/64K context length - two setups), on ~1B tokens, with the Muon optimizer for 2D parameters and AdamW for 1D and scalar parameters, with my modded-nanogpt-jax baseline which replicates the modded-nanogpt speedrun in JAX on TPUs. More scaling experiments need to be done to validate these findings in case you’re training bigger models, even though I am reasonably confident that a lot of what we will see below transfers (and likely becomes more visible at longer context). Experimental setups and conclusions are in the section on experimental details, I’ve tried to summarize my findings instead of flooding the post with numbers (especially since there are so many experiments and this is a blog post, not a paper).
The main question is the following:
If we train a model only at short context, what actually goes wrong in vanilla attention when we run it at long context, and how can we fix that architecturally?
Large enough models with enough long context training can, in principle, learn around a lot of this. But we want to understand and patch the failure modes in the default setup.
Note that length-extension training/tricks are currently out of scope (increasing RoPE \(\theta\), YaRN-style extrapolation and so on) due to compute/time reasons, though I do expect that in general, if we train in a way that sets the model up for better length generalization right out of the box, it will likely have fewer issues later on than a setup that doesn’t care about it. So everything we consider below is “out of the box” length generalization (what the marginal performance at a higher context length looks like when the model is trained at a given context length).
As far as general discussion is concerned - my goals are to make it a somewhat self-contained note that attempts to explain how and why long context attention breaks if trained only at short context, and some practical notes on making training more long context-friendly by fixing certain architectural footguns. We will be talking about things like scaling stuff properly, using hybrid attention, positional encoding, gating and so on.
We will: (i) analyze attention logits under simple probabilistic models; (ii) derive a scaling law for logits; (iii) compare it to existing proposals; and (iv) connect it to, and use it to inform, specific architectural design choices and small-scale experiments.
What goes wrong at long context
Firstly, long context is important for tasks that require extended reasoning, agentic use cases, video models, and so on, so it is becoming increasingly important to pay attention to scaling context length. Many labs offer 1M+ token context windows, which is extremely long, but it is often the case that only a fraction of that 1M context length is usable.
Architecturally, as context lengths grow, attention (which is the only “dynamic” part of the model in terms of the sequence length) becomes the bottleneck - not just computationally, but also (and more importantly) representationally. It has to pick relevant information out of a potentially large set of distractors.
Note that here we will focus on NoPE (no positional encoding) instead of RoPE, since the understanding of more modern literature is that (pure) RoPE is not as good as NoPE for long context behavior/length generalization (and is better replaced/delegated to local information only), and NoPE also makes our analysis simpler. See the section on positional encodings etc. for more detail.
Even if we ignore stability tricks that are often pinned as culprits for bad long context performance, there is a basic statistical bottleneck here: if the attention logit distribution at every position looks the same as the context grows (and there is no bias like a decay as the distance between key and query indices grows), then the tail behavior of the logit distribution tells us that there will be many “competitors” within a small distance of the maximum. The maximum logit (or the second/third/and so on) stops being cleanly separable, and lots of tokens are attended to inadvertently. Specificity is lost and we can’t attend to the tokens we want without also attending to lots of distractors.
There is also a small caveat here - heads are sometimes highly structured and specialized - previous-token heads, punctuation/syntax heads, generally heads with attention sinks, and so on. More on this in the section on gating and attention sinks. But for now, we will talk about extreme-value-selection properties of attention - if attention in your model can’t do that at least half-decently, it is probably not a good model to scale.
The main object I use to reason about this is \(\mathbb{E}\left[\frac{1}{w_{\max}}\right]\), where \(w_{\max}\) is the largest softmax weight at a given query; the expectation is a proxy for an empirical quantity, and the distributions involved are left unspecified for now (we will argue about this later). This has a second interpretation beyond the obvious one - it gives us a heuristic estimate of the number of tokens attention attends to.
If this quantity is roughly \(1\), the head is essentially focusing on a single token most of the time. As it grows, the attention to the top contributor(s) is diluted and spread out to other significant contributors (so a sum of values corresponding to different tokens).
We would like to be in some kind of a “critical” regime where the head can attend almost purely to one token when needed, and on average, it is not forced to put nearly all of its mass on a single token or, conversely, forced to spread over a large number of tokens (e.g., \(N^k\) for \(0 < k < 1\)).
Our analysis will be about staying in that critical regime as the context grows.
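A small simulation makes the dilution concrete (a minimal numpy sketch under the i.i.d. Gaussian logit assumption we will lean on later; `mean_inv_top_weight` is my own helper name):

```python
import numpy as np

rng = np.random.default_rng(0)

def mean_inv_top_weight(n, alpha=1.0, trials=64):
    """Monte Carlo estimate of E[1 / w_max] for i.i.d. standard normal
    logits multiplied by alpha before the softmax."""
    z = alpha * rng.standard_normal((trials, n))
    z -= z.max(axis=1, keepdims=True)       # shift so the top logit is 0
    inv_top = np.exp(z).sum(axis=1)         # this sum equals 1 / w_max
    return float(inv_top.mean())

# With a fixed scale (alpha = 1), the top weight is diluted as n grows,
# i.e. attention loses specificity for purely statistical reasons:
growth = [mean_inv_top_weight(n) for n in (1_000, 10_000, 100_000)]
```

The numbers grow quickly with \(n\), which is exactly the loss of specificity described above; the scaling rules derived later are about keeping this quantity in a critical range.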
Training stability and QK-norm
Before we touch upon how to transform attention logits, we need an architecture that trains stably.
In fact, my main motivation for considering length-generalization was to ensure that the kinds of architectures I work with are all sane and scalable (otherwise you risk working with spherical cows in a vacuum).
There are a few places attention can have issues during training - dot products grow and logits blow up, gradients become unhealthy, and training can show instabilities that are hard to trace from just global metrics.
For softmax-based instabilities (i.e., specific to logits), two tricks help a lot: QK-norm for attention logits, and z-loss for the lm head logits (e.g., see Small-scale proxies for large-scale Transformer training instabilities (2023), though these techniques are older than the paper). These are training stability tricks that are used in quite prominent open model releases (and not just for this specific part of the architecture - people use z-loss for MoE routers for example) - Qwen 3, Gemma 3, OLMo 3, GLM 4.5, and Marin (which explicitly notes that they needed QK-norm to avoid concerning loss spikes), and I am reasonably confident that a few frontier labs (which obviously care about length-generalization) use these in their training setup.
We will look at QK-norm here, since the modern understanding is that QK-norm can lead to poor length generalization. Note that there are methods like QK-clip (in Kimi K2), logit softcapping (in Gemma 2, but replaced by QK-norm in Gemma 3), QKV-norm and other variants explored in different places - and it is possible that logit clipping doesn’t affect length generalization as much as QK-norm (though generally there are still RMSNorms in MLA setups on the compressed states, which might lead to somewhat similar behavior) - we will only look at QK-norm in detail (while doing some experiments for QK-clip later) for the sake of simplicity. Even without QK-norm, one could argue that logit magnitudes are somewhat similarly distributed due to the norms before attention in most models (except in cases like OLMo 2/OLMo 3, where there is no norm before attention), though the connection is indirect due to the (pseudo-)“metric” \(W_Q W_K^T\) in the middle that might meaningfully change training dynamics. It is worth noting that QK-norm does not apply to MLA directly, and since Muon leads to more logit blowups than AdamW in the Kimi K2 setup, they were left with no option other than to use something close to QK-norm in spirit - QK-clip (which rescales projection matrices post-update to bound max logit growth).
QK-norm means normalizing queries and keys, often with an RMSNorm-style operation. People generally used QK LayerNorm, but now QK RMSNorm is more common (still with learned weights) - it seems to avoid certain issues with abnormally large LayerNorm weights, especially when weight decay is applied (both are expected - without centering, vectors are in general more aligned and magnification factors end up not being that large, and weight decay’s purpose is to shrink weights). The modded-nanogpt speedrun goes a step further and removes the learnable weights (so it is practically a projection to the unit hypersphere), though this is not validated at large scale as far as I know, which might make it a bit risky at scale even though many people say it’s not as big of a deal as we think. The motivation behind any kind of QK-norm is to help keep attention in a regime where the gradients are healthy. Note that all the analysis here will use QK-norm with weights frozen to 1 for the sake of simplicity (the modded-nanogpt scenario) - if you plan on using learned weights with QK-norm, check how it behaves in the weighted setup (weights on top of LayerNorm do not interact nicely with the rest of the theory, and there can often be blow-ups; we avoid thinking about them by freezing these weights to 1 instead).
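As a concrete reference, here is a minimal numpy sketch of the parameter-free variant (weights frozen to 1); the function names are mine, and the \(1/\sqrt{d}\) multiplier is one common choice, not a claim about any particular codebase:

```python
import numpy as np

def qk_norm(x, eps=1e-6):
    """Parameter-free RMSNorm over the head dimension: each vector ends
    up with norm ~sqrt(d), so logits become (scaled) bounded cosines
    rather than unbounded dot products."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms

def attention_logits(q, k, d):
    """One head's logits: with QK-norm and a 1/sqrt(d) multiplier, each
    logit is sqrt(d) * cos(q_i, k_j), hence bounded by sqrt(d)."""
    return qk_norm(q) @ qk_norm(k).T / np.sqrt(d)

rng = np.random.default_rng(0)
d = 64
q = rng.standard_normal((8, d))    # 8 queries
k = rng.standard_normal((16, d))   # 16 keys
logits = attention_logits(q, k, d)
```

The boundedness of the logits (by \(\sqrt{d}\)) is what buys training stability here, and it is also what the Beta-distribution analysis later exploits.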
QK-norm is not without controversy. The “RoPE to NoPE and back again” paper (2025) argues that QK-norm is bad due to not preserving magnitude information, and proposes alternatives without QK-norm. Another work (Attention Sinks: A ‘Catch, Tag, Release’ Mechanism for Embeddings (2025)) shows that even models with QK-norm still form attention sinks. Also, lots of open models use QK-norm nevertheless, and show somewhat reasonable long context performance.
So we are in a situation where:
- QK-norm is empirically very helpful, maybe even necessary for stable training at scale.
- It has theoretical drawbacks (though normal attention with pre-norm also has these drawbacks).
- It appears to interact with long context behavior in nontrivial ways.
Part of this note is about understanding those trade-offs and how to patch them. In light of this, our crude baseline will be a strong baseline in terms of loss (the modded-nanogpt speedrun), which uses p-RoPE (considered good for length generalization) and parameter-free QK-norm (the component whose behavior we want to improve, since it doesn’t length-generalize well by itself). Note, however, that our analysis will also apply (though potentially in a slightly weaker manner) to cases where we don’t use QK-norm.
Looking at the distribution
Let’s come back to our setup. The simplest heuristic is to treat attention logits as i.i.d. random variables and ask:
How should we scale the logits as a function of context length \(n\) so that the attention mechanism stays in a usable regime?
I initially worked this out as a math problem, and ended up with a \(\sqrt{2 \log n}\)-like scaling rule (and some potential positional biases) - we will call this \(\sqrt{2 \log n}\) scaling regardless of the minor tweaks we make (for the sake of brevity), which I will make precise and explain below. Later, when doing a literature review, I saw related but different proposals like scalable softmax and scale-invariant attention.
There are two main assumptions on the logits that we can try.
The first one is a Gaussian, a crude but analytically tractable model that becomes increasingly accurate as head dimension increases, and is good as long as the total context length is under \(\exp(O(d^k))\), where \(k\) is \(\frac{1}{2}\) or \(1\) or something like that (depending on what assumptions you care about).
The other one is a Beta distribution, which characterizes dot products of unit vectors in high-dimensional spaces.
The Gaussian assumption and \(\sqrt{2 \log n}\)
Assume that for a fixed query, the pre-softmax logits \(z_i\) over \(n\) keys are approximately i.i.d. Gaussian. This is wrong in general (both i.i.d. and Gaussian), but potentially useful for understanding extremes.
The idea is simple: consider the maximum logit \(z_{(1)}\) and the rest of the order statistics. As \(n \to \infty\), after suitable centering and rescaling, the point process of logits converges in distribution to a Poisson point process with intensity measure \(e^{-x} dx\) (a standard extreme value theory result). This will help us analytically compute the expected value (inverse of the max attention weight) that we care about.
The softmax weights are proportional to \(e^{\alpha z_i}\), where \(\alpha\) is some scaling factor we control. We want to choose \(\alpha\) as a function of \(n\) so that \(\mathbb{E}[1/w_{\max}]\) stays in a nice range.
Here is a sketch of the proof (a full proof is given in spoiler tags for completeness):
- Let \(z_i\) be i.i.d. Gaussian. The maximum is about \(\sqrt{2 \log n}\) plus lower-order corrections.
- Subtract off this leading term and rescale by multiplying by a term that grows like \(\sqrt{2 \log n}\) plus lower-order corrections: the differences between the max and the other large logits then converge to the Poisson point process with intensity \(e^{-x} dx\).
- When we scale logits by \(\alpha\), the exponentials behave like a series \(\sum_{k \ge 1} \exp(-\alpha \Delta_k)\), where \(\Delta_k\) are the gaps from the top, which are \(\approx \frac{\log k}{\sqrt{2 \log n}}\) before rescaling, so the “right” scale is something like \(r \sqrt{2 \log n}\) with \(r \ge 1\).
- If we pick \(\alpha = r \sqrt{2 \log n}\), the structure of the series becomes comparable to a zeta-like series - a more detailed computation gives something like \(1 + 1 / (r - 1)\) as the contribution from the tail.
A more complete proof
Firstly, we state the Poisson point process convergence.
Let \(X_1, X_2, \dots\) be i.i.d.. Define centering and scaling constants \(b_n, a_n > 0\) and the point process on \(\mathbb{R}\)
\[N_n := \sum_{i = 1}^n \delta_{a_n (X_i - b_n)}\]
where \(\delta_x\) is the Dirac point mass at \(x\). Then one can choose \(b_n, a_n\) so that
\[N_n \Rightarrow_{n \to \infty} N\]
where \(N\) is a Poisson point process on \(\mathbb{R}\) with intensity measure
\[\mu(dx) = e^{-x} dx\]
This is a standard result in extreme value theory (e.g., Resnick, Extreme Values, Regular Variation and Point Processes, Chapter 4).
We generally pick \(b_n = \Phi^{-1}(1 - 1/n)\) where \(\Phi\) is the CDF of the standard normal distribution, and roughly equals \(\sqrt{2 \log n} - \frac{\log \log n + \log {4 \pi}}{2 \sqrt{2 \log n}} + o(1/\sqrt{\log n})\).
The heuristic we claimed earlier (\(\frac{-\log i}{\sqrt{2 \log n}}\)) comes from integrating the intensity and computing how many variables are, on average, ahead of us. We do a more precise analysis below.
Consider the atoms \(\xi_k\) of the Poisson process to which this converges (in decreasing order). Then we are looking for \(\sum_{k \ge 1} \exp(-r (\xi_1 - \xi_k))\). Consider the mapping \(x \mapsto e^{-x}\), applied to these atoms (and in general, this process). The intensity of the mapped process is \(1\) on \((0, \infty)\), and we get the mapped atoms \(0 < \eta_1 < \eta_2 < \dots\), and we are looking for \(\sum_{k=1}^\infty \left(\frac{\eta_1}{\eta_k}\right)^r\).
We condition on \(\eta_1\) first; suppose \(\eta_1 = s\) (the image of the maximum logit). Using Campbell’s theorem, the expected value of the sum given \(s\) is \(1 + \int_{s}^\infty (s/t)^r \, dt = 1 + \frac{s}{r - 1}\). Now taking expectations over \(s\), since \(s\) is an \(\mathrm{Exp}(1)\) random variable (and its expectation is \(1\)), the required expectation is just \(1 + \frac{1}{r - 1}\).
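To sanity-check this computation, we can simulate the mapped process directly (a unit-rate Poisson process on \((0, \infty)\) built from Exp(1) spacings, which is the standard construction) and compare against \(1 + \frac{1}{r-1}\):

```python
import numpy as np

rng = np.random.default_rng(0)

def expected_inv_top_weight(r, trials=5_000, atoms=1_000):
    """Estimate E[sum_k (eta_1 / eta_k)^r] where 0 < eta_1 < eta_2 < ...
    are the atoms of a unit-rate Poisson process on (0, inf), generated
    as cumulative sums of Exp(1) spacings.  Truncating at `atoms` terms
    is negligible for r well above 1 (terms decay like k^{-r})."""
    eta = rng.exponential(size=(trials, atoms)).cumsum(axis=1)
    return float(((eta[:, :1] / eta) ** r).sum(axis=1).mean())

# Campbell's theorem predicts 1 + 1/(r - 1); for r = 3 that is 1.5.
est = expected_inv_top_weight(3.0)
```

The Monte Carlo estimate lands close to the closed form, which is a useful check that no factor was dropped in the conditioning step.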
We observe two things:
- There is a regime where scaling by \(\sqrt{2 \log n}\) (as the leading term) keeps the expected inverse of the top weight in a critical zone.
- In that regime (and more broadly, when we scale by \(r \sqrt{2 \log n}\) for any \(r\)), the ratio between the top weight and the second top weight (or any two constant-rank weights) is roughly constant, not degenerating to 1 or to infinity.
Note that (if you looked at the proof) \(r = 1 \pm o(1)\) is likely the place to be at - any more and in the expected case, you have a constant weight at the top logit (when training models using \(r > 1\), we see worse losses than at \(r = 1\)), and any less and you attend to something like \(n^{1 - r}\) tokens in the average case, which is not exactly ideal (unless for training reasons - it might not hurt to start out with a trainable scaling factor less than 1 in front of the \(\sqrt{2 \log n}\), and it seems to help in my experiments).
When making this \(r\) learned (either layer-wise or global), it seems to settle below \(1\) (around \(0.5\)). This is potentially due to small/finite context-length effects (attending to \(\sqrt{1024} = 32\) positions is not that bad), or maybe it is just easy for the model to allocate attention to many positions (even if polynomial in the length) for loss reasons, or our logits are not really i.i.d. Gaussian random variables, or the model attends to different tokens in different layers and by the end there is enough information for the unembedding head to figure out the next token despite the diffuse attention.

When allowing layerwise constants, the learned \(r\) values somewhat match our intuition - initial layers have small scales (presumably attending to more global information/coarse structure), middle layers have larger scales (presumably more specific and precise behavior, which also leads to good embeddings), and the final few layers have scales somewhere in the middle.

Note that we can try something like norming the SDPA output (or something less intrusive, like plotting the SDPA output vector norm by position, fitting a functional form to it, and dividing by that, so that updates to the residual stream are not forced to have some particular norm), since the “participation ratio” (sum of squares of attention weights) for \(r < 1/2\) in the i.i.d. Gaussian case is inversely proportional to the position (and for \(1/2 < r < 1\) is proportional to \(t^{2(r - 1)}\)), so it might go down to \(0\) as context length grows. More experimentation is required to find the optimal \(r\) (or even a slowly growing functional form that is not exactly \(\sqrt{2 \log n}\)), but for the sake of simplicity, we use a learnable \(r\) since that is better than nothing.
This is the case where all of our logits are equally scaled. If we started from scratch and didn’t have any positional encoding, we might want some ALiBi-style biases (but with much less aggressive decay) - something roughly like \(- \log (1 + t)\) for the token \(t\) places before the current token may be much nicer than naive ALiBi.
Note that there is another way - take the expected value quite literally and explicitly solve for moderate \(n\) as well. This does not work very well, maybe because as \(n\) becomes smaller, the inverse top weight expectation becomes similar to \(n\). One can also try an attention scale of \(\log n\) or something similar, but that also doesn’t work as well. As a consequence, there are many different ways of modifying \(\sqrt{2 \log n}\) for small \(n\) - one thing that worked well was truncating the expected largest logit to its first two terms, \(\sqrt{2 \log n} - \frac{\log \log n + \log {4 \pi} - 2 \gamma}{2 \sqrt{2 \log n}}\), where \(\gamma\) is the Euler-Mascheroni constant (but \(n\) needs to be shifted by at least \(2\) to make the \(\log \log n\) well-defined). In general, we would use something like \(a + \sqrt{2 \log (b n + c)}\) for some constants \(a, b, c\) - \(a = 0, b = 0.1, c = 2\) (picked at random while keeping in mind that \(b < 1\) should work according to our previous choice) also seems to work somewhat. As earlier, learnable scales can improve things and are generally good for loss.
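For concreteness, the small-\(n\)-corrected multiplier can be written as below (constants as above; in a real model they would be learnable or at least tuned):

```python
import numpy as np

def logit_scale(n, a=0.0, b=0.1, c=2.0):
    """The a + sqrt(2 log(b n + c)) logit multiplier from the text; the
    defaults are the hand-picked constants, which in practice would be
    replaced or augmented by learnable scales."""
    return a + np.sqrt(2.0 * np.log(b * np.asarray(n, dtype=float) + c))

# Slowly growing in n, and well-behaved at small n thanks to the shift c:
scales = logit_scale(np.array([64, 1024, 65536]))
```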
The Beta assumption and \(n^{2 / (d - 1)}\)
The Gaussian assumption is mathematically convenient but not that realistic if you think about long context behavior. For random high-dimensional vectors on a sphere, dot products or cosines often behave like Beta distributions - in the case of attention, if \(q\) and \(k\) were on the hypersphere in \(d\) dimensions and completely random, a logit can be reparametrized as \(\sqrt{d}(2 B - 1)\) where \(B\) is the Beta distribution \(B(a, a)\) with \(a = \frac{d - 1}{2}\) where \(d\) is the head dimension.
If you assume a Beta distribution for the logits (or, equivalently, cosines), you get a very different scaling - the required logit multiplier grows like \(n^{2 / (d - 1)}\).
Note that this is roughly accurate - unless \(W_Q\) and \(W_K\) have perfectly aligned subspaces, or the representations lie exactly on a low dimensional ball, this exponent is correct. A simple way to think about it is to consider two random vectors in three dimensions, restricted to two circles. If the circles coincide, the \(d\) to use is trivially \(2\). If they are at an angle, the tail behaves like a Beta distribution (you have to deal with an elliptic integral, and there is an inverse-sine dependence on the angle in the constant factor, but still, the exponent corresponds to \(d = 3\)).
This scaling is much faster than even \(\text{poly}(\log n)\). Intuitively, the lighter the tails, the more aggressively you need to crank up the sharpness (logit scaling) to be able to distinguish the top token. And since Beta tails are light - in fact bounded (here by \(\sqrt{d}\)), we need to pay a heavier price than in the Gaussian case.
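A quick numerical sanity check of the reparametrization (under the assumption of completely random unit vectors): the cosine of two such vectors has variance exactly \(1/d\), which matches \(\mathrm{Var}[2B - 1]\) for \(B \sim \mathrm{Beta}(\frac{d-1}{2}, \frac{d-1}{2})\), and the corresponding logits are bounded by \(\sqrt{d}\).

```python
import numpy as np

rng = np.random.default_rng(0)
d, samples = 64, 20_000

# Cosines of pairs of uniformly random unit vectors in R^d.
u = rng.standard_normal((samples, d))
v = rng.standard_normal((samples, d))
u /= np.linalg.norm(u, axis=1, keepdims=True)
v /= np.linalg.norm(v, axis=1, keepdims=True)
cos = (u * v).sum(axis=1)

# (cos + 1)/2 ~ Beta((d-1)/2, (d-1)/2), so Var[cos] = 1/d and the
# support is bounded: |logit| = sqrt(d) * |cos| <= sqrt(d).
var = cos.var()
```

The bounded support (as opposed to the Gaussian's unbounded tails) is the root cause of the much more aggressive \(n^{2/(d-1)}\) scaling requirement.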
Unfortunately, actually computing this scaling exactly in practice is numerically painful on multiple accounts: the analysis involves the confluent hypergeometric function \({}_1F_1\) (hyp1f1), which behaves badly for large \(d\) and \(n\); cumulant or series expansions converge very slowly in the regimes we care about; Monte Carlo integration has quadratic sample complexity in the desired precision (equivalently, quadrature methods require a quadratic number of nodes for the double integral); Newton’s method tends to diverge on the relevant equations; and root bisection and even Adam-style methods can be brittle. Some of these issues hold even for the “asymptotic” solution, which requires integrating over an incomplete Gamma function.
Empirically, when I tried to use Beta-based scaling, I ran into these numerical issues and effectively had to reject the approach for now. Even after some degradation due to numerical approximations, the performance matched the simpler \(\sqrt{2 \log n}\) setup or was slightly better (for very small head dimensions) - though it might be interesting to look at it, if you’re interested enough in numerical methods to try to make it work.
So the recommendation is to use Gaussian-based \(\sqrt{2 \log n}\) scaling heuristics by default - they are simple and robust. The Beta analysis is just a warning that, eventually, for moderate \(d\) and very large \(n\) (something like \(\exp(O(d^k))\) where \(k\) is \(1/2\) or \(1\), depending on what you care about), you might need more aggressive scaling (or bigger heads). Large head dimensions push the problematic regime far out in \(n\), so in practice you can postpone this issue.
Local vs global behavior and inductive biases
Another problem is that when you extend context length, you implicitly change how local patterns behave. For the rest of this section, suppose that our current attention layer cares about local information (or the intent of including it is to do so).
Locally (over a small window), you want the model to behave essentially the same whether the total context length is 4k or 64k, as well as maintain roughly the same distribution of logits over a small neighborhood.
However, if you simply scale logits as a function of global length \(n\), you may end up paying an unnecessary penalty for attending over the entire context and making local attention patterns weaker or sharper in a way that depends on the global context length, not just local structure.
A good inductive bias would be one that is length-invariant locally (or at least doesn’t depend as sharply on total context length as our scaling does), and, for global information (if the attention layer is allowed to look at everything), is very picky about what it attends to (both in terms of sharpness and total weight) - so as we move farther from the current token, a bias (like the negative logarithmic bias we discussed earlier) and some sharpness would be nice. Note that in the Gaussian regime, \(\mathbb{E}[\exp(aX + b)] = \exp(a^2/2 + b)\), so some part of \(b\) needs to compensate for \(a\) if we only look at unnormalized attention mass (the second moment, however, is \(\exp(2a^2 + 2b)\), so we need to keep this in mind to avoid over-sharpening).
Note that if we want to keep the overall behavior intact, we can just use a sqrt-log-scaling in \(t\), the distance from the current token (or a power-law if you care about the Beta case), so that for the vast majority of positions, we still follow a very similar scaling, plus a bias to compensate for the multiplicative factor that comes from sharpening the Gaussian, and a potential positional decay to avoid attending globally too much.
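A sketch of what this could look like for a single query's causal logit row (the specific functional forms, \(\sqrt{2 \log(2 + t)}\) for the scale and an optional \(-\log(1 + t)\) decay, are illustrative assumptions following the text, not a tuned recipe):

```python
import numpy as np

def positionwise_logits(logits, decay=0.0):
    """Sharpen one query's causal logit row as a function of distance t
    to the query, with a bias that cancels the E[exp(a X)] = exp(a^2/2)
    inflation for Gaussian logits, plus an optional positional decay."""
    T = logits.shape[-1]
    t = np.arange(T - 1, -1, -1, dtype=float)   # distance to the query
    a = np.sqrt(2.0 * np.log(2.0 + t))          # slowly growing sharpness
    m = a * a / 2.0                             # compensating bias
    return a * logits - m - decay * np.log1p(t)

rng = np.random.default_rng(0)
row = rng.standard_normal(1024)                 # raw logits for one query
z = positionwise_logits(row, decay=1.0)
w = np.exp(z - z.max())
w /= w.sum()                                    # attention weights
```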
Note that this is more of a positional embedding than a scaling technique that we would like to use to stabilize our training. Also, using only this kind of a layer (as opposed to also using global layers) might conflict with benchmarks such as NIAH or RULER, where if the needle is at the beginning vs if the needle is at the end, there will be different behaviors if there is a positional decay (before you say NIAH is a bad benchmark, it is just an example - in real world usage, there may be more non-trivial retrievals prone to this failure mode).
Similar existing literature
After the above, I decided to look at what approaches are present in the literature. Two main approaches stood out to me as adjacent to this idea: scalable softmax (which is also used in Llama 4 and Grok 2) and scale-invariant attention.
Scalable softmax
The scalable softmax literature (many papers proposing roughly the same idea) proposes a more aggressive solution: multiply logits by \(\log n\) (instead of the \(\sqrt{2 \log n}\) we derive), alongside additional learned scale factors.
As \(n \to \infty\), this makes the distribution increasingly sharp (consider the difference between the maximum and the second-maximum logits), and it is not good in the average case for multiple reasons.
In certain worst-case or toy problems (e.g. “Overcoming a theoretical limitation of self-attention”), a \(\log n\) scaling can be provably good. From a worst-case theoretical perspective, this is attractive: you can guarantee that, in the worst case, attention remains capable of selecting a single relevant token. But from an average-case ML perspective, it is too aggressive.
Scalable softmax often includes per-head learnable scales, which somewhat mask this issue for normal context lengths. In principle, you could try some heads with \(\log n\) scaling, others with \(\sqrt{\log n}\) scaling as above. This might work in practice but is hard to interpret.
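To make “too aggressive” concrete: for any fixed logit vector, softmax entropy is strictly decreasing in the multiplier, and \(\log n\) exceeds \(\sqrt{2 \log n}\) for all \(n\) past a small threshold, so the \(\log n\) rule produces strictly sharper distributions at large \(n\). A minimal sketch (random logits as a stand-in, no learned per-head scales):

```python
import numpy as np

def softmax_entropy(z):
    """Entropy of softmax(z); lower means sharper attention."""
    w = np.exp(z - z.max())
    w /= w.sum()
    return float(-(w * np.log(w + 1e-30)).sum())

rng = np.random.default_rng(0)
n = 65_536
z = rng.standard_normal(n)  # stand-in for one query's raw logits

h_ssmax = softmax_entropy(np.log(n) * z)              # scalable-softmax-style
h_sqrt = softmax_entropy(np.sqrt(2 * np.log(n)) * z)  # sqrt(2 log n)-style
```

At this \(n\) the \(\log n\) multiplier is about \(11.1\) versus \(4.7\) for \(\sqrt{2 \log n}\), so the resulting attention is far more concentrated than the critical regime calls for.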
In my experiments, I found:
- \(\sqrt{2 \log n}\) scaling (with identical parameter counts to \(\log n\) scaling) consistently improves upon \(\log n\) scaling in terms of validation loss at all context lengths (per-token NLL averaged over ranges of powers of 2), and also slightly beats it (or is at parity) in terms of the gap between the loss at the training context length and the loss at longer contexts.
- The original scalable softmax formulation used per-head learnable multipliers - I then removed those (but did keep a global learnable multiplier) in both setups when comparing to my setup, to isolate the effect of the functional form of the scaling. The wins hold up in this case too.
Note that in my experiments, it helped to use a similar kind of \(a, b, c\) formulation for \(\log n\) as well (which is what I compared against), though I think using it for \(\log n\) just pushes the issues further in the future.
Position-dependent scaling and scale-invariant attention
Another dimension is whether scaling depends only on global sequence length \(n\), or also on position \(t\) (distance to current query token). This is similar to what we discussed earlier.
The scale-invariant attention paper addresses this by introducing position-dependent scaling, and this and the earlier idea about making things position-dependent for local attention layers are remarkably similar. It ends up with a similar functional form to \(\sqrt{\log t}\), and depends on a scale parameter \(\tau\), which is quite similar to the \(a + \sqrt{2 \log(b n + c)}\) form as before (they use \(\sqrt{2 \log(1 + t/\tau) + 1} = \sqrt{2 \log (\sqrt{e} + t \sqrt{e}/\tau)}\) as the scale and \(-2 \log(1 + t/\tau)\) as the bias, which is a bias of \(\approx -\log t\) after removing the compensation for the exponential-of-scaled-Gaussian). Their idea is to think about what happens at different scales (i.e., total attention mass in \([a^k, a^{k + 1})\) ranges, global attention sparsity and so on), and trying to construct desiderata based on that, which somewhat align with what we want to do as well, other than the “scale-invariance” which we think of as a positional encoding thing more than a long context thing.
More precisely, \(L_{\text{effective}}(t) = a(t) \cdot L - m(t)\), where \(L\) is the raw logit, \(a(t)\) is a slowly varying sharpness scale (as described earlier), and \(m(t)\) is a slowly increasing offset. The log of the mean of the exponentiated logit is \(\frac{a^2}{2} - m = 1/2 - \log(1 + t/\tau)\), which shows a decay, and hence this might suffer from the same issues as mentioned before at long context. The constant factor in front of \(m\) can also be tuned, but I haven’t experimented with this yet. Note that the second moment of the exponential of the logit does not decay with \(t\) (and the third and higher moments increase with \(t\)), though thinking of it as a decay seems more consistent with experimental observations later.
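The decay of the log-mean can be checked numerically. The sketch below assumes standard Gaussian raw logits (the same assumption discussed just below) and compares a Monte Carlo estimate of \(\log E[\exp(a(t) L - m(t))]\) against the closed form \(1/2 - \log(1 + t/\tau)\):

```python
import math, random

random.seed(0)
TAU = 64.0
SAMPLES = [random.gauss(0.0, 1.0) for _ in range(200_000)]

results = {}
for t in (10, 100, 300):
    a = math.sqrt(2.0 * math.log(1.0 + t / TAU) + 1.0)  # sharpness scale a(t)
    m = 2.0 * math.log(1.0 + t / TAU)                   # offset m(t)
    # Monte Carlo estimate of log E[exp(a*L - m)] for L ~ N(0, 1)
    mc = math.log(sum(math.exp(a * x - m) for x in SAMPLES) / len(SAMPLES))
    closed = a * a / 2.0 - m                            # = 1/2 - log(1 + t/tau)
    results[t] = (mc, closed)
    print(f"t={t:>4}  monte-carlo: {mc:+.2f}  closed-form: {closed:+.2f}")
```

The two agree up to sampling noise, and both decrease as \(t\) grows, which is the decay referred to above.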
Note that the OpenReview discussion on this paper mentions that logits look at least marginally Gaussian - a weaker property than what our setup needs, but it is some confirmation that our Gaussian-logit assumptions are not completely broken (they are still broken, just not completely).
In my experiments, scale-invariant attention outperforms a pure \(\sqrt{\log n}\) scaling that ignores position, in terms of loss, but not in terms of loss difference at training vs longer validation context length (in which case they have similar performance). The performance gap appears to come from an implicit positional bias, which is roughly like a slow ALiBi-like bias. To check this, I also tried just adding a \(-\log t\) style bias, but it seems that the more careful small \(t\) behavior is also part of the positional bias that plays a role here.
Small note: their experimental setup has some issues, like not using document masking, which is likely a missed detail. It ends up favoring their method over the other approaches more than it should (because the attention mass on other documents is artificially reduced in their setup), and also gives wrong “length generalization” metrics (having to attend to multiple documents is arguably quite bad for models at the scale both they and I are training, though this might not be as big of an issue for larger models). After fixing this by adding document masking, there is rough parity between the approaches, modulo the slight constant loss offset due to positional effects of their setup. It also shows that the paper actually undersold their own results - loss visibly decreases at long context, instead of the increase they report from 16K to 64K.
In an architecture with different short and long context layers, I think one natural configuration could be that short context layers (e.g., RoPE with sliding windows) use scale-invariant attention (maybe even just full attention with scale-invariant attention), and long context layers (e.g., NoPE + global attention) use pure \(\sqrt{2 \log n}\) scaling. There are also other possible variants where we “interpolate” between scale-invariant attention with configurable decay rate and pure \(\sqrt{2 \log n}\) scaling, as well as configurable coefficient of \(m\) that decides the decay. I have not tried a lot of experiments on this yet either.
Positional encodings and hybrid attention (local and global)
RoPE has become the default positional encoding for LLMs, but it can behave badly even after context length extension adjustments and training. In practice, RoPE thetas are quite large (around a million, or in the case of Grok-2, around 200 million). At these scales, the slowly varying components can sometimes safely be replaced by no positional encodings (NoPE), and it seems wise to do so, especially because long distance behavior with RoPE would gradually become OOD for the model. NoPE by itself seems to length-generalize better than RoPE, but with poor performance on evals (which is likely due to poor handling of local information).
Both the “Round and round we go” paper and the “RoPE to NoPE and back again” paper notice that RoPE seems best at capturing high-frequency features, and that it is less well-suited for low-frequency or global information. Keeping this in mind, they propose different hybrid schemes - the former proposes p-RoPE (where some fraction of dimensions don’t use positional encoding at all; this is the setup I use for my experiments as well, alongside QK-norm), and the latter proposes RNoPE-swa (where out of every 4 contiguous layers, 3 are local (sliding-window) attention with RoPE and the last is global attention with NoPE). Their hypothesis is that using RoPE for high-frequency information and NoPE for full-context (global) aggregation works well, and that restricting RoPE to SWA layers eliminates RoPE-related issues with length generalization.
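A minimal sketch of the p-RoPE idea on a single head vector (pure Python; the function name and the choice of which pairs stay unrotated are illustrative - the paper drops the lowest-frequency rotations, which here are the later pairs):

```python
import math

def p_rope(vec, pos, base=10000.0, p=0.75):
    """Rotate only the first p fraction of dimension pairs (the
    higher-frequency ones); leave the rest unrotated, i.e., NoPE-like.
    A sketch on a single query/key vector for one head."""
    d = len(vec)
    n_pairs = d // 2
    rotated_pairs = int(p * n_pairs)  # remaining pairs carry no position info
    out = list(vec)
    for i in range(rotated_pairs):
        theta = pos * base ** (-2.0 * i / d)
        x, y = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = x * math.cos(theta) - y * math.sin(theta)
        out[2 * i + 1] = x * math.sin(theta) + y * math.cos(theta)
    return out
```

Setting p=1 recovers plain RoPE and p=0 recovers NoPE; since rotations preserve norms, this composes cleanly with QK-norm.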
Lots of commonly seen models do the SWA-for-local-attention thing - e.g., Gemma 2, Gemma 3, Command A, GPT-OSS-20B, GPT-OSS-120B.
Another architectural direction is to replace some self-attention layers with RNN-ish or state-space-like layers, though this is mostly done to improve decoding speed rather than to remove RoPE. SWA layers have a sense of stationarity that these might not have (and would require more work to enforce), but these layers tend to be more sophisticated, especially with more recent RNN architectures, and they can also store some global information rather than just local information, which can help. There are a lot of these kinds of hybrid models, e.g. (in no particular order) Jamba, RecurrentGemma, Zamba, Samba, Mamba-in-Llama, Zamba2, Hymba, Bamba, MiniMax-Text-01, Hunyuan-TurboS, Zebra-Llama, Nemotron-H, Qwen3-Next-80B-A3B, Jet-Nemotron, Kimi-Linear-48B. Note that some of these “short context” (or more precisely, “constant context”) layers are not compatible with RoPE - people generally either don’t bother with positional embeddings, or use something like a 1D causal convolution (i.e., a linear combination of the past few tokens), which can be applied at different parts of the architecture, not just at the beginning or inside the “attention” layer. As a side note, causal convolutions are important for linear attention performance (they have been in RWKV for a long time, and the Primer paper also finds them useful for transformer architectures), and modern linear attention variants use them - one possible explanation is https://kexue.fm/archives/11320.
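For intuition, a short causal convolution is just a learned linear combination of the current token and the previous few tokens per channel. A minimal single-channel sketch (the kernel here is hypothetical; real implementations are depthwise over all channels):

```python
def short_causal_conv(seq, kernel):
    """Causal 1-D convolution over one channel: the output at position t
    mixes tokens t, t-1, ..., t-(len(kernel)-1) only, so no lookahead."""
    out = []
    for t in range(len(seq)):
        acc = 0.0
        for j, w in enumerate(kernel):
            if t - j >= 0:
                acc += w * seq[t - j]
        out.append(acc)
    return out
```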
For this design, it seems important to use more modern linear attention architectures instead of plain linear attention (the Gated DeltaNet/RWKV-7 style architectures). Indeed, MiniMax’s experiments and posts by their researchers suggest that plain linear attention doesn’t work (they used vanilla linear attention in their first attempt, MiniMax-Text-01, and reverted to full attention in their second model), while Moonshot and Alibaba released Kimi-Linear (which uses KDA - Kimi Delta Attention) and Qwen3-Next (which uses Gated DeltaNet) respectively, and both models are competitive with (and sometimes outperform) their own full-attention baselines - in Moonshot’s case, they stress-tested the architecture throughout the training process by throwing potentially adversarial evals at it. Another reason to prefer modern linear attention variants is that standard linear attention doesn’t have strong attention-like properties, and ends up being weaker than SWA. DeepSeek also released a sparse-attention variant (though not a local one) they call DeepSeek Sparse Attention (DSA), used in their model DeepSeek-V3.2-Exp, and it is also a very promising approach. These three papers by Moonshot, Alibaba, and DeepSeek are worth reading.
There is another kind of approach for SWA that seems worth mentioning, one example being 2-simplicial attention. Originally it was a tensor-based generalization of attention (the name 2-simplicial comes from the simplex between three indices \(i, j, k\)), as in older papers on 2-simplicial attention, but implementing it naively requires compute that scales as \(n^3\) with sequence length, leaving aside concerns about caching. A recent approach is to use it only within a “sliding window” (they use a window of 512 for one K dimension and 32 for the other K’ dimension), and it seems to improve the exponent in their scaling laws, which sounds like a big deal. Throwing more compute at something in deep learning generally works, so spending more compute on a short window where resolution is important sounds reasonable.
I think most of these approaches make sense and are quite strong, and putting some of them together (along with fixing scale issues, and potentially gating for full attention) would likely keep us on track for long context (modulo things like precision issues). Side note: we handle both a constant amount of past context and global context, but we might be missing deliberate architectural choices that cater to “moderate” amounts of context. While logit scaling with \(r < r_{\text{crit}}\) and/or decay terms with different speeds might provide some multi-scale behavior, it is still possible that this doesn’t solve the issue as cleanly as for global and local behavior, and that might matter at extremely long context lengths.
The rest of this post assumes something like that is available on the positional side, and focuses on the scaling and normalization aspects.
There is one important thing that the above discussion misses, though, and that is length generalization in an “algorithmic” sense. For example, a task which needs \(f(n)\) sequential short-context computations and \(g(n)\) sequential “all-gathers”, where either \(f(n)\) or \(g(n)\) increases with \(n\), will be bottlenecked by the number of layers at some point, if we look at single-forward-pass performance (note that we are being somewhat tautological as well as imprecise here, but this is mainly for intuition). There are architectures with looped layers (universal transformers, for example, are an active area of research) or adaptive compute in general, but we will not talk about these (even though they are interesting to me), especially since, as of now, the largest models with a variable number of layers are nowhere near the frontier, and such models are hard to train via backprop in general.
Attention sinks and gating
The Qwen “Gated LLM” paper introduces a gating mechanism on top of attention (their models seem to use QK-norm too, judging from the Hugging Face config) and shows good length generalization, and that too without attention sinks. Attention sinks might hurt our analysis in some obscure way and are bad for interpretability, but they also emerge naturally during training and are the reason why transformers can do many of the important things they are able to do. From a certain perspective, one could argue that since attention sinks are dramatically reduced after gating, and we don’t see performance degradation (rather, we see better length generalization), we should be happy.
Attention sinks are somewhat controversial in that way - so it makes sense to defer to the Qwen paper on gated LLMs and the catch/tag/release mechanism paper for the current consensus on their purpose (especially their related work sections) - there is a whole body of literature on it. Note that this second paper talks about much weaker attention sinks than the one at the first token - from the attention diagrams in the first paper, gating seems to preserve the kinds of sinks that they talk about.
In transformers (we will be talking mostly about LLMs), attention sinks play several roles: serving as tags for global information, potentially mitigating issues with in-context covariate shift, and, more simply, providing a place for heads to dump attention mass.
But this might be wasteful, since sinks use capacity that could be spent distinguishing hard tokens or performing better aggregation, and a lot of the sharpness (distinguishability) budget goes towards global tagging. Attention sinks are important enough that people design mechanisms around them, yet the model seems to need to do some gymnastics to support them (large activations in general, much smaller value vectors paired with huge attention logits for sink tokens, and very specific low-variance keys for sink tokens as a sink-identification signal to the query - this is partly the model fighting against the softmax’s total weight of 1). If we had an alternative mechanism that does everything attention sinks do without going through the logits (the only dynamic part of our network), it might make life much simpler.
Gating seems to be one solution to this - you compute attention as usual, and then gate the output of SDPA in an input dependent manner.
My rough hypothesis is the following. Part of what attention sinks achieve (e.g., working around the sum-to-one constraint of the softmax, among other things) is also achievable through an input-dependent gating mechanism. This gating governs the output scale of the linearly-combined value vectors, which was earlier done through sinks. Since the attention sink’s key signals to the query tokens, there is information in the representations of all in-context tokens that attend to the sink, so a projection of the input should be sufficient to capture this attention-sink behavior (thinking to first order, and not worrying about different training/representation dynamics). In fact, headwise gating should be sufficient to achieve this - and we can go further by using elementwise gating for more capacity in the value vectors: different scales for different elements, all depending on different projections of the input. This might ensure that every element of a value vector can be meaningful on its own, because the model can choose post hoc which dimensions of a value are worth passing on - sparsity is important. This also reduces the optimization pressure that pushes the model to create sinks in the first place, and it is something I think should improve interpretability as well.
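A minimal sketch of elementwise output gating in this spirit (pure Python, all names hypothetical; real implementations use a learned projection per head and fold this into the attention block):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gated_sdpa_output(token_x, sdpa_out, w_gate):
    """Elementwise gate computed from the *input* token representation
    (not from the attention output), applied to the SDPA output.
    w_gate is a d x d list-of-lists; headwise gating would instead use
    a single scalar gate per head."""
    d = len(sdpa_out)
    gate = [sigmoid(sum(w_gate[i][j] * token_x[j] for j in range(d)))
            for i in range(d)]
    return [g * v for g, v in zip(gate, sdpa_out)]
```

One init option is all-zero gate weights with a doubled sigmoid, so the gate is exactly 1 at init and the block starts as an identity on the SDPA output.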
The Qwen gating paper also shows that norming the SDPA output itself (or adding non-linearities there) is beneficial, which is another explanation of why gating helps performance (since dividing into multiple heads gives low-rank structure to \(W_V W_O\)), though it doesn’t as directly explain the long-context improvements.
Note that this is an orthogonal change on top of our logit scaling - and since it removes attention sinks, I think it might favorably impact our scaling. One caveat is that the paper’s attention maps seem to show significant attention mass concentrated at the current token, which could be slightly concerning - this is expected, though, since the q and k projections of the current token are correlated in the general case, being projections of the same vector. Here is why I think removing this attention sink, alongside the improved logit scaling, also improves their length generalization. Attention sinks provide some kind of non-scaling anchor, and at large context lengths, their contribution changes. Without our scaling, there might also be issues with the long-context behavior of the scale of the value vectors (the “participation ratio”, defined as the sum of squares of the attention weights, decreases as context length increases, so the value-vector scale might decrease). It might be the case that in a vanilla setup these effects cancel out as context length grows, but in our setup, being more deliberate about it should help (by both removing attention sinks and controlling the variance of the SDPA output, ensuring a near-constant or very slowly varying participation ratio).
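The participation-ratio point can be checked numerically under the i.i.d. Gaussian-logit assumption (a small sketch; trial counts are kept low for speed, and the constants are illustrative):

```python
import math, random

random.seed(1)

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def mean_participation_ratio(n, scale, trials=10):
    # "participation ratio" as in the text: sum of squared attention weights
    total = 0.0
    for _ in range(trials):
        w = softmax([scale * random.gauss(0.0, 1.0) for _ in range(n)])
        total += sum(x * x for x in w)
    return total / trials

pr = {}
for n in (128, 2048, 32768):
    pr[n] = (mean_participation_ratio(n, 1.0),                         # fixed scale
             mean_participation_ratio(n, math.sqrt(2 * math.log(n))))  # critical scale
    print(n, [round(v, 5) for v in pr[n]])
```

With a fixed logit scale the participation ratio collapses roughly like \(1/n\), while near-critical \(\sqrt{2 \log n}\) scaling keeps it orders of magnitude larger at long context.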
My current intuition is that gating should contribute positively to get a robust long context setup, when combined with QK-norm for stability, careful logit scaling and hybrid positional encodings/layers for both good local and global structure. Note that gating is not the only method there is that can lead to this behavior.
In my experiments, gating helped reduce overall loss (I initialized the gating matrix at all zeros and doubled the sigmoid output, so that at init you get the same behavior as without gating - it’s possible that this is not the best init, but it worked better than a couple of other inits for me). Note that the loss improvements might just be due to the extra parameters, so we look at length generalization. There, at a small scale, there seems to be only a slight negative difference, if any, due to gating (on top of scale-invariant attention), and this is in line with the paper. It is possible that the gains primarily arise during context length extension, as seen in the paper, which is out of scope for this post. In terms of training stability, which is another thing gating aims to fix, there were no spikes at the small scale I tested at, either before or after, and it is likely that noticeable differences in training stability with/without gating only show up at large scale (model size and training steps).
Revisiting QK-norm and norm information
The “RoPE to NoPE and back again” paper makes an interesting hypothesis about the poor length-generalization results they observe for QK-norm in their setup - that QK-norm effectively removes magnitude information from \(Q\) and \(K\).
But do we actually want magnitude information to stay in \(Q\) and \(K\)? Our approach so far has been to be very deliberate with the norm information we want to keep in logits instead of relying on magnitude information from \(Q\) and \(K\). However, one pitfall with being deliberate about it is that we can be empirically wrong when it comes to practical logit behavior (i.e., the geometry of the \((q_i, k_j)\) vectors, let alone thinking about the projection matrices).
There are also other arguments in favor of QK-norm: there is an interesting line of work (called test-time regression/training) that frames linear attention as a kind of least-squares estimator in the context of in-context learning, and that work also shows that attention with QK-norm is a locally-constant non-parametric estimator in that setting, while unnormalized attention is suboptimal there, so it is theoretically interesting. Our scaling also fits into their setup - it only requires changing the bandwidth of the kernel, and our choice corresponds to the case where the “participation ratio” is near its critical state. While this line of work is not yet very well understood, it gives some intuition nevertheless.
Most modern architectures are prenorm, so some of these effects are already decoupled from QK-norm. So even if we don’t use QK-norm, it’s plausible that our scaling still applies if the projection matrices are well-behaved (in small-scale experiments, this assumption seems to hold weakly for QK-clip). QK-norm also relieves pressure on \(Q\) and \(K\) to encode position or context-length information in their raw norms.
If you do find that norm information is genuinely important but QK-norm removes too much of it (which I haven’t faced yet), there is a simple idea to preserve some norm information, similar in spirit to soft capping: make the epsilon non-negligible (and optionally learnable). The output then carries information about the norm of the original vector - the squared norm of the “RMSNorm”-ed vector is \(\|x\|^2 / (\|x\|^2/d + \epsilon)\), and since \(\epsilon\) is non-negligible, you get norm information for free (though potentially not in the form you might want it). This also gives the norm layer some Lipschitzness, if you care about that. However, it partially gets rid of the niceties one gets from using norms.
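A sketch of this finite-epsilon variant (hypothetical function; a real RMSNorm would also carry a learned gain):

```python
import math

def rmsnorm_soft(x, eps):
    """RMSNorm with a deliberately non-negligible eps: the output norm is
    ||x|| / sqrt(||x||^2/d + eps), which grows with ||x|| but saturates
    at sqrt(d), so some input-norm information survives (soft-capping-like)."""
    d = len(x)
    mean_sq = sum(v * v for v in x) / d
    return [v / math.sqrt(mean_sq + eps) for v in x]
```

With eps close to 0 this recovers exact RMSNorm (all norm information destroyed); a larger eps trades normalization strength for norm signal.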
Experimental details
Details
The ordering of these experiments is for ease of understanding, and doesn’t reflect chronological order.
Note that some of the results are bound to be specific to my setup - double check before scaling up (both at small and large scale).
Also, most experiments were run post hoc (i.e., after coming up with theory) due to limited compute, so there might be parts of the design space that were not sufficiently explored.
- Baselines and initial setup
- Scale-invariant attention paper has some code they provide for their experiments.
- From their code, we use scale-invariant attention and log-N scaling with p-RoPE.
- Their code doesn’t have document masking, which adversely affects results for other methods (especially \(\log n\) which oversharpens). Adding in document masking improves results across methods, even theirs. We will discuss results only with document masking.
- My modded nanogpt variant implementation in JAX.
- Both vanilla attention and multiplier-based scaling \(\sqrt{2 \log (n + 1)}\)
All experiments for evaluation were run with three seeds, tuning was done with one seed.
Light tuning of \(a + \sqrt{2 \log (bn + c)}\) (with \(a = 0\)) gives \(b = 0.1\) and \(c = 2 - 0.1\) (due to slight 0-based v/s 1-based indexing differences) and works well (there might be a better choice that I have not explored). Before that, I also tried \(\sqrt{2 \log (n + 1)} - \frac{\log \log (n + 1) + \log (4 \pi) - 2 \gamma}{2\sqrt{2 \log (n + 1)}}\), which also works well (this is the expectation of the maximum logit up to \(o(1 / \sqrt{\log n})\), which is incidentally not a very good bound, but works for us). This matters because small-\(n\) behavior is important, and we don’t want our functional forms of scaling to be merely asymptotic (a bigger problem here than one would expect, because these scales grow very slowly, and so the approximation error also decays very slowly).
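The expansion above can be sanity-checked against a direct Monte Carlo estimate of the expected maximum of \(n\) i.i.d. standard Gaussians (a sketch; i.i.d. Gaussian logits are of course an idealization):

```python
import math, random

random.seed(2)
EULER_GAMMA = 0.5772156649015329

def expected_max_expansion(n):
    # sqrt(2 log(n+1)) minus the first-order correction term from the text
    a = math.sqrt(2.0 * math.log(n + 1))
    corr = math.log(math.log(n + 1)) + math.log(4 * math.pi) - 2 * EULER_GAMMA
    return a - corr / (2 * a)

def mc_expected_max(n, trials=300):
    # sample mean of max over n standard normals
    return sum(max(random.gauss(0.0, 1.0) for _ in range(n))
               for _ in range(trials)) / trials

approx = {}
for n in (256, 4096):
    approx[n] = (mc_expected_max(n), expected_max_expansion(n))
    print(n, [round(v, 3) for v in approx[n]])
```

Even at these small \(n\), the corrected expansion lands within a few hundredths of the simulated mean, unlike the bare \(\sqrt{2 \log n}\) term.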
Another thing to keep in mind is that just changing the functional form of the scaling is not enough; it is also important to adjust the other scales to match - e.g., the initial attention scale (to ensure parity across variants at long context; attention entropy matters), the expected SDPA output variance after gating, and so on. Keeping this heuristic in mind makes hyperparameter tuning efficient.
When measuring performance according to final loss, the results were (rightmost is best): original baseline < \(\log n\) scaling < \(\sqrt{2 \log n}\) scaling < scale-invariant attention. All comparisons are at similar parameter counts if any scales are learned. The differences are roughly in line with the literature, enough for it to make sense to use the latter two.
When measuring performance according to length generalization (loss at 4K context v/s loss at 64K context, or when training on 2K, loss at 2K context v/s loss at 32K context), the results were (rightmost is best): original baseline \(\ll\) \(\log n\) scaling \(\lesssim\) \(\sqrt{2 \log n}\) scaling \(\lesssim\) scale-invariant attention.
Training-time-wise, scale-invariant attention is slower than the other alternatives - for the rest, when not using biases, the scaling can be absorbed into the query scale, so it is sufficient to take a pointwise product of a static vector with \(Q\).
In terms of principled choices, the most suitable choices seem to be \(\sqrt{2 \log n}\) scaling and scale-invariant attention, and since validation losses are better too, I think they are both good decisions overall (one for global layers, one for local layers, and in general can also be mix-and-match’d with learned bias scales and such).
Then I looked at the \(\sqrt{2 \log n}\) multiplier setup with head-wise, layer-wise, global, and static weights. It seems that layer-wise/global weights are sufficient, with no loss of length generalization. Overall loss improves the more weights we add. Upon inspection, the weights were consistently around (mostly less than) \(1\), which was interesting.
I also looked at solutions to \(E[1 / w_{\max}] = C\) assuming Gaussian logits. Numerically this gives small factors in front of the \(\sqrt{2 \log n}\) for small \(n\), and doesn’t help loss-wise either. The reason why this is expected to not work is that as \(r\) tends to \(1\) (the critical regime), at infinite \(n\) the expected inverse weight blows up, and in the finite \(n\) regime, the weight becomes something that slowly grows with \(n\), which should be a better scaling. Edge effects at small \(n\) (say at \(n \le C\)) and the large \(n\) behavior are at odds with each other in this scenario.
I tried a bunch of different techniques for beta logits (for both scale-invariant attention as well as our uniform-scaling setup). For long context, it becomes quite inaccurate really quickly (numerical issues - solving equations, sampling, or just precision), and some quick calculations and a couple of experiments showed that for typical head sizes, we would need quite long context lengths to see noticeable issues - depending on what you care about, a head size of 64 could lead to issues at roughly a 10K context length or at 1M+ context length, and these thresholds grow like \(\exp(O(\sqrt{d}))\) and \(\exp(O(d))\) respectively - this is why I kept this set of experiments for later. I ran experiments with head size 16, and didn’t find a very significant difference from the exact solution (this was only one experiment and likely suboptimal, and can be refined further if one wants). The main issue is hyp1f1- and Beta-related computations degrading at large \(n\) and \(d\), and it is worth trying a linear interpolation as an approximation. In any case, one can practically increase the head dimension (to something like 256) and stop worrying about it.
Then to understand if the difference between scale-invariant attention and normal attention is just the positional bias, I added an extra bias of the form \(- \log (1 + t / \tau)\) to my setup. It reduces the total loss, but doesn’t completely catch up to scale-invariant attention, which means there is some low-\(t\) or low-\(n\) behavior that is also important. Ideally to verify this, one would sweep over more parametric forms.
I also looked at learned coefficient of \(m\) very briefly (though it needs more rigorous experiments) - making the coefficient of \(m\) learned (and per-layer constant) in scale-invariant attention shows that for some layers, the coefficient is close to or smaller than \(1\) (less local bias), and for other layers, it sometimes even exceeds \(2\) (highly local bias). This suggests that the easiest way of improving loss is to improve local behavior of the model (or the “decay” doesn’t actually affect things), and more fine-grained tests are important (e.g., finetuning the model to solve NIAH since at the scale at which we’re training models, doing this is non-trivial).
I did a few other experiments (some of them are from the ones below), and it seemed that results from scale-invariant attention transfer to the \(\sqrt{2 \log n}\) scaling; so, since scale-invariant attention is better validated in the literature, I decided to use it (instead of all setups) for the rest of my experiments that required a large number of training runs, to save compute and time. Note that this is only a proxy.
Other minor experiments included per-head/per-layer/global/static variants for scaling these multipliers; global scaling generally seems good enough at our scale within margin of error (sometimes you would need an extra smaller-than-\(1\) scale in front), but per-layer is also a good (and probably safer) choice.
Then I moved on to gating - as pointed out earlier, it seems gating improves loss (at least due to parameters, hard to decouple it from actual improvements), and going by the Qwen paper it seems likely that it will pay off when doing context length extension - this is something that needs to be done in order to validate the usefulness of gating, but I like the approach.
Then I did a bunch of other experiments verifying whether QK-clip could replace QK-norm, and whether our setup (or the scale-invariant attention setup) works with it (since for Muon, the optimizer I was using, it seems that at scale it is absolutely necessary to have some mechanism in place to control logit growth). It seems that both scale-invariant attention and the \(\sqrt{2 \log n}\) scaling work with QK-clip (or at least improve upon it slightly), though in my experiments, the onset of length generalization comes a bit later in training for scale-invariant attention than for the \(\sqrt{2 \log n}\) scaling. It is worth noting that scale-transfer properties for QK-clip are not as well understood as for QK-norm, and it is possible that things change at scale.
I also looked at whether NoPE + QK-norm gets better length generalization than p-RoPE + QK-norm in our setup - there seems to be a slight difference, but it might as well be negligible; when performing length extension, though, I would expect p-RoPE to come with some extra things to keep in mind.
Different positions and types of normalization may have different effects on length generalization, so I looked at a few recommendations from the literature - peri-LN (sandwich norm), and QKV normalization (with or without the pre-norm). With QKV normalization, there wasn’t any significant improvement in length-generalization performance as far as my experiments explored - it is also possible that the tuning wasn’t sufficient (such drastic changes can alter training dynamics a lot), so this is still inconclusive. I did try normalizing the SDPA output (which is also somewhat in the spirit of peri-LN), and that gave slightly better final loss and length-generalization performance. An alternative could be adding an RMSNorm after the output projection (with learned weights, not frozen like in our case), as in peri-LN, but I haven’t tried this yet.
It is possible that I’ve missed mentioning other experiments I did, I’ll add them if I remember any. Please let me know if there are any experiments worth checking out - it’s possible I’ve done those, and if I haven’t, I’ll run those if feasible.
As far as future work is concerned, the following is something I couldn’t do due to compute/time constraints:
- Redoing these experiments at scale (and more exhaustively)
- Revisiting the Beta scaling more precisely, and figuring out at what scale we end up in a regime where the head size starts to matter
- A good start for both stability and length generalization seems to be the following, performing ablations should be interesting:
- Architecture/positional encoding as mentioned in the section on them (different hybrids, getting rid of RoPE)
- QK-norm or QK-clip for stability (and comparing it to peri-LN/sandwich norm and QKV normalization and/or the many different kinds of stability tricks people use)
- Gating or similar mechanisms for better length generalization (and checking whether this holds after our scaling and so on)
- Logit scaling mechanisms such as \(\sqrt{2 \log n}\) scaling or scale-invariant attention (and looking at what happens to sinks in our setup, what happens at different positions, looking at head specialization and so on)
- Adding a norm (either after SDPA or learned norm after the attention module) to prevent the attention module’s output from varying with position, in the limit.
- Tracking different kinds of statistics across context length as well as across the training run should provide insights and help improve scaling in more fine-grained ways
- For all these methods, check if finetuning helps (my setup partially does this due to context-length warmup to double the length halfway through the training run, but it would be better to do evals like NIAH on models of this size after finetuning them for such capabilities since they’re very small/undertrained compared to usual releases and hence bad at these evals)
Acknowledgements
Acknowledgements, in no particular order
- fairy - for discussing some results across modalities, pointing out Llama 4’s log scaling and other miscellaneous discussion.
- Grad - for miscellaneous discussion.
- fluffy - for miscellaneous discussion.
- fern - for miscellaneous discussion.
- uwu1 - for pointing me to Grok-2’s log scaling.
- WhereIsTheExit - for pointing me to the attention "Catch, Tag, Release" paper a while back and pointing out that QK-norm also has attention sinks.
- Ravna - for believing in QK-norm; their conviction is what got me into this rabbit hole in the first place.
I would also like to thank Google’s TPU Research Cloud for TPU credits which provided for a part of the compute used for this work.
References
A non-exhaustive list of references
The following list of references is incomplete; please refer to their references for more information.
- Henry, A., Dachapally, P. R., Pawar, S. S., & Chen, Y. (2020, November). Query-key normalization for transformers. In Findings of the Association for Computational Linguistics: EMNLP 2020 (pp. 4246-4253).
- Dehghani, M., Djolonga, J., Mustafa, B., Padlewski, P., Heek, J., Gilmer, J., … & Houlsby, N. (2023, July). Scaling vision transformers to 22 billion parameters. In International conference on machine learning (pp. 7480-7512). PMLR.
- Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., … & Fiedel, N. (2023). Palm: Scaling language modeling with pathways. Journal of Machine Learning Research, 24(240), 1-113.
- Wortsman, M., Liu, P. J., Xiao, L., Everett, K., Alemi, A., Adlam, B., … & Kornblith, S. Small-scale proxies for large-scale transformer training instabilities, 2023. URL https://arxiv.org/abs/2309.14322.
- Ding, M., Yang, Z., Hong, W., Zheng, W., Zhou, C., Yin, D., … & Tang, J. (2021). Cogview: Mastering text-to-image generation via transformers. Advances in neural information processing systems, 34, 19822-19835.
- Kim, J., Lee, B., Park, C., Oh, Y., Kim, B., Yoo, T., … & Yoo, K. M. (2025). Peri-ln: Revisiting normalization layer in the transformer architecture. arXiv preprint arXiv:2502.02732.
- Zhuo, Z., Zeng, Y., Wang, Y., Zhang, S., Yang, J., Li, X., … & Ma, J. (2025). HybridNorm: Towards Stable and Efficient Transformer Training via Hybrid Normalization. arXiv preprint arXiv:2503.04598.
- Rybakov, O., Chrzanowski, M., Dykas, P., Xue, J., & Lanir, B. (2024). Methods of improving llm training stability. arXiv preprint arXiv:2410.16682.
- Su, J., Ahmed, M., Lu, Y., Pan, S., Bo, W., & Liu, Y. (2024). Roformer: Enhanced transformer with rotary position embedding. Neurocomputing, 568, 127063.
- Press, O., Smith, N. A., & Lewis, M. (2021). Train short, test long: Attention with linear biases enables input length extrapolation. arXiv preprint arXiv:2108.12409.
- Barbero, F., Vitvitskyi, A., Perivolaropoulos, C., Pascanu, R., & Veličković, P. (2024). Round and round we go! what makes rotary positional encodings useful?. arXiv preprint arXiv:2410.06205.
- Yang, B., Venkitesh, B., Talupuru, D., Lin, H., Cairuz, D., Blunsom, P., & Locatelli, A. (2025). Rope to nope and back again: A new hybrid attention strategy. arXiv preprint arXiv:2501.18795.
- Kazemnejad, A., Padhi, I., Natesan Ramamurthy, K., Das, P., & Reddy, S. (2023). The impact of positional encoding on length generalization in transformers. Advances in Neural Information Processing Systems, 36, 24892-24928.
- Nakanishi, K. M. (2025). Scalable-softmax is superior for attention. arXiv preprint arXiv:2501.19399.
- Anson, B., Wang, X., & Aitchison, L. (2025). Scale-invariant attention. arXiv preprint arXiv:2505.17083.
- Chiang, D., & Cholak, P. (2022). Overcoming a theoretical limitation of self-attention. arXiv preprint arXiv:2202.12172.
- Zhang, S., Khan, M., & Papyan, V. Attention Sinks: A 'Catch, Tag, Release' Mechanism for Embeddings. In The Thirty-ninth Annual Conference on Neural Information Processing Systems.
- Darcet, T., Oquab, M., Mairal, J., & Bojanowski, P. (2023). Vision transformers need registers. arXiv preprint arXiv:2309.16588.
- Qiu, Z., Wang, Z., Zheng, B., Huang, Z., Wen, K., Yang, S., … & Lin, J. (2025). Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free. arXiv preprint arXiv:2505.06708.
- Xiao, G., Tian, Y., Chen, B., Han, S., & Lewis, M. (2023). Efficient streaming language models with attention sinks. arXiv preprint arXiv:2309.17453.
- Beltagy, I., Peters, M. E., & Cohan, A. (2020). Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150.
- Peng, B., Quesnelle, J., Fan, H., & Shippole, E. (2023). Yarn: Efficient context window extension of large language models. arXiv preprint arXiv:2309.00071.
- Yang, S., Kautz, J., & Hatamizadeh, A. (2024). Gated delta networks: Improving mamba2 with delta rule. arXiv preprint arXiv:2412.06464.
- Peng, B., Zhang, R., Goldstein, D., Alcaide, E., Du, X., Hou, H., … & Zhou-Zheng, C. (2025). RWKV-7 "Goose" with expressive dynamic state evolution. arXiv preprint arXiv:2503.14456.
- Lieber, O., Lenz, B., Bata, H., Cohen, G., Osin, J., Dalmedigos, I., … & Shoham, Y. (2024). Jamba: A hybrid transformer-mamba language model. arXiv preprint arXiv:2403.19887.
- De, S., Smith, S. L., Fernando, A., Botev, A., Cristian-Muraru, G., Gu, A., … & Gulcehre, C. (2024). Griffin: Mixing gated linear recurrences with local attention for efficient language models. arXiv preprint arXiv:2402.19427.
- Glorioso, P., Anthony, Q., Tokpanov, Y., Whittington, J., Pilault, J., Ibrahim, A., & Millidge, B. (2024). Zamba: A compact 7b ssm hybrid model. arXiv preprint arXiv:2405.16712.
- Ren, L., Liu, Y., Lu, Y., Shen, Y., Liang, C., & Chen, W. (2024). Samba: Simple hybrid state space models for efficient unlimited context language modeling. arXiv preprint arXiv:2406.07522.
- Glorioso, P., Anthony, Q., Tokpanov, Y., Golubeva, A., Shyam, V., Whittington, J., … & Millidge, B. (2024). The zamba2 suite: Technical report. arXiv preprint arXiv:2411.15242.
- Dong, X., Fu, Y., Diao, S., Byeon, W., Chen, Z., Mahabaleshwarkar, A. S., … & Molchanov, P. (2024). Hymba: A hybrid-head architecture for small language models. arXiv preprint arXiv:2411.13676.
- Yang, M., Rezagholizadeh, M., Li, G., Appia, V., & Barsoum, E. (2025). Zebra-Llama: Towards Extremely Efficient Hybrid Models. arXiv preprint arXiv:2505.17272.
- IBM Research. (2025). https://research.ibm.com/blog/bamba-ssm-transformer-model (2025)
- Li, A., Gong, B., Yang, B., Shan, B., Liu, C., Zhu, C., … & Wu, Z. (2025). Minimax-01: Scaling foundation models with lightning attention. arXiv preprint arXiv:2501.08313.
- Minimax Team. (2025). Why full-attention instead of hybrid linear attention in Minimax-2? https://x.com/zpysky1125/status/1983383094607347992 (2025)
- Team, T. H., Liu, A., Zhou, B., Xu, C., Zhou, C., Zhang, C., … & He, L. (2025). Hunyuan-turbos: Advancing large language models through mamba-transformer synergy and adaptive chain-of-thought. arXiv preprint arXiv:2505.15431.
- Blakeman, A., Basant, A., Khattar, A., Renduchintala, A., Bercovich, A., Ficek, A., … & Ghosh, S. (2025). Nemotron-h: A family of accurate and efficient hybrid mamba-transformer models. arXiv preprint arXiv:2504.03624.
- Qwen Team. (2025). Qwen3-Next https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list (2025)
- Gu, Y., Hu, Q., Yang, S., Xi, H., Chen, J., Han, S., & Cai, H. (2025). Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search. arXiv preprint arXiv:2508.15884.
- Team, K., Zhang, Y., Lin, Z., Yao, X., Hu, J., Meng, F., … & Du, Y. (2025). Kimi Linear: An Expressive, Efficient Attention Architecture. arXiv preprint arXiv:2510.26692.
- Gu, A., & Dao, T. (2024, May). Mamba: Linear-time sequence modeling with selective state spaces. In First conference on language modeling.
- Gu, A., Goel, K., & Ré, C. (2021). Efficiently modeling long sequences with structured state spaces. arXiv preprint arXiv:2111.00396.
- Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020, November). Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning (pp. 5156-5165). PMLR.
- Wang, K. A., Shi, J., & Fox, E. B. (2025). Test-time regression: a unifying framework for designing sequence models with associative memory. arXiv preprint arXiv:2501.12352.
- Zhang, T., Bi, S., Hong, Y., Zhang, K., Luan, F., Yang, S., … & Tan, H. (2025). Test-time training done right. arXiv preprint arXiv:2505.23884.
- Yang, S., Wang, B., Shen, Y., Panda, R., & Kim, Y. (2023). Gated linear attention transformers with hardware-efficient training. arXiv preprint arXiv:2312.06635.
- Dao, T., & Gu, A. (2024). Transformers are ssms: Generalized models and efficient algorithms through structured state space duality. arXiv preprint arXiv:2405.21060.
- Schlag, I., Irie, K., & Schmidhuber, J. (2021, July). Linear transformers are secretly fast weight programmers. In International conference on machine learning (pp. 9355-9366). PMLR.
- Roy, A., Chou, T., Duvvuri, S. S., Chen, S., Yu, J., Wang, X., … & Anil, R. (2025). Fast and Simplex: 2-Simplicial Attention in Triton. arXiv preprint arXiv:2507.02754.
- Deepseek Team. (2025). Deepseek Sparse Attention. https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/main/DeepSeek_V3_2.pdf (2025)
- Resnick, S. I. (2008). Extreme values, regular variation, and point processes (Vol. 4). Springer Science & Business Media.
- Meta (2025). Llama-4 https://www.llama.com/models/llama-4/ (2025)
- Grok-2 temperature scaling - https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/triton_ops/decode_attention.py#L91 (2025)
- Yang, A., Li, A., Yang, B., Zhang, B., Hui, B., Zheng, B., … & Qiu, Z. (2025). Qwen3 technical report. arXiv preprint arXiv:2505.09388.
- Team, G., Riviere, M., Pathak, S., Sessa, P. G., Hardin, C., Bhupatiraju, S., … & Garg, S. (2024). Gemma 2: Improving open language models at a practical size. arXiv preprint arXiv:2408.00118.
- Team, G., Kamath, A., Ferret, J., Pathak, S., Vieillard, N., Merhej, R., … & Iqbal, S. (2025). Gemma 3 technical report. arXiv preprint arXiv:2503.19786.
- OLMo Team, et al. (2024). 2 OLMo 2 Furious. arXiv preprint arXiv:2501.00656.
- OLMo Team, et al. (2025). Olmo 3. https://allenai.org/papers/olmo3 (2025)
- Zeng, A., Lv, X., Zheng, Q., Hou, Z., Chen, B., Xie, C., … & Zhou, Z. (2025). Glm-4.5: Agentic, reasoning, and coding (arc) foundation models. arXiv preprint arXiv:2508.06471.
- Team, K., Bai, Y., Bao, Y., Chen, G., Chen, J., Chen, N., … & Zhang, H. (2025). Kimi k2: Open agentic intelligence. arXiv preprint arXiv:2507.20534.
- Agarwal, S., Ahmad, L., Ai, J., Altman, S., Applebaum, A., Arbus, E., … & Zhao, S. (2025). gpt-oss-120b & gpt-oss-20b model card. arXiv preprint arXiv:2508.10925.
- Marin Team (2025). Fixing spikes in Marin 32B pretraining https://x.com/dlwh/status/1938282073296671045 (2025)
- Cohere, T., Ahmadian, A., Ahmed, M., Alammar, J., Alizadeh, M., Alnumay, Y., … & Muppalla, V. (2025). Command a: An enterprise-ready large language model. arXiv preprint arXiv:2504.00698.
- Liu, A., Feng, B., Xue, B., Wang, B., Wu, B., Lu, C., … & Piao, Y. (2024). Deepseek-v3 technical report. arXiv preprint arXiv:2412.19437.
- Liu, A., Feng, B., Wang, B., Wang, B., Liu, B., Zhao, C., … & Xu, Z. (2024). Deepseek-v2: A strong, economical, and efficient mixture-of-experts language model. arXiv preprint arXiv:2405.04434.
- Brogan, J., Bharati, A., Moreira, D., Bowyer, K., Flynn, P., Rocha, A., & Scheirer, W. (2019). Needle in a haystack: A framework for seeking small objects in big datasets. arXiv preprint arXiv:1903.10019.
- Hsieh, C. P., Sun, S., Kriman, S., Acharya, S., Rekesh, D., Jia, F., … & Ginsburg, B. (2024). RULER: What’s the Real Context Size of Your Long-Context Language Models?. arXiv preprint arXiv:2404.06654.
- Liu, N. F., Lin, K., Hewitt, J., Paranjape, A., Bevilacqua, M., Petroni, F., & Liang, P. (2024). Lost in the middle: How language models use long contexts. Transactions of the Association for Computational Linguistics, 12, 157-173.
- Xiong, R., Yang, Y., He, D., Zheng, K., Zheng, S., Xing, C., … & Liu, T. (2020, November). On layer normalization in the transformer architecture. In International conference on machine learning (pp. 10524-10533). PMLR.
- So, D., Mańke, W., Liu, H., Dai, Z., Shazeer, N., & Le, Q. V. (2021). Searching for efficient transformers for language modeling. Advances in neural information processing systems, 34, 6010-6022.
- Zhang, B., & Sennrich, R. (2019). Root mean square layer normalization. Advances in neural information processing systems, 32.
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.
- Dehghani, M., Gouws, S., Vinyals, O., Uszkoreit, J., & Kaiser, Ł. (2018). Universal transformers. arXiv preprint arXiv:1807.03819.
- Su, Jianlin (苏剑林) (2025). Why does linear attention need a short conv? https://kexue.fm/archives/11320 (2025)
- Jordan, K., Bernstein, J., Rappazzo, B., @fernbear.bsky.social, Vlado, B., Jiacheng, Y., Cesista, Franz., Koszarsky, B., @Grad62304977 (2024). modded-nanogpt: Speedrunning the NanoGPT baseline
- nor (2025). The modded nanogpt speedrun, but in JAX and on TPUs. https://nor-blog.pages.dev/posts/2025-08-21-modded-nanogpt-jax (2025)
- Chinatalk team (2025). The Z.ai Playbook. https://www.chinatalk.media/p/the-zai-playbook (2025)
Final notes
Note that some of the explanations above are heuristic or merely hypothetical; if you find that something is incorrect, please let me know.
Cite this post
@online{a-short-note-on-some-aspects-of-long-context-attention,
author = {nor},
title = {A short note on some aspects of long context attention},
year = {2025},
month = {11},
day = {27},
url = {https://nor-blog.pages.dev/posts/2025-11-27-attention-and-long-context/},
}