
Parameter sharing in Haiku#

Introduction#

In Haiku, parameter reuse is determined uniquely by module instance names, i.e., if a module instance has the same name as another module instance, they share parameters.

Unless specified, module names are automatically determined by Haiku from the module class name (following a pattern established in TensorFlow 1 with Sonnet V1). In more detail, module naming follows these rules:

  1. Module names are assigned when the module instance is constructed. Unless a module instance name is provided as an argument to the constructor, Haiku generates one from the current module class name (basically: to_snake_case(CurrentClassName)).

  2. If the module instance name doesn’t end in a _N (where N is a number) and another module instance with the same name already exists, Haiku adds an incremental number to the end of the new module instance name (e.g. module_1).

  3. When two modules are nested (i.e., a module instance is constructed inside another module’s class definition), the inner module name is prefixed with the outer module name and, possibly (see the next point), the name of the outer module’s method that is currently being called. The constructor (i.e., __init__) is represented by the tilde ~ symbol.

  4. If the calling method is __call__, its name is omitted: the inner module name is prefixed with the outer module name only.

  5. When there are multiple levels of nesting, the previous rules are applied at each level: each inner module name is built from the name and calling method of the module immediately above it in the hierarchy of calls.

Let’s see how this works with a practical example.

Flat modules (no nesting)#

This section covers parameter sharing when the modules are not nested.

[4]:
#@title Imports and accessory functions
import functools
import haiku as hk
import jax
import jax.numpy as jnp


def parameter_shapes(params):
  """Make printing parameters a little more readable."""
  return jax.tree_util.tree_map(lambda p: p.shape, params)


def transform_and_print_shapes(fn, x_shape=(2, 3)):
  """Print name and shape of the parameters."""
  rng = jax.random.PRNGKey(42)
  x = jnp.ones(x_shape)

  transformed_fn = hk.transform(fn)
  params = transformed_fn.init(rng, x)
  print('\nThe name and shape of the parameters are:')
  print(parameter_shapes(params))

def assert_all_equal(params_1, params_2):
  assert all(jax.tree_util.tree_leaves(
      jax.tree_util.tree_map(lambda a, b: (a == b).all(), params_1, params_2)))
[6]:
w_init = hk.initializers.TruncatedNormal(stddev=1)

class SimpleModule(hk.Module):
  """A simple module class with one variable."""

  def __init__(self, output_channels, name=None):
    super().__init__(name)
    assert isinstance(output_channels, int)
    self._output_channels = output_channels

  def __call__(self, x):
    w_shape = (x.shape[-1], self._output_channels)
    w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
    return jnp.dot(x, w)
[ ]:
def f(x):
  # This instance will be named `simple_module`.
  simple = SimpleModule(output_channels=2)
  simple_out = simple(x)  # implicitly calls simple.__call__()
  print(f'The name assigned to "simple" is: "{simple.module_name}".')
  return simple_out

transform_and_print_shapes(f)
The name assigned to "simple" is: "simple_module".

The name and shape of the parameters are:
{'simple_module': {'w': (3, 2)}}

Great! Here we see that if we create a SimpleModule instance and do not specify a name, Haiku assigns it the name simple_module. This is also reflected in the parameters associated with the module.
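
As a quick aside (a small sketch, not a cell from the original notebook; the name 'my_module' is just an illustrative choice), rule 1 also lets us pick the name explicitly by passing name= to the constructor:

def f_named(x):
  # Passing `name=` overrides the automatically generated name.
  simple = SimpleModule(output_channels=2, name='my_module')
  print(f'The name assigned to "simple" is: "{simple.module_name}".')
  return simple(x)

transform_and_print_shapes(f_named)
# Expected parameters: {'my_module': {'w': (3, 2)}}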

What happens if we instantiate SimpleModule twice though? Does Haiku assign to both instances the same name?

[ ]:
def f(x):
  # This instance will be named `simple_module`.
  simple_one = SimpleModule(output_channels=2)
  # This instance will be named `simple_module_1`.
  simple_two = SimpleModule(output_channels=2)
  first_out = simple_one(x)
  second_out = simple_two(x)
  print(f'The name assigned to "simple_one" is: "{simple_one.module_name}".')
  print(f'The name assigned to "simple_two" is: "{simple_two.module_name}".')
  return first_out + second_out

transform_and_print_shapes(f)
The name assigned to "simple_one" is: "simple_module".
The name assigned to "simple_two" is: "simple_module_1".

The name and shape of the parameters are:
{'simple_module': {'w': (3, 2)}, 'simple_module_1': {'w': (3, 2)}}

As expected, Haiku is smart enough to differentiate the two instances and avoid accidental parameter sharing: the second instance is named simple_module_1 and each instance has its own set of parameters. Good!

But what if we wanted to share parameters? In this case, we would have to instantiate the module only once and call it multiple times. Let’s see how this works:

[ ]:
def f(x):
  # This instance will be named `simple_module`.
  simple_one = SimpleModule(output_channels=2)
  first_out = simple_one(x)
  second_out = simple_one(x)  # share parameters w/ previous call
  print(f'The name assigned to "simple_one" is: "{simple_one.module_name}".')
  return first_out + second_out

transform_and_print_shapes(f)
The name assigned to "simple_one" is: "simple_module".

The name and shape of the parameters are:
{'simple_module': {'w': (3, 2)}}

Nested modules#

In this section we’ll see what happens when we nest one hk.Module into another.

[ ]:
class NestedModule(hk.Module):
  """A module class with a nested module created in the constructor."""

  def __init__(self, output_channels, name=None):
    super().__init__(name)
    assert isinstance(output_channels, int)
    self._output_channels = output_channels
    self.inner_simple = SimpleModule(self._output_channels)

  def __call__(self, x):
    w_shape = (x.shape[-1], self._output_channels)
    # Another variable that is also called `w`.
    w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
    return jnp.dot(x, w) + self.inner_simple(x)
[ ]:
def f(x):
  # This will be named `nested_module` and the SimpleModule instance created
  # inside it will be named `nested_module/~/simple_module`.
  nested = NestedModule(output_channels=2)
  nested_out = nested(x)
  print('The name assigned to the outer module (i.e., "nested") is: '
        f'"{nested.module_name}".')
  print('The name assigned to the inner module (i.e., inside "nested") is: "'
        f'{nested.inner_simple.module_name}".')
  return nested_out

transform_and_print_shapes(f)
The name assigned to the outer module (i.e., "nested") is: "nested_module".
The name assigned to the inner module (i.e., inside "nested") is: "nested_module/~/simple_module".

The name and shape of the parameters are:
{'nested_module': {'w': (3, 2)}, 'nested_module/~/simple_module': {'w': (3, 2)}}

As expected, the inner module name depends on: (a) the outer module name; and (b) the outer module’s method being called.

Note also how the outer module’s constructor name __init__ is replaced by a ~ in the parameter names. If the inner module instance had been created inside the __call__ method of the outer module, the inner module instance name would have been 'nested_module/simple_module'.
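
To make this concrete, here is a minimal sketch (a hypothetical variant, not a cell from the original notebook) in which the inner module is created inside __call__ rather than in the constructor:

class NestedInCall(hk.Module):
  """Like NestedModule, but the inner module is created in __call__."""

  def __init__(self, output_channels, name=None):
    super().__init__(name)
    self._output_channels = output_channels

  def __call__(self, x):
    # Created inside __call__, so neither `~` nor a method name appears
    # in the inner module name.
    inner = SimpleModule(self._output_channels)
    return inner(x)

transform_and_print_shapes(lambda x: NestedInCall(output_channels=2)(x))
# Expected parameters: {'nested_in_call/simple_module': {'w': (3, 2)}}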

In this example we defined all the modules from scratch, but the same holds for any of the modules and networks defined in Haiku, e.g., hk.Linear, hk.nets.MLP, etc. If you are curious, see what happens if you assign an instance of hk.Linear instead of SimpleModule to self.inner_simple, as sketched below.
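
A possible sketch of that exercise (again a hypothetical variant, not a cell from the original notebook) could look like this:

class NestedLinear(hk.Module):
  """Like NestedModule, but with an hk.Linear as the inner module."""

  def __init__(self, output_channels, name=None):
    super().__init__(name)
    self.inner_linear = hk.Linear(output_channels)

  def __call__(self, x):
    return self.inner_linear(x)

transform_and_print_shapes(lambda x: NestedLinear(output_channels=2)(x))
# Expected parameters: {'nested_linear/~/linear': {'b': (2,), 'w': (3, 2)}}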

Let’s try now multiple levels of nesting:

[ ]:
class TwiceNestedModule(hk.Module):
  """A module class with a nested module containing a nested module."""

  def __init__(self, output_channels, name=None):
    super().__init__(name)
    assert isinstance(output_channels, int)
    self._output_channels = output_channels
    self.inner_nested = NestedModule(self._output_channels)

  def __call__(self, x):
    w_shape = (x.shape[-1], self._output_channels)
    w = hk.get_parameter("w", w_shape, x.dtype, init=w_init)
    return jnp.dot(x, w) + self.inner_nested(x)
[ ]:
def f(x):
  """Create the module instances and inspect their names."""
  # Instantiate a TwiceNestedModule instance. This will be named
  # `twice_nested_module`; the NestedModule created inside it will be named
  # `twice_nested_module/~/nested_module`, and the SimpleModule inside that
  # `twice_nested_module/~/nested_module/~/simple_module`.
  outer = TwiceNestedModule(output_channels=2)
  outer_out = outer(x)
  print(f'The name assigned to the outermost module is: "{outer.module_name}".')
  print('The name assigned to the module inside "outer" is: "'
        f'{outer.inner_nested.module_name}".')
  print('The name assigned to the module inside it is "'
        f'{outer.inner_nested.inner_simple.module_name}".')
  return outer_out

transform_and_print_shapes(f)
The name assigned to the outermost module is: "twice_nested_module".
The name assigned to the module inside "outer" is: "twice_nested_module/~/nested_module".
The name assigned to the module inside it is "twice_nested_module/~/nested_module/~/simple_module".

The name and shape of the parameters are:
{'twice_nested_module': {'w': (3, 2)}, 'twice_nested_module/~/nested_module': {'w': (3, 2)}, 'twice_nested_module/~/nested_module/~/simple_module': {'w': (3, 2)}}

Great, this also works as expected: the full hierarchy of module names and calls is reflected in the inner module names.

Multitransform: merge the parameters without sharing them#

Sometimes, when we have multiple transformed functions, it can be convenient to merge all their parameters into a single structure, to reduce the number of dictionaries we have to store and pass around. It can happen, though, that some of these functions instantiate the same modules, and we want to make sure that their parameters don’t get shared accidentally.

hk.multi_transform comes to the rescue in this case: it merges the parameters into a single dictionary, making sure that duplicated module names are renamed to avoid accidental sharing.

[ ]:
def f(x):
  """A SimpleModule followed by a Linear layer."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  linear = hk.Linear(40)
  return linear(out)

def g(x):
  """A SimpleModule followed by an MLP."""
  module_instance = SimpleModule(output_channels=2)
  return module_instance(x) * 2  # twice

# Transform both functions, and print their respective parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))
transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)
print('f parameters:', parameter_shapes(params_f))
print('g parameters:', parameter_shapes(params_g))

# Transform both functions at once with hk.multi_transform, and print the
# resulting merged parameter structure.

def multitransform_f_and_g():
  def template(x):
    return f(x), g(x)
  return template, (f, g)
init, (f_apply, g_apply) = hk.multi_transform(multitransform_f_and_g)
merged_params = init(rng, x)

print('\nThe name and shape of the multi-transform parameters are:\n',
      parameter_shapes(merged_params))
f parameters: {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}}
g parameters: {'simple_module': {'w': (3, 2)}}

The name and shape of the multi-transform parameters are:
 {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}, 'simple_module_1': {'w': (3, 2)}}

In this example f and g both instantiate a SimpleModule instance with the same arguments, and if we transform them separately we see that both dictionaries contain a 'simple_module' key.

When we transform them together instead, hk.multi_transform takes care of renaming one of them to 'simple_module_1', thus preventing accidental parameter sharing.
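
As a quick usage sketch (not part of the original notebook), the merged parameters can be passed directly to the per-function apply functions returned by hk.multi_transform:

f_out = f_apply(merged_params, rng, x)
g_out = g_apply(merged_params, rng, x)
print('f output shape:', f_out.shape)  # (2, 40): SimpleModule then Linear(40)
print('g output shape:', g_out.shape)  # (2, 2): doubled SimpleModule output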

Sharing parameters between transformed functions#

Now that we understand how module names are assigned and how this affects parameter sharing, let’s see how we can share parameters between transformed functions.

In this section we will consider two functions, f and g, and explore different strategies to share parameters. We will look at a number of cases that differ in how many of the modules instantiated by each function are the same, and in whether their parameters have the same shape.

Case 1: All modules have the same names, and the same shape#

Let’s reuse one of the modules we created before, and try to instantiate it twice inside two different functions:

[ ]:
def f(x):
  """Apply SimpleModule to x."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  return out

def g(x):
  """Like f, but double the output"""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  return out * 2

# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('f parameters:', parameter_shapes(params_f))
print('g parameters:', parameter_shapes(params_g))
f parameters: {'simple_module': {'w': (3, 2)}}
g parameters: {'simple_module': {'w': (3, 2)}}

Great! Since f and g are using exactly the same modules, the sets of initialized variables generated with each have the same name structure (note that the actual values might differ, depending on initialization).

Now, if we wanted to share parameters in this case, we could initialize only one of the two functions (e.g., f) and use the resulting parameters for both functions, i.e., when we call transformed_f.apply and transformed_g.apply.
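
A minimal sketch of that idea (not a cell from the original notebook), reusing the rng, x, transformed_f and transformed_g defined above:

shared_params = transformed_f.init(rng, x)  # initialise f only
f_out = transformed_f.apply(shared_params, rng, x)
g_out = transformed_g.apply(shared_params, rng, x)
# g doubles f's output, so with shared parameters the two results must agree.
assert jnp.allclose(g_out, 2 * f_out)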

Case 2: Common modules have the same names, and the same shape#

This is a nice trick, but what if the functions were not identical? Let’s build two such functions:

[ ]:
def f(x):
  """A SimpleModule followed by a Linear layer."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  linear = hk.Linear(40)
  return linear(out)

def g(x):
  """A SimpleModule followed by an MLP."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  linear = hk.nets.MLP((10, 40))
  return linear(out)

# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))

The name and shape of the f parameters are:
 {'linear': {'b': (40,), 'w': (2, 40)}, 'simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (40,), 'w': (10, 40)}, 'simple_module': {'w': (3, 2)}}

Now we have a problem! Both sets of parameters have a 'simple_module' component, but they also each contain parameters that are specific only to that function, so we cannot simply initialise only one of the functions and use the returned parameters for both as we did before. But we would still like to share the parameters of 'simple_module'. How can we do that?

One option here is to use haiku.data_structures.merge (https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.data_structures.merge) to combine the two sets of parameters. This merges the two structures, keeping only the value from the last structure when both contain the same parameter (i.e., 'simple_module' in our example). Let’s try that:

[ ]:
merged_params = hk.data_structures.merge(params_f, params_g)
print('\nThe name and shape of the shared parameters are:\n',
      parameter_shapes(merged_params))

The name and shape of the shared parameters are:
 {'linear': {'b': (40,), 'w': (2, 40)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (40,), 'w': (10, 40)}, 'simple_module': {'w': (3, 2)}}

Brilliant! Now we have a shared set of parameters that contains all the disjoint parameters and a single set of parameters for the shared 'simple_module'. Let’s verify that we can use this set of parameters when calling either function:

[ ]:
f_out = transformed_f.apply(merged_params, rng, x)
g_out = transformed_g.apply(merged_params, rng, x)

print('f_out mean:', f_out.mean())
print('g_out mean:', g_out.mean())
f_out mean: 0.037986994
g_out mean: 0.104857825

This gives us little control over what gets shared though: what if the two functions had parameters with the same name that we don’t want to share?

Case 3: Common modules have the same names, but different shapes#

Let’s modify our previous example so that both functions apply the same MLP, followed by a final hk.Linear layer of a different size in each:

[ ]:
def f(x):
  """A SimpleModule followed by two Linear layers."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  mlp = hk.nets.MLP((10, 5))
  out = mlp(out)
  last_linear = hk.Linear(4)
  return last_linear(out)

def g(x):
  """Same as f, with a bigger final layer."""
  module_instance = SimpleModule(output_channels=2)
  out = module_instance(x)
  mlp = hk.nets.MLP((10, 5))
  out = mlp(out)
  last_linear = hk.Linear(20)  # another Linear, but bigger
  return last_linear(out)

# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))

The name and shape of the f parameters are:
 {'linear': {'b': (4,), 'w': (5, 4)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'linear': {'b': (20,), 'w': (5, 20)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}

Now we have a problem! Both sets of parameters have a 'linear' component, but their respective parameters have different shapes. If we merged them as we did before, the 'linear' parameters from f would be dropped, and we could no longer use the merged parameters to call transformed_f:

merged_params = hk.data_structures.merge(params_f, params_g)
print('\nThe name and shape of the merged parameters are:\n',
      parameter_shapes(merged_params))

f_out = transformed_f.apply(merged_params, rng, x)  # fails
# ValueError: 'linear/w' with retrieved shape (5, 20) does not match shape=[5, 4] dtype=dtype('float32')

How can we share the parameters of 'simple_module' and mlp, but keep the parameters of the two output linear layers separated?

A solution would be to instantiate simple_module and mlp outside of the functions, so that they are created only once, and then use those instances in both functions. But all Haiku modules must be initialized inside an hk.transform, so doing this naively results in an error:

module_instance = SimpleModule(output_channels=2)  # this fails
# ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.
mlp = hk.nets.MLP((10, 5))

def f(x):
  """A SimpleModule followed by a Linear layer."""
  out = module_instance(x)
  out = mlp(out)
  linear = hk.Linear(4)
  return linear(out)

def g(x):
  """A SimpleModule followed by a bigger Linear layer."""
  out = module_instance(x)
  out = mlp(out)
  linear = hk.Linear(20)  # another Linear, but bigger
  return linear(out)

We can work around this by wrapping the shared modules in a small class that creates them lazily on first call, inside the transform, and then reuses the cached instances:

[ ]:
class CachedModule():

  def __call__(self, *inputs):
    # Create the instances if they are not already in the cache.
    if not hasattr(self, 'cached_simple_module'):
      self.cached_simple_module = SimpleModule(output_channels=2)
    if not hasattr(self, 'cached_mlp'):
      self.cached_mlp = hk.nets.MLP((10, 5))

    # Apply the cached instances.
    out = self.cached_simple_module(*inputs)
    out = self.cached_mlp(out)
    return out


def f(x):
  """A SimpleModule followed by a Linear layer."""
  shared_preprocessing = CachedModule()
  out = shared_preprocessing(x)
  linear = hk.Linear(4)
  return linear(out)

def g(x):
  """A SimpleModule followed by a bigger Linear layer."""
  shared_preprocessing = CachedModule()
  out = shared_preprocessing(x)
  linear = hk.Linear(20)  # another Linear, but bigger
  return linear(out)


# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))

# Verify that the MLP parameters are the same for both functions.
assert_all_equal(params_f['mlp/~/linear_0'],
                 params_g['mlp/~/linear_0'])
assert_all_equal(params_f['mlp/~/linear_1'],
                 params_g['mlp/~/linear_1'])
print('\nThe MLP parameters are shared!')

The name and shape of the f parameters are:
 {'linear': {'b': (4,), 'w': (5, 4)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'linear': {'b': (20,), 'w': (5, 20)}, 'mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'simple_module': {'w': (3, 2)}}

The MLP parameters are shared!

If we want to share a large number of modules, it can become tedious to cache each one of them manually inside CachedModule. Furthermore, it would be nice if we didn’t have to define a different CachedModule class for every function we want to cache.

We can use hk.to_module to create a more general CachedModule object that takes an arbitrary Haiku function and caches it:

[ ]:
class CachedModule():
  """Cache one instance of the function and call it multiple times."""
  def __init__(self, fn):
    self._fn = fn

  def __call__(self, *args, **kwargs):
    if not hasattr(self, "_instance"):
      ModularisedFn = hk.to_module(self._fn)
      self._instance = ModularisedFn()
    return self._instance(*args, **kwargs)

def shared_preprocessing_fn(x):
  simple_module = SimpleModule(output_channels=2)
  out = simple_module(x)
  mlp = hk.nets.MLP((10, 5))
  return mlp(out)

def f(x):
  """A SimpleModule followed by a Linear layer."""
  shared_preprocessing = CachedModule(shared_preprocessing_fn)
  out = shared_preprocessing(x)
  linear = hk.Linear(4)
  return linear(out)

def g(x):
  """A SimpleModule followed by a bigger Linear layer."""
  shared_preprocessing = CachedModule(shared_preprocessing_fn)
  out = shared_preprocessing(x)
  linear = hk.Linear(20)  # another Linear, but bigger
  return linear(out)


# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

transformed_f = hk.transform(f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))

# Verify that the MLP parameters are the same for both functions.
assert_all_equal(params_f['shared_preprocessing_fn/mlp/~/linear_0'],
                 params_g['shared_preprocessing_fn/mlp/~/linear_0'])
assert_all_equal(params_f['shared_preprocessing_fn/mlp/~/linear_1'],
                 params_g['shared_preprocessing_fn/mlp/~/linear_1'])
print('\nThe MLP parameters are shared!')

The name and shape of the f parameters are:
 {'linear': {'b': (4,), 'w': (5, 4)}, 'shared_preprocessing_fn/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing_fn/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing_fn/simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'linear': {'b': (20,), 'w': (5, 20)}, 'shared_preprocessing_fn/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing_fn/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing_fn/simple_module': {'w': (3, 2)}}

The MLP parameters are shared!

When we work with classes, it can also be convenient to define a decorator that does the same:

[7]:
def share_parameters():
  def decorator(fn):
    def wrapper(*args, **kwargs):
      if wrapper.instance is None:
        wrapper.instance = hk.to_module(fn)()
      return wrapper.instance(*args, **kwargs)
    wrapper.instance = None
    return functools.wraps(fn)(wrapper)
  return decorator


class Wrapper():

  @share_parameters()
  def shared_preprocessing(self, x):
    simple_module = SimpleModule(output_channels=2)
    out = simple_module(x)
    mlp = hk.nets.MLP((10, 5))
    return mlp(out)

  def f(self, x):
    """A SimpleModule followed by a Linear layer."""
    out = self.shared_preprocessing(x)
    linear = hk.Linear(4)
    return linear(out)

  def g(self, x):
    """A SimpleModule followed by a bigger Linear layer."""
    out = self.shared_preprocessing(x)
    linear = hk.Linear(20)  # another Linear, but bigger
    return linear(out)

# Transform both functions, and print the parameter shapes.
rng = jax.random.PRNGKey(42)
x = jnp.ones((2, 3))

wrapper = Wrapper()
transformed_f = hk.transform(wrapper.f)
params_f = transformed_f.init(rng, x)
transformed_g = hk.transform(wrapper.g)
params_g = transformed_g.init(rng, x)

print('\nThe name and shape of the f parameters are:\n',
      parameter_shapes(params_f))
print('\nThe name and shape of the g parameters are:\n',
      parameter_shapes(params_g))

# Verify that the MLP parameters are the same for both functions.
assert_all_equal(params_f['shared_preprocessing/mlp/~/linear_0'],
                 params_g['shared_preprocessing/mlp/~/linear_0'])
assert_all_equal(params_f['shared_preprocessing/mlp/~/linear_1'],
                 params_g['shared_preprocessing/mlp/~/linear_1'])
print('\nThe MLP parameters are shared!')

The name and shape of the f parameters are:
 {'linear': {'b': (4,), 'w': (5, 4)}, 'shared_preprocessing/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing/simple_module': {'w': (3, 2)}}

The name and shape of the g parameters are:
 {'linear': {'b': (20,), 'w': (5, 20)}, 'shared_preprocessing/mlp/~/linear_0': {'b': (10,), 'w': (2, 10)}, 'shared_preprocessing/mlp/~/linear_1': {'b': (5,), 'w': (10, 5)}, 'shared_preprocessing/simple_module': {'w': (3, 2)}}

The MLP parameters are shared!
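
As a final sanity check (a short sketch, not part of the original notebook), the transformed methods of Wrapper can be applied with their own parameter sets like any other transformed function:

f_out = transformed_f.apply(params_f, rng, x)
g_out = transformed_g.apply(params_g, rng, x)
print('f output shape:', f_out.shape)  # (2, 4)
print('g output shape:', g_out.shape)  # (2, 20)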