
[ ]:
import haiku as hk
import jax
import jax.numpy as jnp

TL;DR: A JAX transform inside of a hk.transform is likely to transform a side-effecting function, which will result in an UnexpectedTracerError. This page describes two ways to get around this.

Limitations of Nesting JAX Functions and Haiku Modules

Once a Haiku network has been transformed to a pair of pure functions using hk.transform, it’s possible to freely combine these with any JAX transformations like jax.jit, jax.grad, jax.scan and so on.

If you want to use JAX transformations inside of a hk.transform however, you need to be more careful. It’s possible, but most functions inside of the hk.transform boundary are still side-effecting, and cannot safely be transformed by JAX. This is a common cause of UnexpectedTracerErrors in code using Haiku. These errors are a result of applying a JAX transform to a side-effecting function (for more information on this JAX error, see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError).

An example with jax.eval_shape:

[ ]:
def net(x): # inside of a hk.transform, this is still side-effecting
  w = hk.get_parameter("w", (2, 2), init=jnp.ones)
  return w @ x

def eval_shape_net(x):
  output_shape = jax.eval_shape(net, x) # eval_shape on side-effecting function
  return net(x)                         # UnexpectedTracerError!

init, _ = hk.transform(eval_shape_net)
try:
  init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
  print("UnexpectedTracerError: applied JAX transform to side effecting function")
UnexpectedTracerError: applied JAX transform to side effecting function

These examples use jax.eval_shape, but the same applies to any higher-order JAX function (e.g. jax.vmap, jax.scan, jax.while_loop, …).

The error points to hk.get_parameter. This is the operation which makes net a side-effecting function. The side effect in this case is the creation of a parameter, which gets stored into the Haiku state. Similarly, you would get an error using hk.next_rng_key, because it advances the Haiku RNG state and stores a new PRNGKey into the Haiku state. In general, applying a JAX transform to a function that uses Haiku state (without purifying it first) will result in an UnexpectedTracerError.

You could re-write the code above to create the parameter outside of the eval_shape transformation, making net a pure function by threading the parameter through explicitly as an argument:

[ ]:
def net(w, x): # no side effects!
  return w @ x

def eval_shape_net(x):
  w = hk.get_parameter("w", (3, 2), init=jnp.ones)
  output_shape = jax.eval_shape(net, w, x) # net is now side-effect free
  return output_shape, net(w, x)

key = jax.random.PRNGKey(777)
x = jnp.ones((2, 3))
init, apply = hk.transform(eval_shape_net)
params = init(key, x)
apply(params, key, x)
(ShapeDtypeStruct(shape=(3, 3), dtype=float32),
 DeviceArray([[2., 2., 2.],
              [2., 2., 2.],
              [2., 2., 2.]], dtype=float32))

However, that’s not always possible. Consider the following code, which calls a Haiku module (hk.nets.MLP) that we don’t own. This module will internally call hk.get_parameter.

[ ]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])
  output_shape = jax.eval_shape(net, x)
  return output_shape, net(x)

init, _ = hk.transform(eval_shape_net)
try:
  init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
  print("UnexpectedTracerError: applied JAX transform to side effecting function")
UnexpectedTracerError: applied JAX transform to side effecting function

Using hk.lift

We want a way to get access to our implicit Haiku state, and to get a functionally pure version of hk.nets.MLP. The usual way to achieve this is with hk.transform, so all we need is a way to nest an inner hk.transform inside an outer hk.transform! We’ll create another pair of init and apply functions through hk.transform, and these can then be safely combined with any higher-order JAX function.

However, we need a way to register this nested hk.transform state into the outer scope. We can use hk.lift for this. Wrapping our inner init function with hk.lift registers the inner params into the outer parameter scope.

[ ]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])    # still side-effecting
  init, apply = hk.transform(net)  # nested transform
  params = hk.lift(init, name="inner")(hk.next_rng_key(), x) # register parameters in outer module scope with name "inner"
  output_shape = jax.eval_shape(apply, params, hk.next_rng_key(), x) # apply is functionally pure and can be transformed!
  out = net(x)
  return out, output_shape


init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))
jax.tree_map(lambda x: x.shape, params)
FlatMap({
  'inner/mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),
  'inner/mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),
  'mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),
  'mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),
})
Note that the parameter tree contains two copies of the MLP parameters: the lifted inner ones (under inner/...) and those created by the direct net(x) call in the outer transform.

Using Haiku versions of JAX transforms

Haiku also provides wrapped versions of some of the JAX functions for convenience. For example: hk.grad, hk.vmap, …. See https://dm-haiku.readthedocs.io/en/latest/api.html#jax-fundamentals for a full list of available functions.

These wrappers apply the JAX function to a functionally pure version of the Haiku function, doing the explicit state threading for you. They don’t introduce an extra name-scoping level the way hk.lift does.

[ ]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])         # still side-effecting
  output_shape = hk.eval_shape(net, x)  # hk.eval_shape threads through the Haiku state for you
  out = net(x)
  return out, output_shape


init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
out = apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))

Summary

To summarize, some good and bad examples of combining JAX transforms and Haiku modules:

| What? | Works? | Example |
| --- | --- | --- |
| vmapping outside a hk.transform | ✔ yes! | jax.vmap(hk.transform(hk.nets.ResNet50)) |
| vmapping inside a hk.transform | ✖ unexpected tracer error | hk.transform(jax.vmap(hk.nets.ResNet50)) |
| vmapping a nested hk.transform (without lift) | ✖ inner state is not registered | hk.transform(jax.vmap(hk.transform(hk.nets.ResNet50))) |
| vmapping a nested hk.transform (with lift) | ✔ yes! | hk.transform(jax.vmap(hk.lift(hk.transform(hk.nets.ResNet50)))) |
| using hk.vmap | ✔ yes! | hk.transform(hk.vmap(hk.nets.ResNet50)) |