Haiku Fundamentals

Haiku Transforms

transform(f, *[, apply_rng])

Transforms a function using Haiku modules into a pair of pure functions.

transform_with_state(f)

Transforms a function using Haiku modules into a pair of pure functions.

multi_transform(f)

Transforms a collection of functions using Haiku into pure functions.

multi_transform_with_state(f)

Transforms a collection of functions using Haiku into pure functions.

without_apply_rng(f)

Removes the rng argument from the apply function.

without_state(f)

Wraps a transformed tuple and ignores state in/out.

transform

haiku.transform(f, *, apply_rng=True)[source]

Transforms a function using Haiku modules into a pair of pure functions.

For a function out = f(*a, **k) this function returns a pair of pure functions that call f(*a, **k) explicitly collecting and injecting parameter values:

params = init(rng, *a, **k)
out = apply(params, rng, *a, **k)

Note that the rng argument is typically not required for apply and passing None is accepted.

The first thing to do is to define a Module. A module encapsulates some parameters and a computation on those parameters:

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.zeros)
...     return x + w

Next, define some function that creates and applies modules. We use transform() to transform that function into a pair of functions that allow us to lift all the parameters out of the function (f.init) and apply the function with a given set of parameters (f.apply):

>>> def f(x):
...   a = MyModule()
...   b = MyModule()
...   return a(x) + b(x)
>>> f = hk.transform(f)

To get the initial state of the module call init with an example input:

>>> params = f.init(None, 1)
>>> params
{'my_module': {'w': DeviceArray(0., dtype=float32)},
 'my_module_1': {'w': DeviceArray(0., dtype=float32)}}

You can then apply the function with the given parameters by calling apply (note that since we don’t use Haiku’s random number APIs to apply our network we pass None as an RNG key):

>>> f.apply(params, None, 1)
DeviceArray(2., dtype=float32)

It is expected that your program will at some point produce updated parameters and you will want to re-apply apply. You can do this by calling apply with different parameters:

>>> new_params = {"my_module": {"w": jnp.array(2.)},
...               "my_module_1": {"w": jnp.array(3.)}}
>>> f.apply(new_params, None, 2)
DeviceArray(9., dtype=float32)

If your transformed function needs to maintain internal state (e.g. moving averages in batch norm) then see transform_with_state().

Parameters
  • f – A function closing over Module instances.

  • apply_rng – In the process of being removed. Only the value True is accepted.

Return type

Transformed

Returns

A Transformed tuple with init and apply pure functions.

transform_with_state

haiku.transform_with_state(f)[source]

Transforms a function using Haiku modules into a pair of pure functions.

See transform() for general details on Haiku transformations.

For a function out = f(*a, **k) this function returns a pair of pure functions that call f(*a, **k) explicitly collecting and injecting parameter values and state:

params, state = init(rng, *a, **k)
out, state = apply(params, state, rng, *a, **k)

Note that the rng argument is typically not required for apply and passing None is accepted.

This function is equivalent to transform(), however it allows you to maintain and update internal state (e.g. ExponentialMovingAverage in BatchNorm) via get_state() and set_state():

>>> def f():
...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
...                          init=jnp.zeros)
...   hk.set_state("counter", counter + 1)
...   return counter
>>> f = hk.transform_with_state(f)
>>> params, state = f.init(None)
>>> for _ in range(10):
...   counter, state = f.apply(params, state, None)
>>> counter
DeviceArray(9, dtype=int32)
Parameters

f – A function closing over Module instances.

Return type

TransformedWithState

Returns

A TransformedWithState tuple with init and apply pure functions.

multi_transform

haiku.multi_transform(f)[source]

Transforms a collection of functions using Haiku into pure functions.

In many scenarios we have several modules which are used either as primitives for several Haiku modules/functions, or whose pure versions are to be reused in downstream code. This utility enables this by applying transform() to an arbitrary tree of Haiku functions which share modules and have a common init function.

f is expected to return a tuple of two elements. First is a template Haiku function which provides an example of how all internal Haiku modules are connected. This function is used to create a common init function (with your parameters).

The second object is an arbitrary tree of Haiku functions all of which reuse the modules connected in the template function. These functions are transformed to pure apply functions.

Example:

>>> def f():
...   encoder = hk.Linear(1, name="encoder")
...   decoder = hk.Linear(1, name="decoder")
...
...   def init(x):
...     z = encoder(x)
...     return decoder(z)
...
...   return init, (encoder, decoder)
>>> f = hk.multi_transform(f)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 1])
>>> params = f.init(rng, x)
>>> jax.tree_map(jnp.shape, params)
{'decoder': {'b': (1,), 'w': (1, 1)},
 'encoder': {'b': (1,), 'w': (1, 1)}}
>>> encode, decode = f.apply
>>> z = encode(params, None, x)
>>> y = decode(params, None, z)
Parameters

f (Callable[[], Tuple[TemplateFn, TreeOfApplyFns]]) – A factory function that returns two functions, firstly a common init function that creates all modules, and secondly a pytree of apply functions which make use of those modules.

Return type

MultiTransformed

Returns

A MultiTransformed instance which contains a pure init function that creates all parameters, and a pytree of pure apply functions that, given the params, apply the given function.

See also

multi_transform_with_state(): Equivalent for modules using state.

multi_transform_with_state

haiku.multi_transform_with_state(f)[source]

Transforms a collection of functions using Haiku into pure functions.

See multi_transform() for more details.

Example:

>>> def f():
...   encoder = hk.Linear(1, name="encoder")
...   decoder = hk.Linear(1, name="decoder")
...
...   def init(x):
...     z = encoder(x)
...     return decoder(z)
...
...   return init, (encoder, decoder)
>>> f = hk.multi_transform_with_state(f)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 1])
>>> params, state = f.init(rng, x)
>>> jax.tree_map(jnp.shape, params)
{'decoder': {'b': (1,), 'w': (1, 1)},
 'encoder': {'b': (1,), 'w': (1, 1)}}
>>> encode, decode = f.apply
>>> z, state = encode(params, state, None, x)
>>> y, state = decode(params, state, None, z)
Parameters

f (Callable[[], Tuple[TemplateFn, TreeOfApplyFns]]) – Function returning a “template” function and an arbitrary tree of functions using modules connected in the template function.

Return type

MultiTransformedWithState

Returns

An init function and a tree of pure apply functions.

See also

transform_with_state(): Transform a single apply function. multi_transform(): Transform multiple apply functions without state.

without_apply_rng

haiku.without_apply_rng(f)[source]

Removes the rng argument from the apply function.

Return type

TransformedT
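
A minimal usage sketch, assuming the usual imports used throughout this page (import haiku as hk; import jax; import jax.numpy as jnp): after wrapping with without_apply_rng(), apply is called without an rng argument:

>>> def forward(x):
...   return hk.Linear(10)(x)
>>> f = hk.without_apply_rng(hk.transform(forward))
>>> x = jnp.zeros([1, 1])
>>> params = f.init(jax.random.PRNGKey(42), x)  # init still takes an rng key
>>> out = f.apply(params, x)                    # no rng argument here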

without_state

haiku.without_state(f)[source]

Wraps a transformed tuple and ignores state in/out.

The example below is equivalent to f = hk.transform(f):

>>> def f(x):
...   mod = hk.Linear(10)
...   return mod(x)
>>> f = hk.without_state(hk.transform_with_state(f))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.zeros([1, 1])
>>> params = f.init(rng, x)
>>> f.apply(params, rng, x)
DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
Parameters

f (TransformedWithState) – A transformed function.

Return type

Transformed

Returns

A transformed function that does not take or return state.

with_empty_state

haiku.with_empty_state(f)[source]

Wraps a transformed tuple and passes empty state in/out.

The example below is equivalent to f = hk.transform_with_state(f):

>>> def f(x):
...   mod = hk.Linear(10)
...   return mod(x)
>>> f = hk.with_empty_state(hk.transform(f))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.zeros([1, 1])
>>> params, state = f.init(rng, x)
>>> state
{}
>>> out, state = f.apply(params, state, rng, x)
>>> out
DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
>>> state
{}
Parameters

f (Transformed) – A transformed function.

Return type

TransformedWithState

Returns

A transformed function that accepts and returns state.

Modules, Parameters and State

Module([name])

Base class for Haiku modules.

to_module(f)

Converts a function into a callable module class.

get_parameter(name, shape[, dtype, init])

Creates or reuses a parameter for the given transformed function.

get_state(name[, shape, dtype, init])

Gets the current value for state with an optional initializer.

set_state(name, value)

Sets the current value for some state.

transparent(method)

Decorator to wrap a method, preventing automatic variable scope wrapping.

Module

class haiku.Module(name=None)[source]

Base class for Haiku modules.

A Haiku module is a lightweight container for variables and other modules. Modules typically define one or more “forward” methods (e.g. __call__) which apply operations combining user input and module parameters.

Modules must be initialized inside a transform() call.

For example:

>>> class AddModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.ones)
...     return x + w
>>> def forward_fn(x):
...   mod = AddModule()
...   return mod(x)
>>> forward = hk.transform(forward_fn)
>>> x = 1.
>>> rng = None
>>> params = forward.init(rng, x)
>>> forward.apply(params, None, x)
DeviceArray(2., dtype=float32)
__init__(name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__post_init__(name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

params_dict()[source]

Returns parameters keyed by name for this module and submodules.

Return type

Mapping[str, jnp.array]

state_dict()[source]

Returns state keyed by name for this module and submodules.

Return type

Mapping[str, jnp.array]

to_module

haiku.to_module(f)[source]

Converts a function into a callable module class.

Sample usage:

>>> def bias_fn(x):
...   b = hk.get_parameter("b", [], init=hk.initializers.RandomNormal())
...   return x + b
>>> Bias = hk.to_module(bias_fn)
>>> def net(x, y):
...   b = Bias(name="my_bias")
...   # Bias x and y by the same amount.
...   return b(x) * b(y)
Parameters

f (Callable[..., Any]) – The function to convert.

Return type

Type[CallableModule]

Returns

A module class which runs f when called.

get_parameter

haiku.get_parameter(name, shape, dtype=jnp.float32, init=None)[source]

Creates or reuses a parameter for the given transformed function.

>>> hk.get_parameter("w", [], init=jnp.ones)
DeviceArray(1., dtype=float32)

Parameters within the same transform() and/or Module with the same name have the same value:

>>> w1 = hk.get_parameter("w", [], init=jnp.zeros)
>>> w2 = hk.get_parameter("w", [], init=jnp.zeros)
>>> assert w1 is w2
Parameters
  • name (str) – A name for the parameter.

  • shape (Sequence[int]) – The shape of the parameter.

  • dtype (Any) – The dtype of the parameter.

  • init (Optional[Initializer]) – A callable of shape, dtype to generate an initial value for the parameter.

Return type

jnp.ndarray

Returns

A jnp.ndarray with the parameter of the given shape.

get_state

haiku.get_state(name, shape=None, dtype=jnp.float32, init=None)[source]

Gets the current value for state with an optional initializer.

“State” can be used to represent mutable state in your network. The most common usage of state is to represent the moving averages used in batch normalization (see ExponentialMovingAverage). If your network uses “state” then you are required to use transform_with_state() and pass state into and out of the apply function.

>>> hk.get_state("counter", [], init=jnp.zeros)
DeviceArray(0., dtype=float32)

If the value for the given state is already defined (e.g. using set_state()) then you can call with just the name:

>>> hk.get_state("counter")
DeviceArray(0., dtype=float32)

NOTE: state within the same transform() and/or Module with the same name has the same value:

>>> c1 = hk.get_state("counter")
>>> c2 = hk.get_state("counter")
>>> assert c1 is c2
Parameters
  • name (str) – A name for the state.

  • shape (Optional[Sequence[int]]) – The shape of the state.

  • dtype (Any) – The dtype of the state.

  • init (Optional[Initializer]) – A callable f(shape, dtype) that returns an initial value for the state.

Return type

jnp.ndarray

Returns

A jnp.ndarray with the state of the given shape.

set_state

haiku.set_state(name, value)[source]

Sets the current value for some state.

See get_state().

“State” can be used to represent mutable state in your network. The most common usage of state is to represent the moving averages used in batch normalization (see ExponentialMovingAverage). If your network uses “state” then you are required to use transform_with_state() and pass state into and out of the apply function.

>>> hk.set_state("counter", jnp.zeros([]))
>>> hk.get_state("counter")
DeviceArray(0., dtype=float32)

NOTE: state within the same transform() and/or Module with the same name has the same value:

>>> w1 = hk.get_state("counter")
>>> w2 = hk.get_state("counter")
>>> assert w1 is w2
Parameters
  • name (str) – A name for the state.

  • value – A value to set.

transparent

haiku.transparent(method)[source]

Decorator to wrap a method, preventing automatic variable scope wrapping.

By default, all variables and modules created in a method are scoped by the module and method names. This is undesirable in some cases. Any method decorated with transparent() will create variables and modules in the scope in which it was called.

Parameters

method (T) – the method to wrap.

Return type

T

Returns

The method, with a flag indicating no name scope wrapping should occur.
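
A minimal sketch, assuming a module whose helper method should not introduce its own name scope; the parameter created in residual is expected to live directly under the module's scope rather than under a nested method scope:

>>> class Block(hk.Module):
...   @hk.transparent
...   def residual(self, x):
...     w = hk.get_parameter("w", [], init=jnp.zeros)
...     return x + w
...   def __call__(self, x):
...     return self.residual(x)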

Random Numbers

PRNGSequence(key_or_seed)

Iterator of JAX random keys.

next_rng_key()

Returns a unique JAX random key split from the current global key.

next_rng_keys(num)

Returns one or more JAX random keys split from the current global key.

maybe_next_rng_key()

next_rng_key() if random numbers are available, else None.

reserve_rng_keys(num)

Pre-allocate some number of JAX RNG keys.

with_rng(key)

Provides a new sequence for next_rng_key() to draw from.

PRNGSequence

class haiku.PRNGSequence(key_or_seed)[source]

Iterator of JAX random keys.

>>> seq = hk.PRNGSequence(42)  # OR pass a jax.random.PRNGKey
>>> key1 = next(seq)
>>> key2 = next(seq)
>>> assert key1 is not key2

If you know how many keys you will want then you can use reserve() to more efficiently split the keys you need:

>>> seq.reserve(4)
>>> keys = [next(seq) for _ in range(4)]
__init__(key_or_seed)[source]

Creates a new PRNGSequence.

reserve(num)[source]

Splits additional num keys for later use.

__next__()[source]

Return the next item from the iterator. When exhausted, raise StopIteration

Return type

PRNGKey

next()[source]

Return the next item from the iterator. When exhausted, raise StopIteration

Return type

PRNGKey

next_rng_key

haiku.next_rng_key()[source]

Returns a unique JAX random key split from the current global key.

>>> key = hk.next_rng_key()
>>> _ = jax.random.uniform(key, [])
Return type

PRNGKey

Returns

A unique (within a call to init or apply) JAX rng key that can be used with APIs such as jax.random.uniform().

next_rng_keys

haiku.next_rng_keys(num)[source]

Returns one or more JAX random keys split from the current global key.

>>> k1, k2 = hk.next_rng_keys(2)
>>> assert (k1 != k2).all()
>>> a = jax.random.uniform(k1, [])
>>> b = jax.random.uniform(k2, [])
>>> assert a != b
Parameters

num (int) – The number of keys to split.

Return type

Tuple[PRNGKey, ...]

Returns

One or more unique (within a transformed function) JAX rng keys that can be used with APIs such as jax.random.uniform().

maybe_next_rng_key

haiku.maybe_next_rng_key()[source]

next_rng_key() if random numbers are available, else None.

Return type

Optional[PRNGKey]
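
A minimal sketch, assuming a network that adds noise only when an rng key is available (for example when apply is given a key), and otherwise runs deterministically:

>>> def forward(x):
...   key = hk.maybe_next_rng_key()
...   if key is not None:
...     x = x + jax.random.normal(key, x.shape)  # noise only if a key is available
...   return hk.Linear(4)(x)
>>> f = hk.transform(forward)
>>> x = jnp.ones([1, 3])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> noisy = f.apply(params, jax.random.PRNGKey(1), x)  # key available
>>> clean = f.apply(params, None, x)                   # no key, no noise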

reserve_rng_keys

haiku.reserve_rng_keys(num)[source]

Pre-allocate some number of JAX RNG keys.

See next_rng_key().

This API offers a way to micro-optimize how RNG keys are split when using Haiku. It is unlikely that you need it unless you find compilation time of your init function to be a problem, or you sample a lot of random numbers in apply.

>>> hk.reserve_rng_keys(2)  # Pre-allocate 2 keys for us to consume.
>>> _ = hk.next_rng_key()   # Takes the first pre-allocated key.
>>> _ = hk.next_rng_key()   # Takes the second pre-allocated key.
>>> _ = hk.next_rng_key()   # Splits a new key.
Parameters

num (int) – The number of JAX rng keys to allocate.

with_rng

haiku.with_rng(key)[source]

Provides a new sequence for next_rng_key() to draw from.

When next_rng_key() is called, it draws a new key from the PRNGSequence defined by the input key to the transformed function. This context manager overrides the sequence for the duration of the scope.

>>> with hk.with_rng(jax.random.PRNGKey(428)):
...   s = jax.random.uniform(hk.next_rng_key(), ())
>>> print("{:.1f}".format(s))
0.5
Parameters

key (PRNGKey) – The key to seed the sequence with.

Returns

Context manager under which the given sequence is active.

Type Hints

LSTMState(hidden, cell)

An LSTM core state consists of hidden and cell vectors.

Params

A nested mapping of module parameters (module name to parameter name to array).

State

A nested mapping of module state (module name to state name to array).

Transformed(init, apply)

Holds a pair of pure functions.

TransformedWithState(init, apply)

Holds a pair of pure functions.

MultiTransformed(init, apply)

Holds a collection of pure functions.

MultiTransformedWithState(init, apply)

Holds a collection of pure functions.

LSTMState

class haiku.LSTMState(hidden: jnp.ndarray, cell: jnp.ndarray)[source]

An LSTM core state consists of hidden and cell vectors.

hidden

Hidden state.

cell

Cell state.

Params

haiku.Params

State

haiku.State

Transformed

class haiku.Transformed(init: Callable[..., hk.Params], apply: Callable[..., Any])[source]

Holds a pair of pure functions.

init

A pure function: params = init(rng, *a, **k)

apply

A pure function: out = apply(params, rng, *a, **k)

TransformedWithState

class haiku.TransformedWithState(init: Callable[..., Tuple[hk.Params, hk.State]], apply: Callable[..., Tuple[Any, hk.State]])[source]

Holds a pair of pure functions.

init

A pure function: params, state = init(rng, *a, **k)

apply

A pure function: out, state = apply(params, state, rng, *a, **k)

MultiTransformed

class haiku.MultiTransformed(init: Callable[..., hk.Params], apply: Any)[source]

Holds a collection of pure functions.

init

A pure function: params = init(rng, *a, **k)

apply

A JAX tree of pure functions each with the signature: out = apply(params, rng, *a, **k).

See also

Transformed: Single apply variant of multi-transform. MultiTransformedWithState: Multi apply with state variant.

MultiTransformedWithState

class haiku.MultiTransformedWithState(init: Callable[..., Tuple[hk.Params, hk.State]], apply: Any)[source]

Holds a collection of pure functions.

init

A pure function: params, state = init(rng, *a, **k)

apply

A JAX tree of pure functions each with the signature: out, state = apply(params, state, rng, *a, **k).

See also

TransformedWithState: Single apply variant of multi-transform. MultiTransformed: Multi apply variant without state.

Common Modules

Linear

Linear(output_size[, with_bias, w_init, …])

Linear module.

Bias([output_size, bias_dims, b_init, name])

Adds a bias to inputs.

Linear

class haiku.Linear(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]

Linear module.

__init__(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]

Constructs the Linear module.

Parameters
  • output_size (int) – Output dimensionality.

  • with_bias (bool) – Whether to add a bias to the output.

  • w_init (Optional[hk.initializers.Initializer]) – Optional initializer for weights. By default, uses random values from truncated normal, with stddev 1 / sqrt(fan_in). See https://arxiv.org/abs/1502.03167v3.

  • b_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias. By default, zero.

  • name (Optional[str]) – Name of the module.

__call__(inputs, *, precision=None)[source]

Computes a linear transform of the input.

Return type

jnp.ndarray
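
A minimal usage sketch: the module is connected inside a transformed function and produces an output with the requested output_size:

>>> def forward(x):
...   return hk.Linear(output_size=4)(x)
>>> f = hk.transform(forward)
>>> x = jnp.ones([2, 3])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> y = f.apply(params, None, x)  # y.shape == (2, 4)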

Bias

class haiku.Bias(output_size=None, bias_dims=None, b_init=None, name=None)[source]

Adds a bias to inputs.

Example Usage:

>>> N, H, W, C = 1, 2, 3, 4
>>> x = jnp.ones([N, H, W, C])
>>> scalar_bias = hk.Bias(bias_dims=[])
>>> scalar_bias_output = scalar_bias(x)
>>> assert scalar_bias.bias_shape == ()

Create a bias over all non-minibatch dimensions:

>>> all_bias = hk.Bias()
>>> all_bias_output = all_bias(x)
>>> assert all_bias.bias_shape == (H, W, C)

Create a bias over the last non-minibatch dimension:

>>> last_bias = hk.Bias(bias_dims=[-1])
>>> last_bias_output = last_bias(x)
>>> assert last_bias.bias_shape == (C,)

Create a bias over the first non-minibatch dimension:

>>> first_bias = hk.Bias(bias_dims=[1])
>>> first_bias_output = first_bias(x)
>>> assert first_bias.bias_shape == (H, 1, 1)

Subtract and later add the same learned bias:

>>> bias = hk.Bias()
>>> h1 = bias(x, multiplier=-1)
>>> h2 = bias(x)
>>> h3 = bias(x, multiplier=-1)
>>> reconstructed_x = bias(h3)
>>> assert (x == reconstructed_x).all()
__init__(output_size=None, bias_dims=None, b_init=None, name=None)[source]

Constructs a Bias module that supports broadcasting.

Parameters
  • output_size (Optional[Sequence[int]]) – Output size (output shape without batch dimension). If output_size is left as None, the size will be inferred from the input.

  • bias_dims (Optional[Sequence[int]]) – Sequence of which dimensions to retain from the input shape when constructing the bias. The remaining dimensions will be broadcast over (given size of 1), and leading dimensions will be removed completely. See class doc for examples.

  • b_init (Optional[hk.initializers.Initializer]) – Optional initializer for the bias. Defaults to zeros.

  • name (Optional[str]) – Name of the module.

__call__(inputs, multiplier=None)[source]

Adds bias to inputs and optionally multiplies by multiplier.

Parameters
  • inputs (jnp.ndarray) – A Tensor of size [batch_size, input_size1, ...].

  • multiplier (Union[float, jnp.ndarray]) – A scalar or Tensor which the bias term is multiplied by before adding it to inputs. Anything which works in the expression bias * multiplier is acceptable here. This may be useful if you want to add a bias in one place and subtract the same bias in another place via multiplier=-1.

Return type

jnp.ndarray

Returns

A Tensor of size [batch_size, input_size1, ...].

Pooling

avg_pool(value, window_shape, strides, padding)

Average pool.

AvgPool(window_shape, strides, padding[, …])

Average pool.

max_pool(value, window_shape, strides, padding)

Max pool.

MaxPool(window_shape, strides, padding[, …])

Max pool.

Average Pool

haiku.avg_pool(value, window_shape, strides, padding, channel_axis=-1)[source]

Average pool.

Parameters
  • value (jnp.ndarray) – Value to pool.

  • window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.

  • strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

Return type

jnp.ndarray

Returns

Pooled result. Same rank as value.

Raises

ValueError – If the padding is not valid.
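
A minimal sketch: avg_pool() is a pure function (it creates no parameters), so it can be called directly; here an integer window and stride are applied over the spatial dimensions of an NHWC input:

>>> x = jnp.ones([1, 4, 4, 3])
>>> y = hk.avg_pool(x, window_shape=2, strides=2, padding="VALID")  # y.shape == (1, 2, 2, 3)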

class haiku.AvgPool(window_shape, strides, padding, channel_axis=-1, name=None)[source]

Average pool.

Equivalent to partial application of avg_pool().

__init__(window_shape, strides, padding, channel_axis=-1, name=None)[source]

Average pool.

Parameters
  • window_shape (Union[int, Sequence[int]]) – Shape of window to pool over. Same rank as value or int.

  • strides (Union[int, Sequence[int]]) – Strides for the window. Same rank as value or int.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.

  • name (Optional[str]) – String name for the module.

__call__(value)[source]

Call self as a function.

Return type

jnp.ndarray

Max Pool

haiku.max_pool(value, window_shape, strides, padding, channel_axis=-1)[source]

Max pool.

Parameters
  • value (jnp.ndarray) – Value to pool.

  • window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.

  • strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

Return type

jnp.ndarray

Returns

Pooled result. Same rank as value.

class haiku.MaxPool(window_shape, strides, padding, channel_axis=-1, name=None)[source]

Max pool.

Equivalent to partial application of max_pool().

__init__(window_shape, strides, padding, channel_axis=-1, name=None)[source]

Max pool.

Parameters
  • window_shape (Union[int, Sequence[int]]) – Shape of window to pool over. Same rank as value or int.

  • strides (Union[int, Sequence[int]]) – Strides for the window. Same rank as value or int.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.

  • name (Optional[str]) – String name for the module.

__call__(value)[source]

Call self as a function.

Return type

jnp.ndarray

Dropout

dropout(rng, rate, x)

Randomly drop units in the input at a given rate.

dropout

haiku.dropout(rng, rate, x)[source]

Randomly drop units in the input at a given rate.

See: http://www.cs.toronto.edu/~hinton/absps/dropout.pdf

Parameters
  • rng (PRNGKey) – A JAX random key.

  • rate (float) – Probability that each element of x is discarded. Must be a scalar in the range [0, 1).

  • x (jnp.ndarray) – The value to be dropped out.

Return type

jnp.ndarray

Returns

x, but dropped out and scaled by 1 / (1 - rate).
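
A minimal sketch, assuming dropout is only applied during training and drawing the key from Haiku's RNG sequence via next_rng_key():

>>> def net(x, is_training):
...   x = jax.nn.relu(hk.Linear(128)(x))
...   if is_training:
...     x = hk.dropout(hk.next_rng_key(), 0.5, x)
...   return hk.Linear(10)(x)
>>> f = hk.transform(net)
>>> x = jnp.ones([8, 32])
>>> rng = jax.random.PRNGKey(0)
>>> params = f.init(rng, x, is_training=True)
>>> out = f.apply(params, rng, x, is_training=True)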

Combinator

Sequential(layers[, name])

Sequentially calls the given list of layers.

Sequential

class haiku.Sequential(layers, name=None)[source]

Sequentially calls the given list of layers.

Note that Sequential is limited in the range of possible architectures it can handle. This is a deliberate design decision; Sequential is only meant to be used for the simple case of fusing together modules/ops where the input of a particular module/op is the output of the previous one.

Another restriction is that it is not possible to have extra arguments in the __call__() method that are passed to the constituents of the module - for example, if there is a BatchNorm module in Sequential and the user wishes to switch the is_training flag. If this is the desired use case, the recommended solution is to subclass Module and implement __call__:

>>> class CustomModule(hk.Module):
...   def __call__(self, x, is_training):
...     x = hk.Conv2D(32, 4, 2)(x)
...     x = hk.BatchNorm(True, True, 0.9)(x, is_training)
...     x = jax.nn.relu(x)
...     return x
__init__(layers, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(inputs, *args, **kwargs)[source]

Calls all layers sequentially.
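
A minimal usage sketch for the simple fusing case described above, mixing Haiku modules and plain JAX functions:

>>> def forward(x):
...   mlp = hk.Sequential([
...       hk.Linear(64), jax.nn.relu,
...       hk.Linear(10),
...   ])
...   return mlp(x)
>>> f = hk.transform(forward)
>>> x = jnp.ones([1, 28 * 28])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> logits = f.apply(params, None, x)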

Convolutional

ConvND(num_spatial_dims, output_channels, …)

General N-dimensional convolution.

Conv1D(output_channels, kernel_shape[, …])

One dimensional convolution.

Conv2D(output_channels, kernel_shape[, …])

Two dimensional convolution.

Conv3D(output_channels, kernel_shape[, …])

Three dimensional convolution.

ConvNDTranspose(num_spatial_dims, …[, …])

General n-dimensional transposed convolution (aka. deconvolution).

Conv1DTranspose(output_channels, kernel_shape)

One dimensional transposed convolution (aka. deconvolution).

Conv2DTranspose(output_channels, kernel_shape)

Two dimensional transposed convolution (aka. deconvolution).

Conv3DTranspose(output_channels, kernel_shape)

Three dimensional transposed convolution (aka. deconvolution).

DepthwiseConv2D(channel_multiplier, kernel_shape)

2-D Depthwise Convolution Module.

SeparableDepthwiseConv2D(channel_multiplier, kernel_shape)

Separable 2-D Depthwise Convolution Module.

ConvND

class haiku.ConvND(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, feature_group_count=1, name=None)[source]

General N-dimensional convolution.

__init__(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, feature_group_count=1, name=None)[source]

Initializes the module.

Parameters
  • num_spatial_dims (int) – The number of spatial dimensions of the input.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length num_spatial_dims. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. or a callable or sequence of callables of size num_spatial_dims. Any callables must take a single integer argument equal to the effective kernel size and return a sequence of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • name (Optional[str]) – The name of the module.

__call__(inputs, *, precision=None)[source]

Connects ConvND layer.

Parameters
  • inputs (jnp.ndarray) – An array of shape [spatial_dims, C] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, C] and rank-N+2 if batched.

  • precision (Optional[lax.Precision]) – Optional jax.lax.Precision to pass to jax.lax.conv_general_dilated().

Return type

jnp.ndarray

Returns

An array of shape [spatial_dims, output_channels] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, output_channels] and rank-N+2 if batched.

Conv1D

class haiku.Conv1D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, feature_group_count=1, name=None)[source]

One dimensional convolution.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, feature_group_count=1, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 1. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 1. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • name (Optional[str]) – The name of the module.

Conv2D

class haiku.Conv2D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, feature_group_count=1, name=None)[source]

Two dimensional convolution.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, feature_group_count=1, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 2. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 2. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • name (Optional[str]) – The name of the module.
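
A minimal usage sketch, assuming an NHWC input; with SAME padding and stride 2 the spatial size is halved:

>>> def forward(images):
...   return hk.Conv2D(output_channels=16, kernel_shape=3, stride=2)(images)
>>> f = hk.transform(forward)
>>> images = jnp.ones([1, 32, 32, 3])  # NHWC
>>> params = f.init(jax.random.PRNGKey(0), images)
>>> out = f.apply(params, None, images)  # out.shape == (1, 16, 16, 16)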

Conv3D

class haiku.Conv3D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, feature_group_count=1, name=None)[source]

Three dimensional convolution.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, feature_group_count=1, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 3. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 3. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NDHWC or NCDHW. By default, NDHWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • name (Optional[str]) – The name of the module.

ConvNDTranspose

class haiku.ConvNDTranspose(num_spatial_dims, output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)[source]

General n-dimensional transposed convolution (aka. deconvolution).

__init__(num_spatial_dims, output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)[source]

Initializes the module.

Parameters
  • num_spatial_dims (int) – The number of spatial dimensions of the input.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.

  • output_shape (Optional[Union[int, Sequence[int]]]) – Output shape of the spatial dimensions of a transpose convolution. Can be either an integer or an iterable of integers. If a None value is given, a default shape is automatically calculated.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either “VALID” or “SAME”. Defaults to “SAME”. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

__call__(inputs, *, precision=None)[source]

Computes the transposed convolution of the input.

Parameters
  • inputs (jnp.ndarray) – An array of shape [spatial_dims, C] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, C] and rank-N+2 if batched.

  • precision (Optional[lax.Precision]) – Optional jax.lax.Precision to pass to jax.lax.conv_transpose().

Return type

jnp.ndarray

Returns

An array of shape [spatial_dims, output_channels] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, output_channels] and rank-N+2 if batched.

Conv1DTranspose

class haiku.Conv1DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]

One dimensional transposed convolution (aka. deconvolution).

__init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.

  • output_shape (Optional[Union[int, Sequence[int]]]) – Output shape of the spatial dimensions of a transpose convolution. Can be either an integer or an iterable of integers. If a None value is given, a default shape is automatically calculated.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

Conv2DTranspose

class haiku.Conv2DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]

Two dimensional transposed convolution (aka. deconvolution).

__init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.

  • output_shape (Optional[Union[int, Sequence[int]]]) – Output shape of the spatial dimensions of a transpose convolution. Can be either an integer or an iterable of integers. If a None value is given, a default shape is automatically calculated.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.
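
A minimal usage sketch, assuming an NHWC input; with SAME padding and stride 2 the spatial size is roughly doubled:

>>> def forward(x):
...   return hk.Conv2DTranspose(output_channels=8, kernel_shape=4, stride=2)(x)
>>> f = hk.transform(forward)
>>> x = jnp.ones([1, 16, 16, 3])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> y = f.apply(params, None, x)  # y.shape == (1, 32, 32, 8)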

Conv3DTranspose

class haiku.Conv3DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]

Three dimensional transposed convolution (aka. deconvolution).

__init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.

  • output_shape (Optional[Union[int, Sequence[int]]]) – Output shape of the spatial dimensions of a transpose convolution. Can be either an integer or an iterable of integers. If a None value is given, a default shape is automatically calculated.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NDHWC or NCDHW. By default, NDHWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

DepthwiseConv2D

class haiku.DepthwiseConv2D(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

2-D Depthwise Convolution Module.

__init__(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Construct a 2D Depthwise Convolution.

Parameters
  • channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set this to 1.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID, SAME or a sequence of before, after pairs. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.

  • name (Optional[str]) – The name of the module.

__call__(inputs)[source]

Call self as a function.

Return type

jnp.ndarray
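
A minimal usage sketch, assuming an NHWC input; with channel_multiplier=2 the three input channels become six output channels:

>>> def forward(x):
...   return hk.DepthwiseConv2D(channel_multiplier=2, kernel_shape=3)(x)
>>> f = hk.transform(forward)
>>> x = jnp.ones([1, 8, 8, 3])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> y = f.apply(params, None, x)  # y.shape == (1, 8, 8, 6)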

SeparableDepthwiseConv2D

class haiku.SeparableDepthwiseConv2D(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Separable 2-D Depthwise Convolution Module.

__init__(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Construct a Separable 2D Depthwise Convolution module.

Parameters
  • channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set this to 1.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID, SAME or a sequence of before, after pairs. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.

  • name (Optional[str]) – The name of the module.

__call__(inputs)[source]

Call self as a function.

Return type

jnp.ndarray

Normalization

BatchNorm(create_scale, create_offset, …)

Normalizes inputs to maintain a mean of ~0 and stddev of ~1.

GroupNorm(groups[, axis, create_scale, …])

Group normalization module.

InstanceNorm(create_scale, create_offset[, …])

Normalizes inputs along the spatial dimensions.

LayerNorm(axis, create_scale, create_offset)

LayerNorm module.

RMSNorm(axis[, eps, scale_init, name])

RMSNorm module.

SpectralNorm([eps, n_steps, name])

Normalizes an input by its first singular value.

ExponentialMovingAverage(decay[, …])

Maintains an exponential moving average.

SNParamsTree([eps, n_steps, ignore_regex, name])

Applies Spectral Normalization to all parameters in a tree.

EMAParamsTree(decay[, zero_debias, …])

Maintains an exponential moving average for all parameters in a tree.

BatchNorm

class haiku.BatchNorm(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, cross_replica_axis_index_groups=None, data_format='channels_last', name=None)[source]

Normalizes inputs to maintain a mean of ~0 and stddev of ~1.

See: https://arxiv.org/abs/1502.03167.

There are many different variations for how users want to manage scale and offset if they require them at all. These are:

  • No scale/offset in which case create_* should be set to False and scale/offset aren’t passed when the module is called.

  • Trainable scale/offset in which case create_* should be set to True and again scale/offset aren’t passed when the module is called. In this case this module creates and owns the scale/offset variables.

  • Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and then the values fed in at call time.

NOTE: jax.vmap(hk.transform(BatchNorm)) will update summary statistics and normalize values on a per-batch basis; we currently do not support normalizing across a batch axis introduced by vmap.

__init__(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, cross_replica_axis_index_groups=None, data_format='channels_last', name=None)[source]

Constructs a BatchNorm module.

Parameters
  • create_scale (bool) – Whether to include a trainable scaling factor.

  • create_offset (bool) – Whether to include a trainable offset.

  • decay_rate (float) – Decay rate for EMA.

  • eps (float) – Small epsilon to avoid division by zero variance. Defaults 1e-5, as in the paper and Sonnet.

  • scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). Can only be set if create_scale=True. By default, 1.

  • offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias (aka offset). Can only be set if create_offset=True. By default, 0.

  • axis (Optional[Sequence[int]]) – Which axes to reduce over. The default (None) signifies that all but the channel axis should be normalized. Otherwise this is a list of axis indices which will have normalization statistics calculated.

  • cross_replica_axis (Optional[str]) – If not None, it should be a string representing the axis name over which this module is being run within a jax.pmap. Supplying this argument means that batch statistics are calculated across all replicas on that axis.

  • cross_replica_axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.

  • name (Optional[str]) – The module name.

__call__(inputs, is_training, test_local_stats=False, scale=None, offset=None)[source]

Computes the normalized version of the input.

Parameters
  • inputs (jnp.ndarray) – An array, where the data format is [..., C].

  • is_training (bool) – Whether this is during training.

  • test_local_stats (bool) – Whether local stats are used when is_training=False.

  • scale (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.

  • offset (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.

Return type

jnp.ndarray

Returns

The array, normalized across all but the last dimension.
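
Because BatchNorm keeps moving averages in Haiku state, it has to be used with transform_with_state(); a minimal sketch:

>>> def forward(x, is_training):
...   bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)
...   return bn(x, is_training)
>>> f = hk.transform_with_state(forward)
>>> x = jnp.ones([8, 16])
>>> params, state = f.init(jax.random.PRNGKey(0), x, is_training=True)
>>> out, state = f.apply(params, state, None, x, is_training=True)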

GroupNorm

class haiku.GroupNorm(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Group normalization module.

This applies group normalization to x. This involves splitting the channels into groups before calculating the mean and variance. The default behaviour is to compute the mean and variance over the spatial dimensions and the grouped channels. The mean and variance will never be computed over the created groups axis.

It transforms the input x into:

\[\text{outputs} = \text{scale} \dfrac{x - \mu}{\sigma + \epsilon} + \text{offset}\]

Where \(\mu\) and \(\sigma\) are respectively the mean and standard deviation of x.

There are many different variations for how users want to manage scale and offset if they require them at all. These are:

  • No scale/offset in which case create_* should be set to False and scale/offset aren’t passed when the module is called.

  • Trainable scale/offset in which case create_* should be set to True and again scale/offset aren’t passed when the module is called. In this case this module creates and owns the scale/offset parameters.

  • Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and then the values fed in at call time.

__init__(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs a GroupNorm module.

Parameters
  • groups (int) – number of groups to divide the channels by. The number of channels must be divisible by this.

  • axis (Union[int, slice, Sequence[int]]) – int, slice or sequence of ints representing the axes which should be normalized across. By default this is all but the first dimension. For time series data use slice(2, None) to average over all dimensions except the Batch and Time dimensions.

  • create_scale (bool) – whether to create a trainable scale per channel applied after the normalization.

  • create_offset (bool) – whether to create a trainable offset per channel applied after normalization and scaling.

  • eps (float) – Small epsilon to add to the variance to avoid division by zero. Defaults to 1e-5.

  • scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for the scale parameter. Can only be set if create_scale=True. By default scale is initialized to 1.

  • offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for the offset parameter. Can only be set if create_offset=True. By default offset is initialized to 0.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.

  • name (Optional[str]) – Name of the module.

__call__(x, scale=None, offset=None)[source]

Returns normalized inputs.

Parameters
  • x (jnp.ndarray) – An n-D tensor of the data_format specified in the constructor on which the transformation is performed.

  • scale (Optional[jnp.ndarray]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the scale applied to the normalized x. This cannot be passed in if the module was constructed with create_scale=True.

  • offset (Optional[jnp.ndarray]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the offset applied to the normalized x. This cannot be passed in if the module was constructed with create_offset=True.

Return type

jnp.ndarray

Returns

An n-d tensor of the same shape as x that has been normalized.
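As an illustrative sketch (the group count and input shape are arbitrary); GroupNorm carries no state, so a plain transform() suffices:

>>> def f(x):
...   return hk.GroupNorm(groups=4)(x)
>>> f = hk.transform(f)
>>> x = jnp.ones([2, 8, 8, 16])  # 16 channels, divisible by 4 groups
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> y = f.apply(params, None, x)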

InstanceNorm

class haiku.InstanceNorm(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Normalizes inputs along the spatial dimensions.

See LayerNorm for more details.

__init__(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs an InstanceNorm module.

This method creates a module which normalizes over the spatial dimensions.

Parameters
  • create_scale (bool) – bool representing whether to create a trainable scale per channel applied after the normalization.

  • create_offset (bool) – bool representing whether to create a trainable offset per channel applied after normalization and scaling.

  • eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5.

  • scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for the scale variable. Can only be set if create_scale=True. By default scale is initialized to 1.

  • offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for the offset variable. Can only be set if create_offset=True. By default offset is initialized to 0.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.

  • name (Optional[str]) – Name of the module.

LayerNorm

class haiku.LayerNorm(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, name=None)[source]

LayerNorm module.

See: https://arxiv.org/abs/1607.06450.

__init__(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, name=None)[source]

Constructs a LayerNorm module.

Parameters
  • axis (Union[int, Sequence[int], slice]) – Integer, list of integers, or slice indicating which axes to normalize over.

  • create_scale (bool) – Bool, defines whether to create a trainable scale per channel applied after the normalization.

  • create_offset (bool) – Bool, defines whether to create a trainable offset per channel applied after normalization and scaling.

  • eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5, as in the paper and Sonnet.

  • scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). By default, one.

  • offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias (aka offset). By default, zero.

  • name (Optional[str]) – The module name.

__call__(inputs, scale=None, offset=None)[source]

Connects the layer norm.

Parameters
  • inputs (jnp.ndarray) – An array, where the data format is [N, ..., C].

  • scale (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.

  • offset (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.

Return type

jnp.ndarray

Returns

The array, normalized.
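A small usage sketch (arbitrary shape), normalizing over the last axis with learned scale and offset:

>>> def f(x):
...   ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
...   return ln(x)
>>> f = hk.transform(f)
>>> x = jnp.ones([4, 16])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> y = f.apply(params, None, x)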

RMSNorm

class haiku.RMSNorm(axis, eps=1e-05, scale_init=None, name=None)[source]

RMSNorm module.

RMSNorm provides an alternative that can be both faster and more stable than LayerNorm. The inputs are normalized by the root-mean-squared (RMS) and scaled by a learned parameter, but they are not recentered around their mean.

See https://arxiv.org/pdf/1910.07467.pdf

__init__(axis, eps=1e-05, scale_init=None, name=None)[source]

Constructs an RMSNorm module.

Parameters
  • axis (Union[int, Sequence[int], slice]) – Integer, list of integers, or slice indicating which axes to normalize over.

  • eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5.

  • scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). By default, one.

  • name (Optional[str]) – The module name.

__call__(inputs)[source]

Connects the RMS norm.

Parameters

inputs (jnp.ndarray) – An array, where the data format is [N, ..., C].

Returns

The normalized array, of the same shape as the inputs.
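A minimal sketch (arbitrary shape), normalizing over the last axis:

>>> def f(x):
...   return hk.RMSNorm(axis=-1)(x)
>>> f = hk.transform(f)
>>> x = jnp.ones([4, 16])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> y = f.apply(params, None, x)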

SpectralNorm

class haiku.SpectralNorm(eps=0.0001, n_steps=1, name=None)[source]

Normalizes an input by its first singular value.

This module uses power iteration to calculate this value based on the input and an internal hidden state.

__init__(eps=0.0001, n_steps=1, name=None)[source]

Initializes a SpectralNorm module.

Parameters
  • eps (float) – The constant used for numerical stability.

  • n_steps (int) – How many steps of power iteration to perform to approximate the singular value of the input.

  • name (Optional[str]) – The name of the module.

__call__(value, update_stats=True, error_on_non_matrix=False)[source]

Performs Spectral Normalization and returns the new value.

Parameters
  • value – The array-like object on which you would like to perform spectral normalization.

  • update_stats (bool) – A boolean defaulting to True. Regardless of this arg, this function will return the normalized input. When update_stats is True, the internal state of this object will also be updated to reflect the input value. When update_stats is False the internal stats will remain unchanged.

  • error_on_non_matrix (bool) – Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw errors in those cases.

Return type

jnp.ndarray

Returns

The input value normalized by its first singular value.

Raises

ValueError – If error_on_non_matrix is True and value has ndims > 2.
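A minimal sketch (arbitrary matrix shape). The power-iteration state lives in Haiku state, so transform_with_state() is used:

>>> def f(w):
...   return hk.SpectralNorm()(w)
>>> f = hk.transform_with_state(f)
>>> w = jnp.ones([3, 4])
>>> params, state = f.init(jax.random.PRNGKey(0), w)
>>> w_sn, state = f.apply(params, state, None, w)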

ExponentialMovingAverage

class haiku.ExponentialMovingAverage(decay, zero_debias=True, warmup_length=0, name=None)[source]

Maintains an exponential moving average.

This uses the Adam debiasing procedure. See https://arxiv.org/pdf/1412.6980.pdf for details.

__init__(decay, zero_debias=True, warmup_length=0, name=None)[source]

Initializes an ExponentialMovingAverage module.

Parameters
  • decay – The chosen decay. Must be in [0, 1). Values close to 1 result in slow decay; values close to 0 result in fast decay.

  • zero_debias (bool) – Whether to run with zero-debiasing.

  • warmup_length (int) – A positive integer. The EMA has no effect until the internal counter has reached warmup_length, at which point the initial value for the decaying average is set to the input value after warmup_length iterations.

  • name (Optional[str]) – The name of the module.

initialize(shape, dtype=jnp.float32)[source]

If uninitialized, sets the average to zeros of the given shape/dtype.

__call__(value, update_stats=True)[source]

Updates the EMA and returns the new value.

Parameters
  • value (jnp.ndarray) – The array-like object on which you would like to perform an exponential decay.

  • update_stats (bool) – A Boolean, whether to update the internal state of this object to reflect the input value. When update_stats is False the internal stats will remain unchanged.

Return type

jnp.ndarray

Returns

The exponentially weighted average of the input value.
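A minimal sketch tracking a single scalar, mirroring the EMAParamsTree example below (the decay value is arbitrary):

>>> ema_fn = hk.transform_with_state(lambda x: hk.ExponentialMovingAverage(0.9)(x))
>>> x = jnp.array(1.0)
>>> _, state = ema_fn.init(None, x)
>>> avg, state = ema_fn.apply(None, state, None, x)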

SNParamsTree

class haiku.SNParamsTree(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]

Applies Spectral Normalization to all parameters in a tree.

This is isomorphic to EMAParamsTree in moving_averages.py.

__init__(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]

Initializes an SNParamsTree module.

Parameters
  • eps (float) – The constant used for numerical stability.

  • n_steps (int) – How many steps of power iteration to perform to approximate the singular value of the input.

  • ignore_regex (str) – A string. Any parameter in the tree whose name matches this regex will not have spectral normalization applied to it. The empty string means this module applies to all parameters.

  • name (Optional[str]) – The name of the module.

__call__(tree, update_stats=True)[source]

Call self as a function.

EMAParamsTree

class haiku.EMAParamsTree(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]

Maintains an exponential moving average for all parameters in a tree.

While ExponentialMovingAverage is meant to be applied to single parameters within a function, this class is meant to be applied to the entire tree of parameters for a function.

Given a set of parameters for some network:

>>> network_fn = lambda x: hk.Linear(10)(x)
>>> x = jnp.ones([1, 1])
>>> params = hk.transform(network_fn).init(jax.random.PRNGKey(428), x)

You might use EMAParamsTree as follows:

>>> ema_fn = hk.transform_with_state(lambda x: hk.EMAParamsTree(0.2)(x))
>>> _, ema_state = ema_fn.init(None, params)
>>> ema_params, ema_state = ema_fn.apply(None, ema_state, None, params)

Here, we are transforming a Haiku function and constructing its parameters via an init_fn as normal, but are creating a second transformed function which expects a tree of parameters as input. This function is then called with the current parameters as input, which then returns an identical tree with every parameter replaced with its exponentially decayed average. This ema_params object can then be passed into the network_fn as usual, and will cause it to run with EMA weights.

__init__(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]

Initializes an EMAParamsTree module.

Parameters
  • decay – The chosen decay. Must be in [0, 1). Values close to 1 result in slow decay; values close to 0 result in fast decay.

  • zero_debias (bool) – Whether to run with zero-debiasing.

  • warmup_length (int) – A positive integer. The EMA has no effect until the internal counter has reached warmup_length, at which point the initial value for the decaying average is set to the input value after warmup_length iterations.

  • ignore_regex (str) – A string. Any parameter in the tree whose name matches this regex will not have any moving average applied to it. The empty string means this module will EMA all parameters.

  • name (Optional[str]) – The name of the module.

__call__(tree, update_stats=True)[source]

Call self as a function.

Recurrent

RNNCore([name])

Base class for RNN cores.

dynamic_unroll(core, input_sequence, …[, …])

Performs a dynamic unroll of an RNN.

static_unroll(core, input_sequence, …[, …])

Performs a static unroll of an RNN.

expand_apply(f[, axis])

Wraps f to temporarily add a size-1 axis to its inputs.

VanillaRNN(hidden_size[, double_bias, name])

Basic fully-connected RNN core.

LSTM(hidden_size[, name])

Long short-term memory (LSTM) RNN core.

GRU(hidden_size[, w_i_init, w_h_init, …])

Gated Recurrent Unit.

DeepRNN(layers[, name])

Wraps a sequence of cores and callables as a single core.

deep_rnn_with_skip_connections(layers[, name])

Constructs a DeepRNN with skip connections.

ResetCore(core[, name])

A wrapper for managing state resets during unrolls.

IdentityCore([name])

A recurrent core that forwards the inputs and an empty state.

Conv1DLSTM(input_shape, output_channels, …)

1-D convolutional LSTM.

Conv2DLSTM(input_shape, output_channels, …)

2-D convolutional LSTM.

Conv3DLSTM(input_shape, output_channels, …)

3-D convolutional LSTM.

RNNCore

class haiku.RNNCore(name=None)[source]

Base class for RNN cores.

This class defines the basic functionality that every core should implement: initial_state(), used to construct an example of the core state; and __call__() which applies the core parameterized by a previous state to an input.

Cores may be used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from the given input sequence.

abstract __call__(inputs, prev_state)[source]

Run one step of the RNN.

Parameters
  • inputs – An arbitrarily nested structure.

  • prev_state – Previous core state.

Return type

Tuple[Any, Any]

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.

abstract initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.

dynamic_unroll

haiku.dynamic_unroll(core, input_sequence, initial_state, time_major=True)[source]

Performs a dynamic unroll of an RNN.

An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:

state = initial_state
for t in range(len(input_sequence)):
   outputs, state = core(input_sequence[t], state)

A dynamic unroll preserves the loop structure when executed inside jax.jit(). See static_unroll() for an unroll function which replaces a loop with its body repeated multiple times.

Parameters
  • core – An RNNCore to unroll.

  • input_sequence – An arbitrarily nested structure of tensors of shape [T, ...] if time_major=True, or [B, T, ...] if time_major=False, where T is the number of time steps.

  • initial_state – An initial state of the given core.

  • time_major – If True, inputs are expected time-major, otherwise they are expected batch-major.

Returns

  • output_sequence - An arbitrarily nested structure of tensors of shape [T, ...] if time-major, otherwise [B, T, ...].

  • final_state - Core state at time step T.

Return type

A tuple with two elements
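An illustrative sketch (arbitrary sizes) unrolling an LSTM over a time-major [T, B, F] sequence inside a transform:

>>> def unroll(seq):
...   core = hk.LSTM(hidden_size=8)
...   initial_state = core.initial_state(batch_size=seq.shape[1])
...   return hk.dynamic_unroll(core, seq, initial_state)
>>> unroll = hk.transform(unroll)
>>> seq = jnp.ones([5, 2, 3])  # [T, B, F]
>>> params = unroll.init(jax.random.PRNGKey(0), seq)
>>> outs, final_state = unroll.apply(params, None, seq)
>>> outs.shape
(5, 2, 8)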

static_unroll

haiku.static_unroll(core, input_sequence, initial_state, time_major=True)[source]

Performs a static unroll of an RNN.

An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:

state = initial_state
for t in range(len(input_sequence)):
   outputs, state = core(input_sequence[t], state)

A static unroll replaces a loop with its body repeated multiple times when executed inside jax.jit():

state = initial_state
outputs0, state = core(input_sequence[0], state)
outputs1, state = core(input_sequence[1], state)
outputs2, state = core(input_sequence[2], state)
...

See dynamic_unroll() for a loop-preserving unroll function.

Parameters
  • core – An RNNCore to unroll.

  • input_sequence – An arbitrarily nested structure of tensors of shape [T, ...] if time_major=True, or [B, T, ...] if time_major=False, where T is the number of time steps.

  • initial_state – An initial state of the given core.

  • time_major – If True, inputs are expected time-major, otherwise they are expected batch-major.

Returns

  • output_sequence - An arbitrarily nested structure of tensors of shape [T, ...] if time-major, otherwise [B, T, ...].

  • final_state - Core state at time step T.

Return type

A tuple with two elements

expand_apply

haiku.expand_apply(f, axis=0)[source]

Wraps f to temporarily add a size-1 axis to its inputs.

Syntactic sugar for:

ins = jax.tree_util.tree_map(lambda t: np.expand_dims(t, axis=axis), ins)
out = f(ins)
out = jax.tree_util.tree_map(lambda t: np.squeeze(t, axis=axis), out)

This may be useful for applying a function built for [Time, Batch, ...] arrays to a single timestep.

Parameters
  • f – The callable to be applied to the expanded inputs.

  • axis – Where to add the extra axis.

Returns

f, wrapped as described above.

VanillaRNN

class haiku.VanillaRNN(hidden_size, double_bias=True, name=None)[source]

Basic fully-connected RNN core.

Given \(x_t\) and the previous hidden state \(h_{t-1}\) the core computes

\[h_t = \operatorname{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)\]

The output is equal to the new state, \(h_t\).

__init__(hidden_size, double_bias=True, name=None)[source]

Constructs a vanilla RNN core.

Parameters
  • hidden_size (int) – Hidden layer size.

  • double_bias (bool) – Whether to use a bias in the two linear layers. This makes no difference to the learning performance of the cell; however, doubling will create two sets of bias parameters rather than one.

  • name (Optional[str]) – Name of the module.

__call__(inputs, prev_state)[source]

Run one step of the RNN.

Parameters
  • inputs – An arbitrarily nested structure.

  • prev_state – Previous core state.

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.

LSTM

class haiku.LSTM(hidden_size, name=None)[source]

Long short-term memory (LSTM) RNN core.

The implementation is based on [1]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

The output is equal to the new hidden, \(h_t\).

Notes

Forget gate initialization:

Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.

__init__(hidden_size, name=None)[source]

Constructs an LSTM.

Parameters
  • hidden_size (int) – Hidden layer size.

  • name (Optional[str]) – Name of the module.

__call__(inputs, prev_state)[source]

Run one step of the RNN.

Parameters
  • inputs (jnp.ndarray) – An arbitrarily nested structure.

  • prev_state (LSTMState) – Previous core state.

Return type

Tuple[jnp.ndarray, LSTMState]

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Return type

LSTMState

Returns

Arbitrarily nested initial state for this core.
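A single-step usage sketch (arbitrary sizes):

>>> def f(x):
...   core = hk.LSTM(hidden_size=4)
...   state = core.initial_state(batch_size=x.shape[0])
...   out, next_state = core(x, state)
...   return out
>>> f = hk.transform(f)
>>> x = jnp.ones([2, 3])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> f.apply(params, None, x).shape
(2, 4)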

GRU

class haiku.GRU(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]

Gated Recurrent Unit.

The implementation is based on: https://arxiv.org/pdf/1412.3555v1.pdf with biases.

Given \(x_t\) and the previous state \(h_{t-1}\) the core computes

\[\begin{array}{ll} z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t \end{array}\]

where \(z_t\) and \(r_t\) are the update and reset gates respectively.

The output is equal to the new hidden state, \(h_t\).

Warning: Backwards compatibility of GRU weights is currently unsupported.

TODO(tycai): Make policy decision/benchmark performance for GRU variants.

__init__(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(inputs, state)[source]

Run one step of the RNN.

Parameters
  • inputs – An arbitrarily nested structure.

  • prev_state – Previous core state.

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.

DeepRNN

class haiku.DeepRNN(layers, name=None)[source]

Wraps a sequence of cores and callables as a single core.

>>> deep_rnn = hk.DeepRNN([
...     hk.LSTM(hidden_size=4),
...     jax.nn.relu,
...     hk.LSTM(hidden_size=2),
... ])

The state of a DeepRNN is a tuple with one element per RNNCore. If no layers are RNNCores, the state is an empty tuple.

__init__(layers, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

haiku.deep_rnn_with_skip_connections(layers, name=None)[source]

Constructs a DeepRNN with skip connections.

Skip connections alter the dependency structure within a DeepRNN. Specifically, input to the i-th layer (i > 0) is given by a concatenation of the core’s inputs and the outputs of the (i-1)-th layer.

The output of the DeepRNN is the concatenation of the outputs of all cores.

outputs0, ... = layers[0](inputs, ...)
outputs1, ... = layers[1](jnp.concatenate([inputs, outputs0], axis=-1), ...)
outputs2, ... = layers[2](jnp.concatenate([inputs, outputs1], axis=-1), ...)
...

Parameters
  • layers (Sequence[RNNCore]) – List of RNNCores.

  • name (Optional[str]) – Name of the module.

Return type

RNNCore

Returns

A _DeepRNN with skip connections.

Raises

ValueError – If any of the layers is not an RNNCore.

ResetCore

class haiku.ResetCore(core, name=None)[source]

A wrapper for managing state resets during unrolls.

When unrolling an RNNCore on a batch of input sequences, it may be necessary to reset the core’s state at different timesteps for different elements of the batch. The ResetCore class enables this by taking a batch of should_reset booleans in addition to the batch of inputs, and conditionally resetting the core’s state for individual elements of the batch. You may also reset individual entries of the state by passing a should_reset nest compatible with the state structure.

__init__(core, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(inputs, state)[source]

Run one step of the wrapped core, handling state reset.

Parameters
  • inputs – Tuple with two elements, inputs, should_reset, where should_reset is the signal used to reset the wrapped core’s state. should_reset can be either a tensor or a nest. If a nest, should_reset must match the state structure, and its components’ shapes must be prefixes of the shapes of the corresponding entries in the state nest. If a tensor, supported shapes are all common shape prefixes of the state component tensors, e.g. [batch_size].

  • state – Previous wrapped core state.

Returns

Tuple of the wrapped core’s output, next_state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.
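A single-step sketch (arbitrary sizes) passing a per-batch-element should_reset flag alongside the inputs:

>>> def f(x, should_reset):
...   core = hk.ResetCore(hk.LSTM(hidden_size=4))
...   state = core.initial_state(batch_size=x.shape[0])
...   out, state = core((x, should_reset), state)
...   return out
>>> f = hk.transform(f)
>>> x = jnp.ones([2, 3])
>>> should_reset = jnp.array([False, True])
>>> params = f.init(jax.random.PRNGKey(0), x, should_reset)
>>> y = f.apply(params, None, x, should_reset)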

IdentityCore

class haiku.IdentityCore(name=None)[source]

A recurrent core that forwards the inputs and an empty state.

This is commonly used when switching between recurrent and feedforward versions of a model while preserving the same interface.

__call__(inputs, state)[source]

Run one step of the RNN.

Parameters
  • inputs – An arbitrarily nested structure.

  • prev_state – Previous core state.

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.

Conv1DLSTM

class haiku.Conv1DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]

1-D convolutional LSTM.

The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

The output is equal to the new hidden state, \(h_t\).

Notes

Forget gate initialization:

Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.

__init__(input_shape, output_channels, kernel_shape, name=None)[source]

Constructs a 1-D convolutional LSTM.

Parameters
  • input_shape (Sequence[int]) – Shape of the inputs excluding batch size.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 1), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.

  • name (Optional[str]) – Name of the module.

Conv2DLSTM

class haiku.Conv2DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]

2-D convolutional LSTM.

The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

The output is equal to the new hidden state, \(h_t\).

Notes

Forget gate initialization:

Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.

__init__(input_shape, output_channels, kernel_shape, name=None)[source]

Constructs a 2-D convolutional LSTM.

Parameters
  • input_shape (Sequence[int]) – Shape of the inputs excluding batch size.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 2), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.

  • name (Optional[str]) – Name of the module.

Conv3DLSTM

class haiku.Conv3DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]

3-D convolutional LSTM.

The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

The output is equal to the new hidden state, \(h_t\).

Notes

Forget gate initialization:

Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.

__init__(input_shape, output_channels, kernel_shape, name=None)[source]

Constructs a 3-D convolutional LSTM.

Parameters
  • input_shape (Sequence[int]) – Shape of the inputs excluding batch size.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 3), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.

  • name (Optional[str]) – Name of the module.

Attention

MultiHeadAttention

class haiku.MultiHeadAttention(num_heads, key_size, w_init_scale, value_size=None, model_size=None, name=None)[source]

Multi-headed attention mechanism.

As described in the vanilla Transformer paper:

“Attention is all you need” https://arxiv.org/abs/1706.03762

__init__(num_heads, key_size, w_init_scale, value_size=None, model_size=None, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(query, key, value, mask=None)[source]

Compute (optionally masked) MHA with queries, keys & values.

Return type

jnp.ndarray
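A self-attention sketch (the sizes are arbitrary; model_size here defaults to key_size * num_heads, i.e. 16):

>>> def f(q, k, v):
...   mha = hk.MultiHeadAttention(num_heads=2, key_size=8, w_init_scale=1.0)
...   return mha(q, k, v)
>>> f = hk.transform(f)
>>> q = k = v = jnp.ones([1, 5, 16])
>>> params = f.init(jax.random.PRNGKey(0), q, k, v)
>>> f.apply(params, None, q, k, v).shape
(1, 5, 16)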

Batch

Reshape(output_shape[, preserve_dims, name])

Reshapes input Tensor, preserving the batch dimension.

Flatten([preserve_dims, name])

Flattens the input, preserving the batch dimension(s).

BatchApply(f[, num_dims])

Temporarily merges leading dimensions of input tensors.

Reshape

class haiku.Reshape(output_shape, preserve_dims=1, name=None)[source]

Reshapes input Tensor, preserving the batch dimension.

For example, given an input tensor with shape [B, H, W, C, D]:

>>> B, H, W, C, D = range(1, 6)
>>> x = jnp.ones([B, H, W, C, D])

The default behavior when output_shape is (-1, D) is to flatten all dimensions between B and D:

>>> mod = hk.Reshape(output_shape=(-1, D))
>>> assert mod(x).shape == (B, H*W*C, D)

You can change the number of preserved leading dimensions via preserve_dims:

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=2)
>>> assert mod(x).shape == (B, H, W*C, D)

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=3)
>>> assert mod(x).shape == (B, H, W, C, D)

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=4)
>>> assert mod(x).shape == (B, H, W, C, 1, D)

Alternatively, a negative value of preserve_dims specifies the number of trailing dimensions to replace with output_shape:

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=-3)
>>> assert mod(x).shape == (B, H, W*C, D)

This is useful in the case of applying the same module to batched and unbatched outputs:

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=-3)
>>> assert mod(x[0]).shape == (H, W*C, D)

__init__(output_shape, preserve_dims=1, name=None)[source]

Constructs a Reshape module.

Parameters
  • output_shape (Sequence[int]) – Shape to reshape the input tensor to while preserving its first preserve_dims dimensions. When the special value -1 appears in output_shape the corresponding size is automatically inferred. Note that -1 can only appear once in output_shape. To flatten all non-batch dimensions use Flatten.

  • preserve_dims (int) – Number of leading dimensions that will not be reshaped. If negative, this is interpreted instead as the number of trailing dimensions to replace with the new shape.

  • name (Optional[str]) – Name of the module.

Raises

ValueError – If preserve_dims is zero.

__call__(inputs)[source]

Call self as a function.

Flatten

class haiku.Flatten(preserve_dims=1, name=None)[source]

Flattens the input, preserving the batch dimension(s).

By default, Flatten combines all dimensions except the first. Additional leading dimensions can be preserved by setting preserve_dims.

>>> x = jnp.ones([3, 2, 4])
>>> flat = hk.Flatten()
>>> flat(x).shape
(3, 8)

When the input to flatten has fewer than preserve_dims dimensions it is returned unchanged:

>>> x = jnp.ones([3])
>>> flat(x).shape
(3,)

Alternatively, a negative value of preserve_dims specifies the number of trailing dimensions flattened:

>>> x = jnp.ones([3, 2, 4])
>>> negative_flat = hk.Flatten(preserve_dims=-2)
>>> negative_flat(x).shape
(3, 8)

This allows the same module to be seamlessly applied to a single element or a batch of elements with the same element shape:

>>> negative_flat(x[0]).shape
(8,)

__init__(preserve_dims=1, name=None)[source]

Constructs a Flatten module.

Parameters

  • preserve_dims (int) – Number of leading dimensions that will not be reshaped. If negative, this is interpreted instead as the number of trailing dimensions to replace with the new shape.

  • name (Optional[str]) – Name of the module.

Raises

ValueError – If preserve_dims is zero.

BatchApply

class haiku.BatchApply(f, num_dims=2)[source]

Temporarily merges leading dimensions of input tensors.

Merges the leading dimensions of a tensor into a single dimension, runs the given callable, then splits the leading dimension of the result to match the input.

Input arrays whose rank is smaller than the number of dimensions to collapse are passed unmodified.

This may be useful for applying a module to each timestep of e.g. a [Time, Batch, ...] array.

For some functions f and some platforms, this may be more efficient than jax.vmap(), especially when combined with other transformations like jax.grad().

__init__(f, num_dims=2)[source]

Constructs a BatchApply module.

Parameters
  • f – The callable to be applied to the reshaped array.

  • num_dims – The number of dimensions to merge.

__call__(*args, **kwargs)[source]

Call self as a function.
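An illustrative sketch (arbitrary sizes) applying a Linear layer to every timestep of a [Time, Batch, Features] array:

>>> def f(x):
...   return hk.BatchApply(hk.Linear(2))(x)
>>> f = hk.transform(f)
>>> x = jnp.ones([5, 3, 8])  # [Time, Batch, Features]
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> f.apply(params, None, x).shape
(5, 3, 2)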

Embedding

Embed([vocab_size, embed_dim, …])

Module for embedding tokens in a low-dimensional space.

EmbedLookupStyle(value)

How to return the embedding matrices given IDs.

Embed

class haiku.Embed(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None)[source]

Module for embedding tokens in a low-dimensional space.

__init__(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None)[source]

Constructs an Embed module.

Parameters
  • vocab_size (Optional[int]) – The number of unique tokens to embed. If not provided, an existing vocabulary matrix from which vocab_size can be inferred must be provided as embedding_matrix.

  • embed_dim (Optional[int]) – Number of dimensions to assign to each embedding. If an existing vocabulary matrix initializes the module, this should not be provided as it will be inferred.

  • embedding_matrix (Optional[jnp.ndarray]) – A matrix-like object equivalent in size to [vocab_size, embed_dim]. If given, it is used as the initial value for the embedding matrix and neither vocab_size nor embed_dim need be given. If they are given, their values are checked to be consistent with the dimensions of embedding_matrix.

  • w_init (Optional[hk.initializers.Initializer]) – An initializer for the embeddings matrix. As a default, embeddings are initialized via a truncated normal distribution.

  • lookup_style (Union[str, hk.EmbedLookupStyle]) – One of the enum values of EmbedLookupStyle determining how to access the value of the embeddings given an ID. Regardless, the input should be a dense array of integer values representing ids. This setting changes how internally this module maps those ids to embeddings. The result is the same, but the speed and memory tradeoffs are different. It defaults to using NumPy-style array indexing. This value is only the default for the module, and at any given invocation can be overridden in __call__().

  • name (Optional[str]) – Optional name for this module.

Raises

ValueError – If none of embed_dim, embedding_matrix and vocab_size are supplied, or if embedding_matrix is supplied and embed_dim or vocab_size is not consistent with the supplied matrix.

__call__(ids, lookup_style=None)[source]

Lookup embeddings.

Looks up an embedding vector for each value in ids. All ids must be within [0, vocab_size) to prevent NaNs from propagating.

Parameters
  • ids (jnp.ndarray) – integer array.

  • lookup_style (Optional[Union[str, hk.EmbedLookupStyle]]) – Overrides the lookup_style given in the constructor.

Return type

jnp.ndarray

Returns

Tensor of ids.shape + [embedding_dim].

Raises
  • AttributeError – If lookup_style is not valid.

  • ValueError – If ids is not an integer array.
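A lookup sketch (the vocabulary size, embedding dimension and ids are arbitrary):

>>> def f(ids):
...   return hk.Embed(vocab_size=100, embed_dim=16)(ids)
>>> f = hk.transform(f)
>>> ids = jnp.array([[1, 2, 3]])
>>> params = f.init(jax.random.PRNGKey(0), ids)
>>> f.apply(params, None, ids).shape
(1, 3, 16)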

EmbedLookupStyle

class haiku.EmbedLookupStyle(value)[source]

How to return the embedding matrices given IDs.

ARRAY_INDEX = 1
ONE_HOT = 2

Initializers

Initializer

A callable that takes an output shape and dtype and returns an initial value of that shape and dtype.

Constant(constant)

Initializes with a constant.

Identity([gain])

Initializer that generates the identity matrix.

Orthogonal([scale, axis])

Initializer that generates uniformly distributed orthogonal matrices.

RandomNormal([stddev, mean])

Initializes by sampling from a normal distribution.

RandomUniform([minval, maxval])

Initializes by sampling from a uniform distribution.

TruncatedNormal([stddev, mean])

Initializes by sampling from a truncated normal distribution.

VarianceScaling([scale, mode, distribution])

Initializer which adapts its scale to the shape of the initialized array.

UniformScaling([scale])

Uniform scaling initializer.

Initializer

haiku.initializers.Initializer

Constant

class haiku.initializers.Constant(constant)[source]

Initializes with a constant.

__init__(constant)[source]

Constructs a Constant initializer.

Parameters

constant (Union[float, int, jnp.ndarray]) – Constant to initialize with.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

Identity

class haiku.initializers.Identity(gain=1.0)[source]

Initializer that generates the identity matrix.

Constructs a 2D identity matrix or batches of these.

__init__(gain=1.0)[source]

Constructs an Identity initializer.

Parameters

gain (Union[float, jnp.ndarray]) – Multiplicative factor to apply to the identity matrix.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

Orthogonal

class haiku.initializers.Orthogonal(scale=1.0, axis=- 1)[source]

Initializer that generates uniformly distributed orthogonal matrices.

__init__(scale=1.0, axis=- 1)[source]

Construct an initializer for uniformly distributed orthogonal matrices.

These matrices will be row-orthonormal along the axis specified by axis. If the rank of the weight is greater than 2, the shape will be flattened in all other dimensions and then will be row-orthonormal along the final dimension. Note that this only works if the axis dimension is larger, otherwise the matrix will be transposed (equivalently, it will be column orthonormal instead of row orthonormal).

If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.

Parameters
  • scale – Scale factor.

  • axis – Which axis corresponds to the “output dimension” of the tensor.

Returns

An orthogonally initialized parameter.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

RandomNormal

class haiku.initializers.RandomNormal(stddev=1.0, mean=0.0)[source]

Initializes by sampling from a normal distribution.

__init__(stddev=1.0, mean=0.0)[source]

Constructs a RandomNormal initializer.

Parameters
  • stddev – The standard deviation of the normal distribution to sample from.

  • mean – The mean of the normal distribution to sample from.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

RandomUniform

class haiku.initializers.RandomUniform(minval=0.0, maxval=1.0)[source]

Initializes by sampling from a uniform distribution.

__init__(minval=0.0, maxval=1.0)[source]

Constructs a RandomUniform initializer.

Parameters
  • minval – The lower limit of the uniform distribution.

  • maxval – The upper limit of the uniform distribution.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

TruncatedNormal

class haiku.initializers.TruncatedNormal(stddev=1.0, mean=0.0)[source]

Initializes by sampling from a truncated normal distribution.

__init__(stddev=1.0, mean=0.0)[source]

Constructs a TruncatedNormal initializer.

Parameters
  • stddev (Union[float, jnp.ndarray]) – The standard deviation parameter of the truncated normal distribution.

  • mean (Union[float, jnp.ndarray]) – The mean of the truncated normal distribution.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

VarianceScaling

class haiku.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal')[source]

Initializer which adapts its scale to the shape of the initialized array.

The initializer first computes the scaling factor s = scale / n, where n is:

  • Number of input units in the weight tensor, if mode = fan_in.

  • Number of output units, if mode = fan_out.

  • Average of the numbers of input and output units, if mode = fan_avg.

Then, with distribution="truncated_normal" or "normal", samples are drawn from a distribution with a mean of zero and a standard deviation (after truncation, if used) stddev = sqrt(s).

With distribution="uniform", samples are drawn from a uniform distribution within [-limit, limit], with limit = sqrt(3 * s).

The variance scaling initializer can be configured to generate other standard initializers using the scale, mode and distribution arguments. Here are some example configurations:

  • glorot_uniform: VarianceScaling(1.0, "fan_avg", "uniform")

  • glorot_normal: VarianceScaling(1.0, "fan_avg", "truncated_normal")

  • lecun_uniform: VarianceScaling(1.0, "fan_in", "uniform")

  • lecun_normal: VarianceScaling(1.0, "fan_in", "truncated_normal")

  • he_uniform: VarianceScaling(2.0, "fan_in", "uniform")

  • he_normal: VarianceScaling(2.0, "fan_in", "truncated_normal")

__init__(scale=1.0, mode='fan_in', distribution='truncated_normal')[source]

Constructs the VarianceScaling initializer.

Parameters
  • scale – Scale to multiply the variance by.

  • mode – One of fan_in, fan_out, fan_avg

  • distribution – Random distribution to use. One of truncated_normal, normal or uniform.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray
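As a sketch of how one of the configurations listed above might be used (hk.Linear and the sizes are illustrative), the he_normal-style initializer can be passed as w_init to any module that accepts one:

>>> w_init = hk.initializers.VarianceScaling(2.0, "fan_in", "truncated_normal")
>>> def f(x):
...   return hk.Linear(8, w_init=w_init)(x)
>>> f = hk.transform(f)
>>> params = f.init(jax.random.PRNGKey(0), jnp.ones([2, 4]))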

UniformScaling

class haiku.initializers.UniformScaling(scale=1.0)[source]

Uniform scaling initializer.

Initializes by sampling from a uniform distribution, but with the variance scaled by the inverse square root of the number of input units, multiplied by the scale.

__init__(scale=1.0)[source]

Constructs the UniformScaling initializer.

Parameters

scale – Scale to multiply the upper limit of the uniform distribution by.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

Paddings

PadFn

A callable that takes an effective kernel size and returns the (pad_before, pad_after) padding for that dimension.

is_padfn(padding)

Tests whether the given argument is a single or sequence of PadFns.

create(padding, kernel, rate, n)

Generates the padding required for a given padding algorithm.

create_from_padfn(padding, kernel, rate, n)

Generates the padding required for a given padding algorithm.

create_from_tuple(padding, n)

Create a padding tuple using partially specified padding tuple.

causal(effective_kernel_size)

Pre-padding such that output has no dependence on the future.

full(effective_kernel_size)

Maximal padding whilst not convolving over just padded elements.

reverse_causal(effective_kernel_size)

Post-padding such that output has no dependence on the past.

same(effective_kernel_size)

Pads such that the output size matches input size for stride=1.

valid(effective_kernel_size)

No padding.

PadFn

haiku.pad.PadFn

is_padfn

haiku.pad.is_padfn(padding)[source]

Tests whether the given argument is a single or sequence of PadFns.

Return type

bool

create

haiku.pad.create(padding, kernel, rate, n)[source]

Generates the padding required for a given padding algorithm.

Parameters
  • padding (Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – callable/tuple or a sequence of callables/tuples. The callables take an integer representing the effective kernel size (kernel size when the rate is 1) and return a sequence of two integers representing the padding before and padding after for that dimension. The tuples are defined with two elements, padding before and after. If padding is a sequence it must be of length 1 or n.

  • kernel (Union[int, Sequence[int]]) – int or sequence of ints of length n. The size of the kernel for each dimension. If it is an int it will be replicated for the non channel and batch dimensions.

  • rate (Union[int, Sequence[int]]) – int or sequence of ints of length n. The dilation rate for each dimension. If it is an int it will be replicated for the non channel and batch dimensions.

  • n (int) – the number of spatial dimensions.

Return type

Sequence[Tuple[int, int]]

Returns

A sequence of length n containing the padding for each element. These are of the form [pad_before, pad_after].

create_from_padfn

haiku.pad.create_from_padfn(padding, kernel, rate, n)[source]

Generates the padding required for a given padding algorithm.

Parameters
  • padding (Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – callable/tuple or a sequence of callables/tuples. The callables take an integer representing the effective kernel size (kernel size when the rate is 1) and return a sequence of two integers representing the padding before and padding after for that dimension. The tuples are defined with two elements, padding before and after. If padding is a sequence it must be of length 1 or n.

  • kernel (Union[int, Sequence[int]]) – int or sequence of ints of length n. The size of the kernel for each dimension. If it is an int it will be replicated for the non channel and batch dimensions.

  • rate (Union[int, Sequence[int]]) – int or sequence of ints of length n. The dilation rate for each dimension. If it is an int it will be replicated for the non channel and batch dimensions.

  • n (int) – the number of spatial dimensions.

Return type

Sequence[Tuple[int, int]]

Returns

A sequence of length n containing the padding for each element. These are of the form [pad_before, pad_after].
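For example, an illustrative call using the causal padding function described below (the kernel size is arbitrary):

pads = hk.pad.create_from_padfn(padding=hk.pad.causal, kernel=3, rate=1, n=1)
# With kernel=3 and rate=1 the effective kernel size is 3, so the single
# spatial dimension gets two elements of padding before and none after.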

create_from_tuple

haiku.pad.create_from_tuple(padding, n)[source]

Create a padding tuple using partially specified padding tuple.

Return type

Sequence[Tuple[int, int]]

causal

haiku.pad.causal(effective_kernel_size)[source]

Pre-padding such that output has no dependence on the future.

Return type

Tuple[int, int]

full

haiku.pad.full(effective_kernel_size)[source]

Maximal padding whilst not convolving over just padded elements.

Return type

Tuple[int, int]

reverse_causal

haiku.pad.reverse_causal(effective_kernel_size)[source]

Post-padding such that output has no dependence on the past.

Return type

Tuple[int, int]

same

haiku.pad.same(effective_kernel_size)[source]

Pads such that the output size matches input size for stride=1.

Return type

Tuple[int, int]

valid

haiku.pad.valid(effective_kernel_size)[source]

No padding.

Return type

Tuple[int, int]

Full Networks

MLP

class haiku.nets.MLP(output_sizes, w_init=None, b_init=None, with_bias=True, activation=jax.nn.relu, activate_final=False, name=None)[source]

A multi-layer perceptron module.

__init__(output_sizes, w_init=None, b_init=None, with_bias=True, activation=jax.nn.relu, activate_final=False, name=None)[source]

Constructs an MLP.

Parameters
  • output_sizes (Iterable[int]) – Sequence of layer sizes.

  • w_init (Optional[hk.initializers.Initializer]) – Initializer for Linear weights.

  • b_init (Optional[hk.initializers.Initializer]) – Initializer for Linear bias. Must be None if with_bias=False.

  • with_bias (bool) – Whether or not to apply a bias in each layer.

  • activation (Callable[[jnp.ndarray], jnp.ndarray]) – Activation function to apply between Linear layers. Defaults to ReLU.

  • activate_final (bool) – Whether or not to activate the final layer of the MLP.

  • name (Optional[str]) – Optional name for this module.

Raises

ValueError – If with_bias is False and b_init is not None.

__call__(inputs, dropout_rate=None, rng=None)[source]

Connects the module to some inputs.

Parameters
  • inputs (jnp.ndarray) – A Tensor of shape [batch_size, input_size].

  • dropout_rate (Optional[float]) – Optional dropout rate.

  • rng – Optional RNG key. Required when using dropout.

Return type

jnp.ndarray

Returns

The output of the model of size [batch_size, output_size].
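A usage sketch with dropout (the layer sizes, input shape and dropout rate are arbitrary); dropout requires an RNG key at apply time:

>>> def f(x):
...   mlp = hk.nets.MLP([128, 10])
...   return mlp(x, dropout_rate=0.5, rng=hk.next_rng_key())
>>> f = hk.transform(f)
>>> x = jnp.ones([8, 784])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> f.apply(params, jax.random.PRNGKey(1), x).shape
(8, 10)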

reverse(activate_final=None, name=None)[source]

Returns a new MLP which is the layer-wise reverse of this MLP.

NOTE: Since computing the reverse of an MLP requires knowing the input size of each linear layer, this method will fail if the module has not been called at least once.

The contract of reverse is that the reversed module will accept the output of the parent module as input and produce an output which is the input size of the parent.

>>> mlp = hk.nets.MLP([1, 2, 3])
>>> y = mlp(jnp.ones([1, 2]))
>>> rev = mlp.reverse()
>>> rev(y)
DeviceArray(...)

Parameters
  • activate_final (Optional[bool]) – Whether the final layer of the MLP should be activated.

  • name (Optional[str]) – Optional name for the new module. The default name will be the name of the current module prefixed with "reversed_".

Return type

MLP

Returns

An MLP instance which is the reverse of the current instance. Note these instances do not share weights and, apart from being symmetric to each other, are not coupled in any way.

MobileNet

MobileNetV1

class haiku.nets.MobileNetV1(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]

MobileNetV1 model.

__init__(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]

Constructs a MobileNetV1 model.

Parameters
  • strides (Sequence[int]) – The stride to use in the depthwise convolution in each MobileNet block.

  • channels (Sequence[int]) – Number of output channels from the pointwise convolution to use in each block.

  • num_classes (int) – Number of classes.

  • use_bn (bool) – Whether or not to use batch normalization. Defaults to True. When true, biases are not used. When false, biases are used.

  • name (Optional[str]) – Name of the module.

__call__(inputs, is_training)[source]

Call self as a function.

Return type

jnp.ndarray

ResNet

ResNet(blocks_per_group, num_classes[, …])

ResNet model.

ResNet.BlockGroup(channels, num_blocks, …)

Higher level block for ResNet implementation.

ResNet.BlockV1(channels, stride, …[, name])

ResNet V1 block with optional bottleneck.

ResNet.BlockV2(channels, stride, …[, name])

ResNet V2 block with optional bottleneck.

ResNet18(num_classes[, bn_config, …])

ResNet18.

ResNet34(num_classes[, bn_config, …])

ResNet34.

ResNet50(num_classes[, bn_config, …])

ResNet50.

ResNet101(num_classes[, bn_config, …])

ResNet101.

ResNet152(num_classes[, bn_config, …])

ResNet152.

ResNet200(num_classes[, bn_config, …])

ResNet200.

ResNet

class haiku.nets.ResNet(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), logits_config=None, name=None, initial_conv_config=None)[source]

ResNet model.

class BlockGroup(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]

Higher level block for ResNet implementation.

__call__(inputs, is_training, test_local_stats)[source]

Call self as a function.

__init__(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

class BlockV1(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]

ResNet V1 block with optional bottleneck.

__call__(inputs, is_training, test_local_stats)[source]

Call self as a function.

__init__(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

class BlockV2(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]

ResNet V2 block with optional bottleneck.

__call__(inputs, is_training, test_local_stats)[source]

Call self as a function.

__init__(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__init__(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), logits_config=None, name=None, initial_conv_config=None)[source]

Constructs a ResNet model.

Parameters
  • blocks_per_group (Sequence[int]) – A sequence of length 4 that indicates the number of blocks created in each group.

  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers. By default the decay_rate is 0.9 and eps is 1e-5.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • bottleneck (bool) – Whether the block should bottleneck or not. Defaults to True.

  • channels_per_group (Sequence[int]) – A sequence of length 4 that indicates the number of channels used for each block in each group.

  • use_projection (Sequence[bool]) – A sequence of length 4 that indicates whether each residual block should use projection.

  • logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.

  • name (Optional[str]) – Name of the module.

  • initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.

__call__(inputs, is_training, test_local_stats=False)[source]

Call self as a function.
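
A minimal sketch of constructing a model from the ResNet base class (the block configuration below is illustrative and roughly mirrors ResNet18); since the model carries BatchNorm state, transform_with_state() is required:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(images, is_training):
  net = hk.nets.ResNet(
      blocks_per_group=(2, 2, 2, 2),
      num_classes=10,
      bottleneck=False,
      channels_per_group=(64, 128, 256, 512))
  return net(images, is_training)

f = hk.transform_with_state(forward)
x = jnp.ones([1, 224, 224, 3])
params, state = f.init(jax.random.PRNGKey(0), x, is_training=True)
logits, state = f.apply(params, state, None, x, is_training=True)  # [1, 10]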

ResNet18

class haiku.nets.ResNet18(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

ResNet18.

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.

  • name (Optional[str]) – Name of the module.

  • initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.

ResNet34

class haiku.nets.ResNet34(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

ResNet34.

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.

  • name (Optional[str]) – Name of the module.

  • initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.

ResNet50

class haiku.nets.ResNet50(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

ResNet50.

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.

  • name (Optional[str]) – Name of the module.

  • initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.

ResNet101

class haiku.nets.ResNet101(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

ResNet101.

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.

  • name (Optional[str]) – Name of the module.

  • initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.

ResNet152

class haiku.nets.ResNet152(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

ResNet152.

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.

  • name (Optional[str]) – Name of the module.

  • initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.

ResNet200

class haiku.nets.ResNet200(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

ResNet200.

__init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.

  • name (Optional[str]) – Name of the module.

  • initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.

VectorQuantizer

VectorQuantizer(embedding_dim, …[, dtype, …])

Haiku module representing the VQ-VAE layer.

VectorQuantizerEMA(embedding_dim, …[, …])

Haiku module representing the VQ-VAE layer.

VectorQuantizer

class haiku.nets.VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, dtype=jnp.float32, name=None)[source]

Haiku module representing the VQ-VAE layer.

Implements the algorithm presented in “Neural Discrete Representation Learning” by van den Oord et al. https://arxiv.org/abs/1711.00937

Input any tensor to be quantized. Last dimension will be used as space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.

The output tensor will have the same shape as the input.

For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.

embedding_dim

integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

num_embeddings

integer, the number of vectors in the quantized space.

commitment_cost

scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).

__init__(embedding_dim, num_embeddings, commitment_cost, dtype=jnp.float32, name=None)[source]

Initializes a VQ-VAE module.

Parameters
  • embedding_dim (int) – dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

  • num_embeddings (int) – number of vectors in the quantized space.

  • commitment_cost (float) – scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).

  • dtype (Any) – dtype for the embeddings variable, defaults to float32.

  • name (Optional[str]) – name of the module.

__call__(inputs, is_training)[source]

Connects the module to some inputs.

Parameters
  • inputs – Tensor, final dimension must be equal to embedding_dim. All other leading dimensions will be flattened and treated as a large batch.

  • is_training – boolean, whether this connection is to training data.

Returns

Dictionary containing the following keys and values:
  • quantize: Tensor containing the quantized version of the input.

  • loss: Tensor containing the loss to optimize.

  • perplexity: Tensor containing the perplexity of the encodings.

  • encodings: Tensor containing the discrete encodings, i.e. which element of the quantized space each input element was mapped to.

  • encoding_indices: Tensor containing the discrete encoding indices, i.e. which element of the quantized space each input element was mapped to.

Return type

dict

quantize(encoding_indices)[source]

Returns embedding tensor for a batch of indices.
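
A minimal sketch of plugging the quantizer into a model (the single Linear encoder is illustrative); the final dimension of the quantizer input must equal embedding_dim:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
  z = hk.Linear(64)(x)   # last dimension == embedding_dim
  vq = hk.nets.VectorQuantizer(
      embedding_dim=64, num_embeddings=512, commitment_cost=0.25)
  return vq(z, is_training)

f = hk.transform_with_state(forward)
x = jnp.ones([16, 32])
params, state = f.init(jax.random.PRNGKey(0), x, is_training=True)
out, state = f.apply(params, state, None, x, is_training=True)
# out["quantize"].shape == (16, 64); out["loss"] is the term to add to your
# reconstruction loss.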

VectorQuantizerEMA

class haiku.nets.VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=jnp.float32, cross_replica_axis=None, name=None)[source]

Haiku module representing the VQ-VAE layer.

Implements a slightly modified version of the algorithm presented in “Neural Discrete Representation Learning” by van den Oord et al. https://arxiv.org/abs/1711.00937

The difference between VectorQuantizerEMA and VectorQuantizer is that this module uses ExponentialMovingAverages to update the embedding vectors instead of an auxiliary loss. This has the advantage that the embedding updates are independent of the choice of optimizer (SGD, RMSProp, Adam, K-Fac, …) used for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

Input any tensor to be quantized. Last dimension will be used as space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.

The output tensor will have the same shape as the input.

For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.

embedding_dim

integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

num_embeddings

integer, the number of vectors in the quantized space.

commitment_cost

scalar which controls the weighting of the loss terms (see equation 4 in the paper).

decay

float, decay for the moving averages.

epsilon

small float constant to avoid numerical instability.

__init__(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=jnp.float32, cross_replica_axis=None, name=None)[source]

Initializes a VQ-VAE EMA module.

Parameters
  • embedding_dim – integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

  • num_embeddings – integer, the number of vectors in the quantized space.

  • commitment_cost – scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).

  • decay – float between 0 and 1, controls the speed of the Exponential Moving Averages.

  • epsilon (float) – small constant to aid numerical stability, default 1e-5.

  • dtype (Any) – dtype for the embeddings variable, defaults to float32.

  • cross_replica_axis (Optional[str]) – If not None, it should be a string representing the axis name over which this module is being run within a jax.pmap(). Supplying this argument means that cluster statistics and the perplexity are calculated across all replicas on that axis.

  • name (Optional[str]) – name of the module.

__call__(inputs, is_training)[source]

Connects the module to some inputs.

Parameters
  • inputs – Tensor, final dimension must be equal to embedding_dim. All other leading dimensions will be flattened and treated as a large batch.

  • is_training – boolean, whether this connection is to training data. When this is set to False, the internal moving average statistics will not be updated.

Returns

Dictionary containing the following keys and values:
  • quantize: Tensor containing the quantized version of the input.

  • loss: Tensor containing the loss to optimize.

  • perplexity: Tensor containing the perplexity of the encodings.

  • encodings: Tensor containing the discrete encodings, i.e. which element of the quantized space each input element was mapped to.

  • encoding_indices: Tensor containing the discrete encoding indices, i.e. which element of the quantized space each input element was mapped to.

Return type

dict

quantize(encoding_indices)[source]

Returns embedding tensor for a batch of indices.

JAX Fundamentals

Control Flow

cond(*args, **kwargs)

Equivalent to jax.lax.cond() but with Haiku state passed in/out.

fori_loop(lower, upper, body_fun, init_val)

Equivalent to jax.lax.fori_loop() with Haiku state passed in/out.

scan(f, init, xs[, length, reverse, unroll])

Equivalent to jax.lax.scan() but with Haiku state passed in/out.

switch(index, branches, operand)

Equivalent to jax.lax.switch() but with Haiku state passed in/out.

while_loop(cond_fun, body_fun, init_val)

Equivalent to jax.lax.while_loop with Haiku state threaded in/out.

cond

haiku.cond(*args, **kwargs)[source]

Equivalent to jax.lax.cond() but with Haiku state passed in/out.

fori_loop

haiku.fori_loop(lower, upper, body_fun, init_val)[source]

Equivalent to jax.lax.fori_loop() with Haiku state passed in/out.

scan

haiku.scan(f, init, xs, length=None, reverse=False, unroll=1)[source]

Equivalent to jax.lax.scan() but with Haiku state passed in/out.

switch

haiku.switch(index, branches, operand)[source]

Equivalent to jax.lax.switch() but with Haiku state passed in/out.

while_loop

haiku.while_loop(cond_fun, body_fun, init_val)[source]

Equivalent to jax.lax.while_loop with Haiku state threaded in/out.
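
These wrappers behave like their jax.lax counterparts but additionally thread Haiku parameters and state through the control flow, so modules can be created and used inside the loop body. A minimal sketch using scan() (the step function is illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(xs):
  core = hk.Linear(4)

  def step(carry, x):
    y = core(x) + carry   # the same Linear parameters are reused at each step
    return y, y           # (new_carry, per-step output)

  carry0 = jnp.zeros([4])
  _, ys = hk.scan(step, carry0, xs)
  return ys

f = hk.transform(forward)
xs = jnp.ones([10, 4])    # 10 scan steps over 4-dimensional inputs
params = f.init(jax.random.PRNGKey(0), xs)
ys = f.apply(params, None, xs)  # ys.shape == (10, 4)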

JAX Transforms

eval_shape(fun, *args, **kwargs)

Equivalent to jax.eval_shape with any changed Haiku state discarded.

grad(fun[, argnums, has_aux, holomorphic])

Creates a function which evaluates the gradient of fun.

jit(fun, *dec_args, **dec_kwargs)

Decorates a modified version of fun that passes Haiku state.

remat(fun, *dec_args, **dec_kwargs)

Decorates a modified version of fun that passes Haiku state.

value_and_grad(fun[, argnums, has_aux, …])

Creates a function which evaluates both fun and the grad of fun.

vmap(fun[, in_axes, out_axes, axis_name])

Equivalent to jax.vmap() with module parameters/state not mapped.

eval_shape

haiku.eval_shape(fun, *args, **kwargs)[source]

Equivalent to jax.eval_shape with any changed Haiku state discarded.

grad

haiku.grad(fun, argnums=0, has_aux=False, holomorphic=False)[source]

Creates a function which evaluates the gradient of fun.

NOTE: You only need this in the very specific case where you want to take a gradient inside a transform()ed function and the function you are differentiating uses set_state(). For example:

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     hk.set_state("last", x ** 2)
...     return x ** 2
>>> def f(x):
...   m = MyModule()
...   g = hk.grad(m)(x)
...   return g
>>> f = hk.transform_with_state(f)
>>> x = jnp.array(2.)
>>> params, state = jax.jit(f.init)(None, x)
>>> state["my_module"]["last"]
DeviceArray(4., dtype=float32)
Parameters
  • fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns

A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of gradient, auxiliary_data is returned.

For example:

>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.96...

jit

haiku.jit(fun, *dec_args, **dec_kwargs)[source]

Decorates a modified version of fun that passes Haiku state.

remat

haiku.remat(fun, *dec_args, **dec_kwargs)[source]

Decorates a modified version of fun that passes Haiku state.
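
A hedged sketch of using remat() inside a transform to trade compute for memory (the MLP sizes are illustrative); jit() follows the same decorator pattern:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  block = hk.nets.MLP([512, 512, 512])
  # Activations of `block` are recomputed during the backward pass.
  return jnp.sum(hk.remat(block)(x))

f = hk.transform(forward)
x = jnp.ones([8, 512])
params = f.init(jax.random.PRNGKey(0), x)
grads = jax.grad(f.apply)(params, None, x)   # gradient w.r.t. params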

value_and_grad

haiku.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)[source]

Creates a function which evaluates both fun and the grad of fun.

NOTE: You only need this in the very specific case where you want to take a gradient inside a transform()ed function and the function you are differentiating uses set_state(). For example:

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     hk.set_state("last", jnp.sum(x))
...     return x ** 2
>>> def f(x):
...   m = MyModule()
...   y, g = hk.value_and_grad(m)(x)
...   return y, g
>>> f = hk.transform_with_state(f)
>>> x = jnp.array(2.)
>>> _ = jax.jit(f.init)(None, x)
Parameters
  • fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns

A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.

vmap

haiku.vmap(fun, in_axes=0, out_axes=0, axis_name=None)[source]

Equivalent to jax.vmap() with module parameters/state not mapped.
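
A hedged sketch: inside a transform, vmap() maps a function over the batch axis of its inputs while broadcasting (not mapping) module parameters, so every example shares the same weights:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(xs):                      # xs: [batch, features]
  net = hk.Linear(3)
  per_example = lambda x: net(x)      # operates on one example at a time
  return hk.vmap(per_example)(xs)     # parameters are shared across the batch

f = hk.transform(forward)
xs = jnp.ones([5, 7])
params = f.init(jax.random.PRNGKey(0), xs)
ys = f.apply(params, None, xs)        # ys.shape == (5, 3)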

Mixed Precision

Automatic Mixed Precision

set_policy(cls, policy)

Uses the given policy for all instances of the module class.

clear_policy(cls)

Clears any policy associated with the given class.

set_policy

haiku.mixed_precision.set_policy(cls, policy)[source]

Uses the given policy for all instances of the module class.

NOTE: Policies are only applied to modules created in the current thread.

A mixed precision policy describes how inputs, module parameters and module outputs should be cast at runtime. By applying a policy to a given type of module, you can control how all instances of that module behave in your program.

For example, you might want to try running a ResNet50 model in a mixture of float16 and float32 on GPU to get higher throughput. To do so you can apply a mixed precision policy to the ResNet50 type that will create parameters in float32, but cast them to float16 before use, along with all module inputs:

>>> policy = jmp.get_policy('params=float32,compute=float16,output=float32')
>>> hk.mixed_precision.set_policy(hk.nets.ResNet50, policy)
>>> net = hk.nets.ResNet50(4)
>>> x = jnp.ones([4, 224, 224, 3])
>>> net(x, is_training=True)
DeviceArray([[nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan],
             [nan, nan, nan, nan]], dtype=float32)

Oh no, nan! This is because modules like batch norm are not numerically stable in float16. To address this, we apply a second policy to our batch norm modules to keep them in full precision. We are careful to return a float16 output from the module such that subsequent modules receive float16 input:

>>> policy = jmp.get_policy('params=float32,compute=float32,output=float16')
>>> hk.mixed_precision.set_policy(hk.BatchNorm, policy)
>>> net(x, is_training=True)
DeviceArray([[0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]], dtype=float32)

For a fully worked mixed precision example see the imagenet example in Haiku’s examples directory. This example shows mixed precision on GPU offering a 2x speedup in training time with only a small impact on final top-1 accuracy.

>>> hk.mixed_precision.clear_policy(hk.nets.ResNet50)
>>> hk.mixed_precision.clear_policy(hk.BatchNorm)
Parameters
  • cls (Type[hk.Module]) – A Haiku module class.

  • policy (jmp.Policy) – A JMP policy to apply to the module.

See also

clear_policy(): Clears any policies associated with a class.

clear_policy

haiku.mixed_precision.clear_policy(cls)[source]

Clears any policy associated with the given class.

Parameters

cls (Type[hk.Module]) – A Haiku module class.

See also

set_policy(): Sets a policy for a given class.

🚧 Experimental

Getters and Interceptors

custom_creator(creator, *[, params, state])

Registers a custom parameter and/or state creator.

custom_getter(getter, *[, params, state])

Registers a custom parameter or state getter.

GetterContext(full_name, module, …)

Read only state showing where parameters are being created.

intercept_methods(interceptor)

Register a new method interceptor.

MethodContext(module, method_name, …)

Read only state showing the calling context for a method.

custom_creator

haiku.experimental.custom_creator(creator, *, params=True, state=False)[source]

Registers a custom parameter and/or state creator.

When new parameters are created via get_parameter() we first run custom creators passing user defined values through. For example:

>>> def zeros_creator(next_creator, shape, dtype, init, context):
...   init = jnp.zeros
...   return next_creator(shape, dtype, init)
>>> with hk.experimental.custom_creator(zeros_creator):
...   z = hk.get_parameter("z", [], jnp.float32, jnp.ones)
>>> z
DeviceArray(0., dtype=float32)

If state=True then your creator will additionally run on calls to get_state():

>>> with hk.experimental.custom_creator(zeros_creator, state=True):
...   z = hk.get_state("z", [], jnp.float32, jnp.ones)
>>> z
DeviceArray(0., dtype=float32)
Parameters
  • creator (Creator) – A parameter creator.

  • params (bool) – Whether to intercept parameter creation, defaults to True.

  • state (bool) – Whether to intercept state creation, defaults to False.

Return type

contextlib.AbstractContextManager

Returns

Context manager under which the creator is active.

custom_getter

haiku.experimental.custom_getter(getter, *, params=True, state=False)[source]

Registers a custom parameter or state getter.

When parameters are retrieved using get_parameter() we always run all custom getters before returning a value to the user.

>>> def bf16_getter(next_getter, value, context):
...   value = value.astype(jnp.bfloat16)
...   return next_getter(value)
>>> with hk.experimental.custom_getter(bf16_getter):
...   w = hk.get_parameter("w", [], jnp.float32, jnp.ones)
>>> w.dtype
dtype(bfloat16)

If state=True the getter will additionally run for calls to get_state():

>>> with hk.experimental.custom_getter(bf16_getter, state=True):
...   c = hk.get_state("c", [], jnp.float32, jnp.ones)
>>> c.dtype
dtype(bfloat16)
Parameters
  • getter (Getter) – A parameter getter.

  • params (bool) – Whether the getter should run on get_parameter()

  • state (bool) – Whether the getter should run on get_state().

Return type

contextlib.AbstractContextManager

Returns

Context manager under which the getter is active.

GetterContext

class haiku.experimental.GetterContext(full_name: str, module: Optional[Module], original_dtype: Any, original_shape: Sequence[int])[source]

Read only state showing where parameters are being created.

full_name

The full name of the given parameter (e.g. mlp/~/linear_0/w).

module

The module that owns the current parameter, None if this parameter exists outside any module.

original_dtype

The dtype that get_parameter() or get_state() was originally called with.

original_shape

The shape that get_parameter() or get_state() was originally called with.

module_name

The full name of enclosing modules.

name

The name of this parameter.

ParamContext

Deprecated since version 0.0.3: Renamed to GetterContext and used for both custom parameter getters and custom state getters.

haiku.experimental.ParamContext[source]

alias of haiku._src.base.GetterContext

intercept_methods

haiku.experimental.intercept_methods(interceptor)[source]

Register a new method interceptor.

Method interceptors allow you to (at a distance) intercept method calls to modules and modify args/kwargs before calling the underlying method. After the underlying method is called you can modify its result before it is passed back to the user.

For example you could intercept method calls to BatchNorm and ensure it is always computed in full precision:

>>> def my_interceptor(next_f, args, kwargs, context):
...   if (type(context.module) is not hk.BatchNorm
...       or context.method_name != "__call__"):
...     # We ignore methods other than BatchNorm.__call__.
...     return next_f(*args, **kwargs)
...
...   def cast_if_array(x):
...     if isinstance(x, jnp.ndarray):
...       x = x.astype(jnp.float32)
...     return x
...
...   args, kwargs = jax.tree_map(cast_if_array, (args, kwargs))
...   out = next_f(*args, **kwargs)
...   return out

We can create and use our module in the usual way, we just need to wrap any method calls we want to intercept in the context manager:

>>> mod = hk.BatchNorm(decay_rate=0.9, create_scale=True, create_offset=True)
>>> x = jnp.ones([], jnp.bfloat16)
>>> with hk.experimental.intercept_methods(my_interceptor):
...   out = mod(x, is_training=True)
>>> assert out.dtype == jnp.float32

Without the interceptor, BatchNorm would compute in bf16; however, since we cast x before the underlying method is called, we compute in f32.

Parameters

interceptor (MethodGetter) – A method interceptor.

Returns

Context manager under which the interceptor is active.

MethodContext

class haiku.experimental.MethodContext(module: Module, method_name: str, orig_method: Callable[..., Any])[source]

Read only state showing the calling context for a method.

For example, let’s define two interceptors and print the values in the context. Additionally, we will make the first interceptor conditionally short circuit: since interceptors stack and are run in order, an earlier interceptor can decide to call the next interceptor, or to short circuit and call the underlying method directly:

>>> module = hk.Linear(1, name="method_context_example")
>>> short_circuit = False
>>> def my_interceptor_1(next_fun, args, kwargs, context):
...   print('running my_interceptor_1')
...   print('- module.name: ', context.module.name)
...   print('- method_name: ', context.method_name)
...   if short_circuit:
...     return context.orig_method(*args, **kwargs)
...   else:
...     return next_fun(*args, **kwargs)
>>> def my_interceptor_2(next_fun, args, kwargs, context):
...   print('running my_interceptor_2')
...   print('- module.name: ', context.module.name)
...   print('- method_name: ', context.method_name)
...   return next_fun(*args, **kwargs)

When short_circuit=False the two interceptors will run in order:

>>> with hk.experimental.intercept_methods(my_interceptor_1), \
...      hk.experimental.intercept_methods(my_interceptor_2):
...   _ = module(jnp.ones([1, 1]))
running my_interceptor_1
- module.name:  method_context_example
- method_name:  __call__
running my_interceptor_2
- module.name:  method_context_example
- method_name:  __call__

Setting short_circuit=True will cause the first interceptor to call the original method (rather than next_fun which will trigger the next interceptor):

>>> short_circuit = True
>>> with hk.experimental.intercept_methods(my_interceptor_1), \
...      hk.experimental.intercept_methods(my_interceptor_2):
...   _ = module(jnp.ones([1, 1]))
running my_interceptor_1
- module.name:  method_context_example
- method_name:  __call__
module

A Module instance whose method is being called.

method_name

The name of the method being called on the module.

orig_method

The underlying method on the module which, when called, will not trigger interceptors. You should only call this if you want to short circuit all the other interceptors; in general you should prefer to call the next_fun passed to your interceptor, which will run orig_method after running all other interceptors.

TensorFlow Profiler

named_call(fun, *[, name])

Wraps a function in an XLA name_scope and maintains Haiku state.

profiler_name_scopes([enabled])

Enable/disable profiler name_scopes on all haiku module methods.

named_call

haiku.experimental.named_call(fun, *, name=None)[source]

Wraps a function in an XLA name_scope and maintains Haiku state.

Return type

Callable[..., Any]
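
A hedged sketch (the scope name "encoder_block" is illustrative): wrapping part of a forward pass so it appears as a single named scope in XLA/profiler traces while still creating parameters as usual:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  def block(x):
    return jax.nn.relu(hk.Linear(128)(x))
  block = hk.experimental.named_call(block, name="encoder_block")
  return hk.Linear(10)(block(x))

f = hk.transform(forward)
x = jnp.ones([4, 32])
params = f.init(jax.random.PRNGKey(0), x)
out = f.apply(params, None, x)   # out.shape == (4, 10)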

profiler_name_scopes

haiku.experimental.profiler_name_scopes(enabled=True)[source]

Enable/disable profiler name_scopes on all haiku module methods.

Note: currently name scopes are only enabled for __call__. See named_call() if you want to annotate other methods explicitly.

Parameters

enabled – Whether to enable name scopes or not.

Returns

The previous value of the name_scopes setting.

Graphviz Visualisation

to_dot(fun)

Converts a function using Haiku modules to a dot graph.

abstract_to_dot(fun)

Converts a function using Haiku modules to a dot graph.

abstract_to_dot

haiku.experimental.abstract_to_dot(fun)[source]

Converts a function using Haiku modules to a dot graph.

Same as to_dot() but uses JAX’s abstract interpretation machinery to evaluate the function without requiring concrete inputs. Valid inputs for the wrapped function include jax.ShapeDtypeStruct.

abstract_to_dot() does not support data-dependent control-flow, because no concrete values are provided to the function.

Parameters

fun (Callable[..., Any]) – A function using Haiku modules.

Return type

Callable[..., str]

Returns

A function that returns the source code string to a graphviz graph describing the operations executed by the given function clustered by Haiku module.

See also

to_dot(): Generates a graphviz graph using concrete inputs.
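
A hedged sketch, assuming abstract_to_dot() is applied to the transformed apply function: both the example input and the parameters can be jax.ShapeDtypeStruct placeholders (the parameter shapes are derived abstractly here via jax.eval_shape):

import haiku as hk
import jax
import jax.numpy as jnp

def net_fn(x):
  return hk.nets.MLP([300, 100, 10])(x)

net = hk.transform(net_fn)
rng = jax.random.PRNGKey(42)
x_spec = jax.ShapeDtypeStruct((8, 784), jnp.float32)
params_spec = jax.eval_shape(net.init, rng, x_spec)  # abstract params, no flops
dot = hk.experimental.abstract_to_dot(net.apply)(params_spec, rng, x_spec)
# `dot` is graphviz source; render it with graphviz.Source(dot).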

to_dot

haiku.experimental.to_dot(fun)[source]

Converts a function using Haiku modules to a dot graph.

To view the resulting graph in Google Colab or an iPython notebook use the graphviz package:

dot = hk.experimental.to_dot(f)(x)
import graphviz
graphviz.Source(dot)
Parameters

fun (Callable[..., Any]) – A function using Haiku modules.

Return type

Callable[..., str]

Returns

A function that returns the source code string to a graphviz graph describing the operations executed by the given function clustered by Haiku module.

See also

abstract_to_dot(): Generates a graphviz graph using abstract inputs.

Summarisation

tabulate(f, *[, columns, filters, …])

Produces a summarised view of the execution of f.

eval_summary(f)

Records module method calls performed by f.

ArraySpec(shape, dtype)

Shaped and sized specification of an array.

MethodInvocation(module_details, args_spec, …)

Record of a method being invoked on a given module.

ModuleDetails(module, method_name, params, state)

Module and method related information.

tabulate

haiku.experimental.tabulate(f, *, columns=('module', 'config', 'owned_params', 'input', 'output', 'params_size', 'params_bytes'), filters=('has_output',), tabulate_kwargs={'tablefmt': 'grid'})[source]

Produces a summarised view of the execution of f.

>>> def f(x):
...   return hk.nets.MLP([300, 100, 10])(x)
>>> x = jnp.ones([8, 28 * 28])
>>> f = hk.transform(f)
>>> print(hk.experimental.tabulate(f)(x))
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| Module                  | Config                                   | Module params   | Input      | Output     |   Param count |   Param bytes |
+=========================+==========================================+=================+============+============+===============+===============+
| mlp (MLP)               | MLP(output_sizes=[300, 100, 10])         |                 | f32[8,784] | f32[8,10]  |       266,610 |       1.07 MB |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_0 (Linear) | Linear(output_size=300, name='linear_0') | w: f32[784,300] | f32[8,784] | f32[8,300] |       235,500 |     942.00 KB |
|  └ mlp (MLP)            |                                          | b: f32[300]     |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_1 (Linear) | Linear(output_size=100, name='linear_1') | w: f32[300,100] | f32[8,300] | f32[8,100] |        30,100 |     120.40 KB |
|  └ mlp (MLP)            |                                          | b: f32[100]     |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_2 (Linear) | Linear(output_size=10, name='linear_2')  | w: f32[100,10]  | f32[8,100] | f32[8,10]  |         1,010 |       4.04 KB |
|  └ mlp (MLP)            |                                          | b: f32[10]      |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+

Possible values for columns:

  • module: Displays module and method name.

  • config: Displays the constructor arguments used for the module.

  • owned_params: Displays parameters directly owned by this module.

  • input: Displays module inputs.

  • output: Displays module output.

  • params_size: Displays the number of parameters.

  • params_bytes: Displays parameter size in bytes.

Possible values for filters:

  • has_output: Only include methods returning a value other than None.

  • has_params: Removes methods from modules that do not have parameters.

Parameters
  • f (Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState]) – A function to transform OR one of the init/apply functions from Haiku or the result of transform() or transform_with_state().

  • columns (Optional[Sequence[str]]) – A list of column names to enable.

  • filters (Optional[Sequence[str]]) – A list of filters to apply to remove certain module methods.

  • tabulate_kwargs – Keyword arguments to pass to tabulate.tabulate(..).

Return type

Callable[..., str]

Returns

A callable that takes the same arguments as f but returns a string summarising the modules used during the execution of f.

See also

eval_summary(): Raw data used to generate this table.

eval_summary

haiku.experimental.eval_summary(f)[source]

Records module method calls performed by f.

>>> f = lambda x: hk.nets.MLP([300, 100, 10])(x)
>>> x = jnp.ones([8, 28 * 28])
>>> for i in hk.experimental.eval_summary(f)(x):
...   print("mod := {:14} | in := {} out := {}".format(
...       i.module_details.module.module_name, i.args_spec[0], i.output_spec))
mod := mlp            | in := f32[8,784] out := f32[8,10]
mod := mlp/~/linear_0 | in := f32[8,784] out := f32[8,300]
mod := mlp/~/linear_1 | in := f32[8,300] out := f32[8,100]
mod := mlp/~/linear_2 | in := f32[8,100] out := f32[8,10]
Parameters

f (Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState]) – A function or transformed function to trace.

Return type

Callable[..., Sequence[MethodInvocation]]

Returns

A callable taking the same arguments as the provided function, but returning a sequence of MethodInvocation instances revealing the methods called on each module when applying f.

See also

tabulate(): Pretty prints a summary of the execution of a function.

ArraySpec

class haiku.experimental.ArraySpec(shape, dtype)[source]

Shaped and sized specification of an array.

shape

Shape of the array.

dtype

DType of the array.

MethodInvocation

class haiku.experimental.MethodInvocation(module_details, args_spec, kwargs_spec, output_spec, context, call_stack)[source]

Record of a method being invoked on a given module.

module_details

Details about which module and method were invoked.

args_spec

Positional arguments to the method invocation with arrays replaced by ArraySpec.

kwargs_spec

Keyword arguments to the method invocation with arrays replaced by ArraySpec.

output_spec

Output of the method invocation with arrays replaced by ArraySpec.

context

Additional context information for the method call as provided by intercept_methods().

call_stack

Stack of modules currently active while calling this module method. For example if A calls B which calls C then the call stack for C will be [B_DETAILS, A_DETAILS].

ModuleDetails

class haiku.experimental.ModuleDetails(module, method_name, params, state)[source]

Module and method related information.

module

A Module instance.

method_name

The method name that was invoked on the module.

params

The modules params dict with arrays converted to ArraySpec.

state

The modules state dict with arrays converted to ArraySpec.

Managing State

name_scope(name)

Context manager which adds a prefix to all new modules, params or state.

name_like(method_name)

Allows a method to be named like some other method.

lift

alias of haiku._src.lift.LiftingModule

name_scope

haiku.experimental.name_scope(name)[source]

Context manager which adds a prefix to all new modules, params or state.

>>> with hk.experimental.name_scope("my_name_scope"):
...   net = hk.Linear(1, name="my_linear")
>>> net.module_name
'my_name_scope/my_linear'

When used inside a module, any submodules, parameters or state created inside the name scope will have a prefix added to their names:

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     with hk.experimental.name_scope("my_name_scope"):
...       submodule = hk.Linear(1, name="submodule")
...       w = hk.get_parameter("w", [], init=jnp.ones)
...     return submodule(x) + w
>>> f = hk.transform(lambda x: MyModule()(x))
>>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
>>> jax.tree_map(jnp.shape, params)
{'my_module/my_name_scope': {'w': ()},
 'my_module/my_name_scope/submodule': {'b': (1,), 'w': (1, 1)}}

Name scopes are very similar to putting all of the code inside the context manager inside a method on a Module with the name you provide. Behind the scenes this is precisely how name scopes are implemented.

If you are familiar with TensorFlow then Haiku’s name_scope() is similar to tf.variable_scope(..) in TensorFlow 1 and tf.name_scope(..) in TensorFlow 1 and 2 in that it changes the names associated with modules, parameters and state.

Parameters

name (str) – The name scope to use (e.g. "foo" or "foo/bar").

Return type

ContextManager[None]

Returns

A single use context manager that when active prefixes new modules, parameters or state with the given name.

name_like

haiku.experimental.name_like(method_name)[source]

Allows a method to be named like some other method.

In Haiku, submodules are named based on the name of their parent module and the method in which they are created. When refactoring code it may be desirable to maintain previous names in order to keep checkpoint compatibility; this can be achieved using name_like().

As an example, consider the following toy autoencoder:

>>> class Autoencoder(hk.Module):
...   def __call__(self, x):
...     z = hk.Linear(10, name="enc")(x)  # name: autoencoder/enc
...     y = hk.Linear(10, name="dec")(z)  # name: autoencoder/dec
...     return y

If we want to refactor this such that users can encode or decode, we would create two methods (encode, decode) which would create and apply our modules. In order to retain checkpoint compatibility with the original module we can use name_like() to name those submodules as if they were created inside __call__:

>>> class Autoencoder(hk.Module):
...   @hk.experimental.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10, name="enc")(x)  # name: autoencoder/enc
...
...   @hk.experimental.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10, name="dec")(z)  # name: autoencoder/dec
...
...   def __call__(self, x):
...     return self.decode(self.encode(x))

One sharp edge arises when users rely on Haiku’s numbering to give modules unique names and then refactor using name_like(). For example, when refactoring the following:

>>> class Autoencoder(hk.Module):
...   def __call__(self, x):
...     z = hk.Linear(10)(x)  # name: autoencoder/linear
...     y = hk.Linear(10)(z)  # name: autoencoder/linear_1
...     return y

If this is refactored using name_like(), the unnamed linear modules in encode/decode will end up with the same name (both: autoencoder/linear) because module numbering is only applied within a method:

>>> class Autoencoder(hk.Module):
...   @hk.experimental.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10)(x)  # name: autoencoder/linear
...
...   @hk.experimental.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10)(z)  # name: autoencoder/linear  <-- NOT INTENDED

To fix this case you need to explicitly name the modules within the method with their former name:

>>> class Autoencoder(hk.Module):
...   @hk.experimental.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10, name="linear")(x)    # name: autoencoder/linear
...
...   @hk.experimental.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10, name="linear_1")(z)  # name: autoencoder/linear_1
Parameters

method_name (str) – The name of a method whose name we should adopt. This method does not actually have to be defined on the class.

Return type

Callable[[T], T]

Returns

A decorator that when applied to a method marks it as having a different name.

lift

haiku.experimental.lift(*args, **kwargs) → T[source]

Lifts the given init function to a function in the current Haiku namespace.

During init, the returned callable will run the given init_fn, and include the resulting params in the outer transform’s dictionaries. During apply, the returned callable will instead pull the relevant parameters from the outer transform’s dictionaries.

Must be called inside transform(), and be passed the init member of a Transformed.

The user must ensure that the given init does not accidentally catch modules from an outer transform() via functional closure.

This is highly experimental and may be changed or removed at any time.
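
A hedged sketch of the intended pattern (the module and function names are illustrative): the inner transform’s init is lifted so that its parameters live in the outer transform’s params dictionary:

import haiku as hk
import jax
import jax.numpy as jnp

def inner_fn(x):
  return hk.Linear(3)(x)

def outer_fn(x):
  inner = hk.transform(inner_fn)
  init_rng = hk.next_rng_key() if hk.running_init() else None
  # During init this runs inner.init and registers its params in the outer
  # transform; during apply it pulls those params back out of the outer dict.
  inner_params = hk.experimental.lift(inner.init)(init_rng, x)
  return inner.apply(inner_params, None, x)

f = hk.transform(outer_fn)
x = jnp.ones([1, 2])
params = f.init(jax.random.PRNGKey(0), x)
out = f.apply(params, None, x)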

Optimizations

optimize_rng_use(fun)

Optimizes RNG key splitting in fun.

optimize_rng_use

haiku.experimental.optimize_rng_use(fun)[source]

Optimizes RNG key splitting in fun.

Our strategy here is to use abstract interpretation to run your function twice: the first time we use jax.eval_shape() to avoid spending any FLOPs and simply observe how many times you call next_rng_key(). We then run your function again, but this time we reserve enough RNG keys ahead of time such that we only need to call jax.random.split() once.

In the following example, we need three random samples for our weight matrices in our 3-layer MLP. To draw these samples we use next_rng_key() which will split a new key for each sample. By using optimize_rng_use() Haiku will pre-allocate exactly enough RNGs for f to be evaluated by splitting the input key once and only once. For large models (unlike this example) this can lead to a significant reduction in compilation time for init:

>>> def f(x):
...   net = hk.nets.MLP([300, 100, 10])
...   return net(x)
>>> f = hk.experimental.optimize_rng_use(f)
>>> f = hk.transform(f)
>>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
Parameters

fun – A function to wrap.

Returns

A function that applies fun but only requires one call to jax.random.split() by Haiku.

Utilities

Data Structures

filter(predicate, structure)

Filters an input structure according to a user specified predicate.

is_subset(*, subset, superset)

Checks whether the leaves of subset appear in superset.

map(fn, structure)

Maps a function to an input structure accordingly.

merge(*structures)

Merges multiple input structures.

partition(predicate, structure)

Partitions the input structure in two according to a given predicate.

partition_n(fn, structure, n)

Partitions a structure into n structures.

to_immutable_dict(mapping)

Returns an immutable copy of the given mapping.

to_mutable_dict(mapping)

Turns an immutable FlatMapping into a mutable dict.

traverse(structure)

Iterates over a structure yielding module names, names and values.

tree_bytes(tree)

Sums the size in bytes of all arrays in a pytree.

tree_size(tree)

Sums the sizes of all arrays in a pytree.

filter

haiku.data_structures.filter(predicate, structure)[source]

Filters an input structure according to a user specified predicate.

>>> params = {'linear': {'w': None, 'b': None}}
>>> predicate = lambda module_name, name, value: name == 'w'
>>> hk.data_structures.filter(predicate, params)
{'linear': {'w': None}}

Note: returns a new structure not a view.

Parameters
  • predicate (Callable[[str, str, T], bool]) – criterion to be used to partition the input data. The predicate argument is expected to be a boolean function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data.

  • structure (Mapping[str, Mapping[str, T]]) – Haiku params or state data structure to be filtered.

Return type

Mapping[str, Mapping[str, T]]

Returns

All the input parameters or state as selected by the input predicate.

is_subset

haiku.data_structures.is_subset(*, subset, superset)[source]

Checks whether the leaves of subset appear in superset.

Note that this is vacuously true in the case that both structures have no leaves:

>>> hk.data_structures.is_subset(subset={'a': {}}, superset={})
True
Parameters
  • subset (Mapping[str, Mapping[str, Any]]) – The subset to check.

  • superset (Mapping[str, Mapping[str, Any]]) – The superset to check.

Return type

bool

Returns

A boolean indicating whether all elements in subset are contained in superset.

map

haiku.data_structures.map(fn, structure)[source]

Maps a function to an input structure accordingly.

>>> params = {'linear': {'w': 1.0, 'b': 2.0}}
>>> fn = lambda module_name, name, value: 2 * value if name == 'w' else value
>>> hk.data_structures.map(fn, params)
{'linear': {'b': 2.0, 'w': 2.0}}

Note: returns a new structure not a view.

Parameters
  • fn (Callable[[str, str, InT], OutT]) – function used to map the input data. The fn argument is expected to take as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data, and to return the new value.

  • structure (Mapping[str, Mapping[str, InT]]) – Haiku params or state data structure to be mapped.

Return type

Mapping[str, Mapping[str, OutT]]

Returns

All the input parameters or state as mapped by the input fn.

merge

haiku.data_structures.merge(*structures)[source]

Merges multiple input structures.

>>> weights = {'linear': {'w': None}}
>>> biases = {'linear': {'b': None}}
>>> hk.data_structures.merge(weights, biases)
{'linear': {'w': None, 'b': None}}

When structures are not disjoint the output will contain the value from the last structure for each path:

>>> weights1 = {'linear': {'w': 1}}
>>> weights2 = {'linear': {'w': 2}}
>>> hk.data_structures.merge(weights1, weights2)
{'linear': {'w': 2}}

Note: returns a new structure not a view.

Parameters

*structures – One or more structures to merge.

Return type

Mapping[str, Mapping[str, Any]]

Returns

A single structure with an entry for each path in the input structures.

partition

haiku.data_structures.partition(predicate, structure)[source]

Partitions the input structure in two according to a given predicate.

For a given set of parameters, you can use partition() to split them:

>>> params = {'linear': {'w': None, 'b': None}}
>>> predicate = lambda module_name, name, value: name == 'w'
>>> weights, biases = hk.data_structures.partition(predicate, params)
>>> weights
{'linear': {'w': None}}
>>> biases
{'linear': {'b': None}}

Note: returns new structures not a view.

Parameters
  • predicate (Callable[[str, str, jnp.ndarray], bool]) – criterion to be used to partition the input data. The predicate argument is expected to be a boolean function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data.

  • structure (Mapping[str, Mapping[str, T]]) – Haiku params or state data structure to be partitioned.

Return type

Tuple[Mapping[str, Mapping[str, T]], Mapping[str, Mapping[str, T]]]

Returns

A tuple containing all the params or state as partitioned by the input predicate. Entries matching the predicate will be in the first structure, and the rest will be in the second.

partition_n

haiku.data_structures.partition_n(fn, structure, n)[source]

Partitions a structure into n structures.

For a given set of parameters, you can use partition_n() to split them into n groups. For example, to split your parameters/gradients by module name:

>>> def partition_by_module(structure):
...   cnt = itertools.count()
...   d = collections.defaultdict(lambda: next(cnt))
...   fn = lambda m, n, v: d[m]
...   return hk.data_structures.partition_n(fn, structure, len(structure))
>>> structure = {f'layer_{i}': {'w': None, 'b': None} for i in range(3)}
>>> for substructure in partition_by_module(structure):
...   print(substructure)
{'layer_0': {'b': None, 'w': None}}
{'layer_1': {'b': None, 'w': None}}
{'layer_2': {'b': None, 'w': None}}
Parameters
  • fn (Callable[[str, str, T], int]) – Callable returning which bucket in [0, n) the given element should be placed in.

  • structure (Mapping[str, Mapping[str, T]]) – Haiku params or state data structure to be partitioned.

  • n (int) – The total number of buckets.

Return type

Tuple[Mapping[str, Mapping[str, T]], ...]

Returns

A tuple of size n, where each element will contain the values for which the function returned the current index.

to_immutable_dict

haiku.data_structures.to_immutable_dict(mapping)[source]

Returns an immutable copy of the given mapping.

Return type

Mapping[K, V]

to_mutable_dict

haiku.data_structures.to_mutable_dict(mapping)[source]

Turns an immutable FlatMapping into a mutable dict.

traverse

haiku.data_structures.traverse(structure)[source]

Iterates over a structure yielding module names, names and values.

NOTE: Items are iterated in key sorted order.

Parameters

structure (Mapping[str, Mapping[str, T]]) – The structure to traverse.

Yields

Tuples of the module name, name and value from the given structure.

Return type

Generator[Tuple[str, str, T], None, None]
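
For example (a small sketch with a hand-written params dict); traverse yields (module_name, name, value) triples in key sorted order:

import haiku as hk
import jax.numpy as jnp

params = {'linear': {'b': jnp.zeros([3]), 'w': jnp.ones([2, 3])}}
for module_name, name, value in hk.data_structures.traverse(params):
  print(module_name, name, value.shape)
# linear b (3,)
# linear w (2, 3)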

tree_bytes

haiku.data_structures.tree_bytes(tree)[source]

Sums the size in bytes of all arrays in a pytree.

Note that this is the minimum size of the array (e.g. for a float32 we need at least 4 bytes); however, on some accelerators buffers may occupy more memory due to padding/alignment constraints.

For example given a ResNet50 model:

>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([128, 224, 224, 3])
>>> params, state = f.init(rng, x)

We can count the number of parameters and their size at f32:

>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 102.23MB

And compare that with casting our parameters to bf16:

>>> params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 51.11MB
Parameters

tree – A tree of jnp.ndarrays.

Return type

int

Returns

The total size in bytes of the array(s) in the input.

tree_size

haiku.data_structures.tree_size(tree)[source]

Sums the sizes of all arrays in a pytree.

For example given a ResNet50 model:

>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([128, 224, 224, 3])
>>> params, state = f.init(rng, x)

We can count the number of parameters and their size at f32:

>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 102.23MB

And compare that with casting our parameters to bf16:

>>> params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 51.11MB
Parameters

tree – A tree of jnp.ndarrays.

Return type

int

Returns

The total size (number of elements) of the array(s) in the input.

Testing

transform_and_run([f, seed, run_apply, …])

Transforms the given function and runs init then (optionally) apply.

transform_and_run

haiku.testing.transform_and_run(f=None, seed=42, run_apply=True, jax_transform=None)[source]

Transforms the given function and runs init then (optionally) apply.

Equivalent to:

>>> def f(x):
...   return x
>>> x = jnp.ones([])
>>> rng = jax.random.PRNGKey(42)
>>> f = hk.transform_with_state(f)
>>> params, state = f.init(rng, x)
>>> out = f.apply(params, state, rng, x)

This function makes it very convenient to unit test Haiku:

>>> class MyTest(unittest.TestCase):
...   @hk.testing.transform_and_run
...   def test_linear_output(self):
...     mod = hk.Linear(1)
...     out = mod(jnp.ones([1, 1]))
...     self.assertEqual(out.ndim, 2)

It can also be combined with chex to test all pure/jit/pmap versions of a function:

>>> class MyTest(unittest.TestCase):
...   @chex.all_variants
...   def test_linear_output(self):
...     @hk.testing.transform_and_run(jax_transform=self.variant)
...     def f(inputs):
...       mod = hk.Linear(1)
...       return mod(inputs)
...     out = f(jnp.ones([1, 1]))
...     self.assertEqual(out.ndim, 2)

And can also be useful in an interactive environment like ipython, Jupyter or Google Colaboratory:

>>> f = lambda x: hk.Bias()(x)
>>> hk.testing.transform_and_run(f)(jnp.ones([1, 1]))
DeviceArray([[1.]], dtype=float32)

See transform() for more details.

Parameters
  • f (Optional[Fn]) – A function or method to transform.

  • seed (Optional[int]) – A seed to pass to init and apply.

  • run_apply (bool) – Whether to run apply as well as init. Defaults to True.

  • jax_transform (Optional[Callable[[Fn], Fn]]) – An optional jax transform to apply on the init and apply functions.

Return type

T

Returns

A function that transform()s f and runs init and optionally apply.

Conditional Computation

running_init()

Return True if running the init function of a Haiku transform.

running_init

haiku.running_init()[source]

Return True if running the init function of a Haiku transform.

In general you should not need to gate behaviour of your module based on whether you are running init or apply, but sometimes (e.g. when making use of JAX control flow) this is required.

For example, if you want to use switch() to pick between experts, then when we run your init function we need to ensure that params/state for all experts are created (unconditionally), but during apply we only want to conditionally apply (and perhaps update the internal state of) one of our experts:

>>> experts = [hk.nets.ResNet50(10) for _ in range(5)]
>>> x = jnp.ones([1, 224, 224, 3])
>>> if hk.running_init():
...   # During init unconditionally create params/state for all experts.
...   for expert in experts:
...     out = expert(x, is_training=True)
... else:
...   # During apply conditionally apply (and update) only one expert.
...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts))
...   out = hk.switch(index, experts, x)

Return type

bool

Returns

True if running init otherwise False.

Functions

multinomial(rng, logits, num_samples)

Draws samples from a multinomial distribution.

one_hot(x, num_classes[, dtype])

Returns a one-hot version of indices.

multinomial

haiku.multinomial(rng, logits, num_samples)[source]

Draws samples from a multinomial distribution.

Parameters
  • rng – A JAX PRNGKey.

  • logits – Unnormalized log-probabilities, of shape [batch_size, categories] or [categories].

  • num_samples – Number of samples to draw.

Returns

Chosen categories, of shape [batch_size, num_samples] or [num_samples].
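
For example, a minimal sketch drawing five samples from a single (unbatched) distribution; the logits are illustrative:

>>> rng = jax.random.PRNGKey(428)
>>> logits = jnp.log(jnp.array([0.1, 0.7, 0.2]))
>>> samples = hk.multinomial(rng, logits, num_samples=5)
>>> samples.shape
(5,)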

one_hot

haiku.one_hot(x, num_classes, dtype=jnp.float32)[source]

Returns a one-hot version of indices.

DEPRECATED: Use jax.nn.one_hot(x, num_classes).astype(dtype) instead.

Parameters
  • x – A tensor of indices.

  • num_classes – Number of classes in the one-hot dimension.

  • dtype – The dtype.

Returns

The one-hot tensor. If the shape of indices is [A, B, …], the output shape is [A, B, …, num_classes].
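
For example, a minimal sketch using the recommended replacement on illustrative indices:

>>> indices = jnp.array([0, 2, 1])
>>> jax.nn.one_hot(indices, num_classes=3).astype(jnp.float32)
DeviceArray([[1., 0., 0.],
             [0., 0., 1.],
             [0., 1., 0.]], dtype=float32)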
