Limitations of using JAX transforms inside networks

Once a Haiku network has been transformed into a pure function with hk.transform(), it can be freely combined with any JAX transformation, such as jax.jit() or jax.grad(). JAX transformations can also be used inside a Haiku network, but doing so can conflict with JAX’s tracing mechanism, because Haiku functions can have side effects (through hk.next_rng_key(), hk.get_parameter() and other stateful Haiku calls). To work around this, Haiku provides wrapped versions of JAX transforms under the haiku namespace (conventionally imported as hk). You can access these as hk.jit(), hk.grad(), and so on.

These wrappers turn the underlying Haiku function into a pure function, apply the corresponding JAX transformation, and re-package the result as a stateful Haiku function. The wrapped transforms are currently considered experimental.