Base¶
Transforming Functions¶

haiku.transform(f, *, apply_rng=True)[source]¶
Transforms a function using Haiku modules into a pair of pure functions.

For a function out = f(*a, **k), this function returns a pair of two pure functions that call f(*a, **k) explicitly collecting and injecting parameter values:

params = init(rng, *a, **k)
out = apply(params, rng, *a, **k)

Note that the rng argument is typically not required for apply and passing None is accepted.

The first thing to do is to define a Module. A module encapsulates some parameters and a computation on those parameters:

>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.zeros)
...     return x + w

Next, define some function that creates and applies modules. We use transform() to transform that function into a pair of functions that allow us to lift all the parameters out of the function (f.init) and apply the function with a given set of parameters (f.apply):

>>> def f(x):
...   a = MyModule()
...   b = MyModule()
...   return a(x) + b(x)
>>> f = hk.transform(f)

To get the initial state of the module call init with an example input:

>>> params = f.init(None, 1)
>>> params
FlatMapping({
  'my_module': FlatMapping({'w': DeviceArray(0., dtype=float32)}),
  'my_module_1': FlatMapping({'w': DeviceArray(0., dtype=float32)}),
})

You can then apply the function with the given parameters by calling apply (note that since we don’t use Haiku’s random number APIs to apply our network we pass None as an RNG key):

>>> f.apply(params, None, 1)
DeviceArray(2., dtype=float32)

It is expected that your program will at some point produce updated parameters and you will want to reapply apply. You can do this by calling apply with different parameters:

>>> new_params = {"my_module": {"w": jnp.array(2.)},
...               "my_module_1": {"w": jnp.array(3.)}}
>>> f.apply(new_params, None, 2)
DeviceArray(9., dtype=float32)

If your transformed function needs to maintain internal state (e.g. moving averages in batch norm) then see transform_with_state().

Parameters
f – A function closing over Module instances.
apply_rng – In the process of being removed. Can only take the value True.

Return type
Transformed

Returns
A Transformed tuple with init and apply pure functions.

haiku.transform_with_state(f)[source]¶
Transforms a function using Haiku modules into a pair of pure functions.

See transform() for general details on Haiku transformations.

For a function out = f(*a, **k), this function returns a pair of two pure functions that call f(*a, **k) explicitly collecting and injecting parameter values and state:

params, state = init(rng, *a, **k)
out, state = apply(params, state, rng, *a, **k)

Note that the rng argument is typically not required for apply and passing None is accepted.

This function is equivalent to transform(), however it allows you to maintain and update internal state (e.g. ExponentialMovingAverage in BatchNorm) via get_state() and set_state():

>>> def f():
...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
...                          init=jnp.zeros)
...   hk.set_state("counter", counter + 1)
...   return counter
>>> f = hk.transform_with_state(f)

>>> params, state = f.init(None)
>>> for _ in range(10):
...   counter, state = f.apply(params, state, None)
>>> counter
DeviceArray(9, dtype=int32)

Parameters
f – A function closing over Module instances.

Return type
TransformedWithState

Returns
A TransformedWithState tuple with init and apply pure functions.

haiku.without_apply_rng(f)[source]¶
Removes the rng argument from the apply function.

Return type
TransformedT
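
For example, a minimal sketch of combining it with transform() so that apply no longer takes an rng argument (the forward function is illustrative; as with the other examples on this page, hk, jax and jnp refer to haiku, jax and jax.numpy):

def forward(x):
  return hk.Linear(4)(x)

forward = hk.without_apply_rng(hk.transform(forward))
x = jnp.ones([2, 3])
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, x)  # no rng argument needed at apply time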

haiku.without_state(f)[source]¶
Wraps a transformed tuple and ignores state in/out.

The example below is equivalent to f = hk.transform(f):

>>> def f(x):
...   mod = hk.Linear(10)
...   return mod(x)
>>> f = hk.without_state(hk.transform_with_state(f))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.zeros([1, 1])
>>> params = f.init(rng, x)
>>> f.apply(params, rng, x)
DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

Parameters
f (TransformedWithState) – A transformed function.

Return type
Transformed

Returns
A transformed function that does not take or return state.
Type Hints¶

class haiku.LSTMState[source]¶
An LSTM core state consists of hidden and cell vectors.

hidden¶
Hidden state.

cell¶
Cell state.
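
For illustration only (the shapes below are assumptions, not part of the API), a zero-initialized state can be built directly from the two fields:

batch_size, hidden_size = 4, 8
state = hk.LSTMState(hidden=jnp.zeros([batch_size, hidden_size]),
                     cell=jnp.zeros([batch_size, hidden_size]))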

haiku.Params¶

haiku.State¶
Parameters and State¶

class haiku.Module(name=None)[source]¶
Base class for Haiku modules.

A Haiku module is a lightweight container for variables and other modules. Modules typically define one or more “forward” methods (e.g. __call__) which apply operations combining user input and module parameters.

Modules must be initialized inside a transform() call.

For example:

>>> class AddModule(hk.Module):
...   def __call__(self, x):
...     w = hk.get_parameter("w", [], init=jnp.ones)
...     return x + w

>>> def forward_fn(x):
...   mod = AddModule()
...   return mod(x)

>>> forward = hk.transform(forward_fn)
>>> x = 1.
>>> rng = None
>>> params = forward.init(rng, x)
>>> forward.apply(params, None, x)
DeviceArray(2., dtype=float32)

__init__(name=None)[source]¶
Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__post_init__(name=None)[source]¶
Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.


haiku.to_module(f)[source]¶
Converts a function into a callable module class.

Sample usage:

>>> def bias_fn(x):
...   b = hk.get_parameter("b", [], init=hk.initializers.RandomNormal())
...   return x + b
>>> Bias = hk.to_module(bias_fn)
>>> def net(x, y):
...   b = Bias(name="my_bias")
...   # Bias x and y by the same amount.
...   return b(x) * b(y)

Parameters
f (Callable[.., Any]) – The function to convert.

Return type
Type[CallableModule]

Returns
A module class which runs f when called.

haiku.get_parameter(name, shape, dtype=<class 'jax.numpy.lax_numpy.float32'>, init=None)[source]¶
Creates or reuses a parameter for the given transformed function.

>>> hk.get_parameter("w", [], init=jnp.ones)
DeviceArray(1., dtype=float32)

Parameters within the same transform() and/or Module with the same name have the same value:

>>> w1 = hk.get_parameter("w", [], init=jnp.zeros)
>>> w2 = hk.get_parameter("w", [], init=jnp.zeros)
>>> assert w1 is w2

Parameters
name (str) – A name for the parameter.
shape (Sequence[int]) – The shape of the parameter.
dtype (Any) – The dtype of the parameter.
init (Initializer) – A callable of shape, dtype to generate an initial value for the parameter.

Return type
jnp.ndarray

Returns
A jnp.ndarray with the parameter of the given shape.

haiku.get_state(name, shape=None, dtype=<class 'jax.numpy.lax_numpy.float32'>, init=None)[source]¶
Gets the current value for state with an optional initializer.

“State” can be used to represent mutable state in your network. The most common usage of state is to represent the moving averages used in batch normalization (see ExponentialMovingAverage). If your network uses “state” then you are required to use transform_with_state() and pass state into and out of the apply function.

>>> hk.get_state("counter", [], init=jnp.zeros)
DeviceArray(0., dtype=float32)

If the value for the given state is already defined (e.g. using set_state()) then you can call with just the name:

>>> hk.get_state("counter")
DeviceArray(0., dtype=float32)

NOTE: state within the same transform() and/or Module with the same name has the same value:

>>> c1 = hk.get_state("counter")
>>> c2 = hk.get_state("counter")
>>> assert c1 is c2

Parameters
name (str) – A name for the state.
shape (Optional[Sequence[int]]) – The shape of the state.
dtype (Any) – The dtype of the state.
init (Optional[Initializer]) – A callable f(shape, dtype) that returns an initial value for the state.

Return type
jnp.ndarray

Returns
A jnp.ndarray with the state of the given shape.

haiku.set_state(name, value)[source]¶
Sets the current value for some state.

See get_state().

“State” can be used to represent mutable state in your network. The most common usage of state is to represent the moving averages used in batch normalization (see ExponentialMovingAverage). If your network uses “state” then you are required to use transform_with_state() and pass state into and out of the apply function.

>>> hk.set_state("counter", jnp.zeros([]))
>>> hk.get_state("counter")
DeviceArray(0., dtype=float32)

NOTE: state within the same transform() and/or Module with the same name has the same value:

>>> w1 = hk.get_state("counter")
>>> w2 = hk.get_state("counter")
>>> assert w1 is w2

Parameters
name (str) – A name for the state.
value – A value to set.

haiku.transparent(method)[source]¶
Decorator to wrap a method, preventing automatic variable scope wrapping.

By default, all variables and modules created in a method are scoped by the module and method names. This is undesirable in some cases. Any method decorated with transparent() will create variables and modules in the scope in which it was called.

Parameters
method (T) – The method to wrap.

Return type
T

Returns
The method, with a flag indicating no name scope wrapping should occur.
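
As an illustrative sketch (the Encoder class and _embed helper below are hypothetical, not part of the API), a decorated method creates its submodules in the enclosing scope rather than under an extra method-level scope:

class Encoder(hk.Module):
  @hk.transparent
  def _embed(self, x):
    # Because of @hk.transparent, this Linear is scoped as if it were
    # created directly in __call__ (no extra "_embed" scope is added).
    return hk.Linear(16)(x)

  def __call__(self, x):
    return jax.nn.relu(self._embed(x))
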
Random Number Generators¶

class haiku.PRNGSequence(key_or_seed)[source]¶
Iterator of JAX random keys.

>>> seq = hk.PRNGSequence(42)  # OR pass a jax.random.PRNGKey
>>> key1 = next(seq)
>>> key2 = next(seq)
>>> assert key1 is not key2

If you know how many keys you will want then you can use reserve() to more efficiently split the keys you need:

>>> seq.reserve(4)
>>> keys = [next(seq) for _ in range(4)]

__init__(key_or_seed)[source]¶
Creates a new PRNGSequence.


haiku.next_rng_key()[source]¶
Returns a unique JAX random key split from the current global key.

>>> key = hk.next_rng_key()
>>> _ = jax.random.uniform(key, [])

Return type
PRNGKey

Returns
A unique (within a call to init or apply) JAX rng key that can be used with APIs such as jax.random.uniform.

haiku.next_rng_keys(num)[source]¶
Returns one or more JAX random keys split from the current global key.

>>> k1, k2 = hk.next_rng_keys(2)
>>> assert (k1 != k2).all()
>>> a = jax.random.uniform(k1, [])
>>> b = jax.random.uniform(k2, [])
>>> assert a != b

Parameters
num (int) – The number of keys to split.

Return type
Tuple[PRNGKey, ..]

Returns
One or more unique (within a transformed function) JAX rng keys that can be used with APIs such as jax.random.uniform.

haiku.maybe_next_rng_key()[source]¶
next_rng_key() if random numbers are available, else None.

Return type
Optional[PRNGKey]
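
For instance, a small sketch (the function and names are illustrative) that adds noise only when an rng key was supplied to the transformed function:

def forward(x):
  key = hk.maybe_next_rng_key()
  if key is not None:
    # An rng key was passed to init/apply, so we can sample noise.
    x = x + jax.random.normal(key, x.shape)
  return hk.Linear(4)(x)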

haiku.reserve_rng_keys(num)[source]¶
Pre-allocates some number of JAX RNG keys.

See next_rng_key().

This API offers a way to micro-optimize how RNG keys are split when using Haiku. It is unlikely that you need it unless you find compilation time of your init function to be a problem, or you sample a lot of random numbers in apply.

>>> hk.reserve_rng_keys(2)  # Pre-allocate 2 keys for us to consume.
>>> _ = hk.next_rng_key()   # Takes the first pre-allocated key.
>>> _ = hk.next_rng_key()   # Takes the second pre-allocated key.
>>> _ = hk.next_rng_key()   # Splits a new key.

Parameters
num (int) – The number of JAX rng keys to allocate.

haiku.with_rng(key)[source]¶
Provides a new sequence for next_rng_key() to draw from.

When next_rng_key() is called, it draws a new key from the PRNGSequence defined by the input key to the transformed function. This context manager overrides the sequence for the duration of the scope.

>>> with hk.with_rng(jax.random.PRNGKey(428)):
...   s = jax.random.uniform(hk.next_rng_key(), ())
>>> s
DeviceArray(0.501871, dtype=float32)

Parameters
key (PRNGKey) – The key to seed the sequence with.

Returns
Context manager under which the given sequence is active.
Linear modules¶
Linear¶

class haiku.Linear(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]¶
Linear module.

__init__(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]¶
Constructs the Linear module.

Parameters
output_size (int) – Output dimensionality.
with_bias (bool) – Whether to add a bias to the output.
w_init (Optional[hk.initializers.Initializer]) – Optional initializer for weights. By default, uses random values from truncated normal, with stddev 1 / sqrt(fan_in). See https://arxiv.org/abs/1502.03167v3.
b_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias. By default, zero.
name (Optional[str]) – Name of the module.
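
For example, a minimal sketch (shapes and names are illustrative) of using the module inside a transformed function:

def forward(x):
  return hk.Linear(output_size=4)(x)

forward = hk.transform(forward)
x = jnp.ones([8, 3])
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, None, x)  # shape [8, 4]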

Bias¶

class haiku.Bias(output_size=None, bias_dims=None, b_init=None, name=None)[source]¶
Adds a bias to inputs.

Example Usage:

>>> N, H, W, C = 1, 2, 3, 4
>>> x = jnp.ones([N, H, W, C])
>>> scalar_bias = hk.Bias(bias_dims=[])
>>> scalar_bias_output = scalar_bias(x)
>>> assert scalar_bias.bias_shape == ()

Create a bias over all non-minibatch dimensions:

>>> all_bias = hk.Bias()
>>> all_bias_output = all_bias(x)
>>> assert all_bias.bias_shape == (H, W, C)

Create a bias over the last non-minibatch dimension:

>>> last_bias = hk.Bias(bias_dims=[-1])
>>> last_bias_output = last_bias(x)
>>> assert last_bias.bias_shape == (C,)

Create a bias over the first non-minibatch dimension:

>>> first_bias = hk.Bias(bias_dims=[1])
>>> first_bias_output = first_bias(x)
>>> assert first_bias.bias_shape == (H, 1, 1)

Subtract and later add the same learned bias:

>>> bias = hk.Bias()
>>> h1 = bias(x, multiplier=-1)
>>> h2 = bias(x)
>>> h3 = bias(x, multiplier=-1)
>>> reconstructed_x = bias(h3)
>>> assert (x == reconstructed_x).all()

__init__(output_size=None, bias_dims=None, b_init=None, name=None)[source]¶
Constructs a Bias module that supports broadcasting.

Parameters
output_size (Optional[Sequence[int]]) – Output size (output shape without batch dimension). If output_size is left as None, the size will be directly inferred from the input.
bias_dims (Optional[Sequence[int]]) – Sequence of which dimensions to retain from the input shape when constructing the bias. The remaining dimensions will be broadcast over (given size of 1), and leading dimensions will be removed completely. See class doc for examples.
b_init (Optional[hk.initializers.Initializer]) – Optional initializer for the bias. Default to zeros.
name (Optional[str]) – Name of the module.

__call__(inputs, multiplier=None)[source]¶
Adds bias to inputs and optionally multiplies by multiplier.

Parameters
inputs (jnp.ndarray) – A Tensor of size [batch_size, input_size1, ...].
multiplier (Union[float, jnp.ndarray]) – A scalar or Tensor which the bias term is multiplied by before adding it to inputs. Anything which works in the expression bias * multiplier is acceptable here. This may be useful if you want to add a bias in one place and subtract the same bias in another place via multiplier=-1.

Return type
jnp.ndarray

Returns
A Tensor of size [batch_size, input_size1, ...].

Pooling modules¶
Average Pool¶

haiku.avg_pool(value, window_shape, strides, padding, channel_axis=1)[source]¶
Average pool.

Parameters
value (jnp.ndarray) – Value to pool.
window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.
strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.
padding (str) – Padding algorithm. Either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

Return type
jnp.ndarray

Returns
Pooled result. Same rank as value.

Raises
ValueError – If the padding is not valid.
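
As a quick sketch (the shapes are illustrative), pooling an NHWC array with an explicit full-rank window and stride; since avg_pool creates no parameters it can be called outside of a transform:

x = jnp.ones([1, 4, 4, 3])  # NHWC
y = hk.avg_pool(x, window_shape=(1, 2, 2, 1), strides=(1, 2, 2, 1),
                padding="VALID")
# y.shape == (1, 2, 2, 3)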

class haiku.AvgPool(window_shape, strides, padding, channel_axis=1, name=None)[source]¶
Average pool.

Equivalent to partial application of avg_pool().

__init__(window_shape, strides, padding, channel_axis=1, name=None)[source]¶
Average pool.

Parameters
window_shape (Union[int, Sequence[int]]) – Shape of window to pool over. Same rank as value or int.
strides (Union[int, Sequence[int]]) – Strides for the window. Same rank as value or int.
padding (str) – Padding algorithm. Either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.
name (Optional[str]) – String name for the module.

Max Pool¶

haiku.max_pool(value, window_shape, strides, padding, channel_axis=1)[source]¶
Max pool.

Parameters
value (jnp.ndarray) – Value to pool.
window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.
strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.
padding (str) – Padding algorithm. Either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

Return type
jnp.ndarray

Returns
Pooled result. Same rank as value.

class haiku.MaxPool(window_shape, strides, padding, channel_axis=1, name=None)[source]¶
Max pool.

Equivalent to partial application of max_pool().

__init__(window_shape, strides, padding, channel_axis=1, name=None)[source]¶
Max pool.

Parameters
window_shape (Union[int, Sequence[int]]) – Shape of window to pool over. Same rank as value or int.
strides (Union[int, Sequence[int]]) – Strides for the window. Same rank as value or int.
padding (str) – Padding algorithm. Either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.
name (Optional[str]) – String name for the module.

Dropout¶
dropout¶

haiku.dropout(rng, rate, x)[source]¶
Randomly drop units in the input at a given rate.

See: http://www.cs.toronto.edu/~hinton/absps/dropout.pdf

Parameters
rng (PRNGKey) – A JAX random key.
rate (float) – Probability that each element of x is discarded. Must be a scalar in the range [0, 1).
x (jnp.ndarray) – The value to be dropped out.

Return type
jnp.ndarray

Returns
x, but dropped out and scaled by 1 / (1 - rate).
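
For example, a minimal sketch (the network is illustrative) of using dropout inside a transformed function; because dropout consumes randomness via next_rng_key(), an rng key must be supplied to apply:

def net(x):
  x = hk.Linear(16)(x)
  x = hk.dropout(hk.next_rng_key(), 0.5, x)
  return hk.Linear(1)(x)

net = hk.transform(net)
x = jnp.ones([2, 8])
params = net.init(jax.random.PRNGKey(0), x)
out = net.apply(params, jax.random.PRNGKey(1), x)  # rng needed because of dropout
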
Combinator modules¶
Sequential¶

class haiku.Sequential(layers, name=None)[source]¶
Sequentially calls the given list of layers.

Note that Sequential is limited in the range of possible architectures it can handle. This is a deliberate design decision; Sequential is only meant to be used for the simple case of fusing together modules/ops where the input of a particular module/op is the output of the previous one.

Another restriction is that it is not possible to have extra arguments in the __call__() method that are passed to the constituents of the module - for example, if there is a BatchNorm module in Sequential and the user wishes to switch the is_training flag. If this is the desired use case, the recommended solution is to subclass Module and implement __call__:

>>> class CustomModule(hk.Module):
...   def __call__(self, x, is_training):
...     x = hk.Conv2D(32, 4, 2)(x)
...     x = hk.BatchNorm(True, True, 0.9)(x, is_training)
...     x = jax.nn.relu(x)
...     return x

__init__(layers, name=None)[source]¶
Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
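
For the simple feed-forward case, a minimal sketch (layer sizes are illustrative):

def forward(x):
  mlp = hk.Sequential([
      hk.Linear(128), jax.nn.relu,
      hk.Linear(10),
  ])
  return mlp(x)

forward = hk.transform(forward)
x = jnp.ones([1, 28 * 28])
params = forward.init(jax.random.PRNGKey(0), x)
logits = forward.apply(params, None, x)  # shape [1, 10]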

Convolutional modules¶
ConvND¶

class haiku.ConvND(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, feature_group_count=1, name=None)[source]¶
General N-dimensional convolutional.

__init__(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, feature_group_count=1, name=None)[source]¶
Initializes the module.

Parameters
num_spatial_dims (int) – The number of spatial dimensions of the input.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length num_spatial_dims. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME, a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension, or a callable or sequence of callables of size num_spatial_dims. Any callables must take a single integer argument equal to the effective kernel size and return a sequence of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.
mask (Optional[jnp.ndarray]) – Optional mask of the weights.
feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
name (Optional[str]) – The name of the module.

__call__(inputs)[source]¶
Connects the ConvND layer.

Parameters
inputs (jnp.ndarray) – An array of shape [spatial_dims, C] and rank N+1 if unbatched, or an array of shape [N, spatial_dims, C] and rank N+2 if batched.

Return type
jnp.ndarray

Returns
An array of shape [spatial_dims, output_channels] and rank N+1 if unbatched, or an array of shape [N, spatial_dims, output_channels] and rank N+2 if batched.

Conv1D¶

class
haiku.
Conv1D
(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, feature_group_count=1, name=None)[source]¶ One dimensional convolution.

__init__
(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, feature_group_count=1, name=None)[source]¶ Initializes the module.
 Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 1. 1 corresponds to standard ND convolution,
rate > 1
corresponds to dilated convolution. Defaults to 1.padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either
VALID
orSAME
or a callable or sequence of callables of length 1. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults toSAME
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either
NWC
orNCW
. By default,NWC
.mask (Optional[jnp.ndarray]) – Optional mask of the weights.
feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given
output_channels
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.name (Optional[str]) – The name of the module.

Conv2D¶

class haiku.Conv2D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, feature_group_count=1, name=None)[source]¶
Two dimensional convolution.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, feature_group_count=1, name=None)[source]¶
Initializes the module.

Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 2. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME, or a callable or sequence of callables of length 2. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.
mask (Optional[jnp.ndarray]) – Optional mask of the weights.
feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
name (Optional[str]) – The name of the module.
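
For example, a minimal sketch (shapes are illustrative) of applying a Conv2D inside a transformed function:

def conv_net(x):
  x = hk.Conv2D(output_channels=32, kernel_shape=3, stride=2)(x)  # NHWC input
  return jax.nn.relu(x)

conv_net = hk.transform(conv_net)
x = jnp.ones([1, 28, 28, 1])
params = conv_net.init(jax.random.PRNGKey(0), x)
out = conv_net.apply(params, None, x)  # shape [1, 14, 14, 32] with SAME padding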

Conv3D¶

class
haiku.
Conv3D
(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, feature_group_count=1, name=None)[source]¶ Three dimensional convolution.

__init__
(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, feature_group_count=1, name=None)[source]¶ Initializes the module.
 Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 3. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[Tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either
VALID
orSAME
or a callable or sequence of callables of length 3. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults toSAME
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either
NDHWC
orNCDHW
. By default,NDHWC
.mask (Optional[jnp.ndarray]) – Optional mask of the weights.
feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given
output_channels
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.name (Optional[str]) – The name of the module.

ConvNDTranspose¶

class haiku.ConvNDTranspose(num_spatial_dims, output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)[source]¶
General n-dimensional transposed convolution (aka. deconvolution).

__init__
(num_spatial_dims, output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)[source]¶ Initializes the module.
 Parameters
num_spatial_dims (int) – The number of spatial dimensions of the input.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length
num_spatial_dims
.stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length
num_spatial_dims
. Defaults to 1.padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either “VALID” or “SAME”. Defaults to “SAME”. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either
channels_first
,channels_last
,N...C
orNC...
. By default,channels_last
.mask (Optional[jnp.ndarray]) – Optional mask of the weights.
name (str) – The name of the module.

__call__(inputs)[source]¶
Computes the transposed convolution of the input.

Parameters
inputs (jnp.ndarray) – An array of shape [spatial_dims, C] and rank N+1 if unbatched, or an array of shape [N, spatial_dims, C] and rank N+2 if batched.

Return type
jnp.ndarray

Returns
An array of shape [spatial_dims, output_channels] and rank N+1 if unbatched, or an array of shape [N, spatial_dims, output_channels] and rank N+2 if batched.

Conv1DTranspose¶

class
haiku.
Conv1DTranspose
(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]¶ One dimensional transposed convolution (aka. deconvolution).

__init__
(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]¶ Initializes the module.
 Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.
padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either
VALID
orSAME
. Defaults toSAME
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either
NWC
orNCW
. By default,NWC
.mask (Optional[jnp.ndarray]) – Optional mask of the weights.
name (Optional[str]) – The name of the module.

Conv2DTranspose¶

class
haiku.
Conv2DTranspose
(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]¶ Two dimensional transposed convolution (aka. deconvolution).

__init__
(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]¶ Initializes the module.
 Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.
padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either
VALID
orSAME
. Defaults toSAME
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either
NHWC
orNCHW
. By default,NHWC
.mask (Optional[jnp.ndarray]) – Optional mask of the weights.
name (Optional[str]) – The name of the module.

Conv3DTranspose¶

class
haiku.
Conv3DTranspose
(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]¶ Three dimensional transposed convolution (aka. deconvolution).

__init__
(output_channels, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]¶ Initializes the module.
 Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.
padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either
VALID
orSAME
. Defaults toSAME
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either
NDHWC
orNCDHW
. By default,NDHWC
.mask (Optional[jnp.ndarray]) – Optional mask of the weights.
name (Optional[str]) – The name of the module.

DepthwiseConv2D¶

class
haiku.
DepthwiseConv2D
(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]¶ 2D Depthwise Convolution Module.

__init__
(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]¶ Construct a 2D Depthwise Convolution.
 Parameters
channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set 1.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length
num_spatial_dims
.stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length
num_spatial_dims
. Defaults to 1.padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either
VALID
,SAME
or a sequence ofbefore, after
pairs. Defaults toSAME
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either
channels_first
,channels_last
,N...C
orNC...
. By default,channels_last
.name (Optional[str]) – The name of the module.

SeparableDepthwiseConv2D¶

class haiku.SeparableDepthwiseConv2D(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]¶
Separable 2-D Depthwise Convolution Module.

__init__(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]¶
Construct a Separable 2D Depthwise Convolution.
 Parameters
channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set 1.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length
num_spatial_dims
.stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length
num_spatial_dims
. Defaults to 1.padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either
VALID
,SAME
or a sequence ofbefore, after
pairs. Defaults toSAME
. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either
channels_first
,channels_last
,N...C
orNC...
. By default,channels_last
.name (Optional[str]) – The name of the module.

__call__(inputs)[source]¶
Call self as a function.

Return type
jnp.ndarray

Normalization modules¶
BatchNorm¶

class haiku.BatchNorm(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, data_format='channels_last', name=None)[source]¶
Normalizes inputs to maintain a mean of ~0 and stddev of ~1.

See: https://arxiv.org/abs/1502.03167.

There are many different variations for how users want to manage scale and offset if they require them at all. These are:

No scale/offset, in which case create_* should be set to False and scale/offset aren’t passed when the module is called.
Trainable scale/offset, in which case create_* should be set to True and again scale/offset aren’t passed when the module is called. In this case this module creates and owns the scale/offset variables.
Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and then the values fed in at call time.

NOTE: jax.vmap(hk.transform(BatchNorm)) will update summary statistics and normalize values on a per-batch basis; we currently do not support normalizing across a batch axis introduced by vmap.

__init__(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, data_format='channels_last', name=None)[source]¶
Constructs a BatchNorm module.

Parameters
create_scale (bool) – Whether to include a trainable scaling factor.
create_offset (bool) – Whether to include a trainable offset.
decay_rate (float) – Decay rate for EMA.
eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5, as in the paper and Sonnet.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). Can only be set if create_scale=True. By default, 1.
offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias (aka offset). Can only be set if create_offset=True. By default, 0.
axis (Optional[Sequence[int]]) – Which axes to reduce over. The default (None) signifies that all but the channel axis should be normalized. Otherwise this is a list of axis indices which will have normalization statistics calculated.
cross_replica_axis (Optional[str]) – If not None, it should be a string representing the axis name over which this module is being run within a jax.pmap. Supplying this argument means that batch statistics are calculated across all replicas on that axis.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.
name (Optional[str]) – The module name.

__call__(inputs, is_training, test_local_stats=False, scale=None, offset=None)[source]¶
Computes the normalized version of the input.

Parameters
inputs (jnp.ndarray) – An array, where the data format is [..., C].
is_training (bool) – Whether this is during training.
test_local_stats (bool) – Whether local stats are used when is_training=False.
scale (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.
offset (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.

Return type
jnp.ndarray

Returns
The array, normalized across all but the last dimension.
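
Because BatchNorm keeps moving averages in Haiku state, it must be used with transform_with_state(); a minimal sketch (the forward function and shapes are illustrative):

def forward(x, is_training):
  x = hk.Linear(16)(x)
  bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)
  return bn(x, is_training)

forward = hk.transform_with_state(forward)
x = jnp.ones([4, 8])
params, state = forward.init(jax.random.PRNGKey(0), x, is_training=True)
out, state = forward.apply(params, state, None, x, is_training=True)
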
GroupNorm¶

class haiku.GroupNorm(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]¶
Group normalization module.

This applies group normalization to the input x. This involves splitting the channels into groups before calculating the mean and variance. The default behaviour is to compute the mean and variance over the spatial dimensions and the grouped channels. The mean and variance will never be computed over the created groups axis.

It transforms the input x into:

\[\text{outputs} = \text{scale} \dfrac{x - \mu}{\sigma + \epsilon} + \text{offset}\]

Where \(\mu\) and \(\sigma\) are respectively the mean and standard deviation of x.

There are many different variations for how users want to manage scale and offset if they require them at all. These are:

No scale/offset, in which case create_* should be set to False and scale/offset aren’t passed when the module is called.
Trainable scale/offset, in which case create_* should be set to True and again scale/offset aren’t passed when the module is called. In this case this module creates and owns the scale/offset parameters.
Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and then the values fed in at call time.

__init__(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]¶
Constructs a GroupNorm module.

Parameters
groups (int) – number of groups to divide the channels by. The number of channels must be divisible by this.
axis (Union[int, slice, Sequence[int]]) – int, slice or sequence of ints representing the axes which should be normalized across. By default this is all but the first dimension. For time series data use slice(2, None) to average over all dimensions except Batch and Time.
create_scale (bool) – whether to create a trainable scale per channel applied after the normalization.
create_offset (bool) – whether to create a trainable offset per channel applied after normalization and scaling.
eps (float) – Small epsilon to add to the variance to avoid division by zero. Defaults to 1e-5.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for the scale parameter. Can only be set if create_scale=True. By default scale is initialized to 1.
offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for the offset parameter. Can only be set if create_offset=True. By default offset is initialized to 0.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.
name (Optional[str]) – Name of the module.

__call__(x, scale=None, offset=None)[source]¶
Returns normalized inputs.

Parameters
x (jnp.ndarray) – An n-D tensor of the data_format specified in the constructor on which the transformation is performed.
scale (Optional[jnp.ndarray]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the scale applied to the normalized x. This cannot be passed in if the module was constructed with create_scale=True.
offset (Optional[jnp.ndarray]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the offset applied to the normalized x. This cannot be passed in if the module was constructed with create_offset=True.

Return type
jnp.ndarray

Returns
An n-D tensor of the same shape as x that has been normalized.
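
For example, a minimal sketch (the channel count is illustrative and must be divisible by groups):

def forward(x):  # x is channels_last, e.g. [N, H, W, C]
  return hk.GroupNorm(groups=4)(x)

forward = hk.transform(forward)
x = jnp.ones([2, 8, 8, 16])
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, None, x)
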
InstanceNorm¶

class haiku.InstanceNorm(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]¶
Normalizes inputs along the spatial dimensions.

See LayerNorm for more details.

__init__(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]¶
Constructs an InstanceNorm module.

This method creates a module which normalizes over the spatial dimensions.

Parameters
create_scale (bool) – bool representing whether to create a trainable scale per channel applied after the normalization.
create_offset (bool) – bool representing whether to create a trainable offset per channel applied after normalization and scaling.
eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for the scale variable. Can only be set if create_scale=True. By default scale is initialized to 1.
offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for the offset variable. Can only be set if create_offset=True. By default offset is initialized to 0.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.
name (Optional[str]) – Name of the module.

LayerNorm¶

class haiku.LayerNorm(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, name=None)[source]¶
LayerNorm module.

See: https://arxiv.org/abs/1607.06450.

__init__(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, name=None)[source]¶
Constructs a LayerNorm module.

Parameters
axis (Union[int, Sequence[int], slice]) – Integer, list of integers, or slice indicating which axes to normalize over.
create_scale (bool) – Bool, defines whether to create a trainable scale per channel applied after the normalization.
create_offset (bool) – Bool, defines whether to create a trainable offset per channel applied after normalization and scaling.
eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5, as in the paper and Sonnet.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). By default, one.
offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias (aka offset). By default, zero.
name (Optional[str]) – The module name.

__call__(inputs, scale=None, offset=None)[source]¶
Connects the layer norm.

Parameters
inputs (jnp.ndarray) – An array, where the data format is [N, ..., C].
scale (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.
offset (Optional[jnp.ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.

Return type
jnp.ndarray

Returns
The array, normalized.
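
For example, a minimal sketch (shapes are illustrative) normalizing over the last (channel) axis:

def forward(x):
  ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
  return ln(x)

forward = hk.transform(forward)
x = jnp.ones([4, 16])
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, None, x)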

SpectralNorm¶

class haiku.SpectralNorm(eps=0.0001, n_steps=1, name=None)[source]¶
Normalizes an input by its first singular value.

This module uses power iteration to calculate this value based on the input and an internal hidden state.

__init__(eps=0.0001, n_steps=1, name=None)[source]¶
Initializes a SpectralNorm module.

Parameters
eps (float) – The constant used for numerical stability.
n_steps (int) – How many steps of power iteration to perform to approximate the singular value of the input.
name (Optional[str]) – The name of the module.

__call__(value, update_stats=True, error_on_non_matrix=False)[source]¶
Performs Spectral Normalization and returns the new value.

Parameters
value – The array-like object on which you would like to perform spectral normalization.
update_stats (bool) – A boolean defaulting to True. Regardless of this arg, this function will return the normalized input. When update_stats is True, the internal state of this object will also be updated to reflect the input value. When update_stats is False the internal stats will remain unchanged.
error_on_non_matrix (bool) – Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw errors in those cases.

Return type
jnp.ndarray

Returns
The input value, normalized by its first singular value.

Raises
ValueError – If error_on_non_matrix is True and value has ndims > 2.

ExponentialMovingAverage¶

class haiku.ExponentialMovingAverage(decay, zero_debias=True, warmup_length=0, name=None)[source]¶
Maintains an exponential moving average.

This uses the Adam debiasing procedure. See https://arxiv.org/pdf/1412.6980.pdf for details.

__init__(decay, zero_debias=True, warmup_length=0, name=None)[source]¶
Initializes an ExponentialMovingAverage module.

Parameters
decay – The chosen decay. Must be in [0, 1). Values close to 1 result in slow decay; values close to 0 result in fast decay.
zero_debias – Whether to run with zero-debiasing.
warmup_length – A positive integer. EMA has no effect until the internal counter has reached warmup_length, at which point the initial value for the decaying average is initialized to the input value after warmup_length iterations.
name – The name of the module.

initialize(shape, dtype=<class 'jax.numpy.lax_numpy.float32'>)[source]¶
If uninitialized, sets the average to zeros of the given shape/dtype.

__call__(value, update_stats=True)[source]¶
Updates the EMA and returns the new value.

Parameters
value – The array-like object for which you would like to perform an exponential decay on.
update_stats – A Boolean, whether to update the internal state of this object to reflect the input value. When update_stats is False the internal stats will remain unchanged.

Returns
The exponentially weighted average of the input value.
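
Because the running average lives in Haiku state, a typical pattern (a sketch; names are illustrative) wraps it in transform_with_state():

def ema_fn(x):
  return hk.ExponentialMovingAverage(decay=0.99)(x)

ema_fn = hk.transform_with_state(ema_fn)
params, state = ema_fn.init(None, jnp.array(1.0))
avg, state = ema_fn.apply(params, state, None, jnp.array(1.0))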

SNParamsTree¶

class haiku.SNParamsTree(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]¶
Applies Spectral Normalization to all parameters in a tree.

This is isomorphic to EMAParamsTree in moving_averages.py.

__init__(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]¶
Initializes an SNParamsTree module.

Parameters
eps (float) – The constant used for numerical stability.
n_steps (int) – How many steps of power iteration to perform to approximate the singular value of the input.
ignore_regex (str) – A string. Any parameter in the tree whose name matches this regex will not have spectral normalization applied to it. The empty string means this module applies to all parameters.
name (Optional[str]) – The name of the module.

EMAParamsTree¶

class haiku.EMAParamsTree(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]¶
Maintains an exponential moving average for all parameters in a tree.

While ExponentialMovingAverage is meant to be applied to single parameters within a function, this class is meant to be applied to the entire tree of parameters for a function.

Given a set of parameters for some network:

>>> network_fn = lambda x: hk.Linear(10)(x)
>>> x = jnp.ones([1, 1])
>>> params = hk.transform(network_fn).init(jax.random.PRNGKey(428), x)

You might use EMAParamsTree as follows:

>>> ema_fn = hk.transform_with_state(lambda x: hk.EMAParamsTree(0.2)(x))
>>> _, ema_state = ema_fn.init(None, params)
>>> ema_params, ema_state = ema_fn.apply(None, ema_state, None, params)

Here, we are transforming a Haiku function and constructing its parameters via an init_fn as normal, but are creating a second transformed function which expects a tree of parameters as input. This function is then called with the current parameters as input, which then returns an identical tree with every parameter replaced with its exponentially decayed average. This ema_params object can then be passed into the network_fn as usual, and will cause it to run with EMA weights.

__init__(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]¶
Initializes an EMAParamsTree module.

Parameters
decay – The chosen decay. Must be in [0, 1). Values close to 1 result in slow decay; values close to 0 result in fast decay.
zero_debias – Whether to run with zero-debiasing.
warmup_length – A positive integer. EMA has no effect until the internal counter has reached warmup_length, at which point the initial value for the decaying average is initialized to the input value after warmup_length iterations.
ignore_regex – A string. Any parameter in the tree whose name matches this regex will not have any moving average applied to it. The empty string means this module will EMA all parameters.
name – The name of the module.

Recurrent modules¶
RNNCore¶

class haiku.RNNCore(name=None)[source]¶
Base class for RNN cores.

This class defines the basic functionality that every core should implement: initial_state(), used to construct an example of the core state; and __call__() which applies the core parameterized by a previous state to an input.

Cores may be used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from the given input sequence.

__call__(inputs, prev_state)[source]¶
Run one step of the RNN.

Parameters
inputs – An arbitrarily nested structure.
prev_state – Previous core state.

Return type
Tuple[Any, Any]

Returns
A tuple with two elements output, next_state. output is an arbitrarily nested structure. next_state is the next core state; this must be the same shape as prev_state.

initial_state(batch_size)[source]¶
Constructs an initial state for this core.

Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns
Arbitrarily nested initial state for this core.

dynamic_unroll¶

haiku.
dynamic_unroll
(core, input_sequence, initial_state)[source]¶ Performs a dynamic unroll of an RNN.
An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:
state = initial_state for t in range(len(input_sequence)): outputs, state = core(input_sequence[t], state)
A dynamic unroll preserves the loop structure when executed inside
jax.jit
. Seestatic_unroll()
for an unroll function which replaces a loop with its body repeated multiple times. Parameters
core – An
RNNCore
to unroll.input_sequence – An arbitrarily nested structure of tensors of shape
[T, ...]
whereT
is the number of time steps.initial_state – initial state of the given core.
 Returns
output_sequence  An arbitrarily nested structure of tensors of shape
[T, ...]
.final_state  Core state at time step
T
.
 Return type
A tuple with two elements
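For example, unrolling an hk.LSTM over a length-10 sequence of shape [T, B, F] (a minimal sketch; the unroll happens inside a transform because the core creates parameters):
>>> def unroll_fn(seq):
...   core = hk.LSTM(hidden_size=16)
...   initial_state = core.initial_state(batch_size=seq.shape[1])
...   output_sequence, final_state = hk.dynamic_unroll(core, seq, initial_state)
...   return output_sequence
>>> unroll_fn = hk.transform(unroll_fn)
>>> seq = jnp.ones([10, 4, 8])
>>> params = unroll_fn.init(jax.random.PRNGKey(42), seq)
>>> unroll_fn.apply(params, None, seq).shape
(10, 4, 16)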
static_unroll¶

haiku.
static_unroll
(core, input_sequence, initial_state)[source]¶ Performs a static unroll of an RNN.
An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:
state = initial_state for t in range(len(input_sequence)): outputs, state = core(input_sequence[t], state)
A static unroll replaces a loop with its body repeated multiple times when executed inside
jax.jit
:state = initial_state outputs0, state = core(input_sequence[0], state) outputs1, state = core(input_sequence[1], state) outputs2, state = core(input_sequence[2], state) ...
See
dynamic_unroll()
for a loop-preserving unroll function. Parameters
core – An
RNNCore
to unroll.input_sequence – An arbitrarily nested structure of tensors of shape
[T, ...]
whereT
is the number of time steps.initial_state – An initial state of the given core.
 Returns
output_sequence  An arbitrarily nested structure of tensors of shape
[T, ...]
.final_state  Core state at time step
T
.
 Return type
A tuple with two elements
expand_apply¶

haiku.
expand_apply
(f, axis=0)[source]¶ Wraps f to temporarily add a size-1 axis to its inputs.
Syntactic sugar for:
ins = jax.tree_util.tree_map(lambda t: np.expand_dims(t, axis=axis), ins) out = f(ins) out = jax.tree_util.tree_map(lambda t: np.squeeze(t, axis=axis), out)
This may be useful for applying a function built for
[Time, Batch, ...]
arrays to a single timestep. Parameters
f – The callable to be applied to the expanded inputs.
axis – Where to add the extra axis.
 Returns
f, wrapped as described above.
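For example, a function written for [Time, Batch, ...] inputs can be applied to a single timestep of shape [Batch, ...] (a minimal sketch; time_major_fn is a hypothetical stand-in for e.g. an unrolled core):
>>> def time_major_fn(x):
...   # Expects [Time, Batch, Features]; here we simply double the values.
...   return x * 2
>>> step = jnp.ones([4, 8])  # A single [Batch, Features] timestep.
>>> hk.expand_apply(time_major_fn)(step).shape
(4, 8)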
VanillaRNN¶

class
haiku.
VanillaRNN
(hidden_size, name=None)[source]¶ Basic fully-connected RNN core.
Given \(x_t\) and the previous hidden state \(h_{t-1}\) the core computes
\[h_t = \operatorname{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)\]
__init__
(hidden_size, name=None)[source]¶ Constructs a vanilla RNN core.
 Parameters
hidden_size (int) – Hidden layer size.
name (Optional[str]) – Name of the module.

__call__
(inputs, prev_state)[source]¶ Run one step of the RNN.
 Parameters
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
 Returns
A tuple with two elements
output, next_state
.output
is an arbitrarily nested structure.next_state
is the next core state, this must be the same shape asprev_state
.

initial_state
(batch_size)[source]¶ Constructs an initial state for this core.
 Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
 Returns
Arbitrarily nested initial state for this core.

LSTM¶

class
haiku.
LSTM
(hidden_size, name=None)[source]¶ Long short-term memory (LSTM) RNN core.
The implementation is based on [1]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]where \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
Notes
 Forget gate initialization:
Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.

__init__
(hidden_size, name=None)[source]¶ Constructs an LSTM.
 Parameters
hidden_size (int) – Hidden layer size.
name (Optional[str]) – Name of the module.

__call__
(inputs, prev_state)[source]¶ Run one step of the RNN.
 Parameters
inputs (jnp.ndarray) – An arbitrarily nested structure.
prev_state (LSTMState) – Previous core state.
 Return type
Tuple[jnp.ndarray, LSTMState]
 Returns
A tuple with two elements
output, next_state
.output
is an arbitrarily nested structure.next_state
is the next core state, this must be the same shape asprev_state
.

initial_state
(batch_size)[source]¶ Constructs an initial state for this core.
 Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
 Return type
 Returns
Arbitrarily nested initial state for this core.
GRU¶

class
haiku.
GRU
(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]¶ Gated Recurrent Unit.
The implementation is based on: https://arxiv.org/pdf/1412.3555v1.pdf with biases.
Given \(x_t\) and the previous state \(h_{t-1}\) the core computes
\[\begin{array}{ll} z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t \end{array}\]where \(z_t\) and \(r_t\) are the update and reset gates, respectively.
Warning: Backwards compatibility of GRU weights is currently unsupported.
TODO(tycai): Make policy decision/benchmark performance for GRU variants.

__init__
(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]¶ Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
 Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If
name
is not provided then the class name for the current instance is converted tolower_snake_case
and used instead.

__call__
(inputs, state)[source]¶ Run one step of the RNN.
 Parameters
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
 Returns
A tuple with two elements
output, next_state
.output
is an arbitrarily nested structure.next_state
is the next core state, this must be the same shape asprev_state
.

initial_state
(batch_size)[source]¶ Constructs an initial state for this core.
 Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
 Returns
Arbitrarily nested initial state for this core.

DeepRNN¶

class
haiku.
DeepRNN
(layers, name=None)[source]¶ Wraps a sequence of cores and callables as a single core.
>>> deep_rnn = hk.DeepRNN([ ... hk.LSTM(hidden_size=4), ... jax.nn.relu, ... hk.LSTM(hidden_size=2), ... ])
The state of a
DeepRNN
is a tuple with one element perRNNCore
. If no layers areRNNCore
s, the state is an empty tuple.
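For example, the following sketch constructs a DeepRNN and runs a single step; the state is a 2-tuple because only the two LSTM layers carry state (the shapes below are illustrative):
>>> def f(x):
...   core = hk.DeepRNN([
...       hk.LSTM(hidden_size=4),
...       jax.nn.relu,
...       hk.LSTM(hidden_size=2),
...   ])
...   state = core.initial_state(batch_size=1)  # One LSTMState per hk.LSTM.
...   out, next_state = core(x, state)
...   return out
>>> f = hk.transform(f)
>>> params = f.init(jax.random.PRNGKey(0), jnp.ones([1, 3]))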
__init__
(layers, name=None)[source]¶ Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
 Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If
name
is not provided then the class name for the current instance is converted tolower_snake_case
and used instead.


haiku.
deep_rnn_with_skip_connections
(layers, name=None)[source]¶ Constructs a
DeepRNN
with skip connections.Skip connections alter the dependency structure within a
DeepRNN
. Specifically, input to the ith layer (i > 0) is given by a concatenation of the core’s inputs and the outputs of the (i-1)th layer.The output of the
DeepRNN
is the concatenation of the outputs of all cores.
outputs0, ... = layers[0](inputs, ...)
outputs1, ... = layers[1](jnp.concatenate([inputs, outputs0], axis=-1), ...)
outputs2, ... = layers[2](jnp.concatenate([inputs, outputs1], axis=-1), ...)
...
ResetCore¶

class
haiku.
ResetCore
(core, name=None)[source]¶ A wrapper for managing state resets during unrolls.
When unrolling an
RNNCore
on a batch of input sequences it may be necessary to reset the core’s state at different timesteps for different elements of the batch. The
class enables this by taking a batch ofshould_reset
booleans in addition to the batch of inputs, and conditionally resetting the core’s state for individual elements of the batch. You may also reset individual entries of the state by passing ashould_reset
nest compatible with the state structure.
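For example, the following sketch wraps an LSTM and resets the state of the first element of the batch while leaving the second untouched (shapes are illustrative):
>>> def f(x, should_reset):
...   core = hk.ResetCore(hk.LSTM(hidden_size=4))
...   state = core.initial_state(batch_size=2)
...   # The input is an (inputs, should_reset) pair; batch entries where
...   # should_reset is True are run from a freshly initialized state.
...   out, next_state = core((x, should_reset), state)
...   return out
>>> f = hk.transform(f)
>>> x = jnp.ones([2, 3])
>>> should_reset = jnp.array([True, False])
>>> params = f.init(jax.random.PRNGKey(0), x, should_reset)
>>> out = f.apply(params, None, x, should_reset)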
__init__
(core, name=None)[source]¶ Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
 Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If
name
is not provided then the class name for the current instance is converted tolower_snake_case
and used instead.

__call__
(inputs, state)[source]¶ Run one step of the wrapped core, handling state reset.
 Parameters
inputs – Tuple with two elements,
inputs, should_reset
, whereshould_reset
is the signal used to reset the wrapped core’s state.should_reset
can be either tensor or nest. If nest,should_reset
must match the state structure, and its components’ shapes must be prefixes of the shapes of the corresponding entries in the state nest. If tensor, supported shapes are all common shape prefixes of the state component tensors, e.g. [batch_size]
.state – Previous wrapped core state.
 Returns
Tuple of the wrapped core’s
output, next_state
.

initial_state
(batch_size)[source]¶ Constructs an initial state for this core.
 Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
 Returns
Arbitrarily nested initial state for this core.

IdentityCore¶

class
haiku.
IdentityCore
(name=None)[source]¶ A recurrent core that forwards the inputs and an empty state.
This is commonly used when switching between recurrent and feedforward versions of a model while preserving the same interface.

__call__
(inputs, state)[source]¶ Run one step of the RNN.
 Parameters
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
 Returns
A tuple with two elements
output, next_state
.output
is an arbitrarily nested structure.next_state
is the next core state, this must be the same shape asprev_state
.

initial_state
(batch_size)[source]¶ Constructs an initial state for this core.
 Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
 Returns
Arbitrarily nested initial state for this core.

Conv1DLSTM¶

class
haiku.
Conv1DLSTM
(input_shape, output_channels, kernel_shape, name=None)[source]¶ 1D convolutional LSTM.
The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
Notes
 Forget gate initialization:
Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.

__init__
(input_shape, output_channels, kernel_shape, name=None)[source]¶ Constructs a 1D convolutional LSTM.
 Parameters
input_shape (Sequence[int]) – Shape of the inputs excluding batch size.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 1), or an int.
kernel_shape
will be expanded to define a kernel size in all dimensions.name (Optional[str]) – Name of the module.
Conv2DLSTM¶

class
haiku.
Conv2DLSTM
(input_shape, output_channels, kernel_shape, name=None)[source]¶ 2D convolutional LSTM.
The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
Notes
 Forget gate initialization:
Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.

__init__
(input_shape, output_channels, kernel_shape, name=None)[source]¶ Constructs a 2D convolutional LSTM.
 Parameters
input_shape (Sequence[int]) – Shape of the inputs excluding batch size.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 2), or an int.
kernel_shape
will be expanded to define a kernel size in all dimensions.name (Optional[str]) – Name of the module.
Conv3DLSTM¶

class
haiku.
Conv3DLSTM
(input_shape, output_channels, kernel_shape, name=None)[source]¶ 3D convolutional LSTM.
The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
Notes
 Forget gate initialization:
Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.

__init__
(input_shape, output_channels, kernel_shape, name=None)[source]¶ Constructs a 3D convolutional LSTM.
 Parameters
input_shape (Sequence[int]) – Shape of the inputs excluding batch size.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 3), or an int.
kernel_shape
will be expanded to define a kernel size in all dimensions.name (Optional[str]) – Name of the module.
Batch modules¶
Reshape¶

class
haiku.
Reshape
(output_shape, preserve_dims=1, name=None)[source]¶ Reshapes input Tensor, preserving the batch dimension.
For example, given an input tensor with shape
[B, H, W, C, D]
:>>> B, H, W, C, D = range(1, 6) >>> x = jnp.ones([B, H, W, C, D])
The default behavior when
output_shape
is (-1, D)
is to flatten all dimensions betweenB
andD
:>>> mod = hk.Reshape(output_shape=(-1, D)) >>> assert mod(x).shape == (B, H*W*C, D)
You can change the number of preserved leading dimensions via
preserve_dims
:>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=2) >>> assert mod(x).shape == (B, H, W*C, D) >>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=3) >>> assert mod(x).shape == (B, H, W, C, D) >>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=4) >>> assert mod(x).shape == (B, H, W, C, 1, D)

__init__
(output_shape, preserve_dims=1, name=None)[source]¶ Constructs a
Reshape
module. Parameters
output_shape (Sequence[int]) – Shape to reshape the input tensor to while preserving its first
preserve_dims
dimensions. When the special value -1
appears in output_shape
the corresponding size is automatically inferred. Note that -1
can only appear once in output_shape
. To flatten all non-batch dimensions use Flatten
.preserve_dims (int) – Number of leading dimensions that will not be reshaped.
name (Optional[str]) – Name of the module.
 Raises
ValueError – If
preserve_dims
is not positive.

Flatten¶

class
haiku.
Flatten
(preserve_dims=1, name=None)[source]¶ Flattens the input, preserving the batch dimension(s).
By default,
Flatten
combines all dimensions except the first. Additional leading dimensions can be preserved by settingpreserve_dims
.>>> x = jnp.ones([3, 2, 4]) >>> flat = hk.Flatten() >>> flat(x).shape (3, 8)
When the input to flatten has fewer than
preserve_dims
dimensions it is returned unchanged:>>> x = jnp.ones([3]) >>> flat(x).shape (3,)

__init__
(preserve_dims=1, name=None)[source]¶ Constructs a
Flatten
module. Parameters
preserve_dims (int) – Number of leading dimensions that will not be flattened.
name (Optional[str]) – Name of the module.
 Raises
ValueError – If
preserve_dims
is not positive.

BatchApply¶

class
haiku.
BatchApply
(f, num_dims=2)[source]¶ Temporarily merges leading dimensions of input tensors.
Merges the leading dimensions of a tensor into a single dimension, runs the given callable, then splits the leading dimension of the result to match the input.
Input arrays whose rank is smaller than the number of dimensions to collapse are passed unmodified.
This may be useful for applying a module to each timestep of e.g. a
[Time, Batch, ...]
array.For some
f
s and platforms, this may be more efficient thanjax.vmap
, especially when combined with other transformations likejax.grad
.
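For example, applying an hk.Linear (which expects [Batch, Features] inputs) to every timestep of a [Time, Batch, Features] array (a minimal sketch):
>>> def f(x):  # x shape: [Time, Batch, Features]
...   return hk.BatchApply(hk.Linear(5))(x)
>>> f = hk.transform(f)
>>> x = jnp.ones([10, 4, 3])
>>> params = f.init(jax.random.PRNGKey(0), x)
>>> f.apply(params, None, x).shape
(10, 4, 5)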
__init__
(f, num_dims=2)[source]¶ Constructs a
BatchApply
module. Parameters
f – The callable to be applied to the reshaped array.
num_dims – The number of dimensions to merge.

Embedding modules¶
Embed¶

class
haiku.
Embed
(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None)[source]¶ Module for embedding tokens in a low-dimensional space.
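A minimal usage sketch (the vocabulary and embedding sizes are illustrative):
>>> def f(ids):
...   embed = hk.Embed(vocab_size=100, embed_dim=16)
...   return embed(ids)
>>> f = hk.transform(f)
>>> ids = jnp.array([[1, 2, 3]])       # shape [1, 3]
>>> params = f.init(jax.random.PRNGKey(0), ids)
>>> f.apply(params, None, ids).shape   # ids.shape + [embed_dim]
(1, 3, 16)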

__init__
(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None)[source]¶ Constructs an Embed module.
 Parameters
vocab_size (Optional[int]) – The number of unique tokens to embed. If not provided, an existing vocabulary matrix from which
vocab_size
can be inferred must be provided as embedding_matrix
.embed_dim (Optional[int]) – Number of dimensions to assign to each embedding. If an existing vocabulary matrix initializes the module, this should not be provided as it will be inferred.
embedding_matrix (Optional[jnp.ndarray]) – A matrix-like object equivalent in size to
[vocab_size, embed_dim]
. If given, it is used as the initial value for the embedding matrix and neither vocab_size
nor embed_dim
needs to be given. If they are given, their values are checked to be consistent with the dimensions of embedding_matrix
.w_init (Optional[hk.initializers.Initializer]) – An initializer for the embeddings matrix. As a default, embeddings are initialized via a truncated normal distribution.
lookup_style (Union[str, hk.EmbedLookupStyle]) – One of the enum values of
EmbedLookupStyle
determining how to access the value of the embeddings given an ID. Regardless, the input should be a dense array of integer values representing ids. This setting changes how this module internally maps those ids to embeddings. The result is the same, but the speed and memory trade-offs are different. It defaults to using NumPy-style array indexing. This value is only the default for the module, and at any given invocation it can be overridden in
.name (Optional[str]) – Optional name for this module.
 Raises
ValueError – If none of
embed_dim
,embedding_matrix
andvocab_size
are supplied, or ifembedding_matrix
is supplied andembed_dim
orvocab_size
is not consistent with the supplied matrix.

__call__
(ids, lookup_style=None)[source]¶ Lookup embeddings.
Looks up an embedding vector for each value in
ids
. All ids must be within[0, vocab_size)
to preventNaN
s from propagating. Parameters
ids (jnp.ndarray) – integer array.
lookup_style (Optional[Union[str, hk.EmbedLookupStyle]]) – Overrides the
lookup_style
given in the constructor.
 Return type
jnp.ndarray
 Returns
Tensor of
ids.shape + [embedding_dim]
. Raises
AttributeError – If
lookup_style
is not valid.ValueError – If
ids
is not an integer array.

Initializers¶
Constant¶
Identity¶

class
haiku.initializers.
Identity
(gain=1.0)[source]¶ Initializer that generates the identity matrix.
Constructs a 2D identity matrix or batches of these.
Orthogonal¶

class
haiku.initializers.
Orthogonal
(scale=1.0, axis=-1)[source]¶ Initializer that generates orthogonal matrices.

__init__
(scale=1.0, axis=-1)[source]¶ Construct an initializer for uniformly distributed orthogonal matrices.
These matrices will be row-orthonormal along the axis specified by
axis
. If the rank of the weight is greater than 2, the shape will be flattened in all other dimensions and then will be row-orthonormal along the final dimension. Note that this only works if the axis
dimension is larger, otherwise the matrix will be transposed (equivalently, it will be column orthonormal instead of row orthonormal).If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.
 Parameters
scale – Scale factor.
axis – Which axis corresponds to the “output dimension” of the tensor.
 Returns
An orthogonally initialized parameter.

RandomNormal¶

class
haiku.initializers.
RandomNormal
(stddev=1.0, mean=0.0)[source]¶ Initializes by sampling from a normal distribution.

__init__
(stddev=1.0, mean=0.0)[source]¶ Constructs a
RandomNormal
initializer. Parameters
stddev – The standard deviation of the normal distribution to sample from.
mean – The mean of the normal distribution to sample from.

RandomUniform¶

class
haiku.initializers.
RandomUniform
(minval=0.0, maxval=1.0)[source]¶ Initializes by sampling from a uniform distribution.

__init__
(minval=0.0, maxval=1.0)[source]¶ Constructs a
RandomUniform
initializer. Parameters
minval – The lower limit of the uniform distribution.
maxval – The upper limit of the uniform distribution.

TruncatedNormal¶

class
haiku.initializers.
TruncatedNormal
(stddev=1.0, mean=0.0)[source]¶ Initializes by sampling from a truncated normal distribution.

__init__
(stddev=1.0, mean=0.0)[source]¶ Constructs a
TruncatedNormal
initializer. Parameters
stddev – The standard deviation parameter of the truncated normal distribution.
mean – The mean of the truncated normal distribution.

VarianceScaling¶

class
haiku.initializers.
VarianceScaling
(scale=1.0, mode='fan_in', distribution='truncated_normal')[source]¶ Initializer which adapts its scale to the shape of the initialized array.
The initializer first computes the scaling factor
s = scale / n
, where n is:Number of input units in the weight tensor, if
mode = fan_in
.Number of output units, if
mode = fan_out
.Average of the numbers of input and output units, if
mode = fan_avg
.
Then, with
distribution="truncated_normal"
or"normal"
, samples are drawn from a distribution with a mean of zero and a standard deviation (after truncation, if used)stddev = sqrt(s)
.With
distribution="uniform"
, samples are drawn from a uniform distribution within [-limit, limit]
, with limit = sqrt(3 * s)
.The variance scaling initializer can be configured to generate other standard initializers using the scale, mode and distribution arguments. Here are some example configurations:
glorot_uniform: VarianceScaling(1.0, "fan_avg", "uniform")
glorot_normal: VarianceScaling(1.0, "fan_avg", "truncated_normal")
lecun_uniform: VarianceScaling(1.0, "fan_in", "uniform")
lecun_normal: VarianceScaling(1.0, "fan_in", "truncated_normal")
he_uniform: VarianceScaling(2.0, "fan_in", "uniform")
he_normal: VarianceScaling(2.0, "fan_in", "truncated_normal")

__init__
(scale=1.0, mode='fan_in', distribution='truncated_normal')[source]¶ Constructs the
VarianceScaling
initializer. Parameters
scale – Scale to multiply the variance by.
mode – One of
fan_in
,fan_out
,fan_avg
distribution – Random distribution to use. One of
truncated_normal
,normal
oruniform
.
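For example, reproducing the he_normal configuration from the table above and using it as the weight initializer of a Linear layer (a minimal sketch):
>>> he_normal = hk.initializers.VarianceScaling(2.0, "fan_in", "truncated_normal")
>>> def f(x):
...   return hk.Linear(32, w_init=he_normal)(x)
>>> f = hk.transform(f)
>>> params = f.init(jax.random.PRNGKey(0), jnp.ones([1, 8]))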
UniformScaling¶

class
haiku.initializers.
UniformScaling
(scale=1.0)[source]¶ Uniform scaling initializer.
Initializes by sampling from a uniform distribution, but with the variance scaled by the inverse square root of the number of input units, multiplied by the scale.

__init__
(scale=1.0)[source]¶ Constructs the
UniformScaling
initializer. Parameters
scale – Scale to multiply the upper limit of the uniform distribution by.

Paddings¶
create¶

haiku.pad.
create
(padding, kernel, rate, n)[source]¶ Generates the padding required for a given padding algorithm.
 Parameters
padding (Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – callable or list of callables of length n. The callables take an integer representing the effective kernel size (kernel size when the rate is 1) and return a list of two integers representing the padding before and padding after for that dimension.
kernel (Union[int, Sequence[int]]) – int or list of ints of length
n
. The size of the kernel for each dimension. If it is an int it will be replicated for the non-channel and non-batch dimensions.rate (Union[int, Sequence[int]]) – int or list of ints of length
n
. The dilation rate for each dimension. If it is an int it will be replicated for the non-channel and non-batch dimensions.n (int) – the number of spatial dimensions.
 Return type
Sequence[Tuple[int, int]]
 Returns
A sequence of length n containing the padding for each element. These are of the form
[pad_before, pad_after]
.
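For example, a sketch computing the padding for a 3x3 SAME-padded convolution over two spatial dimensions (this assumes the hk.pad.same callable documented below):
>>> padding = hk.pad.create(padding=hk.pad.same, kernel=3, rate=1, n=2)
>>> [tuple(p) for p in padding]  # One (pad_before, pad_after) per spatial dim.
[(1, 1), (1, 1)]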
causal¶
full¶
reverse_causal¶
same¶
Networks¶
MLP¶

class
haiku.nets.
MLP
(output_sizes, w_init=None, b_init=None, with_bias=True, activation=<jax.custom_derivatives.custom_jvp object>, activate_final=False, name=None)[source]¶ A multilayer perceptron module.

__init__
(output_sizes, w_init=None, b_init=None, with_bias=True, activation=<jax.custom_derivatives.custom_jvp object>, activate_final=False, name=None)[source]¶ Constructs an MLP.
 Parameters
output_sizes (Iterable[int]) – Sequence of layer sizes.
w_init (Optional[hk.initializers.Initializer]) – Initializer for
Linear
weights.b_init (Optional[hk.initializers.Initializer]) – Initializer for
Linear
bias. Must beNone
ifwith_bias=False
.with_bias (bool) – Whether or not to apply a bias in each layer.
activation (Callable[[jnp.ndarray], jnp.ndarray]) – Activation function to apply between
Linear
layers. Defaults to ReLU.activate_final (bool) – Whether or not to activate the final layer of the MLP.
name (Optional[str]) – Optional name for this module.
 Raises
ValueError – If
with_bias
isFalse
andb_init
is notNone
.

__call__
(inputs, dropout_rate=None, rng=None)[source]¶ Connects the module to some inputs.
 Parameters
inputs (jnp.ndarray) – A Tensor of shape
[batch_size, input_size]
.dropout_rate (Optional[float]) – Optional dropout rate.
rng – Optional RNG key. Required when using dropout.
 Return type
jnp.ndarray
 Returns
The output of the model of size
[batch_size, output_size]
.
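For example, applying dropout between the layers requires an RNG key at apply time (a minimal sketch):
>>> def f(x):
...   mlp = hk.nets.MLP([300, 100, 10])
...   return mlp(x, dropout_rate=0.5, rng=hk.next_rng_key())
>>> f = hk.transform(f)
>>> x = jnp.ones([8, 28 * 28])
>>> rng = jax.random.PRNGKey(42)
>>> params = f.init(rng, x)
>>> out = f.apply(params, rng, x)   # An rng is required because of dropout.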

reverse
(activate_final=None, name=None)[source]¶ Returns a new MLP which is the layerwise reverse of this MLP.
NOTE: Since computing the reverse of an MLP requires knowing the input size of each linear layer, this method will fail if the module has not been called at least once.
The contract of reverse is that the reversed module will accept the output of the parent module as input and produce an output which is the input size of the parent.
>>> mlp = hk.nets.MLP([1, 2, 3]) >>> y = mlp(jnp.ones([1, 2])) >>> rev = mlp.reverse() >>> rev(y) DeviceArray(...)
 Parameters
activate_final (Optional[bool]) – Whether the final layer of the MLP should be activated.
name (Optional[str]) – Optional name for the new module. The default name will be the name of the current module prefixed with
"reversed_"
.
 Return type
‘MLP’
 Returns
An MLP instance which is the reverse of the current instance. Note these instances do not share weights and, apart from being symmetric to each other, are not coupled in any way.

MobileNetV1¶

class
haiku.nets.
MobileNetV1
(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]¶ MobileNetV1 model.

__init__
(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]¶ Constructs a MobileNetV1 model.
 Parameters
strides (Sequence[int]) – The stride to use in the depthwise convolution in each MobileNet block.
channels (Sequence[int]) – Number of output channels from the pointwise convolution to use in each block.
num_classes (int) – Number of classes.
use_bn (bool) – Whether or not to use batch normalization. Defaults to True. When true, biases are not used. When false, biases are used.
name (Optional[str]) – Name of the module.

ResNet¶

class
haiku.nets.
ResNet
(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), logits_config=None, name=None)[source]¶ ResNet model.

class
BlockGroup
(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]¶ Higher level block for ResNet implementation.

__init__
(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]¶ Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
 Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If
name
is not provided then the class name for the current instance is converted tolower_snake_case
and used instead.


class
BlockV1
(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]¶ ResNet V1 block with optional bottleneck.

__init__
(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]¶ Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
 Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If
name
is not provided then the class name for the current instance is converted tolower_snake_case
and used instead.


class
BlockV2
(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]¶ ResNet V2 block with optional bottleneck.

__init__
(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]¶ Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
 Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If
name
is not provided then the class name for the current instance is converted tolower_snake_case
and used instead.


__init__
(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), logits_config=None, name=None)[source]¶ Constructs a ResNet model.
 Parameters
blocks_per_group (Sequence[int]) – A sequence of length 4 that indicates the number of blocks created in each group.
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements,
decay_rate
andeps
to be passed on to theBatchNorm
layers. By default thedecay_rate
is0.9
andeps
is1e5
.resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to
False
.bottleneck (bool) – Whether the block should bottleneck or not. Defaults to
True
.channels_per_group (Sequence[int]) – A sequence of length 4 that indicates the number of channels used for each block in each group.
use_projection (Sequence[bool]) – A sequence of length 4 that indicates whether each residual block should use projection.
logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.

ResNet18¶

class
haiku.nets.
ResNet18
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ ResNet18.

__init__
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ Constructs a ResNet model.
 Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements,
decay_rate
andeps
to be passed on to theBatchNorm
layers.resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to
False
.logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.

ResNet34¶

class
haiku.nets.
ResNet34
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ ResNet34.

__init__
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ Constructs a ResNet model.
 Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements,
decay_rate
andeps
to be passed on to theBatchNorm
layers.resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to
False
.logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.

ResNet50¶

class
haiku.nets.
ResNet50
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ ResNet50.

__init__
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ Constructs a ResNet model.
 Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements,
decay_rate
andeps
to be passed on to theBatchNorm
layers.resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to
False
.logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.

ResNet101¶

class
haiku.nets.
ResNet101
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ ResNet101.

__init__
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ Constructs a ResNet model.
 Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements,
decay_rate
andeps
to be passed on to theBatchNorm
layers.resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to
False
.logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.

ResNet152¶

class
haiku.nets.
ResNet152
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ ResNet152.

__init__
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ Constructs a ResNet model.
 Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements,
decay_rate
andeps
to be passed on to theBatchNorm
layers.resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to
False
.logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.

ResNet200¶

class
haiku.nets.
ResNet200
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ ResNet200.

__init__
(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None)[source]¶ Constructs a ResNet model.
 Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, float]]) – A dictionary of two elements,
decay_rate
andeps
to be passed on to theBatchNorm
layers.resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to
False
.logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.

VectorQuantizer¶

class
haiku.nets.
VectorQuantizer
(embedding_dim, num_embeddings, commitment_cost, dtype=<class 'jax.numpy.lax_numpy.float32'>, name=None)[source]¶ Haiku module representing the VQVAE layer.
Implements the algorithm presented in “Neural Discrete Representation Learning” by van den Oord et al. https://arxiv.org/abs/1711.00937
Input any tensor to be quantized. Last dimension will be used as space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.
The output tensor will have the same shape as the input.
For example a tensor with shape
[16, 32, 32, 64]
will be reshaped into[16384, 64]
and all16384
vectors (each of64
dimensions) will be quantized independently.
embedding_dim
¶ integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

num_embeddings
¶ integer, the number of vectors in the quantized space.

commitment_cost
¶ scalar which controls the weighting of the loss terms (see equation 4 in the paper; this variable is Beta).

__init__
(embedding_dim, num_embeddings, commitment_cost, dtype=<class 'jax.numpy.lax_numpy.float32'>, name=None)[source]¶ Initializes a VQVAE module.
 Parameters
embedding_dim (int) – dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.
num_embeddings (int) – number of vectors in the quantized space.
commitment_cost (float) – scalar which controls the weighting of the loss terms (see equation 4 in the paper; this variable is Beta).
dtype (Any) – dtype for the embeddings variable, defaults to
float32
.name (str) – name of the module.

__call__
(inputs, is_training)[source]¶ Connects the module to some inputs.
 Parameters
inputs – Tensor, final dimension must be equal to
embedding_dim
. All other leading dimensions will be flattened and treated as a large batch.is_training – boolean, whether this connection is to training data.
 Returns
 Dictionary containing the following keys and values:
quantize
: Tensor containing the quantized version of the input.loss
: Tensor containing the loss to optimize.perplexity
: Tensor containing the perplexity of the encodings.encodings
: Tensor containing the discrete encodings, ie which element of the quantized space each input element was mapped to.encoding_indices
: Tensor containing the discrete encoding indices, ie which element of the quantized space each input element was mapped to.
 Return type
dict

VectorQuantizerEMA¶

class
haiku.nets.
VectorQuantizerEMA
(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e05, dtype=<class 'jax.numpy.lax_numpy.float32'>, cross_replica_axis=None, name=None)[source]¶ Haiku module representing the VQVAE layer.
Implements a slightly modified version of the algorithm presented in “Neural Discrete Representation Learning” by van den Oord et al. https://arxiv.org/abs/1711.00937
The difference between
VectorQuantizerEMA
andVectorQuantizer
is that this module usesExponentialMovingAverage
s to update the embedding vectors instead of an auxiliary loss. This has the advantage that the embedding updates are independent of the choice of optimizer (SGD, RMSProp, Adam, KFac, …) used for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the nonEMA version.Input any tensor to be quantized. Last dimension will be used as space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.
The output tensor will have the same shape as the input.
For example a tensor with shape
[16, 32, 32, 64]
will be reshaped into[16384, 64]
and all16384
vectors (each of 64 dimensions) will be quantized independently.
embedding_dim
¶ integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

num_embeddings
¶ integer, the number of vectors in the quantized space.

commitment_cost
¶ scalar which controls the weighting of the loss terms (see equation 4 in the paper).

decay
¶ float, decay for the moving averages.

epsilon
¶ small float constant to avoid numerical instability.

__init__
(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e05, dtype=<class 'jax.numpy.lax_numpy.float32'>, cross_replica_axis=None, name=None)[source]¶ Initializes a VQVAE EMA module.
 Parameters
embedding_dim – integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.
num_embeddings – integer, the number of vectors in the quantized space.
commitment_cost – scalar which controls the weighting of the loss terms (see equation 4 in the paper; this variable is Beta).
decay – float between 0 and 1, controls the speed of the Exponential Moving Averages.
epsilon (float) – small constant to aid numerical stability, default
1e-5
.dtype (Any) – dtype for the embeddings variable, defaults to
float32
.cross_replica_axis (Optional[str]) – If not
None
, it should be a string representing the axis name over which this module is being run within ajax.pmap
. Supplying this argument means that cluster statistics and the perplexity are calculated across all replicas on that axis.name (str) – name of the module.

__call__
(inputs, is_training)[source]¶ Connects the module to some inputs.
 Parameters
inputs – Tensor, final dimension must be equal to
embedding_dim
. All other leading dimensions will be flattened and treated as a large batch.is_training – boolean, whether this connection is to training data. When this is set to
False
, the internal moving average statistics will not be updated.
 Returns
 Dictionary containing the following keys and values:
quantize
: Tensor containing the quantized version of the input.loss
: Tensor containing the loss to optimize.perplexity
: Tensor containing the perplexity of the encodings.encodings
: Tensor containing the discrete encodings, ie which element of the quantized space each input element was mapped to.encoding_indices
: Tensor containing the discrete encoding indices, ie which element of the quantized space each input element was mapped to.
 Return type
dict

JAX Transforms¶
cond¶
fori_loop¶
grad¶

haiku.
grad
(fun, argnums=0, has_aux=False, holomorphic=False)[source]¶ Creates a function which evaluates the gradient of
fun
NOTE: You only need this in the very specific case where you want to take a gradient inside a
transform()
ed function and the function you are differentiating usesset_state()
. For example:>>> class MyModule(hk.Module): ... def __call__(self, x): ... hk.set_state("last", x ** 2) ... return x ** 2
>>> def f(x): ... m = MyModule() ... g = hk.grad(m)(x) ... return g
>>> f = hk.transform_with_state(f) >>> x = jnp.array(2.) >>> params, state = jax.jit(f.init)(None, x) >>> state["my_module"]["last"] DeviceArray(4., dtype=float32)
 Parameters
fun – Function to be differentiated. Its arguments at positions specified by
argnums
should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape()
but not arrays with shape(1,)
etc.)argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux – Optional, bool. Indicates whether
fun
returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.holomorphic – Optional, bool. Indicates whether
fun
is promised to be holomorphic. Default False.
 Returns
A function with the same arguments as
fun
, that evaluates the gradient offun
. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. Ifhas_aux
is True then a pair ofgradient, auxiliary_data
is returned.
For example:
>>> grad_tanh = jax.grad(jax.numpy.tanh) >>> print(grad_tanh(0.2)) 0.96...
jit¶
remat¶
scan¶
value_and_grad¶

haiku.
value_and_grad
(fun, argnums=0, has_aux=False, holomorphic=False)[source]¶ Creates a function which evaluates both
fun
and the grad offun
NOTE: You only need this in the very specific case where you want to take a gradient inside a
transform()
ed function and the function you are differentiating usesset_state()
. For example:>>> class MyModule(hk.Module): ... def __call__(self, x): ... hk.set_state("last", jnp.sum(x)) ... return x ** 2
>>> def f(x): ... m = MyModule() ... y, g = hk.value_and_grad(m)(x) ... return y, g
>>> f = hk.transform_with_state(f) >>> x = jnp.array(2.) >>> _ = jax.jit(f.init)(None, x)
 Parameters
fun – Function to be differentiated. Its arguments at positions specified by
argnums
should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape()
but not arrays with shape(1,)
etc.)argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux – Optional, bool. Indicates whether
fun
returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.holomorphic – Optional, bool. Indicates whether
fun
is promised to be holomorphic. Default False.
 Returns
A function with the same arguments as
fun
that evaluates bothfun
and the gradient offun
and returns them as a pair (a twoelement tuple). Ifargnums
is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
Testing¶
transform_and_run¶

haiku.testing.
transform_and_run
(f=None, seed=42, run_apply=True)[source]¶ Transforms the given function and runs init then (optionally) apply.
Equivalent to:
>>> def f(x): ... return x >>> x = jnp.ones([]) >>> rng = jax.random.PRNGKey(42) >>> f = hk.transform_with_state(f) >>> params, state = f.init(rng, x) >>> out = f.apply(params, state, rng, x)
This function makes it very convenient to unit test Haiku:
>>> class MyTest(unittest.TestCase): ... @hk.testing.transform_and_run ... def test_linear_output(self): ... mod = hk.Linear(1) ... out = mod(jnp.ones([1, 1])) ... self.assertEqual(out.ndim, 2)
And can also be useful in an interactive environment like ipython, Jupyter or Google Colaboratory:
>>> f = lambda x: hk.Bias()(x) >>> hk.testing.transform_and_run(f)(jnp.ones([1, 1])) DeviceArray([[1.]], dtype=float32)
See
transform()
for more details. Parameters
f – A function or method to transform.
seed (Optional[int]) – A seed to pass to init and apply.
run_apply (bool) – Whether to run apply as well as init. Defaults to true.
 Returns
A function that
transform()
sf
and runsinit
and optionallyapply
.
Data structures¶
filter¶

haiku.data_structures.
filter
(predicate, structure)[source]¶ Filters an input structure according to a user-specified predicate.
>>> params = {'linear': {'w': None, 'b': None}} >>> predicate = lambda module_name, name, value: name == 'w' >>> hk.data_structures.filter(predicate, params) FlatMapping({'linear': FlatMapping({'w': None})})
Note: returns a new structure not a view.
 Parameters
predicate (Callable[[str, str, jnp.ndarray], bool]) – criterion to be used to partition the input data. The
predicate
argument is expected to be a boolean function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data.structure (T) – Haiku params or state data structure to be filtered.
 Return type
T
 Returns
All the input parameters or state as selected by the input predicate.
merge¶

haiku.data_structures.
merge
(*structures)[source]¶ Merges multiple input structures.
>>> weights = {'linear': {'w': None}} >>> biases = {'linear': {'b': None}} >>> hk.data_structures.merge(weights, biases) FlatMapping({'linear': FlatMapping({'w': None, 'b': None})})
When structures are not disjoint the output will contain the value from the last structure for each path:
>>> weights1 = {'linear': {'w': 1}} >>> weights2 = {'linear': {'w': 2}} >>> hk.data_structures.merge(weights1, weights2) FlatMapping({'linear': FlatMapping({'w': 2})})
Note: returns a new structure not a view.
 Parameters
*structures – One or more structures to merge.
 Return type
T
 Returns
A single structure with an entry for each path in the input structures.
partition¶

haiku.data_structures.
partition
(predicate, structure)[source]¶ Partitions the input structure in two according to a given predicate.
For a given set of parameters, you can use
partition()
to split them:>>> params = {'linear': {'w': None, 'b': None}} >>> predicate = lambda module_name, name, value: name == 'w' >>> weights, biases = hk.data_structures.partition(predicate, params) >>> weights FlatMapping({'linear': FlatMapping({'w': None})}) >>> biases FlatMapping({'linear': FlatMapping({'b': None})})
Note: returns new structures not a view.
 Parameters
predicate (Callable[[str, str, jnp.ndarray], bool]) – criterion to be used to partition the input data. The
predicate
argument is expected to be a boolean function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data.structure (T) – Haiku params or state data structure to be partitioned.
 Return type
Tuple[T, T]
 Returns
 A tuple containing all the params or state as partitioned by the input
predicate. Entries matching the predicate will be in the first structure, and the rest will be in the second.
to_immutable_dict¶
to_mutable_dict¶
tree_bytes¶

haiku.data_structures.
tree_bytes
(tree)[source]¶ Sums the size in bytes of all arrays in a pytree.
Note that this is the minimum size of the array (e.g. for a float32 we need at least 4 bytes); however, on some accelerators buffers may occupy more memory due to padding/alignment constraints.
For example given a ResNet50 model:
>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True)) >>> rng = jax.random.PRNGKey(42) >>> x = jnp.ones([128, 224, 224, 3]) >>> params, state = f.init(rng, x)
We can count the number of parameters and their size at f32:
>>> num_params = hk.data_structures.tree_size(params) >>> byte_size = hk.data_structures.tree_bytes(params) >>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB') 25557032 params, size: 102.23MB
And compare that with casting our parameters to bf16:
>>> params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params) >>> num_params = hk.data_structures.tree_size(params) >>> byte_size = hk.data_structures.tree_bytes(params) >>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB') 25557032 params, size: 51.11MB
 Parameters
tree – A tree of jnp.ndarrays.
 Return type
int
 Returns
The total size in bytes of the array(s) in the input.
tree_size¶

haiku.data_structures.
tree_size
(tree)[source]¶ Sums the sizes of all arrays in a pytree.
For example given a ResNet50 model:
>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True)) >>> rng = jax.random.PRNGKey(42) >>> x = jnp.ones([128, 224, 224, 3]) >>> params, state = f.init(rng, x)
We can count the number of parameters and their size at f32:
>>> num_params = hk.data_structures.tree_size(params) >>> byte_size = hk.data_structures.tree_bytes(params) >>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB') 25557032 params, size: 102.23MB
And compare that with casting our parameters to bf16:
>>> params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params) >>> num_params = hk.data_structures.tree_size(params) >>> byte_size = hk.data_structures.tree_bytes(params) >>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB') 25557032 params, size: 51.11MB
 Parameters
tree – A tree of jnp.ndarrays.
 Return type
int
 Returns
The total size (number of elements) of the array(s) in the input.
Experimental¶
custom_creator¶

haiku.experimental.
custom_creator
(creator)[source]¶ Registers a custom parameter creator.
When new parameters are created via
get_parameter()
we first run custom creators passing user defined values through. For example:>>> def zeros_creator(next_creator, shape, dtype, init, context): ... init = jnp.zeros ... return next_creator(shape, dtype, init)
>>> with hk.experimental.custom_creator(zeros_creator): ... z = hk.get_parameter("z", [], jnp.float32, jnp.ones) >>> z DeviceArray(0., dtype=float32)
 Parameters
creator (ParamCreator) – A parameter creator.
 Returns
Context manager under which the creator is active.
custom_getter¶

haiku.experimental.
custom_getter
(getter)[source]¶ Registers a custom parameter getter.
When parameters are retrieved using
get_parameter()
we always run all custom getters before returning a value to the user.>>> def bf16_getter(next_getter, value, context): ... value = value.astype(jnp.bfloat16) ... return next_getter(value)
>>> with hk.experimental.custom_getter(bf16_getter): ... w = hk.get_parameter("w", [], jnp.float32, jnp.ones) >>> w.dtype dtype(bfloat16)
 Parameters
getter (ParamGetter) – A parameter getter.
 Returns
Context manager under which the getter is active.
ParamContext¶

class
haiku.experimental.
ParamContext
[source]¶ Read-only state showing where parameters are being created.

full_name
¶ The full name of the given parameter (e.g.
mlp/~/linear_0/w
).

module
¶ The module that owns the current parameter,
None
if this parameter exists outside any module.

original_dtype
¶ The dtype that get_parameter() was originally called with.

original_shape
¶ The shape that get_parameter() was originally called with.

intercept_methods¶

haiku.experimental.
intercept_methods
(interceptor)[source]¶ Register a new method interceptor.
Method interceptors allow you to (at a distance) intercept method calls to modules and modify args/kwargs before calling the underlying method. After the underlying method is called you can modify its result before it is passed back to the user.
For example you could intercept method calls to
BatchNorm
and ensure it is always computed in full precision:>>> def my_interceptor(next_f, args, kwargs, context): ... if (type(context.module) is not hk.BatchNorm ... or context.method_name != "__call__"): ... # We ignore methods other than BatchNorm.__call__. ... return next_f(*args, **kwargs) ... ... def cast_if_array(x): ... if isinstance(x, jnp.ndarray): ... x = x.astype(jnp.float32) ... return x ... ... args, kwargs = jax.tree_map(cast_if_array, (args, kwargs)) ... out = next_f(*args, **kwargs) ... return out
We can create and use our module in the usual way, we just need to wrap any method calls we want to intercept in the context manager:
>>> mod = hk.BatchNorm(decay_rate=0.9, create_scale=True, create_offset=True) >>> x = jnp.ones([], jnp.bfloat16) >>> with hk.experimental.intercept_methods(my_interceptor): ... out = mod(x, is_training=True) >>> assert out.dtype == jnp.float32
Without the interceptor, BatchNorm would compute in bf16; however, since we cast x before the underlying method is called, we compute in f32.
 Parameters
interceptor (MethodGetter) – A method interceptor.
 Returns
Context manager under which the interceptor is active.
MethodContext¶

class
haiku.experimental.
MethodContext
[source]¶ Read-only state showing the calling context for a method.
For example, let's define two interceptors and print the values in the context. Additionally, we will make the first interceptor conditionally short circuit: since interceptors stack and are run in order, an earlier interceptor can decide to call the next interceptor, or short circuit and call the underlying method directly:
>>> module = hk.Linear(1, name="method_context_example")
>>> short_circuit = False
>>> def my_interceptor_1(next_fun, args, kwargs, context):
...   print('running my_interceptor_1')
...   print('  module.name: ', context.module.name)
...   print('  method_name: ', context.method_name)
...   if short_circuit:
...     return context.orig_method(*args, **kwargs)
...   else:
...     return next_fun(*args, **kwargs)
>>> def my_interceptor_2(next_fun, args, kwargs, context):
...   print('running my_interceptor_2')
...   print('  module.name: ', context.module.name)
...   print('  method_name: ', context.method_name)
...   return next_fun(*args, **kwargs)
When
short_circuit=False
the two interceptors will run in order:

>>> with hk.experimental.intercept_methods(my_interceptor_1), \
...     hk.experimental.intercept_methods(my_interceptor_2):
...   _ = module(jnp.ones([1, 1]))
running my_interceptor_1
  module.name:  method_context_example
  method_name:  __call__
running my_interceptor_2
  module.name:  method_context_example
  method_name:  __call__
Setting
short_circuit=True
will cause the first interceptor to call the original method (rather than next_fun,
which will trigger the next interceptor):

>>> short_circuit = True
>>> with hk.experimental.intercept_methods(my_interceptor_1), \
...     hk.experimental.intercept_methods(my_interceptor_2):
...   _ = module(jnp.ones([1, 1]))
running my_interceptor_1
  module.name:  method_context_example
  method_name:  __call__

method_name
¶ The name of the method being called on the module.

orig_method
¶ The underlying method on the module which, when called, will not trigger interceptors. You should only call this if you want to short circuit all the other interceptors; in general, prefer to call the next_fun passed to your interceptor, which will run orig_method after running all other interceptors.

named_call¶
optimize_rng_use¶

haiku.experimental.
optimize_rng_use
(fun)[source]¶ Optimizes RNG key splitting in
fun
. Our strategy here is to use abstract interpretation to run your function twice: the first time we use jax.eval_shape to avoid spending any flops and simply observe how many times you call next_rng_key(). We then run your function again, but this time we reserve enough RNG keys ahead of time such that we only need to call jax.random.split once.
In the following example, we need three random samples for the weight matrices in our 3-layer MLP. To draw these samples we use next_rng_key(), which will split a new key for each sample. By using optimize_rng_use() Haiku will preallocate exactly enough RNGs for f to be evaluated by splitting the input key once and only once. For large models (unlike this example) this can lead to a significant reduction in compilation time for init:

>>> def f(x):
...   net = hk.nets.MLP([300, 100, 10])
...   return net(x)
>>> f = hk.experimental.optimize_rng_use(f)
>>> f = hk.transform(f)
>>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
 Parameters
fun – A function to wrap.
 Returns
A function that applies
fun
but only requires one call tojax.random.split
by Haiku.
lift¶

haiku.experimental.
lift
(*args, **kwargs) → T[source]¶ Lifts the given init function to a function in the current Haiku namespace.
During init, the returned callable will run the given
init_fn
, and include the resulting params in the outer transform’s dictionaries. During apply, the returned callable will instead pull the relevant parameters from the outer transform’s dictionaries.
Must be called inside transform(), and be passed the init member of a Transformed.
The user must ensure that the given init does not accidentally catch modules from an outer transform() via functional closure.
This is highly experimental and may be changed or removed at any time.
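As a rough, hedged sketch of the intended usage (not an official example; the surrounding names outer_fn and inner are hypothetical, and details may differ between Haiku versions), an inner Transformed.init can be lifted into an outer transform() roughly like this:

>>> def outer_fn(x):
...   # Inner transformed function whose parameters we want registered in the
...   # outer transform's parameter dictionary.
...   inner = hk.transform(lambda x: hk.Linear(1)(x))
...   lifted_init = hk.experimental.lift(inner.init)
...   # During the outer init this runs inner.init; during apply it instead
...   # pulls the previously registered parameters from the outer dictionaries.
...   params = lifted_init(hk.next_rng_key(), x)
...   return inner.apply(params, None, x)
>>> f = hk.transform(outer_fn)
>>> params = f.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))
>>> out = f.apply(params, jax.random.PRNGKey(0), jnp.ones([1, 2]))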
profiler_name_scopes¶

haiku.experimental.
profiler_name_scopes
(enabled=True)[source]¶ Enable/disable profiler name_scopes on all Haiku module methods.
Note: currently only enables for
__call__
. See named_call() if you want to annotate other methods explicitly. Parameters
enabled – Whether to enable name scopes or not.
 Returns
The previous value of the name_scopes setting.
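For instance (illustrative only), the returned previous value makes it easy to restore the setting after a profiling run:

>>> previous = hk.experimental.profiler_name_scopes(enabled=True)
>>> # ... run your transformed function under a profiler here; each module's
>>> # __call__ will show up as a name scope in the trace ...
>>> _ = hk.experimental.profiler_name_scopes(enabled=previous)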
Utilities¶
multinomial¶

haiku.
multinomial
(rng, logits, num_samples)[source]¶ Draws samples from a multinomial distribution.
 Parameters
rng – A JAX PRNGKey.
logits – Unnormalized log-probabilities, of shape [batch_size, categories] or [categories].
num_samples – Number of samples to draw.
 Returns
Chosen categories, of shape [batch_size, num_samples] or [num_samples].
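As an illustrative sketch (not from the original docs): with rank-1 logits the result has shape [num_samples]; the sampled values depend on the key.

>>> rng = jax.random.PRNGKey(0)
>>> logits = jnp.log(jnp.array([0.1, 0.6, 0.3]))  # unnormalized log-probabilities
>>> samples = hk.multinomial(rng, logits, num_samples=5)
>>> samples.shape
(5,)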
one_hot¶

haiku.
one_hot
(x, num_classes, dtype=<class 'jax.numpy.lax_numpy.float32'>)[source]¶ Returns a one-hot version of indices.
DEPRECATED: Use
jax.nn.one_hot(x, num_classes).astype(dtype)
instead. Parameters
x – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – The dtype.
 Returns
 The one-hot tensor. If indices’ shape is [A, B, …], shape is
[A, B, …, num_classes].
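Since this helper is deprecated, a minimal sketch of the recommended replacement noted above:

>>> x = jnp.array([0, 2, 1])
>>> onehot = jax.nn.one_hot(x, num_classes=3).astype(jnp.float32)
>>> onehot.shape
(3, 3)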