It's a Jax, Jax, Jax, Jax World (statmodeling.stat.columbia.edu)

🤖 AI Summary
Stan’s developers warn that the probabilistic-programming landscape is tilting toward JAX: users and institutions are increasingly replacing Stan with JAX-based tools (often via NumPyro). Bob highlights concrete signs: a StanCon talk whose presenter had switched to NumPyro, the CDC porting Stan workflows to JAX, an L.A. Dodgers job listing favoring JAX/NumPyro, and a draft MCMC Handbook chapter demonstrating massively parallel HMC on GPUs with JAX.

New samplers (microcanonical HMC variants) are being implemented primarily in JAX and shared through BlackJAX and Inference Gym, and adapters let JAX call Stan’s C++ when needed. Hardware is the main migration bottleneck: JAX shines on GPUs and multicore setups, and improving ARM Macs plus broader access to accelerators may accelerate the shift.

Technically, the tension is static versus dynamic autodiff: Stan was built on dynamic autodiff (like PyTorch), while JAX uses XLA’s static-graph compilation for much higher performance at the cost of some expressiveness (e.g., restrictions on parameter-dependent control flow). JAX’s composability, its PyTree/vmap primitives, Oryx transforms (for parameter constraining and Jacobians), and full compiler integration make it easier to express and compile complex differentiable programs directly in Python, often removing the need for a middleware language.

Bob doesn’t expect Stan-to-JAX code generation to become a priority, and he isn’t abandoning Stan; instead he’s developing standalone samplers (WALNUTS) and embracing faster Python-based options like Nutpie, while acknowledging Stan will remain in use for years.
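The static-versus-dynamic autodiff tension can be made concrete with a small sketch: under `jax.jit`, XLA traces a static graph, so a Python `if` on a traced value fails and data-dependent branching must go through `jax.lax.cond`. The step-size heuristic below is purely illustrative, not anything from the post.

```python
import jax
import jax.numpy as jnp

@jax.jit
def leapfrog_step_size(grad_norm):
    # A plain Python `if grad_norm > 10.0:` would raise a tracer error
    # here, because jit traces the function once into a static XLA graph.
    # lax.cond embeds both branches in the graph and selects at runtime.
    return jax.lax.cond(
        grad_norm > 10.0,
        lambda g: 0.01,  # shrink the step when gradients blow up
        lambda g: 0.1,   # otherwise keep the larger step
        grad_norm,
    )

small = leapfrog_step_size(jnp.array(5.0))   # takes the 0.1 branch
large = leapfrog_step_size(jnp.array(50.0))  # takes the 0.01 branch
```

This is the kind of restriction the summary alludes to: the control flow is still expressible, but it must be phrased in XLA's structured primitives rather than ordinary Python.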
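The "massively parallel HMC" point rests on `jax.vmap`: write the update for one chain, then vectorize it over a leading chain axis so thousands of chains compile to one batched, GPU-friendly computation. A minimal sketch, using a standard-normal target and a simple gradient-drift step (not the handbook's actual sampler):

```python
import jax
import jax.numpy as jnp

# Toy target: standard-normal log density.
def logp(x):
    return -0.5 * jnp.sum(x * x)

# One drift step for a single chain; jax.grad gives the score function.
def step(x, step_size=0.1):
    return x + step_size * jax.grad(logp)(x)

# vmap lifts the single-chain step to a batch of chains, and jit
# compiles the whole batched update into one XLA program.
many_steps = jax.jit(jax.vmap(step))

chains = jnp.full((1024, 3), 2.0)  # 1024 chains in 3 dimensions
chains = many_steps(chains)        # all chains advance in one call
```

The design point is that parallelism is a program transformation applied after the fact, rather than something threaded through the sampler's code by hand.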