🤖 AI Summary
A step-by-step guide demonstrates how to write a high-performance matrix-multiply kernel for NVIDIA Blackwell using Pallas/plgpu primitives, starting from a simple single-warpgroup implementation and iteratively adding optimizations until it matches or exceeds tuned libraries (cuBLAS/CUTLASS). The guide reports measured TensorCore utilization rising from 37.6% for the basic kernel to 69.4% with full grid tiling (109.6% of the cuBLAS baseline, and comparable to CUTLASS), while cautioning that input data distribution can strongly affect the numbers; the benchmarks use i.i.d. normal float16, a relatively slow case.
Key technical moves behind the speedups include:
- careful tiling (tile_m, tile_n, tile_k), with a rule of thumb of tile_k ≈ 128 / bytewidth (e.g., tile_k = 64 for 2-byte float16);
- SMEM/TMEM layout tuning via swizzles/transforms so operands land in the layout the MMA unit expects;
- an explicit compute/memory pipeline (plgpu.emit_pipeline) with delay_release and multi-stage prefetching (max_concurrent_steps);
- async GMEM→SMEM copies (TMA) and TMEM accumulators, with async_load_tmem for the epilogue;
- fine-grained barrier management to order loads against MMA consumption;
- warp specialization, splitting a warpgroup into four warps via pl.core_map and plgpu.WarpMesh to separate memory and compute roles;
- and larger structural changes: collective (2-CTA) MMA, persistent kernels, a dedicated epilogue warpgroup, and grid tiling.

Full code and an optimized standalone kernel are available in the accompanying test files and the Pallas ops package; the sketches below illustrate two of the pieces listed above.
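For concreteness, here is a hedged sketch of what the pipelined single-warpgroup stage can look like. The primitive names (plgpu.emit_pipeline, plgpu.tcgen05_mma, plgpu.async_load_tmem, the TMEM/SMEM/Barrier scratch shapes) come from the guide's Pallas API, but the tile sizes, the TilingTransform/SwizzleTransform arguments, and the end-of-loop barrier choreography (tcgen05_commit_arrive + barrier_wait) are illustrative assumptions, not the guide's verbatim code. It assumes M, N divide by tile_m, tile_n and K divides by tile_k.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def matmul_kernel(a_gmem, b_gmem, out_gmem, acc_tmem, out_smem, done_barrier,
                  *, tile_m, tile_n, tile_k):
  m_index = lax.axis_index("m")
  n_index = lax.axis_index("n")
  k_iters = a_gmem.shape[1] // tile_k

  def do_mma(idxs, a_smem, b_smem):
    (ki,) = idxs
    # Issue an async TensorCore MMA; accumulate into TMEM after step 0.
    plgpu.tcgen05_mma(acc_tmem, a_smem, b_smem, accumulate=(ki > 0))

  # 128-byte swizzle => 64 f16 elements along the minor dimension (assumed).
  transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
  plgpu.emit_pipeline(
      do_mma,
      grid=(k_iters,),
      in_specs=[
          plgpu.BlockSpec((tile_m, tile_k), lambda ki: (m_index, ki),
                          transforms=transforms),
          plgpu.BlockSpec((tile_k, tile_n), lambda ki: (ki, n_index),
                          transforms=transforms),
      ],
      max_concurrent_steps=4,  # Prefetch depth: pipeline stages in flight.
      delay_release=1,  # Hold buffers one extra step: the MMA reads them async.
  )(a_gmem, b_gmem)

  # Assumed choreography: barrier arrives once all issued MMAs have committed.
  plgpu.tcgen05_commit_arrive(done_barrier)
  plgpu.barrier_wait(done_barrier)

  # Epilogue: async TMEM -> registers, then SMEM -> GMEM via TMA.
  acc = plgpu.async_load_tmem(acc_tmem)
  plgpu.wait_load_tmem()
  out_smem[...] = acc.astype(out_smem.dtype)
  plgpu.commit_smem()
  m_slice = pl.ds(m_index * tile_m, tile_m)
  n_slice = pl.ds(n_index * tile_n, tile_n)
  plgpu.copy_smem_to_gmem(out_smem, out_gmem.at[m_slice, n_slice])
  plgpu.wait_smem_to_gmem(0)


def matmul(a, b, tile_m=128, tile_n=128, tile_k=64):
  m, _ = a.shape
  _, n = b.shape
  kernel = plgpu.kernel(
      functools.partial(matmul_kernel, tile_m=tile_m, tile_n=tile_n,
                        tile_k=tile_k),
      out_shape=jax.ShapeDtypeStruct((m, n), jnp.float16),
      grid=(m // tile_m, n // tile_n),
      grid_names=("m", "n"),
      scratch_shapes=dict(
          acc_tmem=plgpu.TMEM((tile_m, tile_n), jnp.float32),
          out_smem=plgpu.SMEM((tile_m, tile_n), jnp.float16),
          done_barrier=plgpu.Barrier(orders_tensor_core=True),
      ),
  )
  return kernel(a, b)
```

Note how the tile_k default of 64 instantiates the 128 / bytewidth rule of thumb for float16, and how emit_pipeline hides all GMEM→SMEM TMA traffic while the body only ever issues MMAs.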
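And a minimal, hedged sketch of the warp-specialization mechanism (pl.core_map over a plgpu.WarpMesh), reduced to a warp-specialized copy so the barrier logic stays readable: warp 0 plays the memory role, warp 1 the "compute" role (which in the real kernel would issue tcgen05_mma rather than a store). The kernel wrapper, scratch shapes, and barrier sizing are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def ws_copy_kernel(x_gmem, out_gmem, x_smem, load_barrier):
  # Split the warpgroup into its four warps; each warp executes this body
  # independently, distinguished by its index along the "warp" axis.
  @pl.core_map(plgpu.WarpMesh(axis_name="warp"))
  def _per_warp():
    warp_id = lax.axis_index("warp")

    @pl.when(warp_id == 0)
    def _memory_warp():
      # Memory role: issue the async GMEM->SMEM (TMA) copy; the barrier
      # receives an arrival when the copy completes.
      plgpu.copy_gmem_to_smem(x_gmem, x_smem, load_barrier)

    @pl.when(warp_id == 1)
    def _compute_warp():
      # Compute role: wait for the operands, then consume them. A real
      # kernel would issue MMAs here; we just write the tile back out.
      plgpu.barrier_wait(load_barrier)
      plgpu.copy_smem_to_gmem(x_smem, out_gmem)
      plgpu.wait_smem_to_gmem(0)


def ws_copy(x):
  return plgpu.kernel(
      ws_copy_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      scratch_shapes=dict(
          x_smem=plgpu.SMEM(x.shape, x.dtype),
          load_barrier=plgpu.Barrier(),  # one arrival: the single TMA copy
      ),
  )(x)


# Example: round-trip one 128x128 f16 tile through SMEM.
# out = ws_copy(jnp.ones((128, 128), jnp.float16))
```

The same pl.when pattern is what lets a producer warp run ahead issuing TMA copies while a consumer warp drains buffers with MMAs, which is the point of the warp-specialized stage in the guide.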