Limitations of using JAX transforms inside of networks
Once a Haiku network has been transformed into a pure function using
hk.transform(), it can be freely combined with
any JAX transformation, such as
jax.jit() or jax.grad().
It’s also possible to use JAX transformations inside of a Haiku network, but
this can cause problems with JAX’s tracing mechanism, because Haiku functions can have
side effects (through
hk.get_parameter() and other stateful Haiku calls).
To work around this, Haiku provides wrapped versions of JAX transforms under
the haiku namespace. You can access these as
hk.grad(), hk.vmap() and so on.
These wrappers turn the underlying Haiku function into a pure function, apply the corresponding JAX transformation, and repackage the result as a stateful Haiku function. The wrapped transforms are currently considered experimental.