
Haiku and Flax interop 🥂

Utilities to move seamlessly between Haiku and Flax.

Flax inside Haiku

Using a Flax module inside an hk.transform (or hk.transform_with_state) is straightforward.

First construct an instance of your module, then use hkflax.lift to "lift" (see hk.lift) the parameters and state from the Flax module into the Haiku transform.

Example:

import jax
import jax.numpy as jnp
import haiku as hk
import haiku.experimental.flax as hkflax
import flax.linen as flax_nn

def f(x):
  mod = hkflax.lift(flax_nn.Dense(10), name='my_flax_module')
  x = mod(x)
  return x

f = hk.transform(f)
x = jnp.ones([1, 1])
rng = jax.random.PRNGKey(42)
params = f.init(rng, x)   # params contains the parameters of the lifted Flax module ('my_flax_module').
f.apply(params, None, x)  # The lifted Flax module receives its parameters from params.
Array([[ 0.33030465, -1.3496182 ,  0.02847686, -1.6579462 , -0.9166192 ,
         0.2883583 , -0.046898  ,  0.6414894 , -0.404975  , -2.1162813 ]],      dtype=float32)

To use a stateful module simply swap hk.transform for hk.transform_with_state.

Haiku inside Flax

There are two supported approaches for converting Haiku code to Flax. Both produce a Flax linen nn.Module that encapsulates the Haiku code and provides init and apply methods to create and use parameters and state.

  • Convert an hk.Module to an nn.Module.

  • Convert an hk.transform or hk.transform_with_state to an nn.Module.

Converting hk.Module

For stateless modules, construct the Flax module with hkflax.Module.create, passing the Haiku module class followed by its constructor arguments:

mod = hkflax.Module.create(hk.Linear, 1)  # hk.Linear(1)

You can use this like a regular Flax nn.Module (because it is one!):

rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 1])
variables = mod.init(rng, x)
out = mod.apply(variables, x)

For a stateful module such as ResNet, you also need to handle the output state, exactly as you would for a stateful Flax module:

mod = hkflax.Module.create(hk.nets.ResNet50, 10)

# Regular Flax code from here on:
rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 224, 224, 3])
variables = mod.init(rng, x, is_training=True)
for _ in range(10):
  out, state_out = mod.apply(variables, x, is_training=True,
                             mutable=['state'])
  variables = {**variables, **state_out}

Converting hk.transform or hk.transform_with_state

Alternatively, an hkflax.Module can be created directly from the result of hk.transform or hk.transform_with_state:

def mlp(x):
  x = hk.Linear(300)(x)
  x = hk.Linear(100)(x)
  x = hk.Linear(10)(x)
  return x

mlp = hk.transform(mlp)
mlp = hkflax.Module(mlp)

rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 28 * 28])
variables = mlp.init(rng, x)
out = mlp.apply(variables, x)

Gotchas

Initialization is different

Flax and Haiku take different approaches to RNG key splitting. As a result, for the same RNG key, the parameters returned by hkflax.Module(hk.transform(f)).init currently differ from those returned by hk.transform(f).init.

We have a route to making an hk.transform match the initialization of the Flax module, but there is currently no path in the opposite direction.

If aligning initialization across Haiku and Flax is important to you, we recommend creating the parameters with one library and then converting the params/state dictionaries to the layout the other library expects:

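# Utilities.
import haiku as hk
import haiku.data_structures as hkds
import haiku.experimental.flax as hkflax

def make_flat(d):
  # Haiku nests values as {module_name: {name: value}}; the Flax wrapper
  # keys them flat as 'module_name/name'.
  return {f'{m}/{n}': w for m, n, w in hkds.traverse(d)}

def make_nested(d):
  # Inverse of make_flat: split on the last '/' to recover module and name.
  out = {}
  for k, w in d.items():
    m, n = k.rsplit('/', 1)
    out.setdefault(m, {})
    out[m][n] = w
  return out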

# The two modules here should be equivalent when run with Flax or Haiku.
f = hk.transform_with_state(...)
flax_mod = hkflax.Module(f)

# Option 1: Convert Haiku initialized params/state to Flax.
params, state = f.init(...)
variables = {'params': make_flat(params), 'state': make_flat(state)}

# Option 2: Convert Flax initialized variables to Haiku.
variables = flax_mod.init(...)
params = make_nested(variables.get('params', {}))
state = make_nested(variables.get('state', {}))

# The output of the Haiku transformed function or the Flax function should be
# equivalent with either init.
out, state = f.apply(params, state, ...)
out, variables_out = flax_mod.apply(variables, ..., mutable=['state'])
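To make the key layout concrete without depending on Haiku or Flax, here is a plain-Python sketch of the same two helpers. It replaces hkds.traverse with ordinary nested dictionary iteration (an assumption that can only change the order in which keys are produced, not the keys themselves), and uses made-up float values in place of arrays.

```python
def make_flat(d):
  """Nested {module: {name: value}} -> flat {'module/name': value}."""
  return {f'{m}/{n}': w for m, inner in d.items() for n, w in inner.items()}

def make_nested(d):
  """Flat {'module/name': value} -> nested {module: {name: value}}."""
  out = {}
  for k, w in d.items():
    # Split on the last '/' only, so module names like 'mlp/~/linear_0' survive.
    m, n = k.rsplit('/', 1)
    out.setdefault(m, {})[n] = w
  return out

# Hypothetical Haiku-style params for a two-layer MLP (floats stand in for arrays).
params = {'mlp/~/linear_0': {'w': 1.0, 'b': 0.0},
          'mlp/~/linear_1': {'w': 2.0, 'b': 0.5}}
flat = make_flat(params)
print(sorted(flat))
# ['mlp/~/linear_0/b', 'mlp/~/linear_0/w', 'mlp/~/linear_1/b', 'mlp/~/linear_1/w']
assert make_nested(flat) == params  # The round trip is lossless.
```

Splitting with rsplit('/', 1) rather than split('/') matters because Haiku module names for nested modules themselves contain '/', so only the final component is the parameter or state name.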

Multiple forward methods

hkflax.Module only supports __call__ at the moment; please let us know if this is a blocker for you.