🤖 AI Summary
Researchers implemented a hand-optimized CUDA Top-K kernel for LLM sampling that cuts PyTorch's torch.topk latency on an H100 from ~1.15 ms to 0.112 ms, a roughly 10× speedup, and beats a naive Thrust sort (23 ms) by over 200×. Because Top-K runs billions of times in production, this single-op improvement can materially reduce inference tail latency and throughput cost for large-vocabulary sampling (vocab size ≈50k, k ≈ 2,048).
The kernel avoids a full sort: it builds a 512-bin histogram of the logits in shared memory, runs a prefix sum over the bins to find the bin containing the k-th largest element, collects every item from the bins above that threshold, and sorts only the small "threshold" bin. The key helper is extractBinIdx (sketched below): convert each float to its FP16 bit pattern, then flip all the bits for negatives and set the sign bit for positives, so the IEEE-754 layout becomes order-preserving as an unsigned integer; the emitted PTX uses selp (select-with-predicate) rather than divergent branches.

Practical lessons: bin counts are accumulated with atomic adds into shared memory; heavy loop unrolling (8× was fastest) improves memory-level parallelism (MLP); and nvcc/PTX can surprise, since a 4× compiler unroll plus algebraic index transforms actually slowed the kernel. The takeaway: careful bit-level floating-point tricks and compiler-aware optimization (plus attention to FP16/BF16 conversion and rounding) yield large real-world inference gains, but nvcc can emit suboptimal PTX even at -O3.
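A minimal sketch of what an order-preserving mapping like extractBinIdx could look like. The function name comes from the summary above; the exact bit manipulation, rounding mode, and 9-bit bin width are assumptions based on that description, not the original code:

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Hypothetical re-creation of the extractBinIdx idea: map a float score to
// one of 512 histogram bins such that larger floats always land in larger
// (or equal) bin indices.
__device__ __forceinline__ uint32_t extractBinIdx(float x)
{
    // Round the FP32 logit to the nearest FP16 value and take its raw
    // 16-bit pattern (the post notes FP16/BF16 conversion and rounding
    // need care).
    uint16_t bits = __half_as_ushort(__float2half_rn(x));

    // Make the bit pattern order-preserving as an unsigned integer:
    //  - negative values: flip every bit (reverses their ordering so that
    //    more-negative values map to smaller integers)
    //  - positive values: set the sign bit (pushes them above all negatives)
    // Written as a ternary so the compiler can lower it to a predicated
    // select (selp) instead of a divergent branch.
    bits = (bits & 0x8000u) ? static_cast<uint16_t>(~bits)
                            : static_cast<uint16_t>(bits | 0x8000u);

    // Keep the top 9 bits: 512 bins, with larger values in larger bins.
    return bits >> 7;
}
```

Because larger floats map to larger unsigned integers, bin 511 holds the largest logits and the prefix sum can walk down from the top bin until it has accumulated k elements.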
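And a sketch of how the histogram/prefix-sum selection could be structured around that helper. This is an illustrative single-block version assuming the extractBinIdx sketch above; the kernel name, launch shape, and output layout are assumptions, and the final sort of the threshold bin is omitted:

```cuda
constexpr int kNumBins = 512;

// Illustrative selection pass over one logits array of length n.
// outCount is assumed to be zeroed before launch.
__global__ void topKThresholdPass(const float* __restrict__ logits,
                                  int n, int k,
                                  int* __restrict__ outIdx,
                                  int* __restrict__ outCount)
{
    __shared__ int hist[kNumBins];
    __shared__ int thresholdBin;

    // 1) Build the 512-bin histogram in shared memory with atomic adds.
    for (int i = threadIdx.x; i < kNumBins; i += blockDim.x) hist[i] = 0;
    __syncthreads();
    for (int i = threadIdx.x; i < n; i += blockDim.x)
        atomicAdd(&hist[extractBinIdx(logits[i])], 1);
    __syncthreads();

    // 2) Prefix-sum from the highest bin downward to locate the bin that
    //    contains the k-th largest element (single thread for clarity;
    //    a real kernel would use a parallel scan).
    if (threadIdx.x == 0) {
        thresholdBin = 0;
        int running = 0;
        for (int b = kNumBins - 1; b >= 0; --b) {
            running += hist[b];
            if (running >= k) { thresholdBin = b; break; }
        }
    }
    __syncthreads();

    // 3) Everything in a bin strictly above the threshold bin is in the
    //    top-k for sure; only the threshold bin itself still needs a small
    //    sort/partial selection (not shown here).
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        if (extractBinIdx(logits[i]) > static_cast<uint32_t>(thresholdBin))
            outIdx[atomicAdd(outCount, 1)] = i;
    }
}
```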