Base

Transforming Functions

haiku.transform(f, *, apply_rng=True)

Transforms a function using Haiku modules into a pair of pure functions.

For a function out = f(*a, **k), this function returns a pair of pure functions that call f(*a, **k), explicitly collecting and injecting parameter values:

params = init(rng, *a, **k)
out = apply(params, rng, *a, **k)

Note that the rng argument is typically not required for apply and passing None is accepted.

The first thing to do is to define a Module. A module encapsulates some parameters and a computation on those parameters:

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.zeros)
...     return x + w

Next, define some function that creates and applies modules. We use transform() to transform that function into a pair of functions that allow us to lift all the parameters out of the function (f.init) and apply the function with a given set of parameters (f.apply):

>>> def f(x):
...   a = MyModule()
...   b = MyModule()
...   return a(x) + b(x)
>>> f = hk.transform(f)

To get the initial parameters of the module, call init with an example input:

>>> params = f.init(None, 1)
>>> params
frozendict({
  'my_module': frozendict({'w': DeviceArray(0., dtype=float32)}),
  'my_module_1': frozendict({'w': DeviceArray(0., dtype=float32)}),
})

You can then apply the function with the given parameters by calling apply (note that since we don’t use Haiku’s random number APIs to apply our network, we pass None as an RNG key):

>>> f.apply(params, None, 1)
DeviceArray(2., dtype=float32)

Your program will typically produce updated parameters at some point (for example, after a training step), and you will want to call apply again with them. You can do this by passing the new parameters to apply:

>>> new_params = {"my_module": {"w": jnp.array(2.)},
...               "my_module_1": {"w": jnp.array(3.)}}
>>> f.apply(new_params, None, 2)
DeviceArray(9., dtype=float32)

If your transformed function needs to maintain internal state (e.g. moving averages in batch norm) then see transform_with_state().

Parameters
  • f – A function closing over Module instances.

  • apply_rng – In the process of being removed. Can only be True.

Return type

Transformed

Returns

A Transformed tuple with init and apply pure functions.

haiku.transform_with_state(f)

Transforms a function using Haiku modules into a pair of pure functions.

See transform() for general details on Haiku transformations.

For a function out = f(*a, **k), this function returns a pair of pure functions that call f(*a, **k), explicitly collecting and injecting parameter values and state:

params, state = init(rng, *a, **k)
out, state = apply(params, state, rng, *a, **k)

Note that the rng argument is typically not required for apply and passing None is accepted.

This function is equivalent to transform(), except that it also allows you to maintain and update internal state (e.g. ExponentialMovingAverage in BatchNorm) via get_state() and set_state():

>>> def f():
...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
...                          init=jnp.zeros)
...   hk.set_state("counter", counter + 1)
...   return counter
>>> f = hk.transform_with_state(f)
>>> params, state = f.init(None)
>>> for _ in range(10):
...   counter, state = f.apply(params, state, None)
>>> counter
DeviceArray(9, dtype=int32)

Parameters

f – A function closing over Module instances.

Return type

TransformedWithState

Returns

A TransformedWithState tuple with init and apply pure functions.

haiku.without_apply_rng(f)

Removes the rng argument from the apply function.

Return type

TransformedT

haiku.without_state(f)

Wraps a transformed tuple and ignores state in/out.

The example below is equivalent to f = hk.transform(f):

>>> def f(x):
...   mod = hk.Linear(10)
...   return mod(x)
>>> f = hk.without_state(hk.transform_with_state(f))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.zeros([1, 1])
>>> params = f.init(rng, x)
>>> f.apply(params, rng, x)
DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

Parameters

f (TransformedWithState) – A transformed function.

Return type

Transformed

Returns

A transformed function that does not take or return state.

Type Hints

class haiku.LSTMState(hidden: jnp.ndarray, cell: jnp.ndarray)

An LSTM core state consists of hidden and cell vectors.

hidden

Hidden state.

cell

Cell state.

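haiku.Params

A nested mapping from module name to parameter name to array, i.e. Mapping[str, Mapping[str, jnp.ndarray]].

haiku.State

A nested mapping with the same structure as Params, holding state values instead of parameters.

class haiku.Transformed(init: Callable[..., hk.Params], apply: Callable[..., Any])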

Holds a pair of pure functions.

init

A pure function: params = init(rng, *a, **k)

apply

A pure function: out = apply(params, rng, *a, **k)

class haiku.TransformedWithState(init: Callable[..., Tuple[hk.Params, hk.State]], apply: Callable[..., Tuple[Any, hk.State]])

Holds a pair of pure functions.

init

A pure function: params, state = init(rng, *a, **k)

apply

A pure function: out, state = apply(params, state, rng, *a, **k)

Parameters and State

class haiku.Module(name=None)

Base class for Haiku modules.

A Haiku module is a lightweight container for variables and other modules. Modules typically define one or more “forward” methods (e.g. __call__) which apply operations combining user input and module parameters.

Modules must be initialized inside a transform() call.

For example:

>>> class AddModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.ones)
...     return x + w
>>> def forward_fn(x):
...   mod = AddModule()
...   return mod(x)
>>> forward = hk.transform(forward_fn)
>>> x = 1.
>>> rng = None
>>> params = forward.init(rng, x)
>>> forward.apply(params, None, x)
DeviceArray(2., dtype=float32)

__init__(name=None)

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__post_init__(name=None)

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

params_dict()

Returns parameters keyed by name for this module and submodules.

Return type

Mapping[str, jnp.ndarray]

haiku.to_module(f)

Converts a function into a callable module class.

Sample usage:

>>> def bias_fn(x):
...   b = hk.get_parameter("b", [], init=hk.initializers.RandomNormal())
...   return x + b
>>> Bias = hk.to_module(bias_fn)
>>> def net(x, y):
...   b = Bias(name="my_bias")
...   # Bias x and y by the same amount.
...   return b(x) * b(y)

Parameters

f (Callable[..., Any]) – The function to convert.

Return type

Type[CallableModule]

Returns

A module class which runs f when called.

haiku.get_parameter(name, shape, dtype=jnp.float32, init=None)

Creates or reuses a parameter for the given transformed function.

>>> hk.get_parameter("w", [], init=jnp.ones)
DeviceArray(1., dtype=float32)

Parameters within the same transform() and/or Module with the same name have the same value:

>>> w1 = hk.get_parameter("w", [], init=jnp.zeros)
>>> w2 = hk.get_parameter("w", [], init=jnp.zeros)
>>> assert w1 is w2

Parameters
  • name (str) – A name for the parameter.

  • shape (Sequence[int]) – The shape of the parameter.

  • dtype (Any) – The dtype of the parameter.

  • init (Initializer) – A callable f(shape, dtype) that returns an initial value for the parameter.

Return type

jnp.ndarray

Returns

A jnp.ndarray with the parameter of the given shape.

haiku.get_state(name, shape=None, dtype=jnp.float32, init=None)

Gets the current value for state with an optional initializer.

“State” can be used to represent mutable state in your network. The most common usage of state is to represent the moving averages used in batch normalization (see ExponentialMovingAverage). If your network uses “state” then you are required to use transform_with_state() and pass state into and out of the apply function.

>>> hk.get_state("counter", [], init=jnp.zeros)
DeviceArray(0., dtype=float32)

If the value for the given state is already defined (e.g. using set_state()) then you can call with just the name:

>>> hk.get_state("counter")
DeviceArray(0., dtype=float32)

NOTE: state entries within the same transform() and/or Module with the same name have the same value:

>>> c1 = hk.get_state("counter")
>>> c2 = hk.get_state("counter")
>>> assert c1 is c2

Parameters
  • name (str) – A name for the state.

  • shape (Optional[Sequence[int]]) – The shape of the state.

  • dtype (Any) – The dtype of the state.

  • init (Optional[Initializer]) – A callable f(shape, dtype) that returns an initial value for the state.

Return type

jnp.ndarray

Returns

A jnp.ndarray with the state of the given shape.

haiku.set_state(name, value)

Sets the current value for some state.

See get_state().

“State” can be used to represent mutable state in your network. The most common usage of state is to represent the moving averages used in batch normalization (see ExponentialMovingAverage). If your network uses “state” then you are required to use transform_with_state() and pass state into and out of the apply function.

>>> hk.set_state("counter", jnp.zeros([]))
>>> hk.get_state("counter")
DeviceArray(0., dtype=float32)

NOTE: state entries within the same transform() and/or Module with the same name have the same value:

>>> w1 = hk.get_state("counter")
>>> w2 = hk.get_state("counter")
>>> assert w1 is w2

Parameters
  • name (str) – A name for the state.

  • value – A value to set.

haiku.transparent(method)

Decorator to wrap a method, preventing automatic variable scope wrapping.

By default, all variables and modules created in a method are scoped by the module and method names. This is undesirable in some cases. Any method decorated with transparent() will create variables and modules in the scope in which it was called.

Parameters

method (T) – the method to wrap.

Return type

T

Returns

The method, with a flag indicating no name scope wrapping should occur.

Random Number Generators

class haiku.PRNGSequence(key_or_seed)

Iterator of JAX random keys.

>>> seq = hk.PRNGSequence(42)  # OR pass a jax.random.PRNGKey
>>> key1 = next(seq)
>>> key2 = next(seq)
>>> assert key1 is not key2

If you know how many keys you will want then you can use reserve() to more efficiently split the keys you need:

>>> seq.reserve(4)
>>> keys = [next(seq) for _ in range(4)]

__init__(key_or_seed)

Creates a new PRNGSequence.

reserve(num)

Splits an additional num keys for later use.

__next__()

Returns the next item from the iterator. When exhausted, raises StopIteration.

Return type

PRNGKey

next()

Returns the next item from the iterator. When exhausted, raises StopIteration.

Return type

PRNGKey

haiku.next_rng_key()

Returns a unique JAX random key split from the current global key.

>>> key = hk.next_rng_key()
>>> _ = jax.random.uniform(key, [])

Return type

PRNGKey

Returns

A unique (within a call to init or apply) JAX rng key that can be used with APIs such as jax.random.uniform.

haiku.next_rng_keys(num)

Returns one or more JAX random keys split from the current global key.

>>> k1, k2 = hk.next_rng_keys(2)
>>> assert (k1 != k2).all()
>>> a = jax.random.uniform(k1, [])
>>> b = jax.random.uniform(k2, [])
>>> assert a != b

Parameters

num (int) – The number of keys to split.

Return type

Tuple[PRNGKey, ...]

Returns

One or more unique (within a transformed function) JAX rng keys that can be used with APIs such as jax.random.uniform.

haiku.maybe_next_rng_key()

Returns next_rng_key() if random numbers are available, else None.

Return type

Optional[PRNGKey]

haiku.reserve_rng_keys(num)

Pre-allocate some number of JAX RNG keys.

See next_rng_key().

This API offers a way to micro-optimize how RNG keys are split when using Haiku. It is unlikely that you need it unless you find compilation time of your init function to be a problem, or you sample a lot of random numbers in apply.

>>> hk.reserve_rng_keys(2)  # Pre-allocate 2 keys for us to consume.
>>> _ = hk.next_rng_key()   # Takes the first pre-allocated key.
>>> _ = hk.next_rng_key()   # Takes the second pre-allocated key.
>>> _ = hk.next_rng_key()   # Splits a new key.

Parameters

num (int) – The number of JAX rng keys to allocate.

haiku.with_rng(key)

Provides a new sequence for next_rng_key() to draw from.

When next_rng_key() is called, it draws a new key from the PRNGSequence defined by the input key to the transformed function. This context manager overrides the sequence for the duration of the scope.

>>> with hk.with_rng(jax.random.PRNGKey(428)):
...   s = jax.random.uniform(hk.next_rng_key(), ())
>>> s
DeviceArray(0.501871, dtype=float32)

Parameters

key (PRNGKey) – The key to seed the sequence with.

Returns

Context manager under which the given sequence is active.

Linear modules

Linear

class haiku.Linear(output_size, with_bias=True, w_init=None, b_init=None, name=None)

Linear module.

__init__(output_size, with_bias=True, w_init=None, b_init=None, name=None)

Constructs the Linear module.

Parameters
  • output_size (int) – Output dimensionality.

  • with_bias (bool) – Whether to add a bias to the output.

  • w_init (Optional[hk.initializers.Initializer]) – Optional initializer for weights. By default, uses random values from truncated normal, with stddev 1 / sqrt(fan_in). See https://arxiv.org/abs/1502.03167v3.

  • b_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias. By default, zero.

  • name (Optional[str]) – Name of the module.

__call__(inputs)

Computes a linear transform of the input.

Return type

jnp.ndarray

Bias

class haiku.Bias(output_size=None, bias_dims=None, b_init=None, name=None)

Adds a bias to inputs.

Example Usage:

>>> N, H, W, C = 1, 2, 3, 4
>>> x = jnp.ones([N, H, W, C])
>>> scalar_bias = hk.Bias(bias_dims=[])
>>> scalar_bias_output = scalar_bias(x)
>>> assert scalar_bias.bias_shape == ()

Create a bias over all non-minibatch dimensions:

>>> all_bias = hk.Bias()
>>> all_bias_output = all_bias(x)
>>> assert all_bias.bias_shape == (H, W, C)

Create a bias over the last non-minibatch dimension:

>>> last_bias = hk.Bias(bias_dims=[-1])
>>> last_bias_output = last_bias(x)
>>> assert last_bias.bias_shape == (C,)

Create a bias over the first non-minibatch dimension:

>>> first_bias = hk.Bias(bias_dims=[1])
>>> first_bias_output = first_bias(x)
>>> assert first_bias.bias_shape == (H, 1, 1)

Subtract and later add the same learned bias:

>>> bias = hk.Bias()
>>> h = bias(x, multiplier=-1)  # Subtract the bias.
>>> reconstructed_x = bias(h)   # Add it back.
>>> assert (x == reconstructed_x).all()

__init__(output_size=None, bias_dims=None, b_init=None, name=None)

Constructs a Bias module that supports broadcasting.

Parameters
  • output_size (Optional[Sequence[int]]) – Output size (output shape without batch dimension). If output_size is left as None, the size will be directly inferred by the input.

  • bias_dims (Optional[Sequence[int]]) – Sequence of which dimensions to retain from the input shape when constructing the bias. The remaining dimensions will be broadcast over (given size of 1), and leading dimensions will be removed completely. See class doc for examples.

  • b_init (Optional[hk.initializers.Initializer]) – Optional initializer for the bias. Defaults to zeros.

  • name (Optional[str]) – Name of the module.

__call__(inputs, multiplier=None)

Adds bias to inputs and optionally multiplies by multiplier.

Parameters
  • inputs (jnp.ndarray) – A Tensor of size [batch_size, input_size1, ...].

  • multiplier (Union[float, jnp.ndarray]) – A scalar or Tensor which the bias term is multiplied by before adding it to inputs. Anything which works in the expression bias * multiplier is acceptable here. This may be useful if you want to add a bias in one place and subtract the same bias in another place via multiplier=-1.

Return type

jnp.ndarray

Returns

A Tensor of size [batch_size, input_size1, ...].

Pooling modules

Average Pool

haiku.avg_pool(value, window_shape, strides, padding, channel_axis=-1)

Average pool.

Parameters
  • value (jnp.ndarray) – Value to pool.

  • window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.

  • strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

Return type

jnp.ndarray

Returns

Pooled result. Same rank as value.

Raises

ValueError – If the padding is not valid.

class haiku.AvgPool(window_shape, strides, padding, channel_axis=-1, name=None)

Average pool.

Equivalent to partial application of avg_pool().

__init__(window_shape, strides, padding, channel_axis=-1, name=None)

Average pool.

Parameters
  • window_shape (Union[int, Sequence[int]]) – Shape of window to pool over. Same rank as value or int.

  • strides (Union[int, Sequence[int]]) – Strides for the window. Same rank as value or int.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.

  • name (Optional[str]) – String name for the module.

__call__(value)

Applies average pooling to value.

Return type

jnp.ndarray

Max Pool

haiku.max_pool(value, window_shape, strides, padding, channel_axis=-1)

Max pool.

Parameters
  • value (jnp.ndarray) – Value to pool.

  • window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.

  • strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

Return type

jnp.ndarray

Returns

Pooled result. Same rank as value.

class haiku.MaxPool(window_shape, strides, padding, channel_axis=-1, name=None)

Max pool.

Equivalent to partial application of max_pool().

__init__(window_shape, strides, padding, channel_axis=-1, name=None)

Max pool.

Parameters
  • window_shape (Union[int, Sequence[int]]) – Shape of window to pool over. Same rank as value or int.

  • strides (Union[int, Sequence[int]]) – Strides for the window. Same rank as value or int.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.

  • name (Optional[str]) – String name for the module.

__call__(value)

Applies max pooling to value.

Return type

jnp.ndarray

Dropout

dropout

haiku.dropout(rng, rate, x)

Randomly drop units in the input at a given rate.

See: http://www.cs.toronto.edu/~hinton/absps/dropout.pdf

Parameters
  • rng (PRNGKey) – A JAX random key.

  • rate (float) – Probability that each element of x is discarded. Must be a scalar in the range [0, 1).

  • x (jnp.ndarray) – The value to be dropped out.

Return type

jnp.ndarray

Returns

x, but dropped out and scaled by 1 / (1 - rate).

Combinator modules

Sequential

class haiku.Sequential(layers, name=None)

Sequentially calls the given list of layers.

Note that Sequential is limited in the range of possible architectures it can handle. This is a deliberate design decision; Sequential is only meant to be used for the simple case of fusing together modules/ops where the input of a particular module/op is the output of the previous one.

Another restriction is that it is not possible to have extra arguments in the __call__() method that are passed to the constituents of the module - for example, if there is a BatchNorm module in Sequential and the user wishes to switch the is_training flag. If this is the desired use case, the recommended solution is to subclass Module and implement __call__:

>>> class CustomModule(hk.Module):
...   def __call__(self, x, is_training):
...     x = hk.Conv2D(32, 4, 2)(x)
...     x = hk.BatchNorm(True, True, 0.9)(x, is_training)
...     x = jax.nn.relu(x)
...     return x

__init__(layers, name=None)

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters
  • layers – The modules or functions to call, in order. Each one receives the output of the previous one.

  • name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(inputs, *args, **kwargs)

Calls all layers sequentially.

Convolutional modules

ConvND

class haiku.ConvND(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)

General N-dimensional convolution.

__init__(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)

Initializes the module.

Parameters
  • num_spatial_dims (int) – The number of spatial dimensions of the input.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length num_spatial_dims. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME, a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension, or a callable or sequence of callables of size num_spatial_dims. Any callables must take a single integer argument equal to the effective kernel size and return a sequence of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

__call__(inputs)

Connects ConvND layer.

Parameters

inputs (jnp.ndarray) – An array of shape [spatial_dims, C] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, C] and rank-N+2 if batched.

Return type

jnp.ndarray

Returns

An array of shape [spatial_dims, output_channels] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, output_channels] and rank-N+2 if batched.

Conv1D

class haiku.Conv1D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)

One dimensional convolution.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 1. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 1. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

Conv2D

class haiku.Conv2D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)

Two dimensional convolution.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 2. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 2. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

Conv3D

class haiku.Conv3D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)

Three dimensional convolution.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 3. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 3. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NDHWC or NCDHW. By default, NDHWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

ConvNDTranspose

class haiku.ConvNDTranspose(num_spatial_dims, output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)

General n-dimensional transposed convolution (aka. deconvolution).

__init__(num_spatial_dims, output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)

Initializes the module.

Parameters
  • num_spatial_dims (int) – The number of spatial dimensions of the input.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either “VALID” or “SAME”. Defaults to “SAME”. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

__call__(inputs)

Computes the transposed convolution of the input.

Parameters

inputs (jnp.ndarray) – An array of shape [spatial_dims, C] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, C] and rank-N+2 if batched.

Return type

jnp.ndarray

Returns

An array of shape [spatial_dims, output_channels] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, output_channels] and rank-N+2 if batched.
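The output spatial sizes are not spelled out above. As a rough guide, the sketch below computes the length of one spatial axis, assuming the module follows the padding convention of jax.lax.conv_transpose (an assumption, not something this page states): with SAME padding the axis is multiplied by the stride, and with VALID padding it may grow by an extra max(kernel - stride, 0). The helper name is hypothetical.

```python
def transposed_conv_output_size(in_size: int, kernel: int, stride: int,
                                padding: str) -> int:
  """Output length of one spatial axis of a transposed convolution.

  Hypothetical helper; assumes the jax.lax.conv_transpose padding convention.
  """
  if padding == "SAME":
    # SAME: the output is exactly `stride` times longer than the input.
    return in_size * stride
  if padding == "VALID":
    # VALID: the kernel may additionally overhang the upsampled input.
    return in_size * stride + max(kernel - stride, 0)
  raise ValueError(f"Unsupported padding: {padding!r}")


# An 8x8 input with kernel 3 and stride 2 in each spatial dimension:
print([transposed_conv_output_size(8, 3, 2, "SAME")] * 2)   # [16, 16]
print([transposed_conv_output_size(8, 3, 2, "VALID")] * 2)  # [17, 17]
```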

Conv1DTranspose

class haiku.Conv1DTranspose(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]

One dimensional transposed convolution (aka. deconvolution).

__init__(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

Conv2DTranspose

class haiku.Conv2DTranspose(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]

Two dimensional transposed convolution (aka. deconvolution).

__init__(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

Conv3DTranspose

class haiku.Conv3DTranspose(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]

Three dimensional transposed convolution (aka. deconvolution).

__init__(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]

Initializes the module.

Parameters
  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NDHWC or NCDHW. By default, NDHWC.

  • mask (Optional[jnp.ndarray]) – Optional mask of the weights.

  • name (Optional[str]) – The name of the module.

DepthwiseConv2D

class haiku.DepthwiseConv2D(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

2-D Depthwise Convolution Module.

__init__(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Construct a 2D Depthwise Convolution.

Parameters
  • channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set it to 1.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID, SAME or a sequence of before, after pairs. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.

  • name (Optional[str]) – The name of the module.

__call__(inputs)[source]

Computes the depthwise convolution of the input.

Return type

jnp.ndarray
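In a depthwise convolution each input channel is filtered independently, so the output has C * channel_multiplier channels. A naive NumPy sketch (stride 1 and VALID padding only; the function name and the [KH, KW, C, M] filter layout are illustrative assumptions, not Haiku's parameter layout):

```python
import numpy as np


def depthwise_conv2d_valid(x: np.ndarray, w: np.ndarray) -> np.ndarray:
  """Naive NHWC depthwise convolution, stride 1, VALID padding.

  x: [N, H, W, C] inputs.
  w: [KH, KW, C, M] filters, M = channel_multiplier (hypothetical layout).
  Returns [N, H - KH + 1, W - KW + 1, C * M].
  """
  n, h, wd, c = x.shape
  kh, kw, wc, m = w.shape
  if wc != c:
    raise ValueError("Filter channel count must match the input.")
  oh, ow = h - kh + 1, wd - kw + 1
  out = np.zeros((n, oh, ow, c, m), dtype=np.result_type(x, w))
  for i in range(oh):
    for j in range(ow):
      patch = x[:, i:i + kh, j:j + kw, :]  # [N, KH, KW, C]
      # Channel c only meets its own M filters w[..., c, :]; unlike a regular
      # convolution, nothing is summed across input channels.
      out[:, i, j] = np.einsum("nhwc,hwcm->ncm", patch, w)
  # Output channel c * M + k holds input channel c filtered by filter k.
  return out.reshape(n, oh, ow, c * m)


x = np.ones((1, 5, 5, 3))
w = np.ones((3, 3, 3, 2))
y = depthwise_conv2d_valid(x, w)  # shape (1, 3, 3, 6); every entry is 9.0
```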

Normalization modules

LayerNorm

class haiku.LayerNorm(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, name=None)[source]

LayerNorm module.

See: https://arxiv.org/abs/1607.06450.

__init__(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, name=None)[source]

Constructs a LayerNorm module.

Parameters
  • axis (Union[int, Sequence[int], slice]) – Integer, list of integers, or slice indicating which axes to normalize over.

  • create_scale (bool) – Bool, defines whether to create a trainable scale per channel applied after the normalization.

  • create_offset (bool) – Bool, defines whether to create a trainable offset per channel applied after normalization and scaling.

  • eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5, as in the paper and Sonnet.

  • scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). By default, one.

  • offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias (aka offset). By default, zero.

  • name (Optional[str]) – The module name.

__call__(inputs, scale=None, offset=None)[source]

Connects the layer norm.

Parameters
  • inputs (jnp.ndarray) – An array, where the data format is [N, ..., C].

  • scale (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.

  • offset (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.

Return type

jnp.ndarray

Returns

The array, normalized.
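For reference, the normalization can be sketched in NumPy following the formula of the cited paper: subtract the mean and divide by sqrt(variance + eps) over the chosen axes, then apply scale and offset. This is an illustration, not Haiku's implementation; the function name is hypothetical.

```python
import numpy as np


def layer_norm(x, axis=-1, scale=1.0, offset=0.0, eps=1e-5):
  """Normalizes x over `axis`, then applies scale and offset."""
  mean = np.mean(x, axis=axis, keepdims=True)
  var = np.var(x, axis=axis, keepdims=True)
  return (x - mean) / np.sqrt(var + eps) * scale + offset


x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 20.0, 30.0, 40.0]])
y = layer_norm(x, axis=-1)
# Each row now has (approximately) zero mean and unit variance, so the two
# rows, which differ only by a factor of 10, map to almost the same values.
```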

InstanceNorm

class haiku.InstanceNorm(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Normalizes inputs along the spatial dimensions.

See LayerNorm for more details.

__init__(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs an InstanceNorm module.

This method creates a module which normalizes over the spatial dimensions.

Parameters
  • create_scale (bool) – bool representing whether to create a trainable scale per channel applied after the normalization.

  • create_offset (bool) – bool representing whether to create a trainable offset per channel applied after normalization and scaling.

  • eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5.

  • scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for the scale variable. Can only be set if create_scale=True. By default scale is initialized to 1.

  • offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for the offset variable. Can only be set if create_offset=True. By default offset is initialized to 0.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.

  • name (Optional[str]) – Name of the module.
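Concretely, the statistics are computed over the spatial axes separately for every example and every channel. A minimal NumPy sketch of the normalization step alone (scale and offset omitted; the function name is hypothetical):

```python
import numpy as np


def instance_norm(x, data_format="channels_last", eps=1e-5):
  """Normalizes each (example, channel) pair over its spatial axes."""
  if data_format == "channels_last":     # [N, *spatial, C]
    axes = tuple(range(1, x.ndim - 1))
  elif data_format == "channels_first":  # [N, C, *spatial]
    axes = tuple(range(2, x.ndim))
  else:
    raise ValueError(f"Unsupported data_format: {data_format!r}")
  mean = x.mean(axis=axes, keepdims=True)
  var = x.var(axis=axes, keepdims=True)
  return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.normal(3.0, 2.0, size=(2, 4, 4, 3))  # [N, H, W, C]
y = instance_norm(x)
```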

BatchNorm

class haiku.BatchNorm(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, data_format='channels_last', name=None)[source]

Normalizes inputs to maintain a mean of ~0 and stddev of ~1.

See: https://arxiv.org/abs/1502.03167.

Users may want to manage scale and offset in several ways, if they need them at all:

  • No scale/offset in which case create_* should be set to False and scale/offset aren’t passed when the module is called.

  • Trainable scale/offset in which case create_* should be set to True and again scale/offset aren’t passed when the module is called. In this case this module creates and owns the scale/offset variables.

  • Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and then the values fed in at call time.

NOTE: jax.vmap(hk.transform(BatchNorm)) will update summary statistics and normalize values on a per-batch basis; we currently do not support normalizing across a batch axis introduced by vmap.

__init__(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, data_format='channels_last', name=None)[source]

Constructs a BatchNorm module.

Parameters
  • create_scale (bool) – Whether to include a trainable scaling factor.

  • create_offset (bool) – Whether to include a trainable offset.

  • decay_rate (float) – Decay rate for EMA.

  • eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5, as in the paper and Sonnet.

  • scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). Can only be set if create_scale=True. By default, 1.

  • offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias (aka offset). Can only be set if create_offset=True. By default, 0.

  • axis (Optional[Sequence[int]]) – Which axes to reduce over. The default (None) signifies that all but the channel axis should be normalized. Otherwise this is a list of axis indices which will have normalization statistics calculated.

  • cross_replica_axis (Optional[str]) – If not None, it should be a string representing the axis name over which this module is being run within a jax.pmap. Supplying this argument means that batch statistics are calculated across all replicas on that axis.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.

  • name (Optional[str]) – The module name.

__call__(inputs, is_training, test_local_stats=False, scale=None, offset=None)[source]

Computes the normalized version of the input.

Parameters
  • inputs (jnp.ndarray) – An array, where the data format is [..., C].

  • is_training (bool) – Whether this is during training.

  • test_local_stats (bool) – Whether local stats are used when is_training=False.

  • scale (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.

  • offset (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.

Return type

jnp.ndarray

Returns

The array, normalized across all but the last dimension.
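The difference between training and evaluation can be sketched as follows. During training, statistics come from the current batch and are folded into moving averages; at evaluation time (is_training=False, test_local_stats=False) the moving averages are used instead. In Haiku these averages live in module state, which is why a function using this module is transformed with transform_with_state(). This NumPy sketch is a simplification: it uses a plain EMA without debiasing and omits scale, offset and cross-replica statistics.

```python
import numpy as np


def batch_norm(x, state, is_training, decay_rate=0.9, eps=1e-5):
  """Simplified channels-last batch norm; returns (normalized, new_state).

  `state` holds moving averages of the per-channel mean and variance,
  updated here with a plain (non-debiased) EMA, a simplification.
  """
  axes = tuple(range(x.ndim - 1))  # every axis except the channel axis
  if is_training:
    mean, var = x.mean(axis=axes), x.var(axis=axes)
    state = {
        "mean": decay_rate * state["mean"] + (1 - decay_rate) * mean,
        "var": decay_rate * state["var"] + (1 - decay_rate) * var,
    }
  else:
    # At evaluation time the stored moving averages replace batch statistics.
    mean, var = state["mean"], state["var"]
  return (x - mean) / np.sqrt(var + eps), state


rng = np.random.default_rng(0)
x = rng.normal(5.0, 3.0, size=(64, 4))  # [batch, channels]
state = {"mean": np.zeros(4), "var": np.ones(4)}
y, state = batch_norm(x, state, is_training=True)
y_eval, state_eval = batch_norm(x, state, is_training=False)
```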

ExponentialMovingAverage

class haiku.ExponentialMovingAverage(*args, **kwargs)[source]

Maintains an exponential moving average.

This uses the Adam debiasing procedure. See https://arxiv.org/pdf/1412.6980.pdf for details.

__init__(decay, zero_debias=True, warmup_length=0, name=None)[source]

Initializes an ExponentialMovingAverage module.

Parameters
  • decay – The chosen decay. Must be in [0, 1). Values close to 1 result in slow decay; values close to 0 result in fast decay.

  • zero_debias – Whether to run with zero-debiasing.

  • warmup_length – A positive integer. The EMA has no effect until the internal counter reaches warmup_length, at which point the decaying average is initialized to the input value seen after warmup_length iterations.

  • name – The name of the module.

initialize(value)[source]

If uninitialized, sets the average to zeros_like(value).

__call__(value, update_stats=True)[source]

Updates the EMA and returns the new value.

Parameters
  • value – The array-like object to which the exponential moving average is applied.

  • update_stats – A Boolean, whether to update the internal state of this object to reflect the input value. When update_stats is False the internal stats will remain unchanged.

Returns

The exponentially weighted average of the input value.
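The effect of zero-debiasing is easiest to see on a constant input. A plain-Python sketch of the Adam-style correction (warmup_length is not modelled, and the function name is hypothetical): the raw average starts near zero because its hidden state is zero-initialized, while the debiased average divides out the missing weight 1 - decay**t and recovers the input immediately.

```python
def ema_sequence(values, decay, zero_debias=True):
  """Returns the running EMA after each value (a sketch; no warmup)."""
  hidden, averages = 0.0, []
  for t, value in enumerate(values, start=1):
    hidden = decay * hidden + (1 - decay) * value
    # Adam-style debiasing: the zero-initialized hidden state underweights
    # the data by a factor of (1 - decay**t), so divide it back out.
    averages.append(hidden / (1 - decay ** t) if zero_debias else hidden)
  return averages


raw = ema_sequence([5.0] * 3, decay=0.9, zero_debias=False)
debiased = ema_sequence([5.0] * 3, decay=0.9)
print([round(a, 6) for a in raw])       # [0.5, 0.95, 1.355]
print([round(a, 6) for a in debiased])  # [5.0, 5.0, 5.0]
```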

SpectralNorm

class haiku.SpectralNorm(eps=0.0001, n_steps=1, name=None)[source]

Normalizes an input by its first singular value.

This module uses power iteration to calculate this value based on the input and an internal hidden state.

__init__(eps=0.0001, n_steps=1, name=None)[source]

Initializes a SpectralNorm module.

Parameters
  • eps (float) – The constant used for numerical stability.

  • n_steps (int) – How many steps of power iteration to perform to approximate the singular value of the input.

  • name (Optional[str]) – The name of the module.

__call__(value, update_stats=True, error_on_non_matrix=False)[source]

Performs Spectral Normalization and returns the new value.

Parameters
  • value – The array-like object to which spectral normalization is applied.

  • update_stats (bool) – A boolean defaulting to True. Regardless of this arg, this function will return the normalized input. When update_stats is True, the internal state of this object will also be updated to reflect the input value. When update_stats is False the internal stats will remain unchanged.

  • error_on_non_matrix (bool) – Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead raise errors in those cases.

Return type

jnp.ndarray

Returns

The input value normalized by its first singular value.

Raises

ValueError – If error_on_non_matrix is True and value has ndims > 2.
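The power-iteration idea can be sketched in NumPy: alternately multiply a vector estimate by the matrix and its transpose, normalize, and read off the singular value estimate sigma = u^T W v. Carrying u between calls plays the role of the module's internal hidden state. This is an illustration under assumptions (placing eps inside the square root, and the function names, are choices made here), not Haiku's exact implementation.

```python
import numpy as np


def l2_normalize(x, eps):
  # eps inside the square root guards against division by zero.
  return x / np.sqrt(np.sum(x ** 2) + eps)


def spectral_normalize(w, u, n_steps=1, eps=1e-4):
  """Divides matrix w by an estimate of its largest singular value.

  w: [m, n] matrix; u: [m] left-vector estimate carried between calls.
  Returns (w / sigma, updated u).
  """
  if n_steps < 1:
    raise ValueError("n_steps must be at least 1.")
  for _ in range(n_steps):
    v = l2_normalize(w.T @ u, eps)  # right singular vector estimate
    u = l2_normalize(w @ v, eps)    # left singular vector estimate
  sigma = u @ w @ v
  return w / sigma, u


w = np.array([[3.0, 0.0],
              [0.0, 1.0]])
w_sn, u = spectral_normalize(w, np.ones(2), n_steps=20)
# The largest singular value of w_sn is now very close to 1.
```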

SNParamsTree

class haiku.SNParamsTree(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]

Applies Spectral Normalization to all parameters in a tree.

This is isomorphic to EMAParamsTree in moving_averages.py.

__init__(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]

Initializes an SNParamsTree module.

Parameters
  • eps (float) – The constant used for numerical stability.

  • n_steps (int) – How many steps of power iteration to perform to approximate the singular value of the input.

  • ignore_regex (str) – A string. Any parameter in the tree whose name matches this regex will not have spectral normalization applied to it. The empty string means this module applies to all parameters.

  • name (Optional[str]) – The name of the module.

__call__(tree, update_stats=True)[source]

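Applies spectral normalization to every parameter in tree whose name does not match ignore_regex, and returns a tree with the same structure.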

EMAParamsTree

class haiku.EMAParamsTree(*args, **kwargs)[source]

Maintains an exponential moving average for all parameters in a tree.

While ExponentialMovingAverage is meant to be applied to single parameters within a function, this class is meant to be applied to the entire tree of parameters for a function.

Given a set of parameters for some network:

>>> network_fn = lambda x: hk.Linear(10)(x)
>>> x = jnp.ones([1, 1])
>>> params = hk.transform(network_fn).init(jax.random.PRNGKey(428), x)

You might use the EMAParamsTree as follows:

>>> ema_fn = hk.transform_with_state(lambda x: hk.EMAParamsTree(0.2)(x))
>>> _, ema_state = ema_fn.init(None, params)
>>> ema_params, ema_state = ema_fn.apply(None, ema_state, None, params)

Here, we transform a Haiku function and construct its parameters via init as usual, and then create a second transformed function that expects a tree of parameters as input. Calling it with the current parameters returns a tree with the same structure in which every parameter is replaced by its exponentially decayed average. This ema_params object can then be passed to the apply function of the transformed network_fn in place of params, causing the network to run with EMA weights.

__init__(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]

Initializes an EMAParamsTree module.

Parameters
  • decay – The chosen decay. Must be in [0, 1). Values close to 1 result in slow decay; values close to 0 result in fast decay.

  • zero_debias – Whether to run with zero-debiasing.

  • warmup_length – A positive integer. The EMA has no effect until the internal counter reaches warmup_length, at which point the decaying average is initialized to the input value seen after warmup_length iterations.

  • ignore_regex – A string. Any parameter in the tree whose name matches this regex will not have any moving average applied to it. The empty string means this module will EMA all parameters.

  • name – The name of the module.

__call__(tree, update_stats=True)[source]

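Updates the moving averages (when update_stats is True) and returns a tree with the same structure in which every parameter not matching ignore_regex is replaced by its moving average.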

Recurrent modules

RNNCore

class haiku.RNNCore(name=None)[source]

Base class for RNN cores.

This class defines the basic functionality that every core should implement: initial_state(), used to construct an example of the core state; and __call__() which applies the core parameterized by a previous state to an input.

Cores may be used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from the given input sequence.

abstract __call__(inputs, prev_state)[source]

Run one step of the RNN.

Parameters
  • inputs – An arbitrarily nested structure.

  • prev_state – Previous core state.

Return type

Tuple[Any, Any]

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state, which must have the same shape as prev_state.

abstract initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.

dynamic_unroll

haiku.dynamic_unroll(core, input_sequence, initial_state)[source]

Performs a dynamic unroll of an RNN.

An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:

state = initial_state
for t in range(len(input_sequence)):
   outputs, state = core(input_sequence[t], state)

A dynamic unroll preserves the loop structure when executed inside jax.jit. See static_unroll() for an unroll function which replaces a loop with its body repeated multiple times.

Parameters
  • core – An RNNCore to unroll.

  • input_sequence – An arbitrarily nested structure of tensors of shape [T, ...] where T is the number of time steps.

  • initial_state – An initial state of the given core.

Returns

  • output_sequence - An arbitrarily nested structure of tensors of shape [T, ...].

  • final_state - Core state at time step T.

Return type

A tuple with two elements
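Both unroll functions compute the same values; they differ only in how the loop is represented under jax.jit. The output and state semantics can be sketched in plain NumPy (the names unroll and running_sum_core are hypothetical, and only single arrays are handled, not nested structures):

```python
import numpy as np


def unroll(core, input_sequence, initial_state):
  """Reference semantics of an RNN unroll over a time-major array [T, ...]."""
  state = initial_state
  outputs = []
  for t in range(len(input_sequence)):
    output, state = core(input_sequence[t], state)
    outputs.append(output)
  # Per-step outputs are stacked along a new leading time axis.
  return np.stack(outputs), state


def running_sum_core(x, prev_state):
  # A toy core: the output and the next state are both the running sum.
  total = prev_state + x
  return total, total


outs, final = unroll(running_sum_core, np.array([1.0, 2.0, 3.0]), 0.0)
# outs is [1., 3., 6.] and final is 6.0
```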

static_unroll

haiku.static_unroll(core, input_sequence, initial_state)[source]

Performs a static unroll of an RNN.

An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:

state = initial_state
for t in range(len(input_sequence)):
   outputs, state = core(input_sequence[t], state)

A static unroll replaces a loop with its body repeated multiple times when executed inside jax.jit:

state = initial_state
outputs0, state = core(input_sequence[0], state)
outputs1, state = core(input_sequence[1], state)
outputs2, state = core(input_sequence[2], state)
...

See dynamic_unroll() for a loop-preserving unroll function.

Parameters
  • core – An RNNCore to unroll.

  • input_sequence – An arbitrarily nested structure of tensors of shape [T, ...] where T is the number of time steps.

  • initial_state – An initial state of the given core.

Returns

  • output_sequence - An arbitrarily nested structure of tensors of shape [T, ...].

  • final_state - Core state at time step T.

Return type

A tuple with two elements

expand_apply

haiku.expand_apply(f, axis=0)[source]

Wraps f to temporarily add a size-1 axis to its inputs.

Syntactic sugar for:

ins = jax.tree_util.tree_map(lambda t: np.expand_dims(t, axis=axis), ins)
out = f(ins)
out = jax.tree_util.tree_map(lambda t: np.squeeze(t, axis=axis), out)

This may be useful for applying a function built for [Time, Batch, ...] arrays to a single timestep.

Parameters
  • f – The callable to be applied to the expanded inputs.

  • axis – Where to add the extra axis.

Returns

f, wrapped as described above.
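A minimal NumPy version of this wrapper for a single array (real inputs may be nested, which this sketch does not handle; all names are hypothetical):

```python
import numpy as np


def expand_apply_sketch(f, axis=0):
  """Single-array version of the wrapper described above."""
  def wrapped(x):
    out = f(np.expand_dims(x, axis=axis))
    return np.squeeze(out, axis=axis)
  return wrapped


def time_major_fn(x):
  # Hypothetical function that insists on [Time, Batch, Features] inputs.
  if x.ndim != 3:
    raise ValueError("expected a [T, B, F] array")
  return np.cumsum(x, axis=0)


single_step = np.arange(6.0).reshape(2, 3)  # [Batch, Features]
y = expand_apply_sketch(time_major_fn)(single_step)
# y has shape (2, 3) and equals single_step (a cumsum over a length-1 axis).
```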

VanillaRNN

class haiku.VanillaRNN(hidden_size, name=None)[source]

Basic fully-connected RNN core.

Given \(x_t\) and the previous hidden state \(h_{t-1}\) the core computes

\[h_t = \operatorname{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)\]
__init__(hidden_size, name=None)[source]

Constructs a vanilla RNN core.

Parameters
  • hidden_size (int) – Hidden layer size.

  • name (Optional[str]) – Name of the module.

__call__(inputs, prev_state)[source]

Run one step of the RNN.

Parameters
  • inputs – An arbitrarily nested structure.

  • prev_state – Previous core state.

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state, which must have the same shape as prev_state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.
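One step of the formula above can be sketched in NumPy with batched row vectors. The weight shapes and the zero initial state used below are assumptions made for illustration; in this sketch the output and the next state are both h_t.

```python
import numpy as np


def vanilla_rnn_step(x, h_prev, w_i, b_i, w_h, b_h):
  """h_t = ReLU(w_i x_t + b_i + w_h h_{t-1} + b_h) for batched row vectors.

  x: [B, I], h_prev: [B, H], w_i: [I, H], w_h: [H, H], b_i and b_h: [H].
  """
  h = np.maximum(x @ w_i + b_i + h_prev @ w_h + b_h, 0.0)
  return h, h


x = np.array([[1.0, -1.0]])
h0 = np.zeros((1, 2))  # zero initial state (an assumption of this sketch)
w_i, w_h, b = np.eye(2), np.eye(2), np.zeros(2)
out, h1 = vanilla_rnn_step(x, h0, w_i, b, w_h, b)   # [[1., 0.]]: ReLU clips -1
out2, h2 = vanilla_rnn_step(x, h1, w_i, b, w_h, b)  # [[2., 0.]]
```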

LSTM

class haiku.LSTM(hidden_size, name=None)[source]

Long short-term memory (LSTM) RNN core.

The implementation is based on [1]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

Notes

Forget gate initialization:

Following [2], we add 1.0 to \(b_f\) after initialization to reduce the scale of forgetting at the beginning of training.
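The equations above translate directly into NumPy. In this sketch (the weight shapes and the (i, f, g, o) column ordering are illustrative choices, not Haiku's parameter layout) all weights are zero, so every gate sits at sigmoid(0) = 0.5 and g_t = tanh(0) = 0; adding 1.0 to b_f, as described in the note, keeps more of the previous cell state.

```python
import numpy as np


def sigmoid(z):
  return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h_prev, c_prev, w_x, w_h, b):
  """One LSTM step following the equations above.

  x: [B, I], h_prev and c_prev: [B, H], w_x: [I, 4H], w_h: [H, 4H], b: [4H].
  The 4H columns are split into (i, f, g, o) pre-activations; this ordering
  is a choice made for this sketch only.
  """
  gates = x @ w_x + h_prev @ w_h + b
  i, f, g, o = np.split(gates, 4, axis=-1)
  i, f, g, o = sigmoid(i), sigmoid(f), np.tanh(g), sigmoid(o)
  c = f * c_prev + i * g
  h = o * np.tanh(c)
  return h, (h, c)


H = 3
x = np.ones((1, 2))
h0, c0 = np.zeros((1, H)), np.full((1, H), 2.0)
w_x, w_h, b = np.zeros((2, 4 * H)), np.zeros((H, 4 * H)), np.zeros(4 * H)
h1, (_, c1) = lstm_step(x, h0, c0, w_x, w_h, b)
# c1 is 1.0 everywhere: the forget gate halves c0 and g_t = tanh(0) = 0.
b_forget = b.copy()
b_forget[H:2 * H] += 1.0  # the +1 forget-gate bias from the note above
_, (_, c1_fb) = lstm_step(x, h0, c0, w_x, w_h, b_forget)
# c1_fb is 2 * sigmoid(1), about 1.46: more of the previous cell is kept.
```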

__init__(hidden_size, name=None)[source]

Constructs an LSTM.

Parameters
  • hidden_size (int) – Hidden layer size.

  • name (Optional[str]) – Name of the module.

__call__(inputs, prev_state)[source]

Run one step of the RNN.

Parameters
  • inputs (jnp.ndarray) – An arbitrarily nested structure.

  • prev_state (LSTMState) – Previous core state.

Return type

Tuple[jnp.ndarray, LSTMState]

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state, which must have the same shape as prev_state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Return type

LSTMState

Returns

Arbitrarily nested initial state for this core.

GRU

class haiku.GRU(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]

Gated Recurrent Unit.

The implementation is based on: https://arxiv.org/pdf/1412.3555v1.pdf with biases.

Given \(x_t\) and the previous state \(h_{t-1}\) the core computes

\[\begin{array}{ll} z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t \end{array}\]

where \(z_t\) and \(r_t\) are the update and reset gates, respectively.
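A NumPy sketch of one step of the equations above, with batched row vectors. The keys of the params dictionary are hypothetical names. The example checks the role of z_t: with zero weights it is 0.5 and the state is halved, while a strongly negative b_z drives z_t towards 0 so that h_{t-1} is carried over almost unchanged.

```python
import numpy as np


def sigmoid(v):
  return 1.0 / (1.0 + np.exp(-v))


def gru_step(x, h_prev, params):
  """One GRU step following the equations above.

  params maps hypothetical names such as "w_iz" ([I, H]), "w_hz" ([H, H])
  and "b_z" ([H]) to arrays, for each of z (update), r (reset) and a.
  """
  p = params
  z = sigmoid(x @ p["w_iz"] + h_prev @ p["w_hz"] + p["b_z"])  # update gate
  r = sigmoid(x @ p["w_ir"] + h_prev @ p["w_hr"] + p["b_r"])  # reset gate
  a = np.tanh(x @ p["w_ia"] + (r * h_prev) @ p["w_ha"] + p["b_a"])
  h = (1.0 - z) * h_prev + z * a
  return h, h


I, H = 2, 3
params = {}
for g in "zra":
  params[f"w_i{g}"] = np.zeros((I, H))
  params[f"w_h{g}"] = np.zeros((H, H))
  params[f"b_{g}"] = np.zeros(H)
x, h_prev = np.ones((1, I)), np.full((1, H), 4.0)
h, _ = gru_step(x, h_prev, params)
# h is 2.0 everywhere: z_t = 0.5 and a_t = tanh(0) = 0.
params["b_z"] = np.full(H, -50.0)  # push the update gate towards 0
h_kept, _ = gru_step(x, h_prev, params)
# h_kept is (numerically) h_prev: the previous state is carried over.
```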

Warning: Backwards compatibility of GRU weights is currently unsupported.


__init__(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(inputs, state)[source]

Run one step of the RNN.

Parameters
  • inputs – An arbitrarily nested structure.

  • state – Previous core state.

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state, which must have the same shape as the previous state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.

DeepRNN

class haiku.DeepRNN(layers, name=None)[source]

Wraps a sequence of cores and callables as a single core.

>>> deep_rnn = hk.DeepRNN([
...     hk.LSTM(hidden_size=4),
...     jax.nn.relu,
...     hk.LSTM(hidden_size=2),
... ])

The state of a DeepRNN is a tuple with one element per RNNCore. If no layers are RNNCores, the state is an empty tuple.
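How a DeepRNN threads its state can be sketched in plain Python. Here toy cores expose a hypothetical step method rather than Haiku's RNNCore interface, and plain callables (like jax.nn.relu in the example above) are applied to the running output without consuming a state entry.

```python
def deep_rnn_step(layers, inputs, state):
  """Sketch of one DeepRNN step.

  layers: entries are either toy cores with a (hypothetical) step(inputs,
  state) method or plain callables. state: one entry per core, in order.
  """
  next_state, core_idx, x = [], 0, inputs
  for layer in layers:
    if hasattr(layer, "step"):
      x, s = layer.step(x, state[core_idx])
      next_state.append(s)
      core_idx += 1
    else:
      x = layer(x)  # plain callables are stateless
  return x, tuple(next_state)


class RunningSum:
  """Toy core: its output and next state are the running sum of inputs."""

  def step(self, x, prev):
    return prev + x, prev + x


layers = [RunningSum(), lambda v: max(v, 0.0), RunningSum()]
out1, state1 = deep_rnn_step(layers, -1.0, (0.0, 0.0))  # 0.0, (-1.0, 0.0)
out2, state2 = deep_rnn_step(layers, 3.0, state1)       # 2.0, (2.0, 2.0)
```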

__init__(layers, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

haiku.deep_rnn_with_skip_connections(layers, name=None)[source]

Constructs a DeepRNN with skip connections.

Skip connections alter the dependency structure within a DeepRNN. Specifically, input to the i-th layer (i > 0) is given by a concatenation of the core’s inputs and the outputs of the (i-1)-th layer.

The output of the DeepRNN is the concatenation of the outputs of all cores.

outputs0, ... = layers[0](inputs, ...)
outputs1, ... = layers[1](jnp.concatenate([inputs, outputs0], axis=-1), ...)
outputs2, ... = layers[2](jnp.concatenate([inputs, outputs1], axis=-1), ...)
...

Parameters
  • layers (Sequence[RNNCore]) – List of RNNCores.

  • name (Optional[str]) – Name of the module.

Return type

RNNCore

Returns

A _DeepRNN with skip connections.

Raises

ValueError – If any of the layers is not an RNNCore.

ResetCore

class haiku.ResetCore(core, name=None)[source]

A wrapper for managing state resets during unrolls.

When unrolling an RNNCore on a batch of input sequences it may be necessary to reset the core’s state at different timesteps for different elements of the batch. The ResetCore class enables this by taking a batch of should_reset booleans in addition to the batch of inputs, and conditionally resetting the core’s state for individual elements of the batch. You may also reset individual entries of the state by passing a should_reset nest compatible with the state structure.

__init__(core, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(inputs, state)[source]

Run one step of the wrapped core, handling state reset.

Parameters
  • inputs – Tuple with two elements, inputs, should_reset, where should_reset is the signal used to reset the wrapped core’s state. should_reset can be either a tensor or a nest. If a nest, should_reset must match the state structure, and the shapes of its components must be prefixes of the shapes of the corresponding tensors in the state nest. If a tensor, any shape that is a common prefix of the shapes of all state component tensors is supported, e.g. [batch_size].

  • state – Previous wrapped core state.

Returns

Tuple of the wrapped core’s output, next_state.
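For the common case where should_reset is a tensor of shape [batch_size], the behaviour can be sketched in NumPy. This sketch resets the flagged batch entries to the initial state and then steps the wrapped core on the result; the function and toy core names are hypothetical.

```python
import numpy as np


def reset_core_step(core, initial_state, inputs, should_reset, prev_state):
  """Resets flagged batch entries, then runs the wrapped core.

  should_reset: boolean array of shape [B]; prev_state and initial_state:
  arrays of shape [B, ...]. Only the single-tensor case is sketched.
  """
  # Broadcast [B] against [B, ...] by appending singleton axes.
  mask = should_reset.reshape(should_reset.shape + (1,) * (prev_state.ndim - 1))
  state = np.where(mask, initial_state, prev_state)
  return core(inputs, state)


def add_core(x, s):
  # Toy core: output and next state are both the running sum.
  total = s + x
  return total, total


prev = np.full((3, 1), 5.0)
init = np.zeros((3, 1))
out, nxt = reset_core_step(add_core, init, np.ones((3, 1)),
                           np.array([True, False, True]), prev)
# out is [[1.], [6.], [1.]]: entries 0 and 2 restarted from the initial state.
```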

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.

IdentityCore

class haiku.IdentityCore(name=None)[source]

A recurrent core that forwards the inputs and an empty state.

This is commonly used when switching between recurrent and feedforward versions of a model while preserving the same interface.

__call__(inputs, state)[source]

Run one step of the RNN.

Parameters
  • inputs – An arbitrarily nested structure.

  • state – Previous core state.

Returns

A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state, which must have the same shape as the previous state.

initial_state(batch_size)[source]

Constructs an initial state for this core.

Parameters

batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns

Arbitrarily nested initial state for this core.

Conv1DLSTM

class haiku.Conv1DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]

1-D convolutional LSTM.

The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
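To make the equations concrete, the following NumPy sketch computes one step of a 1-D convolutional LSTM for a single (unbatched) example. It assumes stride-1 "SAME" padding and packs the four gates along the channel axis in the order i, f, g, o; both are choices made for this sketch, not necessarily Haiku's internal layout. The convolutions of x_t and h_{t-1} are computed separately and summed, matching the two convolution terms in each gate.

```python
import numpy as np


def conv1d_same(x, w):
    """Stride-1 'SAME' 1-D convolution (cross-correlation, as in most DL libraries).

    x: [length, in_channels]; w: [kernel_size, in_channels, out_channels].
    Returns [length, out_channels].
    """
    k = w.shape[0]
    x_pad = np.pad(x, (((k - 1) // 2, k // 2), (0, 0)))
    length = x.shape[0]
    return sum(x_pad[j:j + length] @ w[j] for j in range(k))


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def conv1d_lstm_step(x, h_prev, c_prev, w_x, w_h, b):
    """One ConvLSTM step; gates are packed as [i, f, g, o] on the last axis."""
    gates = conv1d_same(x, w_x) + conv1d_same(h_prev, w_h) + b
    i, f, g, o = np.split(gates, 4, axis=-1)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # c_t = f_t c_{t-1} + i_t g_t
    h = sigmoid(o) * np.tanh(c)                        # h_t = o_t tanh(c_t)
    return h, c


length, in_ch, out_ch, k = 5, 3, 2, 3
rng = np.random.default_rng(0)
w_x = rng.normal(size=(k, in_ch, 4 * out_ch))
w_h = rng.normal(size=(k, out_ch, 4 * out_ch))
b = np.zeros(4 * out_ch)
b[out_ch:2 * out_ch] += 1.0  # +1.0 forget-gate bias offset (see Notes)
x = rng.normal(size=(length, in_ch))
h0 = np.zeros((length, out_ch))
c0 = np.zeros((length, out_ch))
h1, c1 = conv1d_lstm_step(x, h0, c0, w_x, w_h, b)
```

The 2-D and 3-D variants follow the same equations with the convolution taken over more spatial axes.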

Notes

Forget gate initialization:

Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.

__init__(input_shape, output_channels, kernel_shape, name=None)[source]

Constructs a 1-D convolutional LSTM.

Parameters
  • input_shape (Sequence[int]) – Shape of the inputs excluding batch size.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 1), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.

  • name (Optional[str]) – Name of the module.

Conv2DLSTM

class haiku.Conv2DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]

2-D convolutional LSTM.

The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

Notes

Forget gate initialization:

Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.

__init__(input_shape, output_channels, kernel_shape, name=None)[source]

Constructs a 2-D convolutional LSTM.

Parameters
  • input_shape (Sequence[int]) – Shape of the inputs excluding batch size.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 2), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.

  • name (Optional[str]) – Name of the module.
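A minimal sketch of the kernel_shape expansion described above. expand_kernel_shape is a hypothetical helper, not part of Haiku's public API, and the error it raises for a wrong-length sequence is this sketch's own choice.

```python
from typing import Sequence, Tuple, Union


def expand_kernel_shape(kernel_shape: Union[int, Sequence[int]],
                        num_spatial_dims: int) -> Tuple[int, ...]:
    """Hypothetical helper: turn an int or a sequence into one size per spatial dim."""
    if isinstance(kernel_shape, int):
        return (kernel_shape,) * num_spatial_dims
    kernel_shape = tuple(kernel_shape)
    if len(kernel_shape) != num_spatial_dims:
        raise ValueError(f"Expected {num_spatial_dims} kernel sizes, got {kernel_shape}.")
    return kernel_shape
```

For Conv2DLSTM, kernel_shape=3 would then mean a 3x3 kernel, while kernel_shape=(3, 5) is used as given.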

Conv3DLSTM

class haiku.Conv3DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]

3-D convolutional LSTM.

The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

Notes

Forget gate initialization:

Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.

__init__(input_shape, output_channels, kernel_shape, name=None)[source]

Constructs a 3-D convolutional LSTM.

Parameters
  • input_shape (Sequence[int]) – Shape of the inputs excluding batch size.

  • output_channels (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 3), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.

  • name (Optional[str]) – Name of the module.

Batch modules

Reshape

class haiku.Reshape(output_shape, preserve_dims=1, name=None)[source]

Reshapes input Tensor, preserving the batch dimension.

For example, given an input tensor with shape [B, H, W, C, D]:

>>> B, H, W, C, D = range(1, 6)
>>> x = jnp.ones([B, H, W, C, D])

The default behavior when output_shape is (-1, D) is to flatten all dimensions between B and D:

>>> mod = hk.Reshape(output_shape=(-1, D))
>>> assert mod(x).shape == (B, H*W*C, D)

You can change the number of preserved leading dimensions via preserve_dims:

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=2)
>>> assert mod(x).shape == (B, H, W*C, D)

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=3)
>>> assert mod(x).shape == (B, H, W, C, D)

>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=4)
>>> assert mod(x).shape == (B, H, W, C, 1, D)

__init__(output_shape, preserve_dims=1, name=None)[source]

Constructs a Reshape module.

Parameters
  • output_shape (Sequence[int]) – Shape to reshape the input tensor to while preserving its first preserve_dims dimensions. When the special value -1 appears in output_shape the corresponding size is automatically inferred. Note that -1 can only appear once in output_shape. To flatten all non-batch dimensions use Flatten.

  • preserve_dims (int) – Number of leading dimensions that will not be reshaped.

  • name (Optional[str]) – Name of the module.

Raises

ValueError – If preserve_dims is not positive.
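The shape rule can be summarized in plain NumPy: keep the first preserve_dims dimensions and reshape the rest to output_shape. reshape_like_hk is a hypothetical helper for illustration; it reproduces the shapes in the examples above but is not Haiku's implementation.

```python
import numpy as np


def reshape_like_hk(x, output_shape, preserve_dims=1):
    """Hypothetical helper mirroring Reshape's documented shape semantics."""
    if preserve_dims <= 0:
        raise ValueError("preserve_dims must be positive.")
    # Keep the leading dims, reshape whatever remains to output_shape.
    return x.reshape(x.shape[:preserve_dims] + tuple(output_shape))


B, H, W, C, D = range(1, 6)
x = np.ones([B, H, W, C, D])
```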

__call__(inputs)[source]

Call self as a function.

Flatten

class haiku.Flatten(preserve_dims=1, name=None)[source]

Flattens the input, preserving the batch dimension(s).

By default, Flatten combines all dimensions except the first. Additional leading dimensions can be preserved by setting preserve_dims.

>>> x = jnp.ones([3, 2, 4])
>>> flat = hk.Flatten()
>>> flat(x).shape
(3, 8)

When the input to Flatten has at most preserve_dims dimensions, it is returned unchanged:

>>> x = jnp.ones([3])
>>> flat(x).shape
(3,)

__init__(preserve_dims=1, name=None)[source]

Constructs a Flatten module.

Parameters
  • preserve_dims (int) – Number of leading dimensions that will not be reshaped.

  • name (Optional[str]) – Name of the module.

Raises

ValueError – If preserve_dims is not positive.
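Equivalently, in plain NumPy (a sketch; flatten_like_hk is a hypothetical name): Flatten behaves like Reshape with output_shape=(-1,), plus an early return for inputs that have no trailing dimensions left to merge.

```python
import numpy as np


def flatten_like_hk(x, preserve_dims=1):
    """Hypothetical helper mirroring Flatten's documented behaviour."""
    if x.ndim <= preserve_dims:
        return x  # nothing to merge: returned unchanged
    return x.reshape(x.shape[:preserve_dims] + (-1,))
```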

BatchApply

class haiku.BatchApply(f, num_dims=2)[source]

Temporarily merges leading dimensions of input tensors.

Merges the leading dimensions of a tensor into a single dimension, runs the given callable, then splits the leading dimension of the result to match the input.

Input arrays whose rank is smaller than the number of dimensions to collapse are passed unmodified.

This may be useful for applying a module to each timestep of e.g. a [Time, Batch, ...] array.

For some functions f and some platforms, this may be more efficient than jax.vmap, especially when combined with other transformations like jax.grad.

__init__(f, num_dims=2)[source]

Constructs a BatchApply module.

Parameters
  • f – The callable to be applied to the reshaped array.

  • num_dims – The number of dimensions to merge.
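A NumPy sketch of the merge/apply/split behaviour described above. batch_apply is a hypothetical function that only handles positional array arguments and a single array result; the module's own signature, __call__(*args, **kwargs), also accepts keyword arguments.

```python
import numpy as np


def batch_apply(f, *args, num_dims=2):
    """Sketch of BatchApply's semantics for positional array arguments.

    Arrays with at least num_dims dimensions have their leading num_dims
    dimensions merged into one; lower-rank arrays are passed through as-is.
    """
    leading_shape = None
    merged_args = []
    for a in args:
        if a.ndim < num_dims:
            merged_args.append(a)  # too few dimensions: passed unmodified
            continue
        if leading_shape is None:
            leading_shape = a.shape[:num_dims]
        assert a.shape[:num_dims] == leading_shape, "leading dimensions must match"
        merged_args.append(a.reshape((-1,) + a.shape[num_dims:]))
    out = f(*merged_args)
    # Split the merged leading dimension back into the original leading shape.
    return out.reshape(leading_shape + out.shape[1:])


rng = np.random.default_rng(0)
w = rng.normal(size=(5, 2))
bias = np.ones(2)  # rank 1 < num_dims, so it is not reshaped
x = rng.normal(size=(4, 3, 5))  # [time, batch, features]
y = batch_apply(lambda inputs, b: inputs @ w + b, x, bias)
```

Here f only ever sees a [12, 5] array, and the [4, 3] leading shape is restored on the result.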

__call__(*args, **kwargs)[source]

Call self as a function.

Embedding modules

Embed

class haiku.Embed(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None)[source]

Module for embedding tokens in a low-dimensional space.

__init__(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None)[source]

Constructs an Embed module.

Parameters
  • vocab_size (Optional[int]) – The number of unique tokens to embed. If not provided, an existing vocabulary matrix from which vocab_size can be inferred must be provided as embedding_matrix.

  • embed_dim (Optional[int]) – Number of dimensions to assign to each embedding. If an existing vocabulary matrix initializes the module, this should not be provided as it will be inferred.

  • embedding_matrix (Optional[jnp.ndarray]) – A matrix-like object equivalent in size to [vocab_size, embed_dim]. If given, it is used as the initial value for the embedding matrix and neither vocab_size nor embed_dim needs to be given. If they are given, their values are checked to be consistent with the dimensions of embedding_matrix.

  • w_init (Optional[hk.initializers.Initializer]) – An initializer for the embeddings matrix. As a default, embeddings are initialized via a truncated normal distribution.

  • lookup_style (Union[str, hk.EmbedLookupStyle]) – One of the enum values of EmbedLookupStyle determining how to access the value of the embeddings given an ID. Regardless of this setting, the input should be a dense array of integer values representing IDs. This setting changes how this module internally maps those IDs to embeddings. The result is the same, but the speed and memory tradeoffs are different. It defaults to using NumPy-style array indexing. This value is only the default for the module, and it can be overridden at any given invocation in __call__().

  • name (Optional[str]) – Optional name for this module.

Raises

ValueError – If none of embed_dim, embedding_matrix and vocab_size are supplied, or if embedding_matrix is supplied and embed_dim or vocab_size is not consistent with the supplied matrix.

__call__(ids, lookup_style=None)[source]

Lookup embeddings.

Looks up an embedding vector for each value in ids. All ids must be within [0, vocab_size) to prevent NaNs from propagating.

Parameters
  • ids (jnp.ndarray) – integer array.

  • lookup_style (Optional[Union[str, hk.EmbedLookupStyle]]) – Overrides the lookup_style given in the constructor.

Return type

jnp.ndarray

Returns

Tensor of shape ids.shape + [embed_dim].

Raises
  • AttributeError – If lookup_style is not valid.

  • ValueError – If ids is not an integer array.

EmbedLookupStyle

class haiku.EmbedLookupStyle(value)[source]

How to return the embedding matrices given IDs.

ARRAY_INDEX = 1
ONE_HOT = 2
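The two lookup styles compute the same values in different ways. The NumPy sketch below (not Haiku's code) shows the equivalence: ARRAY_INDEX corresponds to gathering rows, ONE_HOT to multiplying one-hot vectors by the embedding matrix. The one-hot route materializes an ids.shape + [vocab_size] array, which is one source of the memory tradeoff mentioned above.

```python
import numpy as np

vocab_size, embed_dim = 6, 4
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(vocab_size, embed_dim))
ids = np.array([[0, 5, 2], [3, 3, 1]])  # integer array with values in [0, vocab_size)

# ARRAY_INDEX-style lookup: gather rows with NumPy indexing.
by_index = embeddings[ids]

# ONE_HOT-style lookup: build one-hot vectors and multiply by the matrix.
one_hot = np.eye(vocab_size)[ids]  # shape ids.shape + [vocab_size]
by_one_hot = one_hot @ embeddings
```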

Initializers

Initializer

haiku.initializers.Initializer

Constant

class haiku.initializers.Constant(constant)[source]

Initializes with a constant.

__init__(constant)[source]

Constructs a Constant initializer.

Parameters

constant – Constant to initialize with.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

Identity

class haiku.initializers.Identity(gain=1.0)[source]

Initializer that generates the identity matrix.

Constructs a 2D identity matrix or batches of these.

__init__(gain=1.0)[source]

Constructs an Identity initializer.

Parameters

gain (Union[float, jnp.ndarray]) – Multiplicative factor to apply to the identity matrix.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray
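A NumPy sketch of the behaviour described above, assuming that for shapes of rank greater than 2 the leading axes are batch axes and the (possibly rectangular) identity occupies the last two axes. identity_init is a hypothetical name, and Haiku's exact handling of edge cases may differ.

```python
import numpy as np


def identity_init(shape, dtype=np.float32, gain=1.0):
    """Sketch: gain * identity on the last two axes, broadcast over leading axes."""
    if len(shape) < 2:
        raise ValueError("Identity initializer needs a shape with at least 2 dimensions.")
    eye = np.eye(shape[-2], shape[-1], dtype=dtype)
    return gain * np.broadcast_to(eye, shape)


batched = identity_init((2, 3, 3), gain=0.5)  # a batch of two 3x3 matrices
rect = identity_init((2, 4))                  # rectangular "identity"
```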

Orthogonal

class haiku.initializers.Orthogonal(scale=1.0, axis=-1)[source]

Orthogonal initializer.

__init__(scale=1.0, axis=-1)[source]

Construct an initializer for uniformly distributed orthogonal matrices.

These matrices will be row-orthonormal along the axis specified by axis. If the rank of the weight is greater than 2, the shape will be flattened in all other dimensions and then will be row-orthonormal along the final dimension. Note that this only works if the axis dimension is larger; otherwise the matrix will be transposed (equivalently, it will be column-orthonormal instead of row-orthonormal).

If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.

Parameters
  • scale – Scale factor.

  • axis – Which axis corresponds to the “output dimension” of the tensor.

Returns

An orthogonally initialized parameter.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray
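A common way to sample such matrices is the QR decomposition of a Gaussian matrix. The NumPy sketch below handles only the 2-D case with shape [fan_in, fan_out] and uses a hypothetical name; it is not Haiku's implementation. Fixing the signs of the columns by the signs of diag(R) makes the sample uniformly distributed over orthogonal matrices.

```python
import numpy as np


def orthogonal_2d(shape, scale=1.0, seed=0):
    """Sketch: sample a [fan_in, fan_out] matrix with orthonormal rows or columns."""
    fan_in, fan_out = shape
    big, small = max(fan_in, fan_out), min(fan_in, fan_out)
    rng = np.random.default_rng(seed)
    q, r = np.linalg.qr(rng.standard_normal((big, small)))  # q: [big, small]
    q = q * np.sign(np.diag(r))  # fix column signs so the sample is uniform
    w = q if fan_in == big else q.T  # orient as [fan_in, fan_out]
    return scale * w


wide = orthogonal_2d((3, 5))              # rows are orthonormal
tall = orthogonal_2d((5, 3), scale=2.0)   # columns are orthonormal, scaled by 2
```

Whichever side is smaller ends up orthonormal, matching the note about non-square shapes above.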

RandomNormal

class haiku.initializers.RandomNormal(stddev=1.0, mean=0.0)[source]

Initializes by sampling from a normal distribution.

__init__(stddev=1.0, mean=0.0)[source]

Constructs a RandomNormal initializer.

Parameters
  • stddev – The standard deviation of the normal distribution to sample from.

  • mean – The mean of the normal distribution to sample from.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

RandomUniform

class haiku.initializers.RandomUniform(minval=0.0, maxval=1.0)[source]

Initializes by sampling from a uniform distribution.

__init__(minval=0.0, maxval=1.0)[source]

Constructs a RandomUniform initializer.

Parameters
  • minval – The lower limit of the uniform distribution.

  • maxval – The upper limit of the uniform distribution.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

TruncatedNormal

class haiku.initializers.TruncatedNormal(stddev=1.0, mean=0.0)[source]

Initializes by sampling from a truncated normal distribution.

__init__(stddev=1.0, mean=0.0)[source]

Constructs a TruncatedNormal initializer.

Parameters
  • stddev – The standard deviation parameter of the truncated normal distribution.

  • mean – The mean of the truncated normal distribution.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

VarianceScaling

class haiku.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal')[source]

Initializer which adapts its scale to the shape of the initialized array.

The initializer first computes the scaling factor s = scale / n, where n is:

  • Number of input units in the weight tensor, if mode = fan_in.

  • Number of output units, if mode = fan_out.

  • Average of the numbers of input and output units, if mode = fan_avg.

Then, with distribution="truncated_normal" or "normal", samples are drawn from a distribution with a mean of zero and a standard deviation (after truncation, if used) stddev = sqrt(s).

With distribution="uniform", samples are drawn from a uniform distribution within [-limit, limit], with limit = sqrt(3 * s).

The variance scaling initializer can be configured to generate other standard initializers using the scale, mode and distribution arguments. Here are some example configurations:

Name              Parameters
glorot_uniform    VarianceScaling(1.0, "fan_avg", "uniform")
glorot_normal     VarianceScaling(1.0, "fan_avg", "truncated_normal")
lecun_uniform     VarianceScaling(1.0, "fan_in", "uniform")
lecun_normal      VarianceScaling(1.0, "fan_in", "truncated_normal")
he_uniform        VarianceScaling(2.0, "fan_in", "uniform")
he_normal         VarianceScaling(2.0, "fan_in", "truncated_normal")
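The formulas above can be checked with a few lines of Python. The sketch computes only the target standard deviation or uniform limit, not the samples. It assumes the common fan convention in which a weight of shape [..., fan_in_dim, fan_out_dim] has fan_in = shape[-2] times the product of the leading (receptive-field) dimensions, and likewise for fan_out; fan_scale and compute_fans are hypothetical names.

```python
import math


def compute_fans(shape):
    """Assumed convention: leading dims form the receptive field; last two are in/out."""
    if len(shape) == 1:
        return shape[0], shape[0]
    receptive_field = math.prod(shape[:-2])  # 1 for a plain [in, out] matrix
    return shape[-2] * receptive_field, shape[-1] * receptive_field


def fan_scale(shape, scale=1.0, mode="fan_in", distribution="truncated_normal"):
    """Return stddev (normal variants) or limit (uniform) per the formulas above."""
    fan_in, fan_out = compute_fans(shape)
    n = {"fan_in": fan_in, "fan_out": fan_out, "fan_avg": (fan_in + fan_out) / 2}[mode]
    s = scale / n
    if distribution in ("truncated_normal", "normal"):
        return math.sqrt(s)        # stddev = sqrt(s)
    if distribution == "uniform":
        return math.sqrt(3 * s)    # limit = sqrt(3 * s)
    raise ValueError(f"Unknown distribution: {distribution}")
```

For the uniform case, the variance of U[-limit, limit] is limit**2 / 3 = s, which is why the limit is sqrt(3 * s) rather than sqrt(s).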

__init__(scale=1.0, mode='fan_in', distribution='truncated_normal')[source]

Constructs the VarianceScaling initializer.

Parameters
  • scale – Scale to multiply the variance by.

  • mode – One of fan_in, fan_out or fan_avg.

  • distribution – Random distribution to use. One of truncated_normal, normal or uniform.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

UniformScaling

class haiku.initializers.UniformScaling(scale=1.0)[source]

Uniform scaling initializer.

Initializes by sampling from a uniform distribution, but with the variance scaled by the inverse square root of the number of input units, multiplied by the scale.

__init__(scale=1.0)[source]

Constructs the UniformScaling initializer.

Parameters

scale – Scale to multiply the upper limit of the uniform distribution by.

__call__(shape, dtype)[source]

Call self as a function.

Return type

jnp.ndarray

Paddings

PadFn

haiku.pad.PadFn

create

haiku.pad.create(padding, kernel, rate, n)[source]

Generates the padding required for a given padding algorithm.

Parameters
  • padding (Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – callable or list of callables of length n. The callables take an integer representing the effective kernel size (kernel size when the rate is 1) and return a list of two integers representing the padding before and padding after for that dimension.

  • kernel (Union[int, Sequence[int]]) – int or list of ints of length n. The size of the kernel for each dimension. If it is an int, it will be replicated across all n spatial (non-batch, non-channel) dimensions.

  • rate (Union[int, Sequence[int]]) – int or list of ints of length n. The dilation rate for each dimension. If it is an int, it will be replicated across all n spatial (non-batch, non-channel) dimensions.

  • n (int) – the number of spatial dimensions.

Return type

Sequence[Tuple[int, int]]

Returns

A sequence of length n containing the padding for each element. These are of the form [pad_before, pad_after].
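A sketch of how these arguments could combine, assuming the usual definition of effective kernel size for dilated convolutions, (kernel - 1) * rate + 1. create_padding is a hypothetical stand-in for create, and left_only is an illustrative padding callable, not one of Haiku's.

```python
from typing import Callable, List, Sequence, Tuple, Union

PadFn = Callable[[int], Tuple[int, int]]


def create_padding(padding: Union[PadFn, Sequence[PadFn]],
                   kernel: Union[int, Sequence[int]],
                   rate: Union[int, Sequence[int]],
                   n: int) -> List[Tuple[int, int]]:
    """Hypothetical sketch of the documented contract of create."""
    paddings = [padding] * n if callable(padding) else list(padding)
    kernels = [kernel] * n if isinstance(kernel, int) else list(kernel)
    rates = [rate] * n if isinstance(rate, int) else list(rate)
    result = []
    for pad_fn, k, r in zip(paddings, kernels, rates):
        effective_kernel_size = (k - 1) * r + 1  # dilation spreads the taps apart
        result.append(tuple(pad_fn(effective_kernel_size)))
    return result


def left_only(effective_kernel_size: int) -> Tuple[int, int]:
    """Illustrative padding callable: put all padding before the input."""
    return (effective_kernel_size - 1, 0)
```

With kernel=3 and rate=2 the effective kernel size is 5, so left_only yields (4, 0) for each dimension.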

causal

haiku.pad.causal(effective_kernel_size)[source]

Pre-padding such that output has no dependence on the future.

Return type

Tuple[int, int]

full

haiku.pad.full(effective_kernel_size)[source]

Maximal padding whilst not convolving over just padded elements.

Return type

Tuple[int, int]

reverse_causal

haiku.pad.reverse_causal(effective_kernel_size)[source]

Post-padding such that output has no dependence on the past.

Return type

Tuple[int, int]

same

haiku.pad.same(effective_kernel_size)[source]

Pads such that the output size matches input size for stride=1.

Return type

Tuple[int, int]

valid

haiku.pad.valid(effective_kernel_size)[source]

No padding.

Return type

Tuple[int, int]
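The five padding functions can be sketched as below, together with the stride-1 output-length arithmetic: output length = input length + pad_before + pad_after - effective_kernel_size + 1. These are plausible definitions consistent with the one-line descriptions above; in particular, how same splits the total padding between the two sides for even kernel sizes is an assumption of this sketch.

```python
from typing import Tuple


def valid(k: int) -> Tuple[int, int]:
    return (0, 0)                     # no padding


def same(k: int) -> Tuple[int, int]:
    return ((k - 1) // 2, k // 2)     # total k - 1, extra element after


def full(k: int) -> Tuple[int, int]:
    return (k - 1, k - 1)             # every output sees at least one real input


def causal(k: int) -> Tuple[int, int]:
    return (k - 1, 0)                 # all padding before: no look-ahead


def reverse_causal(k: int) -> Tuple[int, int]:
    return (0, k - 1)                 # all padding after: no look-back


def output_length(length: int, k: int, pad: Tuple[int, int]) -> int:
    """Stride-1 convolution output length for a given (before, after) padding."""
    return length + pad[0] + pad[1] - k + 1
```

same, causal and reverse_causal all preserve the input length at stride 1; they differ only in where the padding goes.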

Networks

MLP

class haiku.nets.MLP(output_sizes, w_init=None, b_init=None, with_bias=True, activation=jax.nn.relu, activate_final=False, name=None)[source]

A multi-layer perceptron module.

__init__(output_sizes, w_init=None, b_init=None, with_bias=True, activation=jax.nn.relu, activate_final=False, name=None)[source]

Constructs an MLP.

Parameters
  • output_sizes (Iterable[int]) – Sequence of layer sizes.

  • w_init (Optional[hk.initializers.Initializer]) – Initializer for Linear weights.

  • b_init (Optional[hk.initializers.Initializer]) – Initializer for Linear bias. Must be None if with_bias=False.

  • with_bias (bool) – Whether or not to apply a bias in each layer.

  • activation (Callable[[jnp.ndarray], jnp.ndarray]) – Activation function to apply between Linear layers. Defaults to ReLU.

  • activate_final (bool) – Whether or not to activate the final layer of the MLP.

  • name (Optional[str]) – Optional name for this module.

Raises

ValueError – If with_bias is False and b_init is not None.

__call__(inputs, dropout_rate=None, rng=None)[source]

Connects the module to some inputs.

Parameters
  • inputs (jnp.ndarray) – A Tensor of shape [batch_size, input_size].

  • dropout_rate (Optional[float]) – Optional dropout rate.

  • rng – Optional RNG key. Required when using dropout.

Return type

jnp.ndarray

Returns

The output of the model of size [batch_size, output_size].
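Ignoring dropout, the forward pass is a chain of affine layers with the activation applied between them. A NumPy sketch, with parameters passed explicitly as a hypothetical list of (w, b) pairs instead of being created by Linear modules:

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def mlp_forward(params, inputs, activation=relu, activate_final=False):
    """Sketch of an MLP forward pass; params is a list of (w, b) pairs."""
    out = inputs
    num_layers = len(params)
    for index, (w, b) in enumerate(params):
        out = out @ w + b
        # Activate between layers, and after the last one only if requested.
        if index < num_layers - 1 or activate_final:
            out = activation(out)
    return out


rng = np.random.default_rng(0)
output_sizes = [8, 8, 3]
sizes = [5] + output_sizes  # input_size followed by each layer's output size
params = [(rng.normal(size=(m, n)), np.zeros(n)) for m, n in zip(sizes[:-1], sizes[1:])]
x = rng.normal(size=(4, 5))  # [batch_size, input_size]
y = mlp_forward(params, x)
y_act = mlp_forward(params, x, activate_final=True)
```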

reverse(activate_final=None, name=None)[source]

Returns a new MLP which is the layer-wise reverse of this MLP.

NOTE: Since computing the reverse of an MLP requires knowing the input size of each linear layer this method will fail if the module has not been called at least once.

The contract of reverse is that the reversed module will accept the output of the parent module as input and produce an output which is the input size of the parent.

>>> mlp = hk.nets.MLP([1, 2, 3])
>>> y = mlp(jnp.ones([1, 2]))
>>> rev = mlp.reverse()
>>> rev(y)
DeviceArray(...)

Parameters
  • activate_final (Optional[bool]) – Whether the final layer of the MLP should be activated.

  • name (Optional[str]) – Optional name for the new module. The default name will be the name of the current module prefixed with "reversed_".

Return type

MLP

Returns

An MLP instance which is the reverse of the current instance. Note these instances do not share weights and, apart from being symmetric to each other, are not coupled in any way.
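The layer sizes of the reversed MLP follow from the input size seen by each layer, which is why the module must have been called first. A small pure-Python sketch (reversed_output_sizes is a hypothetical helper):

```python
from typing import List, Sequence


def reversed_output_sizes(input_size: int, output_sizes: Sequence[int]) -> List[int]:
    """Output sizes of the layer-wise reverse of an MLP.

    Layer i of the original maps layer_inputs[i] -> output_sizes[i], so the
    reversed MLP must map back through layer_inputs in reverse order.
    """
    layer_inputs = [input_size] + list(output_sizes[:-1])
    return layer_inputs[::-1]
```

For the MLP([1, 2, 3]) example called on inputs of size 2, this gives [2, 1, 2]: the reverse accepts the parent's output (size 3) and ends with a layer of the parent's input size (2).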

MobileNetV1

class haiku.nets.MobileNetV1(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]

MobileNetV1 model.

__init__(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]

Constructs a MobileNetV1 model.

Parameters
  • strides (Sequence[int]) – The stride to use in the depthwise convolution in each MobileNet block.

  • channels (Sequence[int]) – Number of output channels from the pointwise convolution to use in each block.

  • num_classes (int) – Number of classes.

  • use_bn (bool) – Whether or not to use batch normalization. Defaults to True. When true, biases are not used. When false, biases are used.

  • name (Optional[str]) – Name of the module.

__call__(inputs, is_training)[source]

Call self as a function.

Return type

jnp.ndarray

ResNet

class haiku.nets.ResNet(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), name=None)[source]

ResNet model.

class BlockGroup(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]

Higher level block for ResNet implementation.

__call__(inputs, is_training, test_local_stats)[source]

Call self as a function.

__init__(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

class BlockV1(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]

ResNet V1 block with optional bottleneck.

__call__(inputs, is_training, test_local_stats)[source]

Call self as a function.

__init__(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

class BlockV2(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]

ResNet V2 block with optional bottleneck.

__call__(inputs, is_training, test_local_stats)[source]

Call self as a function.

__init__(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__init__(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), name=None)[source]

Constructs a ResNet model.

Parameters
  • blocks_per_group (Sequence[int]) – A sequence of length 4 that indicates the number of blocks created in each group.

  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers. By default the decay_rate is 0.9 and eps is 1e-5.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • bottleneck (bool) – Whether the block should bottleneck or not. Defaults to True.

  • channels_per_group (Sequence[int]) – A sequence of length 4 that indicates the number of channels used for each block in each group.

  • use_projection (Sequence[bool]) – A sequence of length 4 that indicates whether each residual block should use projection.

  • name (Optional[str]) – Name of the module.
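As a sanity check on blocks_per_group and bottleneck, the nominal depth of a ResNet is 2 (the stem convolution and the final linear layer) plus 2 convolutional layers per basic block or 3 per bottleneck block. The configurations below are the standard ones from the ResNet papers; that the ResNet18 to ResNet200 classes use exactly these values is an assumption of this sketch.

```python
STANDARD_CONFIGS = {
    # name: (blocks_per_group, bottleneck)
    "ResNet18": ((2, 2, 2, 2), False),
    "ResNet34": ((3, 4, 6, 3), False),
    "ResNet50": ((3, 4, 6, 3), True),
    "ResNet101": ((3, 4, 23, 3), True),
    "ResNet152": ((3, 8, 36, 3), True),
    "ResNet200": ((3, 24, 36, 3), True),
}


def nominal_depth(blocks_per_group, bottleneck):
    """Stem conv + final linear layer, plus 2 or 3 conv layers per block."""
    layers_per_block = 3 if bottleneck else 2
    return 2 + layers_per_block * sum(blocks_per_group)
```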

__call__(inputs, is_training, test_local_stats=False)[source]

Call self as a function.

ResNet18

class haiku.nets.ResNet18(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

ResNet18.

__init__(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • name (Optional[str]) – Name of the module.

ResNet34

class haiku.nets.ResNet34(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

ResNet34.

__init__(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • name (Optional[str]) – Name of the module.

ResNet50

class haiku.nets.ResNet50(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

ResNet50.

__init__(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • name (Optional[str]) – Name of the module.

ResNet101

class haiku.nets.ResNet101(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

ResNet101.

__init__(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • name (Optional[str]) – Name of the module.

ResNet152

class haiku.nets.ResNet152(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

ResNet152.

__init__(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • name (Optional[str]) – Name of the module.

ResNet200

class haiku.nets.ResNet200(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

ResNet200.

__init__(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

Constructs a ResNet model.

Parameters
  • num_classes (int) – The number of classes to classify the inputs into.

  • bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.

  • resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.

  • name (Optional[str]) – Name of the module.

VectorQuantizer

class haiku.nets.VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, dtype=jnp.float32, name=None)[source]

Haiku module representing the VQ-VAE layer.

Implements the algorithm presented in “Neural Discrete Representation Learning” by van den Oord et al. https://arxiv.org/abs/1711.00937

Any tensor can be passed in to be quantized. The last dimension is used as the space in which to quantize; all other dimensions are flattened and treated as separate examples to quantize.

The output tensor will have the same shape as the input.

For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.

embedding_dim

integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

num_embeddings

integer, the number of vectors in the quantized space.

commitment_cost

scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).

__init__(embedding_dim, num_embeddings, commitment_cost, dtype=jnp.float32, name=None)[source]

Initializes a VQ-VAE module.

Parameters
  • embedding_dim (int) – dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

  • num_embeddings (int) – number of vectors in the quantized space.

  • commitment_cost (float) – scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).

  • dtype (Any) – dtype for the embeddings variable, defaults to float32.

  • name (str) – name of the module.

__call__(inputs, is_training)[source]

Connects the module to some inputs.

Parameters
  • inputs – Tensor, final dimension must be equal to embedding_dim. All other leading dimensions will be flattened and treated as a large batch.

  • is_training – boolean, whether this connection is to training data.

Returns

Dictionary containing the following keys and values:
  • quantize: Tensor containing the quantized version of the input.

  • loss: Tensor containing the loss to optimize.

  • perplexity: Tensor containing the perplexity of the encodings.

  • encodings: Tensor containing the one-hot discrete encodings, i.e. which element of the quantized space each input element was mapped to.

  • encoding_indices: Tensor containing the discrete encoding indices, i.e. the index of the element of the quantized space each input element was mapped to.

Return type

dict
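The core of the quantization step can be sketched in NumPy: flatten the inputs, find the nearest codebook vector for each one, and compute the loss terms. The sketch omits gradients; in the actual algorithm, stop-gradients make the two loss terms train different parameters, so here both terms have the same value. The [embedding_dim, num_embeddings] codebook layout and the function name vq_forward are assumptions.

```python
import numpy as np


def vq_forward(inputs, embeddings, commitment_cost):
    """Sketch of VQ-VAE quantization; embeddings has shape [embedding_dim, num_embeddings]."""
    embedding_dim, num_embeddings = embeddings.shape
    flat = inputs.reshape(-1, embedding_dim)  # every leading dim becomes "batch"
    # Squared Euclidean distance from each input vector to each codebook vector.
    distances = (np.sum(flat ** 2, axis=1, keepdims=True)
                 - 2 * flat @ embeddings
                 + np.sum(embeddings ** 2, axis=0, keepdims=True))
    encoding_indices = np.argmin(distances, axis=1)
    encodings = np.eye(num_embeddings)[encoding_indices]  # one-hot
    quantized = (encodings @ embeddings.T).reshape(inputs.shape)
    mse = np.mean((quantized - inputs) ** 2)
    # Without stop-gradients both loss terms equal the same MSE value.
    loss = mse + commitment_cost * mse
    avg_probs = encodings.mean(axis=0)
    perplexity = np.exp(-np.sum(avg_probs * np.log(avg_probs + 1e-10)))
    return {"quantize": quantized, "loss": loss, "perplexity": perplexity,
            "encodings": encodings,
            "encoding_indices": encoding_indices.reshape(inputs.shape[:-1])}


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(4, 8))   # embedding_dim=4, num_embeddings=8
inputs = rng.normal(size=(2, 3, 3, 4))  # last dim must equal embedding_dim
out = vq_forward(inputs, embeddings, commitment_cost=0.25)
```

Perplexity is the exponentiated entropy of average code usage, so it ranges from about 1 (one code used) to num_embeddings (all codes used equally).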

quantize(encoding_indices)[source]

Returns embedding tensor for a batch of indices.

VectorQuantizerEMA

class haiku.nets.VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=jnp.float32, name=None)[source]

Haiku module representing the VQ-VAE layer.

Implements a slightly modified version of the algorithm presented in “Neural Discrete Representation Learning” by van den Oord et al. https://arxiv.org/abs/1711.00937

The difference between VectorQuantizerEMA and VectorQuantizer is that this module uses exponential moving averages to update the embedding vectors instead of an auxiliary loss. This has the advantage that the embedding updates are independent of the choice of optimizer (SGD, RMSProp, Adam, K-FAC, …) used for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

Any tensor can be passed in to be quantized. The last dimension is used as the space in which to quantize; all other dimensions are flattened and treated as separate examples to quantize.

The output tensor will have the same shape as the input.

For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.

embedding_dim

integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

num_embeddings

integer, the number of vectors in the quantized space.

commitment_cost

scalar which controls the weighting of the loss terms (see equation 4 in the paper).

decay

float, decay for the moving averages.

epsilon

small float constant to avoid numerical instability.

__init__(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=jnp.float32, name=None)[source]

Initializes a VQ-VAE EMA module.

Parameters
  • embedding_dim – integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

  • num_embeddings – integer, the number of vectors in the quantized space.

  • commitment_cost – scalar which controls the weighting of the loss terms (see equation 4 in the paper; this variable is beta).

  • decay – float between 0 and 1, controls the speed of the exponential moving averages.

  • epsilon (float) – small constant to aid numerical stability, default 1e-5.

  • dtype (Any) – dtype for the embeddings variable, defaults to float32.

  • name (str) – name of the module.
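The EMA codebook update described in the paper's appendix works in two steps. Each step first decays a running count of inputs per code and a running sum of the inputs assigned to each code. It then sets each embedding to sum divided by count. The NumPy sketch below is illustrative only. The helper name, the [num_embeddings, embedding_dim] codebook layout, and the use of `epsilon` for Laplace smoothing of the counts are all assumptions, not a description of Haiku's internals:

```python
import numpy as np

def ema_codebook_update(flat_inputs, encoding_indices, ema_cluster_size, ema_w,
                        decay, epsilon=1e-5):
    """One EMA update of a VQ codebook (illustrative sketch, not Haiku's code).

    flat_inputs: [N, D] flattened inputs.
    encoding_indices: [N] index of the nearest code for each input.
    ema_cluster_size: [K] running count of inputs assigned to each code.
    ema_w: [K, D] running sum of the inputs assigned to each code.
    """
    num_embeddings = ema_cluster_size.shape[0]
    # One-hot encodings, shape [N, K].
    encodings = np.eye(num_embeddings, dtype=flat_inputs.dtype)[encoding_indices]
    # Decay the running counts towards this batch's counts.
    ema_cluster_size = decay * ema_cluster_size + (1 - decay) * encodings.sum(axis=0)
    # Decay the running sums towards this batch's per-code sums, shape [K, D].
    ema_w = decay * ema_w + (1 - decay) * (encodings.T @ flat_inputs)
    # Assumed Laplace smoothing with epsilon: codes with no assignments do not
    # cause a division by zero as long as at least one code has been used.
    n = ema_cluster_size.sum()
    smoothed = (ema_cluster_size + epsilon) / (n + num_embeddings * epsilon) * n
    # Each embedding becomes (running sum) / (running count).
    embeddings = ema_w / smoothed[:, None]
    return ema_cluster_size, ema_w, embeddings
```

With decay=0, the update sets each used embedding to (approximately) the mean of its assigned inputs. Values of decay close to 1 make the codebook change slowly.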

__call__(inputs, is_training)[source]

Connects the module to some inputs.

Parameters
  • inputs – Tensor, final dimension must be equal to embedding_dim. All other leading dimensions will be flattened and treated as a large batch.

  • is_training – boolean, whether this connection is to training data. When this is set to False, the internal moving average statistics will not be updated.

Returns

Dictionary containing the following keys and values:
  • quantize: Tensor containing the quantized version of the input.

  • loss: Tensor containing the loss to optimize.

  • perplexity: Tensor containing the perplexity of the encodings.

  • encodings: Tensor containing the discrete one-hot encodings, i.e. which element of the quantized space each input element was mapped to.

  • encoding_indices: Tensor containing the discrete encoding indices, i.e. which element of the quantized space each input element was mapped to.

Return type

dict

quantize(encoding_indices)[source]

Returns embedding tensor for a batch of indices.

JAX Transforms

cond

haiku.cond(pred, true_operand, true_fun, false_operand, false_fun)[source]

Equivalent to jax.lax.cond but with Haiku state threaded in and out.
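The argument order follows the older five-argument form of jax.lax.cond, in which each branch has its own operand. Setting aside Haiku state and tracing, the value it computes can be sketched in plain Python. Under JAX, both branches are traced and must return outputs with matching structure. `cond_reference` is a hypothetical name:

```python
def cond_reference(pred, true_operand, true_fun, false_operand, false_fun):
    """Value semantics of the five-argument cond, ignoring state and tracing."""
    # Only the selected branch's result is returned.
    if pred:
        return true_fun(true_operand)
    return false_fun(false_operand)
```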

grad

haiku.grad(fun, argnums=0, has_aux=False, holomorphic=False)[source]

Creates a function which evaluates the gradient of fun.

NOTE: You only need this in the specific case where you want to take a gradient inside a transform()ed function and the function you are differentiating uses set_state(). For example:

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     hk.set_state("last", x ** 2)
...     return x ** 2
>>> def f(x):
...   m = MyModule()
...   g = hk.grad(m)(x)
...   return g
>>> f = hk.transform_with_state(f)
>>> x = jnp.array(2.)
>>> params, state = jax.jit(f.init)(None, x)
>>> state["my_module"]["last"]
DeviceArray(4., dtype=float32)

Parameters
  • fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns

A function with the same arguments as fun that evaluates the gradient of fun. If argnums is an integer, the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True, a pair (gradient, auxiliary_data) is returned.

For example:

>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.96...

jit

haiku.jit(fun, *dec_args, **dec_kwargs)[source]

Applies jax.jit to a modified version of fun that passes Haiku state in and out.

remat

haiku.remat(fun, *dec_args, **dec_kwargs)[source]

Applies jax.remat to a modified version of fun that passes Haiku state in and out.

scan

haiku.scan(f, init, xs, length=None, reverse=False)[source]

Equivalent to jax.lax.scan but with Haiku state threaded in and out.
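Ignoring Haiku state, the semantics of scan (including `length` and `reverse`) can be written as a plain Python loop. This sketch returns the per-step outputs as a list, whereas jax.lax.scan stacks them into arrays. The name `scan_reference` is hypothetical:

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

def scan_reference(f: Callable[[Any, Any], Tuple[Any, Any]], init: Any,
                   xs: Optional[Sequence[Any]], length: Optional[int] = None,
                   reverse: bool = False) -> Tuple[Any, List[Any]]:
    # With xs=None, `length` says how many steps to run (each step sees x=None).
    if xs is None:
        xs = [None] * length
    xs = list(xs)
    # reverse=True visits elements last-to-first, but ys[i] still lines up with xs[i].
    order = reversed(range(len(xs))) if reverse else range(len(xs))
    carry = init
    ys: List[Any] = [None] * len(xs)
    for i in order:
        carry, ys[i] = f(carry, xs[i])
    return carry, ys
```

For example, a running sum uses the carry as the accumulator and also emits it at every step.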

value_and_grad

haiku.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)[source]

Creates a function which evaluates both fun and the grad of fun.

NOTE: You only need this in the specific case where you want to take a gradient inside a transform()ed function and the function you are differentiating uses set_state(). For example:

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     hk.set_state("last", jnp.sum(x))
...     return x ** 2
>>> def f(x):
...   m = MyModule()
...   y, g = hk.value_and_grad(m)(x)
...   return y, g
>>> f = hk.transform_with_state(f)
>>> x = jnp.array(2.)
>>> _ = jax.jit(f.init)(None, x)

Parameters
  • fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns

A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.

Testing

transform_and_run

haiku.testing.transform_and_run(f=None, seed=42, run_apply=True)[source]

Transforms the given function and runs init then (optionally) apply.

Equivalent to:

>>> def f(x):
...   return x
>>> x = jnp.ones([])
>>> rng = jax.random.PRNGKey(42)
>>> f = hk.transform_with_state(f)
>>> params, state = f.init(rng, x)
>>> out = f.apply(params, state, rng, x)

This function makes it very convenient to unit test Haiku code:

>>> class MyTest(unittest.TestCase):
...   @hk.testing.transform_and_run
...   def test_linear_output(self):
...     mod = hk.Linear(1)
...     out = mod(jnp.ones([1, 1]))
...     self.assertEqual(out.ndim, 2)

It can also be useful in interactive environments such as IPython, Jupyter or Google Colaboratory:

>>> f = lambda x: hk.Bias()(x)
>>> hk.testing.transform_and_run(f)(jnp.ones([1, 1]))
DeviceArray([[1.]], dtype=float32)

See transform() for more details.

Parameters
  • f – A function or method to transform.

  • seed (Optional[int]) – A seed to pass to init and apply.

  • run_apply (bool) – Whether to run apply as well as init. Defaults to True.

Returns

A function that transform()s f and runs init and optionally apply.

Data structures

filter

haiku.data_structures.filter(predicate, structure)[source]

Filters an input structure according to a user-specified predicate.

>>> params = {'linear': {'w': None, 'b': None}}
>>> predicate = lambda module_name, name, value: name == 'w'
>>> hk.data_structures.filter(predicate, params)
frozendict({'linear': frozendict({'w': None})})

Note: returns a new structure, not a view.

Parameters
  • predicate (Callable[[str, str, jnp.ndarray], bool]) – criterion used to filter the input data. The predicate argument is expected to be a boolean function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data.

  • structure (T) – Haiku params or state data structure to be filtered.

Return type

T

Returns

All the input parameters or state as selected by the input predicate.
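For plain nested dicts, the behaviour described above can be sketched as a double loop over module names and entries. This minimal sketch assumes a two-level {module_name: {name: value}} layout and returns ordinary dicts rather than frozendicts. `filter_sketch` is a hypothetical name. In this sketch, a module with no matching entries is left out of the result:

```python
from typing import Any, Callable, Dict, Mapping

def filter_sketch(predicate: Callable[[str, str, Any], bool],
                  structure: Mapping[str, Mapping[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Return a new {module_name: {name: value}} dict keeping matching entries."""
    out: Dict[str, Dict[str, Any]] = {}
    for module_name, bundle in structure.items():
        for name, value in bundle.items():
            if predicate(module_name, name, value):
                # A module only appears once one of its entries matches.
                out.setdefault(module_name, {})[name] = value
    return out
```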

merge

haiku.data_structures.merge(*structures)[source]

Merges multiple input structures.

>>> weights = {'linear': {'w': None}}
>>> biases = {'linear': {'b': None}}
>>> hk.data_structures.merge(weights, biases)
frozendict({'linear': frozendict({'b': None, 'w': None})})

When structures are not disjoint the output will contain the value from the last structure for each path:

>>> weights1 = {'linear': {'w': 1}}
>>> weights2 = {'linear': {'w': 2}}
>>> hk.data_structures.merge(weights1, weights2)
frozendict({'linear': frozendict({'w': 2})})

Note: returns a new structure, not a view.

Parameters

*structures – One or more structures to merge.

Return type

T

Returns

A single structure with an entry for each path in the input structures.
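The "last structure wins" rule above can be sketched with dict.update on each module's bundle. This minimal sketch assumes two-level {module_name: {name: value}} mappings and returns ordinary dicts. `merge_sketch` is a hypothetical name. It builds fresh dicts, so the inputs are never modified:

```python
from typing import Any, Dict, Mapping

def merge_sketch(*structures: Mapping[str, Mapping[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Merge two-level {module_name: {name: value}} mappings into a new dict."""
    out: Dict[str, Dict[str, Any]] = {}
    for structure in structures:
        for module_name, bundle in structure.items():
            # Later structures win when the same (module_name, name) path repeats.
            out.setdefault(module_name, {}).update(bundle)
    return out
```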

partition

haiku.data_structures.partition(predicate, structure)[source]

Partitions the input structure in two according to a given predicate.

For a given set of parameters, you can use partition() to split them:

>>> params = {'linear': {'w': None, 'b': None}}
>>> predicate = lambda module_name, name, value: name == 'w'
>>> weights, biases = hk.data_structures.partition(predicate, params)
>>> weights
frozendict({'linear': frozendict({'w': None})})
>>> biases
frozendict({'linear': frozendict({'b': None})})
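Because every entry goes to exactly one side, partition amounts to running the predicate once per entry and routing the entry accordingly. The following plain-dict sketch makes the same assumptions as the doctest above: a two-level {module_name: {name: value}} mapping, with ordinary dicts instead of frozendicts. `partition_sketch` is hypothetical:

```python
from typing import Any, Callable, Dict, Mapping, Tuple

Nested = Dict[str, Dict[str, Any]]

def partition_sketch(predicate: Callable[[str, str, Any], bool],
                     structure: Mapping[str, Mapping[str, Any]]) -> Tuple[Nested, Nested]:
    """Split entries into (matching, non-matching) new nested dicts."""
    matched: Nested = {}
    rest: Nested = {}
    for module_name, bundle in structure.items():
        for name, value in bundle.items():
            # Every entry lands in exactly one of the two outputs.
            target = matched if predicate(module_name, name, value) else rest
            target.setdefault(module_name, {})[name] = value
    return matched, rest
```

One common use is to split parameters into a trainable group and a frozen group, then recombine them later.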

Note: returns new structures, not views.

Parameters
  • predicate (Callable[[str, str, jnp.ndarray], bool]) – criterion to be used to partition the input data. The predicate argument is expected to be a boolean function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data.

  • structure (T) – Haiku params or state data structure to be partitioned.

Return type

Tuple[T, T]

Returns

A tuple containing all the params or state as partitioned by the input predicate. Entries matching the predicate will be in the first structure, and the rest will be in the second.

to_immutable_dict

haiku.data_structures.to_immutable_dict(mapping)[source]

Return type

Mapping[K, V]

to_mutable_dict

haiku.data_structures.to_mutable_dict(mapping)[source]

Turns an immutable frozendict into a mutable dict.

tree_bytes

haiku.data_structures.tree_bytes(tree)[source]

Sums the size in bytes of all arrays in a pytree.

Note that this is the minimum size of the arrays (e.g. a float32 element needs at least 4 bytes). On some accelerators, buffers may occupy more memory due to padding or alignment constraints.

For example given a ResNet50 model:

>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([128, 224, 224, 3])
>>> params, state = f.init(rng, x)

We can count the number of parameters and their size at f32:

>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 102.23MB

And compare that with casting our parameters to bf16:

>>> params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 51.11MB

Parameters

tree – A tree of jnp.ndarrays.

Return type

int

Returns

The total size in bytes of the array(s) in the input.
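The byte count is the element count of each array times the size of its dtype, summed over the tree. This is why casting float32 parameters (4 bytes each) to bfloat16 (2 bytes each) halves the figure in the example above. The following is a minimal NumPy sketch. It assumes the tree is a nested dict of NumPy arrays, whereas the real function accepts any pytree. It uses float16 as a stand-in because NumPy has no bfloat16:

```python
import numpy as np

def tree_bytes_sketch(tree) -> int:
    """Minimum number of bytes needed by all arrays in a nested dict."""
    if isinstance(tree, dict):
        return sum(tree_bytes_sketch(v) for v in tree.values())
    # Elements times bytes per element; ignores any accelerator padding.
    return int(tree.size * tree.dtype.itemsize)

params = {"linear": {"w": np.zeros((3, 4), np.float32), "b": np.zeros(4, np.float32)}}
half = {m: {n: v.astype(np.float16) for n, v in b.items()} for m, b in params.items()}
```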

tree_size

haiku.data_structures.tree_size(tree)[source]

Sums the sizes of all arrays in a pytree.

For example given a ResNet50 model:

>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([128, 224, 224, 3])
>>> params, state = f.init(rng, x)

We can count the number of parameters and their size at f32:

>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 102.23MB

And compare that with casting our parameters to bf16:

>>> params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 51.11MB

Parameters

tree – A tree of jnp.ndarrays.

Return type

int

Returns

The total size (number of elements) of the array(s) in the input.
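Unlike the byte count, the element count does not depend on dtype, which is why both runs above report 25557032 params. The following is a minimal NumPy sketch. It assumes a nested dict of NumPy arrays rather than an arbitrary pytree, and `tree_size_sketch` is a hypothetical name:

```python
import numpy as np

def tree_size_sketch(tree) -> int:
    """Total number of elements across all arrays in a nested dict."""
    if isinstance(tree, dict):
        return sum(tree_size_sketch(v) for v in tree.values())
    return int(np.size(tree))

params = {"linear": {"w": np.zeros((3, 4), np.float32), "b": np.zeros(4, np.float32)}}
```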

Experimental

custom_creator

haiku.experimental.custom_creator(creator)[source]

Registers a custom parameter creator.

When new parameters are created via get_parameter(), we first run any custom creators, passing user-defined values through. For example:

>>> def zeros_creator(next_creator, shape, dtype, init, context):
...   init = jnp.zeros
...   return next_creator(shape, dtype, init)
>>> with hk.experimental.custom_creator(zeros_creator):
...   z = hk.get_parameter("z", [], jnp.float32, jnp.ones)
>>> z
DeviceArray(0., dtype=float32)

Parameters

creator (ParamCreator) – A parameter creator.

Returns

Context manager under which the creator is active.
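A custom creator is an interceptor: it receives next_creator and decides what to forward. The plain-Python sketch below models only that chaining pattern. `build_chain`, `base_creator`, and `log_creator` are hypothetical, and the `context` argument is omitted. The order in which the sketch runs creators is a choice of the sketch, not a statement about how Haiku orders nested creators:

```python
import functools
from typing import Callable, List, Sequence

def build_chain(creators: Sequence[Callable], base: Callable) -> Callable:
    """Wrap `base` so creators[0] runs first and `base` runs last."""
    chain = base
    for creator in reversed(creators):
        # Each creator receives the rest of the chain as its `next_creator`.
        chain = functools.partial(creator, chain)
    return chain

def base_creator(shape, dtype, init):
    # Stand-in for the real creation step: just call the initializer.
    return init(shape, dtype)

calls: List[list] = []

def log_creator(next_creator, shape, dtype, init):
    # Record the request, then forward it unchanged.
    calls.append(shape)
    return next_creator(shape, dtype, init)

def zeros_creator(next_creator, shape, dtype, init):
    # Ignore the requested initializer and forward a zeros initializer instead.
    return next_creator(shape, dtype, lambda s, d: [d(0)] * s[0])

create = build_chain([log_creator, zeros_creator], base_creator)
value = create([3], float, lambda s, d: [d(1)] * s[0])  # ones were requested
```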

custom_getter

haiku.experimental.custom_getter(getter)[source]

Registers a custom parameter getter.

When parameters are retrieved using get_parameter() we always run all custom getters before returning a value to the user.

>>> def bf16_getter(next_getter, value, context):
...   value = value.astype(jnp.bfloat16)
...   return next_getter(value)
>>> with hk.experimental.custom_getter(bf16_getter):
...   w = hk.get_parameter("w", [], jnp.float32, jnp.ones)
>>> w.dtype
dtype(bfloat16)

Parameters

getter (ParamGetter) – A parameter getter.

Returns

Context manager under which the getter is active.

ParamContext

class haiku.experimental.ParamContext(full_name: str, module: Optional[Module])[source]

Read-only state showing where parameters are being created.

full_name

The full name of the given parameter (e.g. mlp/~/linear_0/w).

module

The module that owns the current parameter, None if this parameter exists outside any module.

named_call

haiku.experimental.named_call(fun, *, name)[source]

Wraps a function in a name_scope and maintains Haiku state.

Return type

Callable[..., Any]

optimize_rng_use

haiku.experimental.optimize_rng_use(fun)[source]

Optimizes RNG key splitting in fun.

Our strategy is to use abstract interpretation to run your function twice. The first time, we use jax.eval_shape to avoid spending any FLOPs and simply observe how many times you call next_rng_key(). We then run your function again, but this time we reserve enough RNG keys ahead of time that we only need to call jax.random.split once.

In the following example, we need three random samples for the weight matrices of our 3-layer MLP. To draw these samples we use next_rng_key(), which splits a new key for each sample. By using optimize_rng_use(), Haiku will pre-allocate exactly enough RNG keys for f to be evaluated, splitting the input key once and only once. For large models (unlike this example), this can significantly reduce compilation time for init:

>>> def f(x):
...   net = hk.nets.MLP([300, 100, 10])
...   return net(x)
>>> f = hk.experimental.optimize_rng_use(f)
>>> f = hk.transform(f)
>>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))

Parameters

fun – A function to wrap.

Returns

A function that applies fun but only requires one call to jax.random.split by Haiku.
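The two-pass strategy described above can be sketched in plain Python. Everything here is a stand-in:

- `fake_split` uses Python's `random` module instead of jax.random.split.
- The first pass actually runs the function, whereas Haiku uses jax.eval_shape so that no FLOPs are spent.
- The wrapped function receives its key source as an explicit `next_key` argument rather than calling next_rng_key().

```python
import random
from typing import Callable, List

def fake_split(key: int, num: int) -> List[int]:
    """Stand-in for jax.random.split: derive `num` child keys from one key."""
    gen = random.Random(key)
    return [gen.getrandbits(32) for _ in range(num)]

def optimize_rng_use_sketch(fun: Callable) -> Callable:
    def wrapped(key: int, *args):
        # Pass 1: count how many keys `fun` asks for, handing out dummy keys.
        count = 0
        def counting_next_key() -> int:
            nonlocal count
            count += 1
            return 0
        fun(counting_next_key, *args)
        # Pass 2: split once for exactly `count` keys and hand them out in order.
        keys = iter(fake_split(key, count))
        return fun(lambda: next(keys), *args)
    return wrapped

def three_layer_init(next_key):
    # Hypothetical init that needs one key per weight matrix.
    return [next_key() for _ in range(3)]
```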

lift

haiku.experimental.lift(*args, **kwargs) → T[source]

Lifts the given init function to a function in the current Haiku namespace.

During init, the returned callable will run the given init_fn, and include the resulting params in the outer transform’s dictionaries. During apply, the returned callable will instead pull the relevant parameters from the outer transform’s dictionaries.

Must be called inside transform(), and be passed the init member of a Transformed.

The user must ensure that the given init does not accidentally catch modules from an outer transform() via functional closure.

This is highly experimental and may be changed or removed at any time.

profiler_name_scopes

haiku.experimental.profiler_name_scopes(enabled=True)[source]

Enable/disable profiler name_scopes on all Haiku module methods.

to_dot

haiku.experimental.to_dot(fun)[source]

Converts a function using Haiku modules to a dot graph.

Utilities

multinomial

haiku.multinomial(rng, logits, num_samples)[source]

Draws samples from a multinomial distribution.

Parameters
  • rng – A JAX PRNGKey.

  • logits – Unnormalized log-probabilities, of shape [batch_size, categories] or [categories].

  • num_samples – Number of samples to draw.

Returns

Chosen categories, of shape [batch_size, num_samples] or [num_samples].
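Sampling from a multinomial distribution with unnormalized logits amounts to drawing categorical samples from softmax(logits). The NumPy sketch below shows how [batch_size, categories] and [categories] inputs map to the output shapes. It makes two assumptions: the samples are independent draws (sampling with replacement), and a numpy.random.Generator replaces the JAX PRNGKey so the sketch runs without JAX. `multinomial_sketch` is a hypothetical name:

```python
import numpy as np

def multinomial_sketch(rng: np.random.Generator, logits, num_samples: int) -> np.ndarray:
    """Draw `num_samples` category indices per row of `logits` (sketch)."""
    logits = np.asarray(logits, dtype=np.float64)
    batched = logits.ndim == 2
    if not batched:
        logits = logits[None]  # treat [categories] as a batch of one
    # Numerically stable softmax over the category axis.
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    # Independent draws per row, i.e. sampling with replacement.
    samples = np.stack([rng.choice(probs.shape[-1], size=num_samples, p=p)
                        for p in probs])
    return samples if batched else samples[0]
```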

one_hot

haiku.one_hot(x, num_classes, dtype=jnp.float32)[source]

Returns a one-hot version of indices.

DEPRECATED: Use jax.nn.one_hot(x, num_classes).astype(dtype) instead.
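The following NumPy formulation is a sketch for readers without JAX at hand, not the Haiku or JAX code. It compares each index with every class id, which shows where the extra trailing num_classes dimension comes from. For in-range indices, it gives the same values as the recommended replacement. In this sketch, out-of-range indices produce all-zero rows:

```python
import numpy as np

def one_hot_sketch(x, num_classes, dtype=np.float32):
    """One-hot encode integer indices along a new trailing axis."""
    x = np.asarray(x)
    # Compare each index with 0..num_classes-1 via broadcasting.
    return (x[..., None] == np.arange(num_classes)).astype(dtype)
```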

Parameters
  • x – A tensor of indices.

  • num_classes – Number of classes in the one-hot dimension.

  • dtype – The dtype.

Returns

The one-hot tensor. If the indices have shape [A, B, …], the result has shape [A, B, …, num_classes].

References

[1] Wojciech Zaremba, Ilya Sutskever, and Oriol Vinyals. Recurrent neural network regularization. arXiv preprint arXiv:1409.2329, 2014. URL: https://arxiv.org/abs/1409.2329.

[2] Rafal Jozefowicz, Wojciech Zaremba, and Ilya Sutskever. An empirical exploration of recurrent network architectures. In International Conference on Machine Learning, pages 2342–2350, 2015.

[3] SHI Xingjian, Zhourong Chen, Hao Wang, Dit-Yan Yeung, Wai-Kin Wong, and Wang-chun Woo. Convolutional LSTM network: a machine learning approach for precipitation nowcasting. In Advances in Neural Information Processing Systems, pages 802–810, 2015.