Haiku Documentation

Haiku is a library built on top of JAX designed to provide simple, composable abstractions for machine learning research.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  mlp = hk.nets.MLP([300, 100, 10])
  return mlp(x)

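# Turn the function into a pair of pure functions: init and apply.
forward = hk.transform(forward)

rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28 * 28])
# init creates the parameters; apply runs the network with them.
params = forward.init(rng, x)
logits = forward.apply(params, rng, x)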

Installation

See https://github.com/google/jax#pip-installation for instructions on installing JAX.

Install Haiku by running:

$ pip install git+https://github.com/deepmind/dm-haiku

Known issues

Warning

Using JAX transformations such as jax.jit() and jax.remat() inside Haiku networks can lead to hard-to-interpret tracing errors and, potentially, silently wrong results. Read the "Limitations of using JAX transforms inside of networks" section of the documentation to learn how to work around these issues.

Support

If you are having issues, please let us know by filing an issue on our issue tracker.

License

Haiku is licensed under the Apache 2.0 License.
