Show HN: Optimizing DeepSeek's NSA for TPUs – A Kernel Worklog (henryhmko.github.io)

🤖 AI Summary
A developer ported and optimized DeepSeek's NSA (a GPU-optimized sparse attention algorithm) to TPUs using JAX and Pallas, publishing a Colab with the code, kernels, and profiles.

Key wins before any kernel work: a corrected, simplified form of the paper's Equation 9 for blockwise pooling under NSA's stride/divisibility assumptions, and a vectorized JAX rework that avoids gigantic index tensors by using vmap-based gathers, yielding roughly a 286× speedup over the naive, XLA-unfriendly implementation. Profiling showed that most time and memory were still spent materializing K_slc/V_slc, motivating fused Pallas kernels to reduce memory-bandwidth pressure and shift work onto the MXU (systolic array) rather than VPU-bound gathers.

The writeup digs into why NSA's dynamic sparsity is hard on TPUs: JAX/XLA dislike runtime dynamic indexing, Pallas enforces lexicographic grid traversal, and TPUs favor large dense tiles. Practical solutions include scalar prefetch to skip unused blocks, tiling and fused kernels to maximize MXU utilization, and exploiting online softmax's order invariance to sort top-k block indices so traversal becomes lexicographic (sorting is cheap since top-n is tiny); the sketches below illustrate each of these ideas.

A caveat: order invariance holds numerically for BF16 with FP32 accumulation, but low-precision formats (FP8/FP4) expose stability and underflow risks, so precision-aware accumulation and traversal order remain important when adapting sparse GPU algorithms to TPU architectures.
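The summary doesn't reproduce the simplified Equation 9, but a minimal sketch of blockwise pooling under a stride and a divisibility assumption looks like the following. The mean reducer is a stand-in (NSA's actual compression uses a learnable pooling MLP), and the requirement that `(T - block_len)` be divisible by `stride` mirrors the divisibility assumption the worklog relies on:

```python
import jax
import jax.numpy as jnp

def compress_kv(k, block_len, stride):
    """Blockwise pooling of keys/values: (T, d) -> (num_blocks, d).

    Assumes (T - block_len) is divisible by stride so blocks tile the
    sequence exactly. jnp.mean is an illustrative stand-in for NSA's
    learnable compression MLP.
    """
    t, d = k.shape
    num_blocks = (t - block_len) // stride + 1
    starts = jnp.arange(num_blocks) * stride
    # Slice out each (possibly overlapping) block, then pool within it.
    blocks = jax.vmap(
        lambda s: jax.lax.dynamic_slice(k, (s, 0), (block_len, d))
    )(starts)
    return blocks.mean(axis=1)
```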
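For the vmap-based gathers, a minimal sketch of the idea (names and shapes here are assumptions, not the worklog's exact code): instead of materializing a huge index tensor for `jnp.take_along_axis`, vmap a tiny per-query gather and let XLA fuse it.

```python
import jax
import jax.numpy as jnp

def gather_selected(kv_blocks, topk_idx):
    """Gather each query's top-n KV blocks without a giant index tensor.

    kv_blocks: (num_blocks, block_size, d)  -- blocked keys or values
    topk_idx:  (num_queries, n)             -- per-query selected block ids

    Returns (num_queries, n, block_size, d).
    """
    gather_one = lambda idx: kv_blocks[idx]   # (n, block_size, d) per query
    return jax.vmap(gather_one)(topk_idx)
```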
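A sketch of scalar prefetch, following the block-sparse pattern from the Pallas TPU documentation (exact API details can vary by JAX version, and this copy-only kernel is a simplification of the worklog's fused kernels): the block-index array is prefetched into SMEM, the `index_map` uses it to decide which KV block to load, and blocks that are never selected are never touched.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_block_kernel(idx_ref, k_ref, out_ref):
    # idx_ref lives in SMEM via scalar prefetch; the index_map below has
    # already used it to select which block sits in k_ref, so the body
    # is a plain copy. Unselected blocks are skipped entirely.
    del idx_ref
    out_ref[...] = k_ref[...]

def gather_selected_blocks(k, block_idx, block_size):
    """Gather rows of k by block id using scalar prefetch.

    k:         (num_blocks * block_size, d)
    block_idx: (n,) int32 block ids, pre-sorted so the grid walks
               memory in lexicographic order.
    """
    n = block_idx.shape[0]
    d = k.shape[-1]
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=1,  # block_idx is available before the grid runs
        grid=(n,),
        in_specs=[pl.BlockSpec((block_size, d),
                               lambda i, idx_ref: (idx_ref[i], 0))],
        out_specs=pl.BlockSpec((block_size, d),
                               lambda i, idx_ref: (i, 0)),
    )
    return pl.pallas_call(
        copy_block_kernel,
        grid_spec=grid_spec,
        out_shape=jax.ShapeDtypeStruct((n * block_size, d), k.dtype),
    )(block_idx, k)
```

Note that `block_size` and `d` still need to respect TPU tiling constraints (dense tiles such as multiples of 8×128), which is part of why TPUs favor block-level rather than element-level sparsity.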
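The sorting trick rests on online softmax being order-invariant: the running-max rescaling makes the final output independent of which block is processed first, so sorting the top-k indices before the kernel changes traversal order but not the result. A small eager-mode demonstration, assuming single-query shapes for brevity (not the worklog's kernel):

```python
import jax
import jax.numpy as jnp

def online_softmax_attn(q, k_blocks, v_blocks, block_order):
    """Attention over KV blocks via online (streaming) softmax.

    Keeps a running max m, normalizer l, and unnormalized output acc,
    all accumulated in FP32. The rescaling step makes the result
    independent of block order.
    """
    d = q.shape[-1]
    m = jnp.float32(-jnp.inf)
    l = jnp.float32(0.0)
    acc = jnp.zeros((d,), jnp.float32)
    for b in block_order:
        s = (k_blocks[b].astype(jnp.float32) @ q.astype(jnp.float32)) / jnp.sqrt(d)
        m_new = jnp.maximum(m, s.max())
        scale = jnp.exp(m - m_new)       # rescale the old accumulator
        p = jnp.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v_blocks[b].astype(jnp.float32)
        m = m_new
    return acc / l

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (64,), jnp.bfloat16)
k_blocks = jax.random.normal(kk, (8, 16, 64), jnp.bfloat16)
v_blocks = jax.random.normal(kv, (8, 16, 64), jnp.bfloat16)

topk = jnp.array([5, 1, 7, 2])
out_unsorted = online_softmax_attn(q, k_blocks, v_blocks, topk)
out_sorted = online_softmax_attn(q, k_blocks, v_blocks, jnp.sort(topk))
assert jnp.allclose(out_unsorted, out_sorted, atol=1e-5)
```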
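On the precision caveat: in JAX the usual way to keep BF16 matmuls numerically order-stable is to request an FP32 accumulator explicitly, e.g. via `preferred_element_type` (a standard JAX option, shown here as an illustration rather than the worklog's code):

```python
import jax.numpy as jnp

q = jnp.ones((128, 64), jnp.bfloat16)
k = jnp.ones((128, 64), jnp.bfloat16)
# BF16 inputs, FP32 accumulation: this is the regime in which the
# online-softmax order-invariance argument holds numerically.
scores = jnp.einsum("qd,kd->qk", q, k, preferred_element_type=jnp.float32)
```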