🤖 AI Summary
The author set out to “learn CUTLASS the hard way” by building up from first principles to optimized GEMM kernels on an NVIDIA RTX 4090, with the explicit aim of understanding what real CUTLASS kernels do under the hood and trying to beat PyTorch matmul. The write-up walks through successive optimization stages (naive one-element-per-thread kernels, shared-memory tiling, tensor cores/WMMA, swizzling, pipelining, and autotuning) and emphasizes the modern low-precision matmuls (bf16, fp8, mxfp8) that dominate today’s ML workloads. Along the way the author uses visualizations (created with Claude Code) and references recent deep dives such as Pranjal Shankhodhar’s and Aleksa Gordić’s posts.
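Of those stages, the jump to tensor cores via WMMA is the biggest conceptual step. Below is a minimal, hypothetical sketch of the warp-level WMMA pattern such a progression builds toward, assuming half-precision inputs, float accumulation, 16×16×16 fragments, and dimensions divisible by 16; the kernel name and launch geometry are illustrative, not the post's code.

```cuda
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// Illustrative sketch (not the post's code): one warp computes one 16x16 tile
// of C = A*B by stepping over K in 16-wide chunks of tensor-core MMAs.
__global__ void wmma_tile_gemm(int M, int N, int K,
                               const half* A, const half* B, float* C) {
    // Warp coordinates: blockDim.x is assumed to be a multiple of warpSize,
    // so each warp owns exactly one 16x16 output tile.
    int warpN = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
    int warpM = blockIdx.y * blockDim.y + threadIdx.y;

    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> aFrag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> bFrag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc;
    wmma::fill_fragment(acc, 0.0f);

    for (int k = 0; k < K; k += 16) {
        // Load 16x16 tiles of A and B directly from global memory; a real
        // kernel would stage them through shared memory (swizzled, pipelined).
        wmma::load_matrix_sync(aFrag, A + warpM * 16 * K + k, K);
        wmma::load_matrix_sync(bFrag, B + k * N + warpN * 16, N);
        wmma::mma_sync(acc, aFrag, bFrag, acc);  // tensor-core multiply-accumulate
    }
    wmma::store_matrix_sync(C + warpM * 16 * N + warpN * 16, acc, N,
                            wmma::mem_row_major);
}
```

The later stages the post covers (shared-memory tiling, swizzling, pipelining) are largely about feeding these fragments fast enough to keep the tensor cores busy.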
Key technical takeaways: GEMM computes C = αAB + βC at a cost of 2·M·N·K FLOPs, and optimized GEMMs aim for high arithmetic intensity so the kernel becomes compute-bound rather than memory-bound on GPUs. On an RTX 4090 (82.6 TFLOPS FP32, 660.6 TFLOPS FP8 tensor-core throughput, 1,008 GB/s memory bandwidth, 128 KB shared memory per SM), a naive CUDA kernel that assigns one output element per thread suffers from poor memory coalescing and resource utilization. Benchmarks show the naive kernel running ~133× slower than PyTorch (only ~0.76% of PyTorch’s TFLOPS) and achieving ~0.9 GB/s of effective memory bandwidth versus PyTorch’s ~120 GB/s. The post therefore demonstrates why warp-level memory access patterns, shared-memory tiling, tensor-core WMMA, swizzling, and careful autotuning are essential to approach CUTLASS/cuBLAS performance.
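For concreteness, here is a minimal sketch of the kind of naive kernel the benchmark describes, assuming row-major float matrices and one thread per output element; the kernel name and launch configuration are illustrative, not the post's actual code.

```cuda
// Illustrative sketch (not the post's code): naive GEMM, C = alpha*A*B + beta*C,
// with one thread per element of C.
__global__ void naive_gemm(int M, int N, int K, float alpha,
                           const float* A, const float* B,
                           float beta, float* C) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;  // index into M
    int col = blockIdx.x * blockDim.x + threadIdx.x;  // index into N
    if (row < M && col < N) {
        float acc = 0.0f;
        // Each thread streams a full row of A and column of B from global
        // memory: ~2*K loads for 2*K FLOPs, i.e. ~1 FLOP per loaded element,
        // so the kernel is starved by memory traffic rather than compute.
        for (int k = 0; k < K; ++k) {
            acc += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = alpha * acc + beta * C[row * N + col];
    }
}

// Illustrative launch: 16x16 threads per block, one output element per thread.
// dim3 block(16, 16);
// dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
// naive_gemm<<<grid, block>>>(M, N, K, alpha, dA, dB, beta, dC);
```

Depending on how threads are mapped to rows and columns, the per-thread loads can also be poorly coalesced across a warp, which is consistent with the tiny effective bandwidth the post reports for the naive version.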