Haiku and Flax interop
Utilities to move seamlessly between Haiku and Flax.
Flax inside Haiku
Using a Flax module inside an hk.transform (or hk.transform_with_state) is straightforward.
First construct an instance of your module, then use hkflax.lift to "lift" (see hk.lift) the parameters and state from the Flax module into the Haiku transform.
Example:
import jax
import jax.numpy as jnp
import haiku as hk
import haiku.experimental.flax as hkflax
import flax.linen as flax_nn
def f(x):
  mod = hkflax.lift(flax_nn.Dense(10), name='my_flax_module')
  x = mod(x)
  return x
f = hk.transform(f)
x = jnp.ones([1, 1])
rng = jax.random.PRNGKey(42)
params = f.init(rng, x)  # params contains the parameters for my_flax_module.
f.apply(params, None, x)  # my_flax_module is passed its parameters from params.
Array([[ 0.33030465, -1.3496182 , 0.02847686, -1.6579462 , -0.9166192 ,
0.2883583 , -0.046898 , 0.6414894 , -0.404975 , -2.1162813 ]], dtype=float32)
To use a stateful module, simply swap hk.transform for hk.transform_with_state.
Haiku inside Flax
There are two supported approaches for converting Haiku code to Flax. Both produce a Flax linen nn.Module which encapsulates the Haiku code and provides init and apply methods to create and use parameters and state.
- Convert an hk.Module to an nn.Module.
- Convert an hk.transform or hk.transform_with_state to an nn.Module.
Converting hk.Module
For stateless modules, you simply need to construct the Flax module via hkflax.Module.create:
mod = hkflax.Module.create(hk.Linear, 1)  # Wraps hk.Linear(1).
You can use this like a regular Flax nn.Module (because it is one!):
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 1])
variables = mod.init(rng, x)
out = mod.apply(variables, x)
For a stateful module like ResNet, you also need to handle the output state; again, this is the same as for stateful Flax modules:
mod = hkflax.Module.create(hk.nets.ResNet50, 10)
# Regular Flax code from here on:
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 224, 224, 3])
variables = mod.init(rng, x, is_training=True)
for _ in range(10):
  out, state_out = mod.apply(variables, x, is_training=True,
                             mutable=['state'])
  # Merge the updated state back into the variables for the next step.
  variables = {**variables, **state_out}
Converting hk.transform or hk.transform_with_state
An hkflax.Module can also be created from the result of hk.transform or hk.transform_with_state if you prefer:
def mlp(x):
  x = hk.Linear(300)(x)
  x = hk.Linear(100)(x)
  x = hk.Linear(10)(x)
  return x
mlp = hk.transform(mlp)
mlp = hkflax.Module(mlp)
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 28 * 28])
variables = mlp.init(rng, x)
out = mlp.apply(variables, x)
Gotchas
Initialization is different
Flax and Haiku take different approaches to RNG key splitting. As such, the parameters returned from hkflax.Module(f).init currently differ from those returned by hk.transform(f).init, even for the same RNG key.
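To illustrate why different splitting strategies yield different parameters from the same root key, here is a JAX-only sketch. The two key-derivation schemes are hypothetical and are not the schemes Haiku or Flax actually use; the point is only that one key routed through different splits gives different (equally valid) initial values of the same shape.

```python
import jax
import jax.numpy as jnp

root = jax.random.PRNGKey(42)

def init_weight(key, shape=(3, 3)):
  # Stand-in for a layer's weight initializer.
  return jax.random.normal(key, shape)

# Hypothetical scheme A: split the root key and use the second half.
_, key_a = jax.random.split(root)
# Hypothetical scheme B: fold a layer index into the root key.
key_b = jax.random.fold_in(root, 0)

w_a = init_weight(key_a)
w_b = init_weight(key_b)
print(w_a.shape == w_b.shape)        # True: same shapes...
print(bool(jnp.allclose(w_a, w_b)))  # False: ...different values.
```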
We have a route to support making an hk.transform match the initialization of the Flax module, but there is currently no path in the opposite direction.
If aligning initialization across Haiku and Flax is important to you, we recommend using one of the libraries to create parameters, and then manipulating the params/state dictionaries to match the other library as needed:
# Utilities.
import haiku as hk
import haiku.experimental.flax as hkflax
import haiku.data_structures as hkds
def make_flat(d):
  # {'module': {'name': w}} -> {'module/name': w}
  return {f'{m}/{n}': w for m, n, w in hkds.traverse(d)}
def make_nested(d):
  # {'module/name': w} -> {'module': {'name': w}}. Split on the last '/'
  # only, since Haiku module names can themselves contain '/'.
  out = {}
  for k, w in d.items():
    m, n = k.rsplit('/', 1)
    out.setdefault(m, {})
    out[m][n] = w
  return out
# The two modules here should be equivalent when run with Flax or Haiku.
f = hk.transform_with_state(...)
flax_mod = hkflax.Module(f)
# Option 1: Convert Haiku initialized params/state to Flax.
params, state = f.init(...)
variables = {'params': make_flat(params), 'state': make_flat(state)}
# Option 2: Convert Flax initialized variables to Haiku.
variables = flax_mod.init(...)
params = make_nested(variables.get('params', {}))
state = make_nested(variables.get('state', {}))
# The output of the Haiku transformed function or the Flax function should be
# equivalent with either init.
out, state = f.apply(params, state, ...)
out, variables_out = flax_mod.apply(variables, ..., mutable=['state'])
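The flatten/unflatten round trip can be checked without either library. This sketch swaps hkds.traverse for a plain dict comprehension, assuming params are two-level {module_name: {param_name: value}} dicts; the Haiku-style module names and float values are made up. It also shows why make_nested splits on the last '/' only: the module names themselves contain '/'.

```python
def make_flat(d):
  # {'module': {'name': w}} -> {'module/name': w}
  return {f'{m}/{n}': w for m, inner in d.items() for n, w in inner.items()}

def make_nested(d):
  # {'module/name': w} -> {'module': {'name': w}}
  out = {}
  for k, w in d.items():
    # Split on the last '/' only: 'mlp/~/linear_0/w' -> ('mlp/~/linear_0', 'w').
    m, n = k.rsplit('/', 1)
    out.setdefault(m, {})[n] = w
  return out

params = {
    'mlp/~/linear_0': {'w': 1.0, 'b': 0.0},
    'mlp/~/linear_1': {'w': 2.0, 'b': 0.5},
}
flat = make_flat(params)
print(sorted(flat))
# ['mlp/~/linear_0/b', 'mlp/~/linear_0/w', 'mlp/~/linear_1/b', 'mlp/~/linear_1/w']
assert make_nested(flat) == params  # Lossless round trip.
```

Splitting on the first '/' instead would produce the module key 'mlp' and mangle every parameter name, so the round trip would fail.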
Multiple forward methods
hkflax.Module only supports __call__ at the moment; please let us know if this is a blocker for you.