🤖 AI Summary
Grampax is a new PyPI package (Python 3.10+) that brings a torch.autocast-style mixed-precision convenience layer to JAX. It provides an autocast transformation that wraps any callable (plain functions or Equinox models) and automatically casts selected expensive floating-point ops, by default matrix multiplications (lax.dot_general_p) and convolutions (lax.conv_general_dilated_p), to a lower precision (default jnp.bfloat16) for the op, then casts the outputs back to the original dtype. Because the wrapper is applied at trace time using Quax, autocast composes with JAX transformations (jax.grad, jax.vjp, etc.), so forward and backward behavior stays consistent without manual rewrites.
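The summary does not show the package's exact API, but based on its description, usage might look like the sketch below. The grampax import path, the autocast signature, and the function names are assumptions, not confirmed API:

```python
import jax
import jax.numpy as jnp
from grampax import autocast  # hypothetical import path; check the package docs

params = {"w": jnp.ones((8, 4)), "b": jnp.zeros((4,))}
x = jnp.ones((16, 8), dtype=jnp.float32)

def loss(params, x):
    # x @ w lowers to lax.dot_general_p, one of the primitives autocast
    # intercepts by default: its inputs are cast to bfloat16 for the op
    # and the result is cast back to float32.
    return jnp.sum(jnp.tanh(x @ params["w"] + params["b"]))

amp_loss = autocast(loss)               # defaults: matmuls/convs in bfloat16
grads = jax.grad(amp_loss)(params, x)   # composes with jax.grad as usual
```

Note that params is the first positional argument, matching the described behavior of wrapping that argument in an AutocastArray.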
Technically, Grampax wraps the first positional argument in an AutocastArray that propagates through primitives; when an "autocast" primitive is hit, its inputs are cast to the configured dtype and the results are cast back. You can override the default mapping with a config dict that sets per-primitive dtypes (including higher precision for specific ops), or apply autocast selectively to submodules via Equinox model surgery. It does not implement AMP-style master parameter copies or automatic loss scaling, but it defaults to bfloat16 (which often makes loss scaling unnecessary) and provides examples of manual/constant scaling. It also interoperates with other Quax types, although complex multi-type primitive dispatches may require explicit Quax registrations.
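A minimal sketch of the per-primitive override and the constant-scaling idea, assuming a config dict keyed by JAX primitives and a config= keyword (both hypothetical; only the concepts are from the summary):

```python
import jax
import jax.numpy as jnp
from jax import lax
from grampax import autocast  # hypothetical import path

def net(params, x):
    return jnp.sum(jnp.tanh(x @ params))  # one matmul (lax.dot_general_p)

params = jnp.ones((8, 4))
x = jnp.ones((16, 8))

# Hypothetical config shape: a dict from primitive to target dtype.
# Matmuls drop to bfloat16; convolutions are pinned to full float32.
config = {
    lax.dot_general_p: jnp.bfloat16,
    lax.conv_general_dilated_p: jnp.float32,
}
amp_net = autocast(net, config=config)

# Constant loss scaling of the kind the summary mentions: scale the loss
# up before differentiating, then scale the gradients back down.
SCALE = 2.0 ** 10
grads = jax.grad(lambda p, x: SCALE * amp_net(p, x))(params, x)
grads = jax.tree_util.tree_map(lambda g: g / SCALE, grads)
```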