Haiku Documentation
Contents
Haiku Documentation#
Haiku is a library built on top of JAX designed to provide simple, composable abstractions for machine learning research.
import haiku as hk
import jax
import jax.numpy as jnp
def forward(x):
mlp = hk.nets.MLP([300, 100, 10])
return mlp(x)
forward = hk.transform(forward)
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
x = jnp.ones([8, 28 * 28])
params = forward.init(next(rng), x)
logits = forward.apply(params, next(rng), x)
Installation#
See https://github.com/google/jax#pip-installation for instructions on installing JAX.
We suggest installing the latest version of Haiku by running:
$ pip install git+https://github.com/deepmind/dm-haiku
Alternatively, you can install via PyPI:
$ pip install -U dm-haiku
Known issues#
Warning
Using JAX transformations like jax.jit()
and jax.remat()
inside
of Haiku networks can lead to hard to interpret tracing errors and
potentially silently wrong results. Read Limitations of Nesting JAX Functions and Haiku Modules to find
out how to work around these issues.
Contribute#
Issue tracker: https://github.com/deepmind/dm-haiku/issues
Source code: https://github.com/deepmind/dm-haiku/tree/main
Support#
If you are having issues, please let us know by filing an issue on our issue tracker.
License#
Haiku is licensed under the Apache 2.0 License.