🤖 AI Summary
Berkeley’s Language Model Policy Optimization (LMPO) team published a focused performance note on a minimal JAX-based RL loop for LLMs, aiming for end-to-end ease of research while keeping experiments fast on TPU pods (tested on a v4-32 via the TPU Research Cloud). Rather than splitting sampling and training across separate frameworks, they reimplemented Qwen3‑1.7B as a single Flax graph with a KV cache, so the same code handles prefill, autoregressive sampling, and gradient updates, reducing the risk of sampler/trainer mismatch and simplifying deployment. The key point: in LLM-based RL the sampling phase (autoregressive, per-token inference) often dominates cost, so optimizing memory and communication is crucial for throughput in multi‑TPU setups that use FSDP sharding.
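A minimal sketch (hypothetical names, not the LMPO code) of how a single attention step can serve both prefill and one-token decode by writing new keys and values into a preallocated KV cache, so one graph covers sampling and the training-time forward pass:

```python
import jax
import jax.numpy as jnp

def attend_with_cache(q, new_k, new_v, cache_k, cache_v, pos):
    """One attention step over a preallocated KV cache.

    q, new_k, new_v: [batch, step_len, heads, head_dim]; step_len is the
    prompt length during prefill and 1 during autoregressive decoding.
    cache_k, cache_v: [batch, max_seq, heads, head_dim], allocated once.
    pos: number of tokens already written to the cache.
    (Causal masking and grouped-query head broadcasting omitted for brevity.)
    """
    # Write the new keys/values into the cache at the current position.
    cache_k = jax.lax.dynamic_update_slice(cache_k, new_k, (0, pos, 0, 0))
    cache_v = jax.lax.dynamic_update_slice(cache_v, new_v, (0, pos, 0, 0))

    # Attend over everything written so far.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, cache_k)
    weights = jax.nn.softmax(logits / jnp.sqrt(q.shape[-1]), axis=-1)
    out = jnp.einsum("bhqk,bkhd->bqhd", weights, cache_v)
    return out, cache_k, cache_v
```

Because prefill and decode share this code path, there is no separate inference engine whose behavior could drift from the training graph.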
They profiled a naive sampling pass (batch 32, sequence length 1024, 28 layers, 8 KV heads, head dim 128) and found the KV cache drove most of the 18.6 GB per-device footprint. Two practical wins roughly cut that in half: storing the KV cache in bfloat16 (bf16) instead of fp32, and not paying for the cache's memory twice (the latter by exploiting JAX's compilation and buffer-aliasing behavior). The bf16 change brought total per-device memory from ~18.6 GB down to ~9.9 GB, enabling larger effective batch sizes, which improves seqs/sec because communication and VMEM transfers dominate forward passes over few tokens. They also show how JAX's ahead-of-time lower()/compile() memory_analysis() can guide these optimizations, a useful pattern for anyone tuning LLM RL workloads on TPUs.
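A sketch of the two memory fixes, under the assumption that the "pay twice" problem is addressed with JAX's buffer donation (donate_argnums), which lets XLA alias the cache's input and output buffers instead of holding two live copies; names are illustrative and the shapes follow the profiled configuration:

```python
import functools
import jax
import jax.numpy as jnp

BATCH, MAX_SEQ, KV_HEADS, HEAD_DIM = 32, 1024, 8, 128  # per the profile above

def init_kv_cache(dtype=jnp.bfloat16):
    # Win 1: allocate the cache in bf16 rather than fp32, halving its footprint.
    shape = (BATCH, MAX_SEQ, KV_HEADS, HEAD_DIM)
    return jnp.zeros(shape, dtype), jnp.zeros(shape, dtype)

# Win 2: donate the cache arguments so the compiled update can reuse their
# buffers for the outputs instead of allocating a second full-size cache.
@functools.partial(jax.jit, donate_argnums=(0, 1))
def write_kv(cache_k, cache_v, new_k, new_v, pos):
    new_k = new_k.astype(cache_k.dtype)  # keep the cache bf16 end to end
    new_v = new_v.astype(cache_v.dtype)
    cache_k = jax.lax.dynamic_update_slice(cache_k, new_k, (0, pos, 0, 0))
    cache_v = jax.lax.dynamic_update_slice(cache_v, new_v, (0, pos, 0, 0))
    return cache_k, cache_v
```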
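The memory-analysis pattern itself looks roughly like this: lower and compile a jitted step ahead of time, then inspect XLA's memory estimate before paying for a full run. The toy function below stands in for the real sampling step, and the fields of the returned analysis are backend-specific:

```python
import jax
import jax.numpy as jnp

def toy_step(x, w):
    # Stand-in for the real per-token forward pass.
    return jnp.tanh(x @ w)

x = jnp.zeros((32, 2048), jnp.bfloat16)
w = jnp.zeros((2048, 2048), jnp.bfloat16)

# Ahead-of-time: lower, compile, and ask XLA how much memory the program
# needs (temporaries, arguments, outputs) without executing it.
compiled = jax.jit(toy_step).lower(x, w).compile()
print(compiled.memory_analysis())
```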