Haiku is a library built on top of JAX designed to provide simple, composable abstractions for machine learning research.
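```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  mlp = hk.nets.MLP([300, 100, 10])
  return mlp(x)

# hk.transform turns the function into a pair of pure functions: init and apply.
forward = hk.transform(forward)

rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28 * 28])
# init creates the network parameters; apply runs the forward pass with them.
params = forward.init(rng, x)
logits = forward.apply(params, rng, x)
```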
See https://github.com/google/jax#pip-installation for instructions on installing JAX.
Install Haiku by running:
$ pip install git+https://github.com/deepmind/dm-haiku
Using JAX transformations such as jax.remat() inside Haiku networks can lead to hard-to-interpret tracing errors and potentially silently wrong results. Read "Limitations of using JAX transforms inside of networks" to find out how to work around these issues.
Haiku is licensed under the Apache 2.0 License.