Limitations of Nesting JAX Functions and Haiku Modules
[ ]:
import haiku as hk
import jax
import jax.numpy as jnp
TL;DR: A JAX transform inside of a hk.transform is likely to transform a side-effecting function, which will result in an UnexpectedTracerError. This page describes two ways to get around this.
Limitations of Nesting JAX Functions and Haiku Modules#
Once a Haiku network has been transformed to a pair of pure functions using hk.transform, it’s possible to freely combine these with any JAX transformations like jax.jit, jax.grad, jax.lax.scan and so on.
If you want to use JAX transformations inside of a hk.transform however, you need to be more careful. It’s possible, but most functions inside of the hk.transform boundary are still side-effecting, and cannot safely be transformed by JAX. This is a common cause of UnexpectedTracerErrors in code using Haiku. These errors are the result of applying a JAX transform to a side-effecting function (for more information on this JAX error, see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError).
An example with jax.eval_shape:
[ ]:
def net(x):  # inside of a hk.transform, this is still side-effecting
    w = hk.get_parameter("w", (2, 2), init=jnp.ones)
    return w @ x

def eval_shape_net(x):
    output_shape = jax.eval_shape(net, x)  # eval_shape on side-effecting function
    return net(x)                          # UnexpectedTracerError!

init, _ = hk.transform(eval_shape_net)
try:
    init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
    print("UnexpectedTracerError: applied JAX transform to side effecting function")
UnexpectedTracerError: applied JAX transform to side effecting function
These examples use jax.eval_shape, but could have used any higher-order JAX function (e.g. jax.vmap, jax.lax.scan, jax.lax.while_loop, …).
The error points to hk.get_parameter. This is the operation which makes net a side-effecting function. The side effect in this case is the creation of a parameter, which gets stored into the Haiku state. Similarly you would get an error using hk.next_rng_key, because it advances the Haiku RNG state and stores a new PRNGKey into the Haiku state. In general, transforming a non-transformed Haiku module will result in an UnexpectedTracerError.
You could rewrite the code above to create the parameter outside of the eval_shape transformation, making net a pure function by explicitly threading the parameter through as an argument:
[ ]:
def net(w, x):  # no side effects!
    return w @ x

def eval_shape_net(x):
    w = hk.get_parameter("w", (3, 2), init=jnp.ones)
    output_shape = jax.eval_shape(net, w, x)  # net is now side-effect free
    return output_shape, net(w, x)

key = jax.random.PRNGKey(777)
x = jnp.ones((2, 3))
init, apply = hk.transform(eval_shape_net)
params = init(key, x)
apply(params, key, x)
(ShapeDtypeStruct(shape=(3, 3), dtype=float32),
DeviceArray([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]], dtype=float32))
However, that’s not always possible. Consider the following code, which calls a Haiku module (hk.nets.MLP) that we don’t own. This module will internally call hk.get_parameter.
[ ]:
def eval_shape_net(x):
    net = hk.nets.MLP([300, 100])
    output_shape = jax.eval_shape(net, x)
    return output_shape, net(x)

init, _ = hk.transform(eval_shape_net)
try:
    init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
    print("UnexpectedTracerError: applied JAX transform to side effecting function")
UnexpectedTracerError: applied JAX transform to side effecting function
Using hk.lift#
We want a way to get access to our implicit Haiku state, and to get a functionally pure version of hk.nets.MLP. The usual way to achieve this is hk.transform, so all we need is a way to nest an inner hk.transform inside an outer hk.transform! We’ll create another pair of init and apply functions through hk.transform, and these can then be safely combined with any higher-order JAX function.
However, we need a way to register this nested hk.transform state into the outer scope. We can use hk.lift for this. Wrapping our inner init function with hk.lift will register our inner params into the outer parameter scope.
[ ]:
def eval_shape_net(x):
    net = hk.nets.MLP([300, 100])    # still side-effecting
    init, apply = hk.transform(net)  # nested transform
    # Register the inner params in the outer module scope with name "inner":
    params = hk.lift(init, name="inner")(hk.next_rng_key(), x)
    # apply is a functionally pure function and can be transformed!
    output_shape = jax.eval_shape(apply, params, hk.next_rng_key(), x)
    out = net(x)
    return out, output_shape

init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))
jax.tree_util.tree_map(lambda x: x.shape, params)
FlatMap({
'inner/mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),
'inner/mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),
'mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),
'mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),
})
Using Haiku versions of JAX transforms#
Haiku also provides wrapped versions of some JAX functions for convenience, for example hk.grad, hk.vmap, …. See https://dm-haiku.readthedocs.io/en/latest/api.html#jax-fundamentals for a full list of available functions.
These wrappers apply the JAX function to a functionally pure version of the Haiku function by doing the explicit state threading for you. They don’t introduce an extra name-scoping level like lift does.
[ ]:
def eval_shape_net(x):
    net = hk.nets.MLP([300, 100])         # still side-effecting
    output_shape = hk.eval_shape(net, x)  # hk.eval_shape threads the Haiku state for you
    out = net(x)
    return out, output_shape

init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
out = apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))
Summary#
To summarize, some good and bad examples of combining JAX transforms and Haiku modules:
| What? | Works? | Example |
| --- | --- | --- |
| vmapping outside a hk.transform | ✔ yes! | jax.vmap(hk.transform(hk.nets.ResNet50)) |
| vmapping inside a hk.transform | ✖ unexpected tracer error | hk.transform(jax.vmap(hk.nets.ResNet50)) |
| vmapping a nested hk.transform (without lift) | ✖ inner state is not registered | hk.transform(jax.vmap(hk.transform(hk.nets.ResNet50))) |
| vmapping a nested hk.transform (with lift) | ✔ yes! | hk.transform(jax.vmap(hk.lift(hk.transform(hk.nets.ResNet50)))) |
| using hk.vmap | ✔ yes! | hk.transform(hk.vmap(hk.nets.ResNet50)) |