Haiku API reference
Haiku Fundamentals#
Haiku Transforms#
transform | Transforms a function using Haiku modules into a pair of pure functions. |
transform_with_state | Transforms a function using Haiku modules into a pair of pure functions. |
multi_transform | Transforms a collection of functions using Haiku into pure functions. |
multi_transform_with_state | Transforms a collection of functions using Haiku into pure functions. |
without_apply_rng | Removes the rng argument from the apply function. |
without_state | Wraps a transformed tuple and ignores state in/out. |
with_empty_state | Wraps a transformed tuple and passes empty state in/out. |
transform#
- haiku.transform(f, *, apply_rng=True)[source]#
Transforms a function using Haiku modules into a pair of pure functions.
For a function out = f(*a, **k) this function returns a pair of two pure functions that call f(*a, **k) explicitly collecting and injecting parameter values:
    params = init(rng, *a, **k)
    out = apply(params, rng, *a, **k)
Note that the rng argument is typically not required for apply and passing None is accepted.
The first thing to do is to define a Module. A module encapsulates some parameters and a computation on those parameters:
    >>> class MyModule(hk.Module):
    ...   def __call__(self, x):
    ...     w = hk.get_parameter("w", [], init=jnp.zeros)
    ...     return x + w
Next, define some function that creates and applies modules. We use transform() to transform that function into a pair of functions that allow us to lift all the parameters out of the function (f.init) and apply the function with a given set of parameters (f.apply):
    >>> def f(x):
    ...   a = MyModule()
    ...   b = MyModule()
    ...   return a(x) + b(x)
    >>> f = hk.transform(f)
To get the initial state of the module call init with an example input:
    >>> params = f.init(None, 1)
    >>> params
    {'my_module': {'w': ...Array(0., dtype=float32)},
     'my_module_1': {'w': ...Array(0., dtype=float32)}}
You can then apply the function with the given parameters by calling apply (note that since we don’t use Haiku’s random number APIs to apply our network we pass None as an RNG key):
    >>> print(f.apply(params, None, 1))
    2.0
It is expected that your program will at some point produce updated parameters and you will want to run apply again. You can do this by calling apply with different parameters:
    >>> new_params = {"my_module": {"w": jnp.array(2.)},
    ...               "my_module_1": {"w": jnp.array(3.)}}
    >>> print(f.apply(new_params, None, 2))
    9.0
If your transformed function needs to maintain internal state (e.g. moving averages in batch norm) then see transform_with_state().
- Parameters
f – A function closing over Module instances.
apply_rng – In the process of being removed. Can only take the value True.
- Return type
Transformed
- Returns
A Transformed tuple with init and apply pure functions.
transform_with_state#
- haiku.transform_with_state(f)[source]#
Transforms a function using Haiku modules into a pair of pure functions.
See transform() for general details on Haiku transformations.
For a function out = f(*a, **k) this function returns a pair of two pure functions that call f(*a, **k) explicitly collecting and injecting parameter values and state:
    params, state = init(rng, *a, **k)
    out, state = apply(params, state, rng, *a, **k)
Note that the rng argument is typically not required for apply and passing None is accepted.
This function is equivalent to transform(), however it allows you to maintain and update internal state (e.g. ExponentialMovingAverage in BatchNorm) via get_state() and set_state():
    >>> def f():
    ...   counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
    ...                          init=jnp.zeros)
    ...   hk.set_state("counter", counter + 1)
    ...   return counter
    >>> f = hk.transform_with_state(f)
    >>> params, state = f.init(None)
    >>> for _ in range(10):
    ...   counter, state = f.apply(params, state, None)
    >>> print(counter)
    9
- Parameters
f – A function closing over Module instances.
- Return type
TransformedWithState
- Returns
A TransformedWithState tuple with init and apply pure functions.
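The counter example can be read as ordinary functional state threading. The sketch below writes the same pattern by hand in plain Python, without Haiku; the init and apply functions here are hypothetical stand-ins, and the state is a plain dict of ints rather than arrays.

```python
from typing import Dict, Tuple

State = Dict[str, int]


def init() -> State:
    """Returns the initial state (the counter starts at zero)."""
    return {"counter": 0}


def apply(state: State) -> Tuple[int, State]:
    """Returns the current counter and a new state; the input is not mutated."""
    counter = state["counter"]
    new_state = {**state, "counter": counter + 1}
    return counter, new_state


state = init()
for _ in range(10):
    # The caller is responsible for feeding the returned state back in.
    counter, state = apply(state)
```

Because apply returns the value read before the increment, the last value of counter after ten calls is 9 while the threaded state holds 10, which matches the printed result in the Haiku example.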
multi_transform#
- haiku.multi_transform(f)[source]#
Transforms a collection of functions using Haiku into pure functions.
In many scenarios we have several modules which are used either as primitives for several Haiku modules/functions, or whose pure versions are to be reused in downstream code. This utility enables this by applying transform() to an arbitrary tree of Haiku functions which share modules and have a common init function.
f is expected to return a tuple of two elements. First is a template Haiku function which provides an example of how all internal Haiku modules are connected. This function is used to create a common init function (with your parameters).
The second object is an arbitrary tree of Haiku functions all of which reuse the modules connected in the template function. These functions are transformed to pure apply functions.
Example:
    >>> def f():
    ...   encoder = hk.Linear(1, name="encoder")
    ...   decoder = hk.Linear(1, name="decoder")
    ...
    ...   def init(x):
    ...     z = encoder(x)
    ...     return decoder(z)
    ...
    ...   return init, (encoder, decoder)
    >>> f = hk.multi_transform(f)
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.ones([1, 1])
    >>> params = f.init(rng, x)
    >>> jax.tree_util.tree_map(jnp.shape, params)
    {'decoder': {'b': (1,), 'w': (1, 1)},
     'encoder': {'b': (1,), 'w': (1, 1)}}
    >>> encode, decode = f.apply
    >>> z = encode(params, None, x)
    >>> y = decode(params, None, z)
- Parameters
f (Callable[[], tuple[TemplateFn, TreeOfApplyFns]]) – A factory function that returns two functions, firstly a common init function that creates all modules, and secondly a pytree of apply functions which make use of those modules.
- Return type
MultiTransformed
- Returns
A MultiTransformed instance which contains a pure init function that creates all parameters, and a pytree of pure apply functions that given the params apply the given function.
See also
multi_transform_with_state(): Equivalent for modules using state.
multi_transform_with_state#
- haiku.multi_transform_with_state(f)[source]#
Transforms a collection of functions using Haiku into pure functions.
See multi_transform() for more details.
Example:
    >>> def f():
    ...   encoder = hk.Linear(1, name="encoder")
    ...   decoder = hk.Linear(1, name="decoder")
    ...
    ...   def init(x):
    ...     z = encoder(x)
    ...     return decoder(z)
    ...
    ...   return init, (encoder, decoder)
    >>> f = hk.multi_transform_with_state(f)
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.ones([1, 1])
    >>> params, state = f.init(rng, x)
    >>> jax.tree_util.tree_map(jnp.shape, params)
    {'decoder': {'b': (1,), 'w': (1, 1)},
     'encoder': {'b': (1,), 'w': (1, 1)}}
    >>> encode, decode = f.apply
    >>> z, state = encode(params, state, None, x)
    >>> y, state = decode(params, state, None, z)
- Parameters
f (Callable[[], tuple[TemplateFn, TreeOfApplyFns]]) – Function returning a “template” function and an arbitrary tree of functions using modules connected in the template function.
- Return type
MultiTransformedWithState
- Returns
An init function and a tree of pure apply functions.
See also
transform_with_state(): Transform a single apply function.
multi_transform(): Transform multiple apply functions without state.
without_apply_rng#
- haiku.without_apply_rng(f)[source]#
Removes the rng argument from the apply function.
This is a convenience wrapper that makes the rng argument to f.apply default to None. This is useful when f doesn’t actually use random numbers as part of its computation, such that the rng argument wouldn’t be used. Note that if f does use random numbers, this will cause an error to be thrown complaining that f needs a non-None PRNGKey.
- Parameters
f (TransformedT) – A transformed function.
- Return type
TransformedT
- Returns
The same transformed function, with a modified apply.
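The effect of this wrapper on the apply signature can be sketched in plain Python. This is an illustration only, not Haiku's code: ToyTransformed and toy_without_apply_rng are hypothetical names, and the sketch does not model the error Haiku raises when the wrapped function actually needs random numbers.

```python
from typing import Any, Callable, List, NamedTuple


class ToyTransformed(NamedTuple):
    init: Callable[..., Any]
    apply: Callable[..., Any]


def toy_without_apply_rng(f: ToyTransformed) -> ToyTransformed:
    """Returns a copy of f whose apply no longer takes an rng argument."""
    def apply(params, *args, **kwargs):
        # Always forward rng=None to the original apply.
        return f.apply(params, None, *args, **kwargs)

    return ToyTransformed(init=f.init, apply=apply)


seen_rngs: List[Any] = []  # Records what the original apply received.


def _init(rng, x):
    return {"w": 2.0}


def _apply(params, rng, x):
    seen_rngs.append(rng)
    return params["w"] * x


f = ToyTransformed(_init, _apply)
g = toy_without_apply_rng(f)

params = g.init(None, 3.0)  # init is unchanged and still takes an rng.
out = g.apply(params, 3.0)  # apply is now called as apply(params, *args).
```

Only apply changes; init keeps its rng argument, since parameter initialization usually does need random numbers.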
without_state#
- haiku.without_state(f)[source]#
Wraps a transformed tuple and ignores state in/out.
The example below is equivalent to f = hk.transform(f):
    >>> def f(x):
    ...   mod = hk.Linear(10)
    ...   return mod(x)
    >>> f = hk.without_state(hk.transform_with_state(f))
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.zeros([1, 1])
    >>> params = f.init(rng, x)
    >>> print(f.apply(params, rng, x))
    [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
- Parameters
f (TransformedWithState) – A transformed function.
- Return type
Transformed
- Returns
A transformed function that does not take or return state.
with_empty_state#
- haiku.with_empty_state(f)[source]#
Wraps a transformed tuple and passes empty state in/out.
The example below is equivalent to f = hk.transform_with_state(f):
    >>> def f(x):
    ...   mod = hk.Linear(10)
    ...   return mod(x)
    >>> f = hk.with_empty_state(hk.transform(f))
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.zeros([1, 1])
    >>> params, state = f.init(rng, x)
    >>> state
    {}
    >>> out, state = f.apply(params, state, rng, x)
    >>> print(out)
    [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
    >>> state
    {}
- Parameters
f (Transformed) – A transformed function.
- Return type
TransformedWithState
- Returns
A transformed function that accepts and returns state.
Modules, Parameters and State#
Module | Base class for Haiku modules. |
to_module | Converts a function into a callable module class. |
get_parameter | Creates or reuses a parameter for the given transformed function. |
get_state | Gets the current value for state with an optional initializer. |
set_state | Sets the current value for some state. |
Module#
- class haiku.Module(name=None)[source]#
Base class for Haiku modules.
A Haiku module is a lightweight container for variables and other modules. Modules typically define one or more “forward” methods (e.g. __call__) which apply operations combining user input and module parameters.
Modules must be initialized inside a transform() call.
For example:
    >>> class AddModule(hk.Module):
    ...   def __call__(self, x):
    ...     w = hk.get_parameter("w", [], init=jnp.ones)
    ...     return x + w
    >>> def forward_fn(x):
    ...   mod = AddModule()
    ...   return mod(x)
    >>> forward = hk.transform(forward_fn)
    >>> x = 1.
    >>> rng = None
    >>> params = forward.init(rng, x)
    >>> print(forward.apply(params, None, x))
    2.0
- __init__(name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
- __post_init__(name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
to_module#
- haiku.to_module(f)[source]#
Converts a function into a callable module class.
Sample usage:
    >>> def bias_fn(x):
    ...   b = hk.get_parameter("b", [], init=hk.initializers.RandomNormal())
    ...   return x + b
    >>> Bias = hk.to_module(bias_fn)
    >>> def net(x, y):
    ...   b = Bias(name="my_bias")
    ...   # Bias x and y by the same amount.
    ...   return b(x) * b(y)
- Parameters
f (Callable[..., Any]) – The function to convert.
- Return type
type[CallableModule]
- Returns
A module class which runs f when called.
get_parameter#
- haiku.get_parameter(name, shape, dtype=<class 'jax.numpy.float32'>, init=None)[source]#
Creates or reuses a parameter for the given transformed function.
    >>> print(hk.get_parameter("w", [], init=jnp.ones))
    1.0
Parameters within the same transform() and/or Module with the same name have the same value:
    >>> w1 = hk.get_parameter("w", [], init=jnp.zeros)
    >>> w2 = hk.get_parameter("w", [], init=jnp.zeros)
    >>> assert w1 is w2
- Parameters
name (str) – A name for the parameter.
shape (Sequence[int]) – The shape of the parameter.
dtype (Any) – The dtype of the parameter.
init (Optional[Initializer]) – A callable of shape, dtype to generate an initial value for the parameter.
- Return type
jax.Array
- Returns
A jax.Array with the parameter of the given shape.
get_state#
- haiku.get_state(name, shape=None, dtype=<class 'jax.numpy.float32'>, init=None)[source]#
Gets the current value for state with an optional initializer.
“State” can be used to represent mutable state in your network. The most common usage of state is to represent the moving averages used in batch normalization (see ExponentialMovingAverage). If your network uses “state” then you are required to use transform_with_state() and pass state into and out of the apply function.
    >>> print(hk.get_state("counter", [], init=jnp.zeros))
    0.0
If the value for the given state is already defined (e.g. using set_state()) then you can call with just the name:
    >>> print(hk.get_state("counter"))
    0.0
NOTE: state within the same transform() and/or Module with the same name have the same value:
    >>> c1 = hk.get_state("counter")
    >>> c2 = hk.get_state("counter")
    >>> assert c1 is c2
- Parameters
name (str) – A name for the state.
shape (Optional[Sequence[int]]) – The shape of the state.
dtype (Any) – The dtype of the state.
init (Optional[Initializer]) – A callable f(shape, dtype) that returns an initial value for the state.
- Return type
jax.Array
- Returns
A jax.Array with the state of the given shape.
set_state#
- haiku.set_state(name, value)[source]#
Sets the current value for some state.
See get_state().
“State” can be used to represent mutable state in your network. The most common usage of state is to represent the moving averages used in batch normalization (see ExponentialMovingAverage). If your network uses “state” then you are required to use transform_with_state() and pass state into and out of the apply function.
    >>> hk.set_state("counter", jnp.zeros([]))
    >>> print(hk.get_state("counter"))
    0.0
NOTE: state within the same transform() and/or Module with the same name have the same value:
    >>> w1 = hk.get_state("counter")
    >>> w2 = hk.get_state("counter")
    >>> assert w1 is w2
- Parameters
name (str) – A name for the state.
value – A value to set.
Getters and Interceptors#
custom_creator | Registers a custom parameter and/or state creator. |
custom_getter | Registers a custom parameter or state getter. |
custom_setter | Registers a custom state setter. |
GetterContext | Context about where parameters are being created. |
SetterContext | Context about where state is being set. |
intercept_methods | Register a new method interceptor. |
MethodContext | Read only state showing the calling context for a method. |
custom_creator#
- haiku.custom_creator(creator, *, params=True, state=False)[source]#
Registers a custom parameter and/or state creator.
When new parameters are created via get_parameter() we first run custom creators passing user defined values through. For example:
    >>> def zeros_creator(next_creator, shape, dtype, init, context):
    ...   init = jnp.zeros
    ...   return next_creator(shape, dtype, init)
    >>> with hk.custom_creator(zeros_creator):
    ...   z = hk.get_parameter("z", [], jnp.float32, jnp.ones)
    >>> print(z)
    0.0
If state=True then your creator will additionally run on calls to get_state():
    >>> with hk.custom_creator(zeros_creator, state=True):
    ...   z = hk.get_state("z", [], jnp.float32, jnp.ones)
    >>> print(z)
    0.0
- Parameters
creator (Creator) – A parameter creator.
params (bool) – Whether to intercept parameter creation, defaults to True.
state (bool) – Whether to intercept state creation, defaults to False.
- Return type
contextlib.AbstractContextManager
- Returns
Context manager under which the creator is active.
custom_getter#
- haiku.custom_getter(getter, *, params=True, state=False)[source]#
Registers a custom parameter or state getter.
When parameters are retrieved using get_parameter() we always run all custom getters before returning a value to the user.
    >>> def bf16_getter(next_getter, value, context):
    ...   value = value.astype(jnp.bfloat16)
    ...   return next_getter(value)
    >>> with hk.custom_getter(bf16_getter):
    ...   w = hk.get_parameter("w", [], jnp.float32, jnp.ones)
    >>> w.dtype
    dtype(bfloat16)
If state=True the getter will additionally run for calls to get_state():
    >>> with hk.custom_getter(bf16_getter, state=True):
    ...   c = hk.get_state("c", [], jnp.float32, jnp.ones)
    >>> c.dtype
    dtype(bfloat16)
- Parameters
getter (Getter) – A parameter getter.
params (bool) – Whether the getter should run on get_parameter().
state (bool) – Whether the getter should run on get_state().
- Return type
contextlib.AbstractContextManager
- Returns
Context manager under which the getter is active.
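The next_getter argument suggests a chain in which each getter transforms the value and then hands it to the next one. The pure-Python sketch below shows one way such a chain can be run. It is an assumption-laden toy, not Haiku's implementation: run_getters is a hypothetical helper, the context argument is omitted, values are plain floats, and the order in which the getters run is a choice made for this sketch rather than something the text above specifies.

```python
from typing import Callable, List

# A toy getter takes (next_getter, value) and returns the final value.
Getter = Callable[[Callable[[float], float], float], float]


def run_getters(getters: List[Getter], value: float) -> float:
    """Runs getters in list order; each one decides when to call the next."""
    def call(i: int, v: float) -> float:
        if i == len(getters):
            return v  # End of the chain: hand the value back to the user.
        return getters[i](lambda nv: call(i + 1, nv), v)

    return call(0, value)


def double(next_getter, value):
    return next_getter(value * 2)


def add_one(next_getter, value):
    return next_getter(value + 1)


a = run_getters([double, add_one], 3.0)  # (3 * 2) + 1
b = run_getters([add_one, double], 3.0)  # (3 + 1) * 2
```

Since each getter receives the value produced by the one before it, the result depends on the order of the chain, which is worth keeping in mind when nesting several custom_getter contexts.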
custom_setter#
- haiku.custom_setter(setter)[source]#
Registers a custom state setter.
When state is set using set_state() we always run all custom setters before saving the value.
    >>> def zero_during_init(next_setter, value, context):
    ...   if hk.running_init():
    ...     value = jnp.zeros_like(value)
    ...   return next_setter(value)
    >>> with hk.custom_setter(zero_during_init):
    ...   hk.set_state("x", jnp.ones([2]))
    ...   x = hk.get_state("x")
    >>> print(x)
    [0. 0.]
- Parameters
setter (Setter) – A state setter.
- Return type
contextlib.AbstractContextManager
- Returns
Context manager under which the setter is active.
GetterContext#
- class haiku.GetterContext(full_name: str, module: Optional[Module], original_dtype: Any, original_shape: Sequence[int], original_init: Optional[Initializer], lifted_prefix_name: Optional[str])[source]#
Context about where parameters are being created.
- full_name#
The full name of the given parameter (e.g. mlp/~/linear_0/w).
- Type
str
- module#
The module that owns the current parameter, None if this parameter exists outside any module.
- Type
Optional[Module]
- original_dtype#
The dtype that get_parameter() or get_state() was originally called with.
- Type
Any
- original_shape#
The shape that get_parameter() or get_state() was originally called with.
- Type
Sequence[int]
- original_init#
The initializer that get_parameter() or get_state() was originally called with.
- Type
Optional[Initializer]
- lifted_prefix_name#
The module names of all enclosing lifted modules (see lift() for more context). Adding this string as a prefix to full_name gives the final parameter name in the outer transform’s parameter dictionary. NOTE: When get_parameter() or get_state() is called in an apply context, this name will always be None because only init functions are lifted.
- Type
Optional[str]
- module_name#
The full name of enclosing modules.
- name#
The name of this parameter.
SetterContext#
- class haiku.SetterContext(full_name: str, module: Optional[Module], original_dtype: Any, original_shape: Sequence[int], lifted_prefix_name: Optional[str])[source]#
Context about where state is being set.
- full_name#
The full name of the given state (e.g. mlp/~/linear_0/w).
- Type
str
- module#
The module that owns the current state, None if this state exists outside any module.
- Type
Optional[Module]
- original_dtype#
The dtype that set_state() was originally called with.
- Type
Any
- original_shape#
The shape that set_state() or get_state() was originally called with.
- Type
Sequence[int]
- lifted_prefix_name#
The module names of all enclosing lifted modules (see lift() for more context). Adding this string as a prefix to full_name gives the final parameter name in the outer transform’s parameter dictionary. NOTE: When get_parameter() or get_state() is called in an apply context, this name will always be None because only init functions are lifted.
- Type
Optional[str]
- module_name#
The full name of enclosing modules.
- name#
The name of this state.
intercept_methods#
- haiku.intercept_methods(interceptor)[source]#
Register a new method interceptor.
Method interceptors allow you to (at a distance) intercept method calls to modules and modify args/kwargs before calling the underlying method. After the underlying method is called you can modify its result before it is passed back to the user.
For example you could intercept method calls to BatchNorm and ensure it is always computed in full precision:
    >>> def my_interceptor(next_f, args, kwargs, context):
    ...   if (type(context.module) is not hk.BatchNorm
    ...       or context.method_name != "__call__"):
    ...     # We ignore methods other than BatchNorm.__call__.
    ...     return next_f(*args, **kwargs)
    ...
    ...   def cast_if_array(x):
    ...     if isinstance(x, jax.Array):
    ...       x = x.astype(jnp.float32)
    ...     return x
    ...
    ...   args, kwargs = jax.tree_util.tree_map(cast_if_array, (args, kwargs))
    ...   out = next_f(*args, **kwargs)
    ...   return out
We can create and use our module in the usual way; we just need to wrap any method calls we want to intercept in the context manager:
    >>> mod = hk.BatchNorm(decay_rate=0.9, create_scale=True, create_offset=True)
    >>> x = jnp.ones([], jnp.bfloat16)
    >>> with hk.intercept_methods(my_interceptor):
    ...   out = mod(x, is_training=True)
    >>> assert out.dtype == jnp.float32
Without the interceptor BatchNorm would compute in bf16. Since we cast x before the underlying method is called, we compute in f32.
- Parameters
interceptor (MethodGetter) – A method interceptor.
- Returns
Context manager under which the interceptor is active.
MethodContext#
- class haiku.MethodContext(module: 'Module', method_name: str, orig_method: Callable[..., Any], orig_class: type['Module'])[source]#
Read only state showing the calling context for a method.
For example, let’s define two interceptors and print the values in the context. Additionally, we will make the first interceptor conditionally short circuit. Since interceptors stack and are run in order, an earlier interceptor can decide to call the next interceptor, or short circuit and call the underlying method directly:
    >>> module = hk.Linear(1, name="method_context_example")
    >>> short_circuit = False
    >>> def my_interceptor_1(next_fun, args, kwargs, context):
    ...   print('running my_interceptor_1')
    ...   print('- module.name: ', context.module.name)
    ...   print('- method_name: ', context.method_name)
    ...   if short_circuit:
    ...     return context.orig_method(*args, **kwargs)
    ...   else:
    ...     return next_fun(*args, **kwargs)
    >>> def my_interceptor_2(next_fun, args, kwargs, context):
    ...   print('running my_interceptor_2')
    ...   print('- module.name: ', context.module.name)
    ...   print('- method_name: ', context.method_name)
    ...   return next_fun(*args, **kwargs)
When short_circuit=False the two interceptors will run in order:
    >>> with hk.intercept_methods(my_interceptor_1), \
    ...      hk.intercept_methods(my_interceptor_2):
    ...   _ = module(jnp.ones([1, 1]))
    running my_interceptor_1
    - module.name:  method_context_example
    - method_name:  __call__
    running my_interceptor_2
    - module.name:  method_context_example
    - method_name:  __call__
Setting short_circuit=True will cause the first interceptor to call the original method (rather than next_fun which will trigger the next interceptor):
    >>> short_circuit = True
    >>> with hk.intercept_methods(my_interceptor_1), \
    ...      hk.intercept_methods(my_interceptor_2):
    ...   _ = module(jnp.ones([1, 1]))
    running my_interceptor_1
    - module.name:  method_context_example
    - method_name:  __call__
- method_name#
The name of the method being called on the module.
- Type
str
- orig_method#
The underlying method on the module which when called will not trigger interceptors. You should only call this if you want to short circuit all the other interceptors; in general you should prefer to call the next_fun passed to your interceptor, which will run orig_method after running all other interceptors.
- Type
Callable[…, Any]
- orig_class#
The class which defined orig_method. Note that when using inheritance this is not necessarily the same as type(module).
- Type
type[‘Module’]
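The difference between calling next_fun and calling orig_method can be illustrated with a small pure-Python model of an interceptor stack. This is a sketch for intuition only, not Haiku's implementation: ToyContext and run_with_interceptors are hypothetical names, and the context carries only the two fields the sketch needs.

```python
from typing import Any, Callable, List, NamedTuple


class ToyContext(NamedTuple):
    method_name: str
    orig_method: Callable[..., Any]


def run_with_interceptors(interceptors, method, method_name, *args, **kwargs):
    """Calls method through a stack of interceptors, first one outermost."""
    context = ToyContext(method_name, method)

    def make_next(i):
        if i == len(interceptors):
            return method  # No interceptors left: call the real method.

        def next_fun(*a, **k):
            return interceptors[i](make_next(i + 1), a, k, context)

        return next_fun

    return make_next(0)(*args, **kwargs)


log: List[str] = []
short_circuit = False


def first(next_fun, args, kwargs, context):
    log.append("first")
    if short_circuit:
        # Skips every remaining interceptor.
        return context.orig_method(*args, **kwargs)
    return next_fun(*args, **kwargs)


def second(next_fun, args, kwargs, context):
    log.append("second")
    return next_fun(*args, **kwargs)


def double(x):
    return 2 * x


out_full = run_with_interceptors([first, second], double, "__call__", 3)
log_full = list(log)

log.clear()
short_circuit = True
out_short = run_with_interceptors([first, second], double, "__call__", 3)
log_short = list(log)
```

In both runs the underlying method is called exactly once and returns the same result; the only difference is that the second interceptor never runs when the first one calls orig_method directly.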
Random Numbers#
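PRNGSequence | Iterator of JAX random keys. |
next_rng_key | Returns a unique JAX random key split from the current global key. |
next_rng_keys | Returns one or more JAX random keys split from the current global key. |
maybe_next_rng_key | next_rng_key() if random numbers are available, else None. |
reserve_rng_keys | Pre-allocate some number of JAX RNG keys. |
with_rng | Provides a new sequence for next_rng_key() to draw from. |
maybe_get_rng_sequence_state | Returns the internal state of the PRNG sequence. |
replace_rng_sequence_state | Replaces the internal state of the PRNG sequence with the given state. |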
PRNGSequence#
- class haiku.PRNGSequence(key_or_seed)[source]#
Iterator of JAX random keys.
    >>> seq = hk.PRNGSequence(42)  # OR pass a jax.random.PRNGKey
    >>> key1 = next(seq)
    >>> key2 = next(seq)
    >>> assert key1 is not key2
If you know how many keys you will want then you can use reserve() to more efficiently split the keys you need:
    >>> seq.reserve(4)
    >>> keys = [next(seq) for _ in range(4)]
- __init__(key_or_seed)[source]#
Creates a new PRNGSequence.
next_rng_key#
next_rng_keys#
- haiku.next_rng_keys(num)[source]#
Returns one or more JAX random keys split from the current global key.
    >>> k1, k2 = hk.next_rng_keys(2)
    >>> assert (k1 != k2).all()
    >>> a = jax.random.uniform(k1, [])
    >>> b = jax.random.uniform(k2, [])
    >>> assert a != b
- Parameters
num (int) – The number of keys to split.
- Return type
jax.Array
- Returns
An array of shape [num, 2] of unique (within a transformed function) JAX rng keys that can be used with APIs such as jax.random.uniform().
maybe_next_rng_key#
- haiku.maybe_next_rng_key()[source]#
next_rng_key() if random numbers are available, else None.
- Return type
Optional[PRNGKey]
reserve_rng_keys#
- haiku.reserve_rng_keys(num)[source]#
Pre-allocate some number of JAX RNG keys.
See next_rng_key().
This API offers a way to micro-optimize how RNG keys are split when using Haiku. It is unlikely that you need it unless you find compilation time of your init function to be a problem, or you sample a lot of random numbers in apply.
    >>> hk.reserve_rng_keys(2)  # Pre-allocate 2 keys for us to consume.
    >>> _ = hk.next_rng_key()   # Takes the first pre-allocated key.
    >>> _ = hk.next_rng_key()   # Takes the second pre-allocated key.
    >>> _ = hk.next_rng_key()   # Splits a new key.
- Parameters
num (int) – The number of JAX rng keys to allocate.
with_rng#
- haiku.with_rng(key)[source]#
Provides a new sequence for next_rng_key() to draw from.
When next_rng_key() is called, it draws a new key from the PRNGSequence defined by the input key to the transformed function. This context manager overrides the sequence for the duration of the scope.
    >>> with hk.with_rng(jax.random.PRNGKey(428)):
    ...   s = jax.random.uniform(hk.next_rng_key(), ())
    >>> print("{:.1f}".format(s))
    0.5
- Parameters
key (PRNGKey) – The key to seed the sequence with.
- Returns
Context manager under which the given sequence is active.
maybe_get_rng_sequence_state#
replace_rng_sequence_state#
Type Hints#
LSTMState | An LSTM core state consists of hidden and cell vectors. |
Params | A Mapping is a generic container for associating key/value pairs. |
MutableParams | A MutableMapping is a generic container for associating key/value pairs. |
State | A Mapping is a generic container for associating key/value pairs. |
MutableState | A MutableMapping is a generic container for associating key/value pairs. |
Transformed | Holds a pair of pure functions. |
TransformedWithState | Holds a pair of pure functions. |
MultiTransformed | Holds a collection of pure functions. |
MultiTransformedWithState | Holds a collection of pure functions. |
ModuleProtocol | Protocol for Module like types. |
SupportsCall | Protocol for Module like types that are Callable. |
LSTMState#
Params#
- haiku.Params#
alias of collections.abc.Mapping[str, collections.abc.Mapping[str, jax.Array]]
MutableParams#
- haiku.MutableParams#
alias of collections.abc.MutableMapping[str, collections.abc.MutableMapping[str, jax.Array]]
State#
- haiku.State#
alias of collections.abc.Mapping[str, collections.abc.Mapping[str, jax.Array]]
MutableState#
- haiku.MutableState#
alias of collections.abc.MutableMapping[str, collections.abc.MutableMapping[str, jax.Array]]
Transformed#
TransformedWithState#
- class haiku.TransformedWithState(init: Callable[..., tuple[hk.MutableParams, hk.MutableState]], apply: Callable[..., tuple[Any, hk.MutableState]])[source]#
Holds a pair of pure functions.
- init#
A pure function: params, state = init(rng, *a, **k)
- Type
Callable[…, tuple[hk.MutableParams, hk.MutableState]]
- apply#
A pure function: out, state = apply(params, state, rng, *a, **k)
- Type
Callable[…, tuple[Any, hk.MutableState]]
MultiTransformed#
- class haiku.MultiTransformed(init: Callable[..., hk.MutableParams], apply: Any)[source]#
Holds a collection of pure functions.
- init#
A pure function: params = init(rng, *a, **k)
- Type
Callable[…, hk.MutableParams]
- apply#
A JAX tree of pure functions each with the signature: out = apply(params, rng, *a, **k).
- Type
Any
See also
Transformed: Single apply variant of multi-transform.
MultiTransformedWithState: Multi apply with state variant.
MultiTransformedWithState#
- class haiku.MultiTransformedWithState(init: Callable[..., tuple[hk.MutableParams, hk.MutableState]], apply: Any)[source]#
Holds a collection of pure functions.
- init#
A pure function: params, state = init(rng, *a, **k)
- Type
Callable[…, tuple[hk.MutableParams, hk.MutableState]]
- apply#
A JAX tree of pure functions each with the signature: out, state = apply(params, state, rng, *a, **k).
- Type
Any
See also
TransformedWithState: Single apply variant of multi-transform.
MultiTransformed: Multi apply variant without state.
ModuleProtocol#
SupportsCall#
- class haiku.SupportsCall(*args, **kwargs)[source]#
Protocol for Module like types that are Callable.
Being a protocol means you don’t need to explicitly extend this type in order to support instance checks with it. For example, Linear only extends Module, however since it conforms (e.g. implements __call__) to this protocol you can instance check using it:
    >>> assert isinstance(hk.Linear(1), hk.SupportsCall)
Flax Interop#
Haiku inside Flax#
Module#
- class haiku.experimental.flax.Module(transformed, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
A Flax nn.Module that runs a Haiku transformed function.
This type is designed to make it easy to take a Haiku transformed function and/or a Haiku module and use it inside a program that otherwise uses Flax.
Given a Haiku transformed function
    >>> def f(x):
    ...   return hk.Linear(1)(x)
    >>> f = hk.transform(f)
You can convert it into a Flax module using:
    >>> mod = hk.experimental.flax.Module(f)
Calling this module is the same as calling any regular Flax module:
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.ones([1, 1])
    >>> variables = mod.init(rng, x)
    >>> out = mod.apply(variables, x)
If you just want to convert a Haiku module class such that it can be used with Flax you can use the create class method:
    >>> mod = hk.experimental.flax.Module.create(hk.Linear, 1)
    >>> variables = mod.init(rng, x)
    >>> out = mod.apply(variables, x)
flatten_flax_to_haiku#
Flax inside Haiku#
lift#
- haiku.experimental.flax.lift(mod, *, name)[source]#
Lifts a flax nn.Module into a Haiku transformed function.
For a Flax Module (e.g. mod = nn.Dense(10)), mod = lift(mod) allows you to run the call method of the module as if the module was a regular Haiku module.
Parameters and state from the Flax module are registered with Haiku and become part of the params/state dictionaries (as returned from init/apply).
    >>> def f(x):
    ...   # Create and "lift" a Flax module.
    ...   mod = hk.experimental.flax.lift(nn.Dense(300), name='dense')
    ...   x = mod(x)                   # Any params/state will be registered
    ...                                # with Haiku when applying the module.
    ...   x = jax.nn.relu(x)
    ...   x = hk.nets.MLP([100, 10])(x)  # You can of course mix Haiku modules in.
    ...   return x
    >>> f = hk.transform(f)
    >>> rng = jax.random.PRNGKey(42)
    >>> x = jnp.ones([1, 1])
    >>> params = f.init(rng, x)
    >>> out = f.apply(params, None, x)
- Parameters
mod (nn.Module) – Any Flax nn.Module instance.
name (str) – Name scope to prefix entries in the outer params/state dict.
- Return type
Callable[…, Any]
- Returns
A function that when applied calls the call method of the given Flax module and returns its output. As a side effect of calling the module any module parameters and state variables are registered with Haiku.
Advanced State Management#
Lifting#
lift | Registers parameters from an inner init function in an outer transform. |
lift_with_state | Registers params and state from an init function in an outer transform. |
transparent_lift | Registers parameters in an outer transform without adding a name scope. |
transparent_lift_with_state | Registers params and state in an outer transform without adding scope. |
LiftWithStateUpdater | Handles updating the state for a lift_with_state computation. |
lift#
- haiku.lift(init_fn, *, allow_reuse=False, name='lifted')[source]#
Registers parameters from an inner init function in an outer transform.
HINT: lift() is for when you want to make non-trivial use of JAX transforms (e.g. jax.vmap) inside of a transform() or transform_with_state(). We generally recommend trying to use JAX transforms on the pure functions returned by transform(), in which case you do not need lift().
Use lift() when nesting Haiku transforms to register the parameters of the inner transform in any outer transform. This is mainly useful when using JAX functions inside of a Haiku module (e.g. using jax.vmap on a layer). See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html#Using-hk.lift for more explanation of when to use lift(). (If you’re not using JAX functions inside of a module or don’t need access to your parameters inside of a transform, you probably don’t need to use lift().)
Must be called inside transform(), and be passed the init member of a Transformed.
During init, the returned callable will run the given init_fn, and include the resulting params in the outer transform’s dictionaries. During apply, the returned callable will instead pull the relevant parameters from the outer transform’s dictionaries.
By default, users must ensure that the given init does not accidentally catch modules from an outer transform() via functional closure. If this behavior is desirable, set allow_reuse to True.
Example:
A common usage of lift() is to use JAX transformations like vmap in non-trivial ways, inside a transform(). For example, we can use lift() and jax.vmap to create an ensemble.
First we’ll create a helper function that uses lift() to apply vmap to our model. As you can see from the comments, we are using vmap to change how parameters should be created (in this case we create a unique set of parameters for each member of the ensemble) and we change how apply works (we “map” the parameters meaning JAX will compute the forward pass separately, in parallel, for each member of the ensemble):
>>> def create_ensemble(model, size: int):
...   init_rng = hk.next_rng_keys(size) if hk.running_init() else None
...   model = hk.transform(model)
...   # in_axes: rng is mapped, data is not.
...   init_model = jax.vmap(model.init, in_axes=(0, None))
...   # Use hk.lift to "lift" parameters created by `init_model` into the
...   # outer transform.
...   init_model = hk.lift(init_model, name="ensemble")
...   def ensemble(x):
...     params = init_model(init_rng, x)
...     # in_axes: params are mapped, rng/data are not.
...     return jax.vmap(model.apply, in_axes=(0, None, None))(params, None, x)
...   return ensemble
We can now use this function to ensemble any Haiku module(s), inside of a transform. First we define a function for each member of the ensemble:
>>> def member_fn(x):
...   return hk.nets.MLP([300, 100, 10])(x)
Second, we combine our two functions inside a transform() to create an ensemble:
>>> def f(x):
...   ensemble = create_ensemble(member_fn, size=4)
...   x = ensemble(x)
...   # You could create other modules here which were not ensembled.
...   return x
>>> f = hk.transform(f)
When we initialize the network, our ensemble members’ parameters have a leading dimension the size of the ensemble:
>>> rng = jax.random.PRNGKey(777)
>>> x = jnp.ones([32, 128])
>>> params = f.init(rng, x)
>>> jax.tree_util.tree_map(lambda x: x.shape, params)
{'ensemble/mlp/~/linear_0': {'b': (4, 300), 'w': (4, 128, 300)},
 'ensemble/mlp/~/linear_1': {'b': (4, 100), 'w': (4, 300, 100)},
 'ensemble/mlp/~/linear_2': {'b': (4, 10), 'w': (4, 100, 10)}}
When we apply the network, we get an output for each member of the ensemble for the entire batch:
>>> y = f.apply(params, None, x)
>>> y.shape
(4, 32, 10)
- Parameters
init_fn (Callable[..., hk.Params]) – The init function from a Transformed.
allow_reuse (bool) – Allows lifted parameters and state to be reused from the outer transform(). This can be desirable when using lift within control flow (e.g. hk.scan).
name (str) – A string name to prefix parameters with.
- Return type
Callable[…, hk.Params]
- Returns
A callable that during init injects parameter values into the outer context and during apply retrieves parameters from the outer context. In both cases returns parameter values to be used with an apply function.
See also
lift_with_state(): Register params and state with an outer transform.
transparent_lift(): Register params with an outer transform without a namespace.
transparent_lift_with_state(): Register params and state with an outer transform without a namespace.
lift_with_state#
- haiku.lift_with_state(init_fn, *, allow_reuse=False, name='lifted')[source]#
Registers params and state from an init function in an outer transform.
See lift() for more context on when to use lift.
This function returns two objects. The first is a callable that runs your init function with slightly different behaviour depending on whether it is run at init or apply time. The second is an updater that can be used to pass updated state values that result from running your apply function. See later in the docs for a worked example.
During init, the returned callable will run the given init_fn, and include the resulting params/state in the outer transform’s dictionaries. During apply, the returned callable will instead pull the relevant params/state from the outer transform’s dictionaries.
Must be called inside transform_with_state(), and be passed the init member of a TransformedWithState.
By default, users must ensure that the given init does not accidentally catch modules from an outer transform_with_state() via functional closure. If this behavior is desirable, set allow_reuse to True.
Example:
>>> def g(x):
...   return hk.nets.ResNet50(1)(x, True)
>>> g = hk.transform_with_state(g)
>>> params_and_state_fn, updater = (
...   hk.lift_with_state(g.init, name='f_lift'))
>>> init_rng = hk.next_rng_key() if hk.running_init() else None
>>> x = jnp.ones([1, 224, 224, 3])
>>> params, state = params_and_state_fn(init_rng, x)
>>> out, state = g.apply(params, state, None, x)
>>> updater.update(state)
- Parameters
init_fn (Callable[..., tuple[hk.Params, hk.State]]) – The init function from a TransformedWithState.
allow_reuse (bool) – Allows lifted parameters and state to be reused from the outer transform_with_state(). This can be desirable when using lift_with_state within control flow (e.g. hk.scan).
name (str) – A string name to prefix parameters with.
- Return type
tuple[Callable[…, tuple[hk.Params, hk.State]], LiftWithStateUpdater]
- Returns
A callable that during init injects parameter values into the outer context and during apply reuses parameters from the outer context. In both cases returns parameter values to be used with an apply function. The init function additionally returns an object used to update the outer context with new state after apply is called.
See also
lift(): Register parameters with an outer transform.
transparent_lift(): Register parameters with an outer transform without a namespace.
transparent_lift_with_state(): Register parameters and state with an outer transform without a namespace.
transparent_lift#
- haiku.transparent_lift(init_fn, *, allow_reuse=False)[source]#
Registers parameters in an outer transform without adding a name scope.
Functionally this is equivalent to lift() but without automatically adding an additional variable scoping. Note that closing over a module from an outer scope is disallowed.
See lift() for more context on when to use lift.
- Parameters
init_fn (Callable[..., hk.Params]) – The init function from a Transformed.
allow_reuse (bool) – Allows lifted parameters to be reused from the outer transform_with_state(). This can be desirable when lifting within control flow (e.g. hk.scan).
- Return type
Callable[…, hk.Params]
- Returns
A callable that during init injects parameter values into the outer context and during apply reuses parameters from the outer context. In both cases returns parameter values to be used with an apply function.
See also
lift(): Register params with an outer transform.
lift_with_state(): Register params and state with an outer transform.
transparent_lift_with_state(): Register params and state with an outer transform without a namespace.
transparent_lift_with_state#
- haiku.transparent_lift_with_state(init_fn, *, allow_reuse=False)[source]#
Registers params and state in an outer transform without adding scope.
Functionally this is equivalent to lift_with_state() but without automatically adding an additional variable scoping.
See lift_with_state() for more context on when to use lift_with_state.
- Parameters
init_fn (Callable[..., tuple[hk.Params, hk.State]]) – The init function from a TransformedWithState.
allow_reuse (bool) – Allows lifted parameters and state to be reused from the outer transform_with_state(). This can be desirable when lifting within control flow (e.g. hk.scan).
- Return type
tuple[Callable[…, tuple[hk.Params, hk.State]], LiftWithStateUpdater]
- Returns
A callable that during init injects parameter values into the outer context and during apply reuses parameters from the outer context. In both cases returns parameter values to be used with an apply function. The init function additionally returns an object used to update the outer context with new state after apply is called.
See also
lift(): Register params with an outer transform.
lift_with_state(): Register params and state with an outer transform.
transparent_lift(): Register params with an outer transform without a namespace.
LiftWithStateUpdater#
Layer Stack#
|
Utility to wrap a Haiku function and recursively apply it to an input. |
|
Module name mapping for transparent layer_stack. |
layer_stack#
- class haiku.layer_stack(num_layers, with_per_layer_inputs=False, unroll=1, pass_reverse_to_layer_fn=False, transparent=False, transparency_map=None, name=None)[source]#
Utility to wrap a Haiku function and recursively apply it to an input.
This can be used to improve model compile times.
A function is valid if it uses only explicit positional parameters, and its return type matches its input type. The positional parameters can be arbitrarily nested structures with jax.Array at the leaf nodes. Note that kwargs are not supported, nor are functions with a variable number of parameters (specified by *args).
Note that layer_stack cannot at the moment be used with functions that build Haiku modules with state.
If with_per_layer_inputs=False then the new, wrapped function can be understood as performing the following:
>>> f = lambda x: x+1
>>> num_layers = 4
>>> x = 0
>>> for i in range(num_layers):
...   x = f(x)
>>> x
4
And if with_per_layer_inputs=True, assuming f takes two arguments on top of x:
>>> f = lambda x, y0, y1: (x+1, y0+y1)
>>> num_layers = 4
>>> x = 0
>>> ys_0 = [1, 2, 3, 4]
>>> ys_1 = [5, 6, 7, 8]
>>> zs = []
>>> for i in range(num_layers):
...   x, z = f(x, ys_0[i], ys_1[i])
...   zs.append(z)
>>> x, zs
(4, [6, 8, 10, 12])
The code using layer_stack for the above function would be:
>>> f = lambda x, y0, y1: (x+1, y0+y1)
>>> num_layers = 4
>>> x = 0
>>> ys_0 = jnp.array([1, 2, 3, 4])
>>> ys_1 = jnp.array([5, 6, 7, 8])
>>> stack = hk.layer_stack(num_layers, with_per_layer_inputs=True)
>>> x, zs = stack(f)(x, ys_0, ys_1)
>>> print(x, zs)
4 [ 6  8 10 12]
Check the tests in layer_stack_test.py for further examples.
Crucially, any parameters created inside f will not be shared across iterations.
- Parameters
num_layers (int) – The number of times to iterate the wrapped function.
with_per_layer_inputs – Whether or not to pass per-layer inputs to the wrapped function.
unroll (int) – The unroll used by scan.
pass_reverse_to_layer_fn (bool) – Whether or not to pass the reverse keyword to the function f, so that it is aware if the layer stack is being run forward or in reverse (and the underlying scan). To run the layer stack in reverse you need to pass in reverse=True to the call to the layer stack.
transparent (bool) – Whether to apply layer_stack transparently. When this is True, and a correct transparency_map is provided, the parameters are generated in such a way that layer_stack can be replaced by a regular for loop without changing the parameter tree.
transparency_map (Optional[LayerStackTransparencyMapping]) – How to map stacked module names to flat names and reverse. See LayerStackTransparencyMapping and layer_stack_test.py for an example.
name (Optional[str]) – Name of the Haiku context.
- Returns
Callable that will produce a layer stack when called with a valid function.
LayerStackTransparencyMapping#
Naming#
|
Context manager which adds a prefix to all new modules, params or state. |
Returns the currently active module name. |
|
Causes a parameter or state value to not be stored. |
|
Returns the parameters for the current transform(). |
|
Returns the current state for the current transform_with_state(). |
|
Returns the initial state for the current transform_with_state(). |
|
Forces Haiku to use this name, ignoring all context information. |
|
Allows a method to be named like some other method. |
|
Decorator to wrap a method, preventing automatic variable scope wrapping. |
name_scope#
- haiku.name_scope(name, *, method_name='__call__')[source]#
Context manager which adds a prefix to all new modules, params or state.
>>> with hk.name_scope("my_name_scope"):
...   net = hk.Linear(1, name="my_linear")
>>> net.module_name
'my_name_scope/my_linear'
When used inside a module, any submodules, parameters or state created inside the name scope will have a prefix added to their names:
>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     with hk.name_scope("my_name_scope"):
...       submodule = hk.Linear(1, name="submodule")
...       w = hk.get_parameter("w", [], init=jnp.ones)
...     return submodule(x) + w
>>> f = hk.transform(lambda x: MyModule()(x))
>>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'my_module/my_name_scope': {'w': ()},
 'my_module/my_name_scope/submodule': {'b': (1,), 'w': (1, 1)}}
Name scopes are very similar to putting all of the code inside the context manager inside a method on a Module with the name you provide. Behind the scenes this is precisely how name scopes are implemented.
If you are familiar with TensorFlow then Haiku’s name_scope() is similar to tf.variable_scope(..) in TensorFlow 1 and tf.name_scope(..) in TensorFlow 1 and 2 in that it changes the names associated with modules, parameters and state.
- Parameters
name (str) – The name scope to use (e.g. "foo" or "foo/bar").
method_name (str) – (Advanced uses only). Since name scopes are equivalent to calling methods on modules the method name attribute allows you to specify which method name you want to simulate. Most users should leave this as the default value ("__call__").
- Return type
ContextManager[None]
- Returns
A single use context manager that when active prefixes new modules, parameters or state with the given name.
current_name#
- haiku.current_name()[source]#
Returns the currently active module name.
Outside of a Haiku module (but inside a Haiku transform) this will return ~ which matches the key in the params/state dict where top level values are stored.
>>> hk.current_name()
'~'
Inside a module this returns the current module name:
>>> class ExampleModule(hk.Module):
...   def __call__(self):
...     return hk.current_name()
>>> ExampleModule()()
'example_module'
Inside a name scope this returns the current name scope:
>>> with hk.name_scope('example_name_scope'):
...   print(hk.current_name())
example_name_scope
- Return type
str
- Returns
The currently active module or name scope name. If no modules or name scopes are in use returns ~.
DO_NOT_STORE#
- haiku.DO_NOT_STORE = <haiku._src.base.DoNotStore object>#
Causes a parameter or state value to not be stored.
By default, Haiku will put the value returned from get_parameter(), get_state() and set_state() into the dictionaries returned by init. This is not always desirable.
For example, a user may want to have part of their network come from a pretrained checkpoint, and they may want to freeze those values (i.e. have them not appear in the params dict passed later to grad). You can achieve this by manipulating the params dict, however sometimes it is more convenient to do this using custom creators/getters/setters.
Consider the following function:
>>> def f(x):
...   x = hk.Linear(300, name='torso')(x)
...   x = hk.Linear(10, name='tail')(x)
...   return x
Imagine you have a pre-trained set of weights for the torso:
>>> pretrained = {'torso': {'w': jnp.ones([28 * 28, 300]),
...                         'b': jnp.ones([300])}}
First we define a creator that tells Haiku not to store any parameters that are part of the pretrained dict:
>>> def my_creator(next_creator, shape, dtype, init, context):
...   if context.module_name in pretrained:
...     return hk.DO_NOT_STORE
...   return next_creator(shape, dtype, init)
Then we need a getter that provides the parameter value from the pretrained dict:
>>> def my_getter(next_getter, value, context):
...   if context.module_name in pretrained:
...     assert value is hk.DO_NOT_STORE
...     value = pretrained[context.module_name][context.name]
...   return next_getter(value)
Finally we’ll wrap our function in context managers activating our creator and getter:
>>> def f_with_pretrained_torso(x):
...   with hk.custom_creator(my_creator), \
...        hk.custom_getter(my_getter):
...     return f(x)
You can see that when we run our function we only get parameters from modules that were not in the pretrained dict:
>>> f_with_pretrained_torso = hk.transform(f_with_pretrained_torso)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 28 * 28])
>>> params = f_with_pretrained_torso.init(rng, x)
>>> assert list(params) == ['tail']
This value can be used in initialisers, custom_creator() or custom_setter().
get_params#
- haiku.get_params()[source]#
Returns the parameters for the current transform().
>>> def report(when):
...   shapes = jax.tree_util.tree_map(jnp.shape, hk.get_params())
...   print(f'{when}: {shapes}')
>>> def f(x):
...   report('Before call')
...   x = hk.Linear(1)(x)
...   report('After call')
...   return x
>>> f = hk.transform(f)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 1])
During init the parameters dictionary will get populated as modules are called:
>>> params = f.init(rng, x)
Before call: {}
After call: {'linear': {'b': (1,), 'w': (1, 1)}}
During apply the parameters dictionary will remain unchanged:
>>> _ = f.apply(params, None, x)
Before call: {'linear': {'b': (1,), 'w': (1, 1)}}
After call: {'linear': {'b': (1,), 'w': (1, 1)}}
NOTE: Does not run custom_getters() or parameter initializers.
- Return type
Params
- Returns
A copy of the parameters dictionary. During init this dictionary will be populated with any parameters that have been created so far. During apply this will contain all parameters of all modules (the params dict does not change during apply).
See also
get_initial_state(): The initial state for the function.
get_current_state(): The current state for the function.
get_current_state#
- haiku.get_current_state()[source]#
Returns the current state for the current transform_with_state().
Example:
>>> def report(when):
...   state = jax.tree_util.tree_map(int, hk.get_current_state())
...   print(f'{when}: {state}')
>>> def f():
...   report('Before get_state')
...   x = hk.get_state('x', [], init=jnp.zeros)
...   report('After get_state')
...   hk.set_state('x', x + 1)
...   report('After set_state')
>>> f = hk.transform_with_state(f)
During init, the most recently set value (either directly via set_state() or via the init argument to get_state()) will be returned:
>>> _, state = f.init(None)
Before get_state: {}
After get_state: {'~': {'x': 0}}
After set_state: {'~': {'x': 1}}
During apply the most recently set value will be used. If no value has been set then the value that is passed into apply will be used:
>>> state = {'~': {'x': 10}}
>>> _ = f.apply({}, state, None)
Before get_state: {'~': {'x': 10}}
After get_state: {'~': {'x': 10}}
After set_state: {'~': {'x': 11}}
NOTE: Does not run custom_getters() or state initializers.
- Return type
State
- Returns
A copy of the state dictionary that would be returned from init or apply.
See also
get_params(): The current parameters for the function.
get_initial_state(): The initial state for the function.
get_initial_state#
- haiku.get_initial_state()[source]#
Returns the initial state for the current transform_with_state().
Example:
>>> def report(when):
...   state = jax.tree_util.tree_map(int, hk.get_initial_state())
...   print(f'{when}: {state}')
>>> def f():
...   report('Before get_state')
...   x = hk.get_state('x', [], init=jnp.zeros)
...   report('After get_state')
...   hk.set_state('x', x + 1)
...   report('After set_state')
>>> f = hk.transform_with_state(f)
During init, the first set value (either directly via set_state() or via the init argument to get_state()) will be returned:
>>> _, state = f.init(None)
Before get_state: {}
After get_state: {'~': {'x': 0}}
After set_state: {'~': {'x': 0}}
During apply the value passed into the apply function will be used:
>>> state = {'~': {'x': 10}}
>>> _ = f.apply({}, state, None)
Before get_state: {'~': {'x': 10}}
After get_state: {'~': {'x': 10}}
After set_state: {'~': {'x': 10}}
NOTE: Does not run custom_getters() or state initializers.
- Return type
State
- Returns
A copy of the state dictionary that would be returned from init or passed into apply.
See also
get_params(): The current parameters for the function.
get_current_state(): The current state for the function.
force_name#
- haiku.force_name(name)[source]#
Forces Haiku to use this name, ignoring all context information.
NOTE: This method is intended for advanced use cases only and should be avoided whenever possible as it effectively enforces a singleton pattern when setting absolute names.
Haiku names modules according to where they are created (e.g. the stack of modules that created them, or the current name_scope()). This function allows you to create modules that ignore all of this and have precisely the name you provide.
This might be useful in the case that you have two modules and you want to force them to share parameters:
>>> mod0 = hk.Linear(1)
>>> some_hyperparameter = True
>>> if some_hyperparameter:
...   # Force mod1 and mod0 to have shared weights.
...   mod1 = hk.Linear(1, name=hk.force_name(mod0.module_name))
... else:
...   # mod0 and mod1 are independent.
...   mod1 = hk.Linear(1)
(A simpler version of this snippet would do mod1 = mod0 instead of using force_name, however in real examples it can be simpler to use force_name, especially in cases where you may not have access to the module instance without lots of plumbing, but getting the module name is easy [e.g. it is a hyperparameter]).
- Parameters
name (str) – String name for the module. For example "foo" or "foo/bar".
- Return type
str
- Returns
A value suitable to pass into the name argument of any Haiku module constructor.
name_like#
- haiku.name_like(method_name)[source]#
Allows a method to be named like some other method.
In Haiku submodules are named based on the name of their parent module and the method in which they are created. When refactoring code it may be desirable to maintain previous names in order to keep checkpoint compatibility; this can be achieved using name_like().
As an example, consider the following toy autoencoder:
>>> class Autoencoder(hk.Module):
...   def __call__(self, x):
...     z = hk.Linear(10, name="enc")(x)  # name: autoencoder/enc
...     y = hk.Linear(10, name="dec")(z)  # name: autoencoder/dec
...     return y
If we want to refactor this such that users can encode or decode, we would create two methods (encode, decode) which would create and apply our modules. In order to retain checkpoint compatibility with the original module we can use name_like() to name those submodules as if they were created inside __call__:
>>> class Autoencoder(hk.Module):
...   @hk.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10, name="enc")(x)  # name: autoencoder/enc
...
...   @hk.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10, name="dec")(z)  # name: autoencoder/dec
...
...   def __call__(self, x):
...     return self.decode(self.encode(x))
One sharp edge is if users rely on Haiku’s numbering to take care of giving unique names and refactor using name_like(). For example when refactoring the following:
>>> class Autoencoder(hk.Module):
...   def __call__(self, x):
...     z = hk.Linear(10)(x)  # name: autoencoder/linear
...     y = hk.Linear(10)(z)  # name: autoencoder/linear_1
...     return y
After refactoring to use name_like(), the unnamed linear modules in encode/decode will end up with the same name (both: autoencoder/linear) because module numbering is only applied within a method:
>>> class Autoencoder(hk.Module):
...   @hk.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10)(x)  # name: autoencoder/linear
...
...   @hk.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10)(z)  # name: autoencoder/linear  <-- NOT INTENDED
To fix this case you need to explicitly name the modules within the method with their former name:
>>> class Autoencoder(hk.Module):
...   @hk.name_like("__call__")
...   def encode(self, x):
...     return hk.Linear(10, name="linear")(x)  # name: autoencoder/linear
...
...   @hk.name_like("__call__")
...   def decode(self, z):
...     return hk.Linear(10, name="linear_1")(z)  # name: autoencoder/linear_1
- Parameters
method_name (str) – The name of a method whose name we should adopt. This method does not actually have to be defined on the class.
- Return type
Callable[[T], T]
- Returns
A decorator that when applied to a method marks it as having a different name.
transparent#
- haiku.transparent(method)[source]#
Decorator to wrap a method, preventing automatic variable scope wrapping.
By default, all variables and modules created in a method are scoped by the module and method names. This is undesirable in some cases. Any method decorated with transparent() will create variables and modules in the scope in which it was called.
- Parameters
method (T) – the method to wrap.
- Return type
T
- Returns
The method, with a flag indicating no name scope wrapping should occur.
Visualisation#
|
Converts a function using Haiku modules to a dot graph. |
to_dot#
- haiku.to_dot(fun)[source]#
Converts a function using Haiku modules to a dot graph.
To view the resulting graph in Google Colab or an IPython notebook use the graphviz package:
dot = hk.to_dot(f)(x)
import graphviz
graphviz.Source(dot)
- Parameters
fun (Callable[..., Any]) – A function using Haiku modules.
- Return type
Callable[…, str]
- Returns
A function that returns the source code string to a graphviz graph describing the operations executed by the given function clustered by Haiku module.
See also
abstract_to_dot(): Generates a graphviz graph using abstract inputs.
Common Modules#
Linear#
|
Linear module. |
|
Adds a bias to inputs. |
Linear#
- class haiku.Linear(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]#
Linear module.
- __init__(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]#
Constructs the Linear module.
- Parameters
output_size (int) – Output dimensionality.
with_bias (bool) – Whether to add a bias to the output.
w_init (Optional[hk.initializers.Initializer]) – Optional initializer for weights. By default, uses random values from truncated normal, with stddev 1 / sqrt(fan_in). See https://arxiv.org/abs/1502.03167v3.
b_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias. By default, zero.
name (Optional[str]) – Name of the module.
Bias#
- class haiku.Bias(output_size=None, bias_dims=None, b_init=None, name=None)[source]#
Adds a bias to inputs.
Example Usage:
>>> N, H, W, C = 1, 2, 3, 4
>>> x = jnp.ones([N, H, W, C])
>>> scalar_bias = hk.Bias(bias_dims=[])
>>> scalar_bias_output = scalar_bias(x)
>>> assert scalar_bias.bias_shape == ()
Create a bias over all non-minibatch dimensions:
>>> all_bias = hk.Bias()
>>> all_bias_output = all_bias(x)
>>> assert all_bias.bias_shape == (H, W, C)
Create a bias over the last non-minibatch dimension:
>>> last_bias = hk.Bias(bias_dims=[-1])
>>> last_bias_output = last_bias(x)
>>> assert last_bias.bias_shape == (C,)
Create a bias over the first non-minibatch dimension:
>>> first_bias = hk.Bias(bias_dims=[1])
>>> first_bias_output = first_bias(x)
>>> assert first_bias.bias_shape == (H, 1, 1)
Subtract and later add the same learned bias:
>>> bias = hk.Bias()
>>> h1 = bias(x, multiplier=-1)
>>> h2 = bias(x)
>>> h3 = bias(x, multiplier=-1)
>>> reconstructed_x = bias(h3)
>>> assert (x == reconstructed_x).all()
- __init__(output_size=None, bias_dims=None, b_init=None, name=None)[source]#
Constructs a Bias module that supports broadcasting.
- Parameters
output_size (Optional[Sequence[int]]) – Output size (output shape without batch dimension). If output_size is left as None, the size will be directly inferred by the input.
bias_dims (Optional[Sequence[int]]) – Sequence of which dimensions to retain from the input shape when constructing the bias. The remaining dimensions will be broadcast over (given size of 1), and leading dimensions will be removed completely. See class doc for examples.
b_init (Optional[hk.initializers.Initializer]) – Optional initializer for the bias. Default to zeros.
name (Optional[str]) – Name of the module.
- __call__(inputs, multiplier=None)[source]#
Adds bias to inputs and optionally multiplies by multiplier.
- Parameters
inputs (jax.Array) – A Tensor of size [batch_size, input_size1, ...].
multiplier (Optional[Union[float, jax.Array]]) – A scalar or Tensor which the bias term is multiplied by before adding it to inputs. Anything which works in the expression bias * multiplier is acceptable here. This may be useful if you want to add a bias in one place and subtract the same bias in another place via multiplier=-1.
- Return type
jax.Array
- Returns
A Tensor of size [batch_size, input_size1, ...].
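The bias_dims rule (retain the listed dimensions, broadcast the rest, drop leading dimensions) can be written out in plain Python. The helper `bias_shape` below is a hypothetical function reconstructed from the class examples above; it is a sketch of the documented behaviour, not Haiku's own code.

```python
def bias_shape(input_shape, bias_dims=None):
    """Hypothetical helper mirroring the documented `bias_dims` behaviour."""
    rank = len(input_shape)
    if bias_dims is None:
        # Default: one bias entry per non-minibatch element.
        return tuple(input_shape[1:])
    shape = [1] * rank
    for dim in bias_dims:
        if not -rank <= dim < rank:
            raise ValueError(f"Dimension {dim} is out of range for rank {rank}.")
        dim %= rank  # normalise negative indices
        if dim == 0:
            raise ValueError("Cannot apply a bias across the minibatch dimension.")
        shape[dim] = input_shape[dim]
    # Drop the batch dimension and any leading broadcast (size 1) dimensions.
    kept = [d for d in range(1, rank) if shape[d] != 1]
    return tuple(shape[kept[0]:]) if kept else ()


N, H, W, C = 1, 2, 3, 4
print(bias_shape((N, H, W, C), []))    # scalar bias
print(bias_shape((N, H, W, C)))        # all non-minibatch dimensions
print(bias_shape((N, H, W, C), [-1]))  # last dimension only
print(bias_shape((N, H, W, C), [1]))   # first non-minibatch dimension
```

Trailing size-1 dimensions are kept so that the bias still broadcasts against the input from the right.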
Pooling#
|
Average pool. |
|
Average pool. |
|
Max pool. |
|
Max pool. |
Average Pool#
- haiku.avg_pool(value, window_shape, strides, padding, channel_axis=-1)[source]#
Average pool.
- Parameters
value (jax.Array) – Value to pool.
window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, same rank as value.
strides (Union[int, Sequence[int]]) – Strides of the pooling window, same rank as value.
padding (str) – Padding algorithm. Either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.
- Return type
jax.Array
- Returns
Pooled result. Same rank as value.
- Raises
ValueError – If the padding is not valid.
- class haiku.AvgPool(window_shape, strides, padding, channel_axis=-1, name=None)[source]#
Average pool.
Equivalent to partial application of avg_pool().
- __init__(window_shape, strides, padding, channel_axis=-1, name=None)[source]#
Average pool.
- Parameters
window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, same rank as value.
strides (Union[int, Sequence[int]]) – Strides of the pooling window, same rank as value.
padding (str) – Padding algorithm. Either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.
name (Optional[str]) – String name for the module.
Max Pool#
- haiku.max_pool(value, window_shape, strides, padding, channel_axis=-1)[source]#
Max pool.
- Parameters
value (jax.Array) – Value to pool.
window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, same rank as value.
strides (Union[int, Sequence[int]]) – Strides of the pooling window, same rank as value.
padding (str) – Padding algorithm. Either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.
- Return type
jax.Array
- Returns
Pooled result. Same rank as value.
- class haiku.MaxPool(window_shape, strides, padding, channel_axis=-1, name=None)[source]#
Max pool.
Equivalent to partial application of max_pool().
- __init__(window_shape, strides, padding, channel_axis=-1, name=None)[source]#
Max pool.
- Parameters
window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, same rank as value.
strides (Union[int, Sequence[int]]) – Strides of the pooling window, same rank as value.
padding (str) – Padding algorithm. Either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.
name (Optional[str]) – String name for the module.
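For readers new to the window_shape and strides arguments, the following plain-Python sketch shows what average and max pooling compute on a 1-D sequence with VALID padding. It is an illustration only: `pool_1d` and `mean` are hypothetical helpers, and the sketch leaves out SAME padding, multiple dimensions and the channel axis that the Haiku functions handle.

```python
def pool_1d(values, window, stride, reduce_fn):
    """VALID pooling over a 1-D sequence: only windows that fit entirely."""
    num_windows = (len(values) - window) // stride + 1
    return [
        reduce_fn(values[i * stride : i * stride + window])
        for i in range(max(num_windows, 0))
    ]


def mean(xs):
    return sum(xs) / len(xs)


signal = [1.0, 3.0, 2.0, 6.0, 4.0]
averaged = pool_1d(signal, window=2, stride=2, reduce_fn=mean)  # [2.0, 4.0]
maxed = pool_1d(signal, window=2, stride=2, reduce_fn=max)      # [3.0, 6.0]
print(averaged, maxed)
```

With VALID padding the trailing element (4.0) is dropped because no full window covers it; SAME padding would instead pad the input so that every position is covered.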
Dropout#
|
Randomly drop units in the input at a given rate. |
dropout#
- haiku.dropout(rng, rate, x, broadcast_dims=())[source]#
Randomly drop units in the input at a given rate.
See: http://www.cs.toronto.edu/~hinton/absps/dropout.pdf
- Parameters
rng (PRNGKey) – A JAX random key.
rate (float) – Probability that each element of x is discarded. Must be a scalar in the range [0, 1).
x (jax.Array) – The value to be dropped out.
broadcast_dims (Sequence[int]) – specifies dimensions that will share the same dropout mask.
- Return type
jax.Array
- Returns
x, but dropped out and scaled by 1 / (1 - rate).
Note
This involves generating x.size pseudo-random samples from U([0, 1)) computed with the full precision required to compare them with rate. When rate is a Python float, this is typically 32 bits, which is often more than what applications require. A work-around is to pass rate with a lower precision, e.g. using np.float16(rate).
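The drop-and-rescale behaviour described above can be sketched in plain Python. This is an illustration under stated assumptions rather than Haiku's code: dropout_sketch is a hypothetical name, it works on a flat list, and it uses Python's random.Random in place of a JAX PRNG key.

```python
import random

def dropout_sketch(rng, rate, x):
    """Hypothetical list-based dropout: zero with probability `rate`, rescale the rest."""
    if not 0 <= rate < 1:
        raise ValueError("rate must be in [0, 1)")
    keep_rate = 1.0 - rate
    out = []
    for value in x:
        # Draw u ~ U([0, 1)); the element is kept when u < keep_rate.
        keep = rng.random() < keep_rate
        out.append(value / keep_rate if keep else 0.0)
    return out

rng = random.Random(0)
x = [1.0] * 10_000
y = dropout_sketch(rng, 0.25, x)
mean_y = sum(y) / len(y)  # close to 1.0: the rescaling preserves the expected value
```

Dividing the surviving elements by 1 - rate keeps the expected value of each element unchanged, which is why no rescaling is needed when dropout is switched off at evaluation time.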
Combinator#
| Sequential | Sequentially calls the given list of layers. |
Sequential#
- class haiku.Sequential(layers, name=None)[source]#
Sequentially calls the given list of layers.
Note that Sequential is limited in the range of possible architectures it can handle. This is a deliberate design decision; Sequential is only meant to be used for the simple case of fusing together modules/ops where the input of a particular module/op is the output of the previous one.
Another restriction is that it is not possible to have extra arguments in the __call__() method that are passed to the constituents of the module, for example if there is a BatchNorm module in Sequential and the user wishes to switch the is_training flag. If this is the desired use case, the recommended solution is to subclass Module and implement __call__:
>>> class CustomModule(hk.Module):
...   def __call__(self, x, is_training):
...     x = hk.Conv2D(32, 4, 2)(x)
...     x = hk.BatchNorm(True, True, 0.9)(x, is_training)
...     x = jax.nn.relu(x)
...     return x
- __init__(layers, name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
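The chaining behaviour of Sequential can be illustrated with ordinary callables. The sketch below is a hypothetical stand-in, not the Haiku class; in particular, forwarding extra arguments only to the first layer is an assumption made here to show why a flag such as is_training cannot reach later layers.

```python
def sequential(layers):
    """Hypothetical stand-in for Sequential: feed each output into the next layer."""
    def call(inputs, *args, **kwargs):
        out = inputs
        for i, layer in enumerate(layers):
            if i == 0:
                # Only the first layer sees any extra arguments.
                out = layer(out, *args, **kwargs)
            else:
                out = layer(out)
        return out
    return call

double = lambda x: 2 * x
add_one = lambda x: x + 1
net = sequential([double, add_one, double])
result = net(3)  # ((3 * 2) + 1) * 2
```

Any architecture that needs branching, skip connections, or per-layer arguments falls outside this pattern and is better written as a custom module.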
Convolutional#
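| ConvND | General N-dimensional convolution. |
| Conv1D | One dimensional convolution. |
| Conv2D | Two dimensional convolution. |
| Conv3D | Three dimensional convolution. |
| ConvNDTranspose | General n-dimensional transposed convolution (aka. deconvolution). |
| Conv1DTranspose | One dimensional transposed convolution (aka. deconvolution). |
| Conv2DTranspose | Two dimensional transposed convolution (aka. deconvolution). |
| Conv3DTranspose | Three dimensional transposed convolution (aka. deconvolution). |
| DepthwiseConv1D | One dimensional depthwise convolution. |
| DepthwiseConv2D | Two dimensional depthwise convolution. |
| DepthwiseConv3D | Three dimensional depthwise convolution. |
| get_channel_index | Returns the channel index when given a valid data format. |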
ConvND#
- class haiku.ConvND(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, feature_group_count=1, name=None)[source]#
General N-dimensional convolution.
- __init__(num_spatial_dims, output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, feature_group_count=1, name=None)[source]#
Initializes the module.
- Parameters
num_spatial_dims (int) – The number of spatial dimensions of the input.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length num_spatial_dims. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME, a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension, or a callable or sequence of callables of size num_spatial_dims. Any callables must take a single integer argument equal to the effective kernel size and return a sequence of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last. See get_channel_index().
mask (Optional[jax.Array]) – Optional mask of the weights.
feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
name (Optional[str]) – The name of the module.
- __call__(inputs, *, precision=None)[source]#
Connects ConvND layer.
- Parameters
inputs (jax.Array) – An array of shape [spatial_dims, C] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, C] and rank-N+2 if batched.
precision (Optional[lax.Precision]) – Optional jax.lax.Precision to pass to jax.lax.conv_general_dilated().
- Return type
jax.Array
- Returns
An array of shape [spatial_dims, output_channels] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, output_channels] and rank-N+2 if batched.
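The spatial part of the returned shape depends on stride, rate and padding. The sketch below computes it for string paddings only, assuming the usual XLA conventions (SAME gives ceil(n / stride); VALID uses only full windows of the dilated kernel). conv_output_length and conv_output_shape are hypothetical helpers, not part of Haiku.

```python
import math

def conv_output_length(n, kernel, stride=1, rate=1, padding="SAME"):
    """Hypothetical helper: output size of one spatial dimension."""
    # Dilation spreads the kernel taps out, enlarging the kernel's footprint.
    effective_kernel = (kernel - 1) * rate + 1
    if padding == "SAME":
        return math.ceil(n / stride)
    if padding == "VALID":
        return max(math.ceil((n - effective_kernel + 1) / stride), 0)
    raise ValueError(f"Unknown padding: {padding}")

def conv_output_shape(spatial_shape, output_channels, kernel, **kwargs):
    """Shape [spatial_dims, output_channels] for an unbatched channels_last input."""
    spatial = [conv_output_length(n, kernel, **kwargs) for n in spatial_shape]
    return spatial + [output_channels]

same_shape = conv_output_shape([32, 32], 16, kernel=3, stride=2)
valid_shape = conv_output_shape([32, 32], 16, kernel=3, rate=2, padding="VALID")
```

Here same_shape is [16, 16, 16], and valid_shape is [28, 28, 16] because a 3-tap kernel with rate 2 covers 5 input positions.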
Conv1D#
- class haiku.Conv1D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, feature_group_count=1, name=None)[source]#
One dimensional convolution.
- __init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, feature_group_count=1, name=None)[source]#
Initializes the module.
- Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 1. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 1. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.
mask (Optional[jax.Array]) – Optional mask of the weights.
feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
name (Optional[str]) – The name of the module.
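The contract for a padding callable (one integer in, a pair of integers out) can be shown with two small functions. These are hypothetical examples written for this sketch; the functions actually shipped under haiku.pad may differ in name and behaviour.

```python
def causal_pad(effective_kernel_size):
    """Hypothetical pad function: all padding goes before the input."""
    return (effective_kernel_size - 1, 0)

def same_pad(effective_kernel_size):
    """Hypothetical pad function: split the padding as evenly as possible."""
    return ((effective_kernel_size - 1) // 2, effective_kernel_size // 2)

def padded_length(n, kernel, rate, pad_fn):
    """Length of one spatial dimension after applying `pad_fn`."""
    # The callable receives the dilated kernel size, not the raw kernel_shape.
    effective_kernel_size = (kernel - 1) * rate + 1
    before, after = pad_fn(effective_kernel_size)
    return before + n + after

causal = causal_pad(5)   # (4, 0)
same = same_pad(4)       # (1, 2)
length = padded_length(10, kernel=3, rate=2, pad_fn=causal_pad)  # 4 + 10 + 0
```

Padding only on the left, as in causal_pad, means that each output position depends on the current and earlier inputs only, which is the usual choice for autoregressive sequence models.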
Conv2D#
- class haiku.Conv2D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, feature_group_count=1, name=None)[source]#
Two dimensional convolution.
- __init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, feature_group_count=1, name=None)[source]#
Initializes the module.
- Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 2. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 2. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.
mask (Optional[jax.Array]) – Optional mask of the weights.
feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
name (Optional[str]) – The name of the module.
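The parameter saving from feature_group_count can be worked out by counting kernel weights. The helper below is a hypothetical sketch; it assumes a kernel of shape kernel_shape + (input_channels / groups, output_channels), which is the standard layout for grouped convolution, and it ignores the bias.

```python
def conv_weight_count(kernel_shape, input_channels, output_channels,
                      feature_group_count=1):
    """Hypothetical helper: number of kernel weights, excluding the bias."""
    if (input_channels % feature_group_count
            or output_channels % feature_group_count):
        raise ValueError("channels must be divisible by feature_group_count")
    taps = 1
    for k in kernel_shape:
        taps *= k
    # Each output channel only reads the input channels of its own group.
    return taps * (input_channels // feature_group_count) * output_channels

dense = conv_weight_count((3, 3), input_channels=64, output_channels=128)
grouped = conv_weight_count((3, 3), 64, 128, feature_group_count=4)
```

For this 3x3 layer the dense kernel has 73728 weights and the 4-group kernel has 18432, a reduction by exactly the number of groups.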
Conv3D#
- class haiku.Conv3D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, feature_group_count=1, name=None)[source]#
Three dimensional convolution.
- __init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, feature_group_count=1, name=None)[source]#
Initializes the module.
- Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 3. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[tuple[int, int]], hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – Optional padding algorithm. Either VALID or SAME or a callable or sequence of callables of length 3. Any callables must take a single integer argument equal to the effective kernel size and return a list of two integers representing the padding before and after. See haiku.pad.* for more details and example functions. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either NDHWC or NCDHW. By default, NDHWC.
mask (Optional[jax.Array]) – Optional mask of the weights.
feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together. This reduces the number of parameters and possibly the compute for a given output_channels. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
name (Optional[str]) – The name of the module.
ConvNDTranspose#
- class haiku.ConvNDTranspose(num_spatial_dims, output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)[source]#
General n-dimensional transposed convolution (aka. deconvolution).
- __init__(num_spatial_dims, output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='channels_last', mask=None, name=None)[source]#
Initializes the module.
- Parameters
num_spatial_dims (int) – The number of spatial dimensions of the input.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.
output_shape (Optional[Union[int, Sequence[int]]]) – Output shape of the spatial dimensions of a transpose convolution. Can be either an integer or an iterable of integers. If a None value is given, a default shape is automatically calculated.
padding (Union[str, Sequence[tuple[int, int]]]) – Optional padding algorithm. Either “VALID” or “SAME”. Defaults to “SAME”. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.
mask (Optional[jax.Array]) – Optional mask of the weights.
name (Optional[str]) – The name of the module.
- __call__(inputs, *, precision=None)[source]#
Computes the transposed convolution of the input.
- Parameters
inputs (jax.Array) – An array of shape [spatial_dims, C] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, C] and rank-N+2 if batched.
precision (Optional[lax.Precision]) – Optional jax.lax.Precision to pass to jax.lax.conv_transpose().
- Return type
jax.Array
- Returns
An array of shape [spatial_dims, output_channels] and rank-N+1 if unbatched, or an array of shape [N, spatial_dims, output_channels] and rank-N+2 if batched.
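When output_shape is None, the spatial size of the result follows from the input size, stride and padding. The sketch below assumes the convention used by jax.lax.conv_transpose for string paddings (SAME multiplies the size by the stride; VALID adds the part of the kernel that overhangs the last stride). conv_transpose_output_length is a hypothetical helper, and the exact rule should be checked against the installed JAX version.

```python
def conv_transpose_output_length(n, kernel, stride=1, padding="SAME"):
    """Hypothetical helper: default output size of one spatial dimension."""
    if padding == "SAME":
        # Upsamples by exactly the stride.
        return n * stride
    if padding == "VALID":
        # Adds the part of the kernel that extends past the last stride.
        return n * stride + max(kernel - stride, 0)
    raise ValueError(f"Unknown padding: {padding}")

same_len = conv_transpose_output_length(16, kernel=3, stride=2)
valid_len = conv_transpose_output_length(16, kernel=3, stride=2, padding="VALID")
```

Under these assumptions a stride-2 transposed convolution maps 16 positions to 32 with SAME and to 33 with VALID.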
Conv1DTranspose#
- class haiku.Conv1DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]#
One dimensional transposed convolution (aka. deconvolution).
- __init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', mask=None, name=None)[source]#
Initializes the module.
- Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.
output_shape (Optional[Union[int, Sequence[int]]]) – Output shape of the spatial dimensions of a transpose convolution. Can be either an integer or an iterable of integers. If a None value is given, a default shape is automatically calculated.
padding (Union[str, Sequence[tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.
mask (Optional[jax.Array]) – Optional mask of the weights.
name (Optional[str]) – The name of the module.
Conv2DTranspose#
- class haiku.Conv2DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]#
Two dimensional transposed convolution (aka. deconvolution).
- __init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', mask=None, name=None)[source]#
Initializes the module.
- Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.
output_shape (Optional[Union[int, Sequence[int]]]) – Output shape of the spatial dimensions of a transpose convolution. Can be either an integer or an iterable of integers. If a None value is given, a default shape is automatically calculated.
padding (Union[str, Sequence[tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.
mask (Optional[jax.Array]) – Optional mask of the weights.
name (Optional[str]) – The name of the module.
Conv3DTranspose#
- class haiku.Conv3DTranspose(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]#
Three dimensional transposed convolution (aka. deconvolution).
- __init__(output_channels, kernel_shape, stride=1, output_shape=None, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', mask=None, name=None)[source]#
Initializes the module.
- Parameters
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.
output_shape (Optional[Union[int, Sequence[int]]]) – Output shape of the spatial dimensions of a transpose convolution. Can be either an integer or an iterable of integers. If a None value is given, a default shape is automatically calculated.
padding (Union[str, Sequence[tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Either NDHWC or NCDHW. By default, NDHWC.
mask (Optional[jax.Array]) – Optional mask of the weights.
name (Optional[str]) – The name of the module.
DepthwiseConv1D#
- class haiku.DepthwiseConv1D(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', name=None)[source]#
One dimensional depthwise convolution.
- __init__(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', name=None)[source]#
Construct a 1D Depthwise Convolution.
- Parameters
channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set this to 1.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 1. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[tuple[int, int]]]) – Optional padding algorithm. Either VALID, SAME or a sequence of before, after pairs. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last. See get_channel_index().
name (Optional[str]) – The name of the module.
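What distinguishes a depthwise convolution from a dense one is that each filter reads a single input channel, and channel_multiplier sets how many filters each channel gets. The following pure-Python sketch illustrates this for VALID padding and stride 1; depthwise_conv_1d and its nested-list kernel layout are hypothetical, and the output channel ordering is an assumption.

```python
def depthwise_conv_1d(inputs, kernels):
    """Hypothetical depthwise 1-D convolution with VALID padding and stride 1.

    inputs:  list of length W, each entry a list of C channel values.
    kernels: kernels[c][m] is the list of taps producing output channel
             c * channel_multiplier + m from input channel c only.
    """
    width, channels = len(inputs), len(inputs[0])
    kernel_size = len(kernels[0][0])
    out = []
    for w in range(width - kernel_size + 1):
        row = []
        for c in range(channels):
            for taps in kernels[c]:
                # Each filter reads a single input channel; channels never mix.
                row.append(sum(t * inputs[w + k][c] for k, t in enumerate(taps)))
        out.append(row)
    return out

x = [[1, 10], [2, 20], [3, 30]]      # W=3, C=2
kernels = [[[1, 1], [1, -1]],        # two filters for channel 0
           [[1, 0], [0, 1]]]         # two filters for channel 1
y = depthwise_conv_1d(x, kernels)    # channel_multiplier = 2
```

The result has 2 * 2 = 4 output channels per position, and changing the values in one input channel only affects the outputs derived from that channel.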
DepthwiseConv2D#
- class haiku.DepthwiseConv2D(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]#
Two dimensional depthwise convolution.
- __init__(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]#
Construct a 2D Depthwise Convolution.
- Parameters
channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set this to 1.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 2. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[tuple[int, int]]]) – Optional padding algorithm. Either VALID, SAME or a sequence of before, after pairs. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last. See get_channel_index().
name (Optional[str]) – The name of the module.
DepthwiseConv3D#
- class haiku.DepthwiseConv3D(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', name=None)[source]#
Three dimensional depthwise convolution.
- __init__(channel_multiplier, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', name=None)[source]#
Construct a 3D Depthwise Convolution.
- Parameters
channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set this to 1.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 3.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1.
rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 3. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.
padding (Union[str, Sequence[tuple[int, int]]]) – Optional padding algorithm. Either VALID, SAME or a sequence of before, after pairs. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last. See get_channel_index().
name (Optional[str]) – The name of the module.
SeparableDepthwiseConv2D#
- class haiku.SeparableDepthwiseConv2D(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]#
Separable 2-D Depthwise Convolution Module.
- __init__(channel_multiplier, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]#
Construct a Separable 2D Depthwise Convolution module.
- Parameters
channel_multiplier (int) – Multiplicity of output channels. To keep the number of output channels the same as the number of input channels, set this to 1.
kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length num_spatial_dims.
stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length num_spatial_dims. Defaults to 1.
padding (Union[str, Sequence[tuple[int, int]]]) – Optional padding algorithm. Either VALID, SAME or a sequence of before, after pairs. Defaults to SAME. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
with_bias (bool) – Whether to add a bias. By default, true.
w_init (Optional[hk.initializers.Initializer]) – Optional weight initialization. By default, truncated normal.
b_init (Optional[hk.initializers.Initializer]) – Optional bias initialization. By default, zeros.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default, channels_last.
name (Optional[str]) – The name of the module.
get_channel_index#
- haiku.get_channel_index(data_format)[source]#
Returns the channel index when given a valid data format.
>>> hk.get_channel_index('channels_last')
-1
>>> hk.get_channel_index('channels_first')
1
>>> hk.get_channel_index('N...C')
-1
>>> hk.get_channel_index('NCHW')
1
- Parameters
data_format (str) – String, the data format to get the channel index from. Valid data formats are spatial (e.g. NCHW), sequential (e.g. BTHWD), channels_first and channels_last.
- Return type
int
- Returns
The channel index as an int, either 1 or -1.
- Raises
ValueError – If the data format is unrecognised.
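The mapping shown in the doctest above can be approximated in a few lines. channel_index below is a hypothetical re-implementation for illustration only; it covers the named formats and strings containing a C (such as NCHW or N...C) and does not attempt other layouts, so it should not be taken as the library's exact logic.

```python
def channel_index(data_format):
    """Hypothetical re-implementation covering the formats in the doctest."""
    if data_format == "channels_first":
        return 1
    if data_format == "channels_last":
        return -1
    # Otherwise look for the "C" in strings such as "NCHW" or "N...C".
    position = data_format.find("C")
    if position == 1:
        return 1
    if position != -1 and position == len(data_format) - 1:
        return -1
    raise ValueError(f"Unable to extract channel information from {data_format!r}")

shape = (8, 3, 32, 32)                   # an NCHW batch
channels = shape[channel_index("NCHW")]  # picks out the channel dimension
```

Because the result is always 1 or -1, it can be used directly to index an array shape regardless of how many spatial dimensions the input has.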
Normalization#
|
Normalizes inputs to maintain a mean of ~0 and stddev of ~1. |
|
Group normalization module. |
|
Normalizes inputs along the spatial dimensions. |
|
LayerNorm module. |
|
RMSNorm module. |
|
Normalizes an input by its first singular value. |
|
Maintains an exponential moving average. |
|
Applies Spectral Normalization to all parameters in a tree. |
|
Maintains an exponential moving average for all parameters in a tree. |
BatchNorm#
- class haiku.BatchNorm(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, cross_replica_axis_index_groups=None, data_format='channels_last', name=None)[source]#
Normalizes inputs to maintain a mean of ~0 and stddev of ~1.
See: https://arxiv.org/abs/1502.03167.
There are many different variations for how users want to manage scale and offset if they require them at all. These are:
- No scale/offset, in which case create_* should be set to False and scale/offset aren’t passed when the module is called.
- Trainable scale/offset, in which case create_* should be set to True and again scale/offset aren’t passed when the module is called. In this case this module creates and owns the scale/offset variables.
- Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and then the values fed in at call time.
NOTE: jax.vmap(hk.transform(BatchNorm)) will update summary statistics and normalize values on a per-batch basis; we currently do not support normalizing across a batch axis introduced by vmap.
- __init__(create_scale, create_offset, decay_rate, eps=1e-05, scale_init=None, offset_init=None, axis=None, cross_replica_axis=None, cross_replica_axis_index_groups=None, data_format='channels_last', name=None)[source]#
Constructs a BatchNorm module.
- Parameters
create_scale (bool) – Whether to include a trainable scaling factor.
create_offset (bool) – Whether to include a trainable offset.
decay_rate (float) – Decay rate for EMA.
eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5, as in the paper and Sonnet.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). Can only be set if create_scale=True. By default, 1.
offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias (aka offset). Can only be set if create_offset=True. By default, 0.
axis (Optional[Sequence[int]]) – Which axes to reduce over. The default (None) signifies that all but the channel axis should be normalized. Otherwise this is a list of axis indices which will have normalization statistics calculated.
cross_replica_axis (Optional[Union[str, Sequence[str]]]) – If not None, it should be a string (or sequence of strings) representing the axis name(s) over which this module is being run within a jax map (e.g. jax.pmap or jax.vmap). Supplying this argument means that batch statistics are calculated across all replicas on the named axes.
cross_replica_axis_index_groups (Optional[Sequence[Sequence[int]]]) – Specifies how devices are grouped. Valid only within jax.pmap collectives.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last. See get_channel_index().
name (Optional[str]) – The module name.
- __call__(inputs, is_training, test_local_stats=False, scale=None, offset=None)[source]#
Computes the normalized version of the input.
- Parameters
inputs (jax.Array) – An array, where the data format is [..., C].
is_training (bool) – Whether this is during training.
test_local_stats (bool) – Whether local stats are used when is_training=False.
scale (Optional[jax.Array]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.
offset (Optional[jax.Array]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.
- Return type
jax.Array
- Returns
The array, normalized across all but the last dimension.
GroupNorm#
- class haiku.GroupNorm(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]#
Group normalization module.
This applies group normalization to the input x. This involves splitting the channels into groups before calculating the mean and variance. The default behaviour is to compute the mean and variance over the spatial dimensions and the grouped channels. The mean and variance will never be computed over the created groups axis.
It transforms the input x into:
\[\mathrm{outputs} = \mathrm{scale} \dfrac{x - \mu}{\sigma + \epsilon} + \mathrm{offset}\]
where \(\mu\) and \(\sigma\) are respectively the mean and standard deviation of x.
There are many different variations for how users want to manage scale and offset if they require them at all. These are:
- No scale/offset, in which case create_* should be set to False and scale/offset aren’t passed when the module is called.
- Trainable scale/offset, in which case create_* should be set to True and again scale/offset aren’t passed when the module is called. In this case this module creates and owns the scale/offset parameters.
- Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and then the values fed in at call time.
- __init__(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]#
Constructs a GroupNorm module.
- Parameters
groups (int) – Number of groups to divide the channels by. The number of channels must be divisible by this.
axis (Union[int, slice, Sequence[int]]) – int, slice or sequence of ints representing the axes which should be normalized across. By default this is all but the first dimension. For time series data use slice(2, None) to average over the non-batch, non-time dimensions.
create_scale (bool) – Whether to create a trainable scale per channel applied after the normalization.
create_offset (bool) – whether to create a trainable offset per channel applied after normalization and scaling.
eps (float) – Small epsilon to add to the variance to avoid division by zero. Defaults to 1e-5.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for the scale parameter. Can only be set if create_scale=True. By default scale is initialized to 1.
offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for the offset parameter. Can only be set if create_offset=True. By default offset is initialized to 0.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last. See get_channel_index().
name (Optional[str]) – Name of the module.
- __call__(x, scale=None, offset=None)[source]#
Returns normalized inputs.
- Parameters
x (jax.Array) – An n-D tensor of the data_format specified in the constructor on which the transformation is performed.
scale (Optional[jax.Array]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the scale applied to the normalized x. This cannot be passed in if the module was constructed with create_scale=True.
offset (Optional[jax.Array]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the offset applied to the normalized x. This cannot be passed in if the module was constructed with create_offset=True.
- Return type
jax.Array
- Returns
An n-d tensor of the same shape as x that has been normalized.
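The grouping described above can be sketched in plain NumPy. This is an illustration rather than Haiku's implementation: the helper name `group_norm` is hypothetical, it assumes channels_last data and the default axis (everything but the batch axis), and it assumes eps is added to the variance, as the eps parameter description states.

```python
import numpy as np

def group_norm(x, groups, eps=1e-5, scale=1.0, offset=0.0):
    """Group-normalizes a channels-last array of shape [N, ..., C]."""
    c = x.shape[-1]
    if c % groups != 0:
        raise ValueError("The number of channels must be divisible by groups.")
    # Split the channel axis into (groups, channels_per_group).
    grouped = x.reshape(x.shape[:-1] + (groups, c // groups))
    # Reduce over every axis except the batch axis (0) and the groups axis (-2).
    reduce_axes = tuple(range(1, grouped.ndim - 2)) + (grouped.ndim - 1,)
    mean = grouped.mean(axis=reduce_axes, keepdims=True)
    var = grouped.var(axis=reduce_axes, keepdims=True)
    normalized = (grouped - mean) / np.sqrt(var + eps)
    # Restore the original layout before applying scale and offset.
    return normalized.reshape(x.shape) * scale + offset

x = np.random.default_rng(0).normal(size=(2, 4, 4, 6))
y = group_norm(x, groups=3)
```

Each (example, group) pair gets its own mean and variance, which is why the statistics are never reduced over the groups axis.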
InstanceNorm#
- class haiku.InstanceNorm(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]#
Normalizes inputs along the spatial dimensions.
See LayerNorm for more details.
- __init__(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]#
Constructs an InstanceNorm module.
This method creates a module which normalizes over the spatial dimensions.
- Parameters
create_scale (bool) – bool representing whether to create a trainable scale per channel applied after the normalization.
create_offset (bool) – bool representing whether to create a trainable offset per channel applied after normalization and scaling.
eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for the scale variable. Can only be set if create_scale=True. By default scale is initialized to 1.
offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for the offset variable. Can only be set if create_offset=True. By default offset is initialized to 0.
data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last. See get_channel_index().
name (Optional[str]) – Name of the module.
LayerNorm#
- class haiku.LayerNorm(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, use_fast_variance=False, name=None, *, param_axis=None)[source]#
LayerNorm module.
See: https://arxiv.org/abs/1607.06450.
Example usage:
>>> ln = hk.LayerNorm(axis=-1, param_axis=-1,
...                   create_scale=True, create_offset=True)
>>> x = ln(jnp.ones([8, 224, 224, 3]))
- __init__(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, use_fast_variance=False, name=None, *, param_axis=None)[source]#
Constructs a LayerNorm module.
- Parameters
axis (AxisOrAxes) – Integer, list of integers, or slice indicating which axes to normalize over. Note that the shape of the scale/offset parameters is controlled by the param_axis argument.
create_scale (bool) – Bool, defines whether to create a trainable scale per channel applied after the normalization.
create_offset (bool) – Bool, defines whether to create a trainable offset per channel applied after normalization and scaling.
eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5, as in the paper and Sonnet.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). By default, one.
offset_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias (aka offset). By default, zero.
use_fast_variance (bool) – If true, use a faster but less numerically stable formulation for computing variance.
name (Optional[str]) – The module name.
param_axis (Optional[AxisOrAxes]) – Axis used to determine the parameter shape of the learnable scale/offset. Sonnet sets this to the channel/feature axis (e.g. to -1 for NHWC). Other libraries set this to the same as the reduction axis (e.g. axis=param_axis).
- __call__(inputs, scale=None, offset=None)[source]#
Connects the layer norm.
- Parameters
inputs (jax.Array) – An array, where the data format is [N, ..., C].
scale (Optional[jax.Array]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.
offset (Optional[jax.Array]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.
- Return type
jax.Array
- Returns
The array, normalized.
RMSNorm#
- class haiku.RMSNorm(axis, eps=1e-05, scale_init=None, name=None, create_scale=True, *, param_axis=None)[source]#
RMSNorm module.
RMSNorm provides an alternative that can be both faster and more stable than LayerNorm. The inputs are normalized by the root-mean-squared (RMS) and scaled by a learned parameter, but they are not recentered around their mean.
See https://arxiv.org/pdf/1910.07467.pdf
- __init__(axis, eps=1e-05, scale_init=None, name=None, create_scale=True, *, param_axis=None)[source]#
Constructs an RMSNorm module.
- Parameters
axis (AxisOrAxes) – Integer, list of integers, or slice indicating which axes to normalize over.
eps (float) – Small epsilon to avoid division by zero variance. Defaults to 1e-5.
scale_init (Optional[hk.initializers.Initializer]) – Optional initializer for gain (aka scale). By default, one.
name (Optional[str]) – The module name.
create_scale (bool) – Bool, defines whether to create a trainable scale per channel applied after the normalization.
param_axis (Optional[AxisOrAxes]) – Axis used to determine the parameter shape of the learnable scale. Sonnet sets this to the channel/feature axis (e.g. to -1 for NHWC). Other libraries set this to the same as the reduction axis (e.g. axis=param_axis). None defaults to (-1,).
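To make the difference between LayerNorm and RMSNorm concrete, here is a small NumPy sketch of both normalizations without learned parameters. The function names are hypothetical, and the placement of eps inside the square root is an assumption of this sketch rather than a statement about Haiku's source.

```python
import numpy as np

def layer_norm(x, axis=-1, eps=1e-5):
    # Recenter around the mean, then divide by the standard deviation.
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def rms_norm(x, axis=-1, eps=1e-5):
    # Divide by the root-mean-square only; the mean is left untouched.
    mean_sq = np.mean(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(mean_sq + eps)

x = np.array([[1.0, 2.0, 3.0, 4.0]])
ln = layer_norm(x)
rn = rms_norm(x)
```

Here `ln` has zero mean along the last axis, while `rn` keeps a positive mean and has unit mean square.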
SpectralNorm#
- class haiku.SpectralNorm(eps=0.0001, n_steps=1, name=None)[source]#
Normalizes an input by its first singular value.
This module uses power iteration to calculate this value based on the input and an internal hidden state.
- __init__(eps=0.0001, n_steps=1, name=None)[source]#
Initializes a SpectralNorm module.
- Parameters
eps (float) – The constant used for numerical stability.
n_steps (int) – How many steps of power iteration to perform to approximate the singular value of the input.
name (Optional[str]) – The name of the module.
- __call__(value, update_stats=True, error_on_non_matrix=False)[source]#
Performs Spectral Normalization and returns the new value.
- Parameters
value – The array-like object to spectrally normalize.
update_stats (bool) – A boolean defaulting to True. Regardless of this arg, this function will return the normalized input. When update_stats is True, the internal state of this object will also be updated to reflect the input value. When update_stats is False the internal stats will remain unchanged.
error_on_non_matrix (bool) – Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw errors in those cases.
- Return type
jax.Array
- Returns
The input value normalized by its first singular value.
- Raises
ValueError – If error_on_non_matrix is True and value has ndims > 2.
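As a rough illustration of the power iteration mentioned above, the following NumPy sketch estimates the first singular value of a matrix and divides by it. The function name `spectral_normalize` and the way eps enters the normalization are assumptions of this sketch; the vector `u` plays the role of the internal hidden state that would be carried between calls.

```python
import numpy as np

def spectral_normalize(w, u, n_steps=1, eps=1e-4):
    """Returns (w / sigma, new_u), where sigma approximates the top singular value.

    w has shape [m, n] and u has shape [1, n]; n_steps must be at least 1.
    """
    for _ in range(n_steps):
        v = u @ w.T                          # [1, m]
        v = v / (np.linalg.norm(v) + eps)
        u = v @ w                            # [1, n]
        u = u / (np.linalg.norm(u) + eps)
    sigma = (v @ w @ u.T).item()             # estimate of the first singular value
    return w / sigma, u

w = np.array([[3.0, 0.0], [0.0, 1.0]])
u0 = np.ones((1, 2)) / np.sqrt(2.0)
w_sn, u = spectral_normalize(w, u0, n_steps=50)
```

With update_stats=False, the analogous behaviour would be to return `w / sigma` and discard the new `u`.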
ExponentialMovingAverage#
- class haiku.ExponentialMovingAverage(decay, zero_debias=True, warmup_length=0, name=None)[source]#
Maintains an exponential moving average.
This uses the Adam debiasing procedure. See https://arxiv.org/pdf/1412.6980.pdf for details.
- __init__(decay, zero_debias=True, warmup_length=0, name=None)[source]#
Initializes an ExponentialMovingAverage module.
- Parameters
decay – The chosen decay. Must be in [0, 1). Values close to 1 result in slow decay; values close to 0 result in fast decay.
zero_debias (bool) – Whether to run with zero-debiasing.
warmup_length (int) – A positive integer. EMA has no effect until the internal counter has reached warmup_length, at which point the initial value for the decaying average is initialized to the input value after warmup_length iterations.
name (Optional[str]) – The name of the module.
- initialize(shape, dtype=<class 'jax.numpy.float32'>)[source]#
If uninitialized, sets the average to zeros of the given shape/dtype.
- __call__(value, update_stats=True)[source]#
Updates the EMA and returns the new value.
- Parameters
value (Union[float, jax.Array]) – The array-like object to which the exponential decay is applied.
update_stats (bool) – A Boolean, whether to update the internal state of this object to reflect the input value. When update_stats is False the internal stats will remain unchanged.
- Return type
jax.Array
- Returns
The exponentially weighted average of the input value.
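The update rule with Adam-style zero-debiasing can be sketched in a few lines of plain Python. The class `EmaSketch` is hypothetical and omits warmup_length and update_stats; it only illustrates why debiasing matters when the average starts at zero.

```python
class EmaSketch:
    """Scalar exponential moving average with optional zero-debiasing."""

    def __init__(self, decay, zero_debias=True):
        self.decay = decay
        self.zero_debias = zero_debias
        self.hidden = 0.0   # biased running average, starts at zero
        self.counter = 0    # number of updates seen so far

    def __call__(self, value):
        self.counter += 1
        self.hidden = self.decay * self.hidden + (1.0 - self.decay) * value
        if self.zero_debias:
            # Divide out the weight still attributed to the zero initial value.
            return self.hidden / (1.0 - self.decay ** self.counter)
        return self.hidden

ema = EmaSketch(decay=0.9)
first = ema(5.0)
second = ema(5.0)
biased_first = EmaSketch(decay=0.9, zero_debias=False)(5.0)
```

For a constant input the debiased average equals the input from the first step, whereas the biased average starts near zero and approaches the input slowly.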
SNParamsTree#
- class haiku.SNParamsTree(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]#
Applies Spectral Normalization to all parameters in a tree.
This is isomorphic to EMAParamsTree in moving_averages.py.
- __init__(eps=0.0001, n_steps=1, ignore_regex='', name=None)[source]#
Initializes an SNParamsTree module.
- Parameters
eps (float) – The constant used for numerical stability.
n_steps (int) – How many steps of power iteration to perform to approximate the singular value of the input.
ignore_regex (str) – A string. Any parameter in the tree whose name matches this regex will not have spectral normalization applied to it. The empty string means this module applies to all parameters.
name (Optional[str]) – The name of the module.
EMAParamsTree#
- class haiku.EMAParamsTree(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]#
Maintains an exponential moving average for all parameters in a tree.
While ExponentialMovingAverage is meant to be applied to single parameters within a function, this class is meant to be applied to the entire tree of parameters for a function.
Given a set of parameters for some network:
>>> network_fn = lambda x: hk.Linear(10)(x)
>>> x = jnp.ones([1, 1])
>>> params = hk.transform(network_fn).init(jax.random.PRNGKey(428), x)
You might use the EMAParamsTree as follows:
>>> ema_fn = hk.transform_with_state(lambda x: hk.EMAParamsTree(0.2)(x))
>>> _, ema_state = ema_fn.init(None, params)
>>> ema_params, ema_state = ema_fn.apply(None, ema_state, None, params)
Here, we are transforming a Haiku function and constructing its parameters via an init_fn as normal, but are creating a second transformed function which expects a tree of parameters as input. This function is then called with the current parameters as input, which then returns an identical tree with every parameter replaced with its exponentially decayed average. This ema_params object can then be passed into the network_fn as usual, and will cause it to run with EMA weights.
- __init__(decay, zero_debias=True, warmup_length=0, ignore_regex='', name=None)[source]#
Initializes an EMAParamsTree module.
- Parameters
decay – The chosen decay. Must be in [0, 1). Values close to 1 result in slow decay; values close to 0 result in fast decay.
zero_debias (bool) – Whether to run with zero-debiasing.
warmup_length (int) – A positive integer. EMA has no effect until the internal counter has reached warmup_length, at which point the initial value for the decaying average is initialized to the input value after warmup_length iterations.
ignore_regex (str) – A string. Any parameter in the tree whose name matches this regex will not have any moving average applied to it. The empty string means this module will EMA all parameters.
name (Optional[str]) – The name of the module.
Recurrent#
| RNNCore | Base class for RNN cores. |
| dynamic_unroll | Performs a dynamic unroll of an RNN. |
| static_unroll | Performs a static unroll of an RNN. |
| expand_apply | Wraps f to temporarily add a size-1 axis to its inputs. |
| VanillaRNN | Basic fully-connected RNN core. |
| LSTM | Long short-term memory (LSTM) RNN core. |
| GRU | Gated Recurrent Unit. |
| DeepRNN | Wraps a sequence of cores and callables as a single core. |
| deep_rnn_with_skip_connections | Constructs a DeepRNN with skip connections. |
| ResetCore | A wrapper for managing state resets during unrolls. |
| IdentityCore | A recurrent core that forwards the inputs and an empty state. |
| Conv1DLSTM | 1-D convolutional LSTM. |
| Conv2DLSTM | 2-D convolutional LSTM. |
| Conv3DLSTM | 3-D convolutional LSTM. |
RNNCore#
- class haiku.RNNCore(name=None)[source]#
Base class for RNN cores.
This class defines the basic functionality that every core should implement:
initial_state(), used to construct an example of the core state; and __call__(), which applies the core parameterized by a previous state to an input.
Cores may be used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from the given input sequence.
- abstract __call__(inputs, prev_state)[source]#
Run one step of the RNN.
- Parameters
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
- Return type
tuple[Any, Any]
- Returns
A tuple with two elements, output and next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.
- abstract initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns
Arbitrarily nested initial state for this core.
dynamic_unroll#
- haiku.dynamic_unroll(core, input_sequence, initial_state, time_major=True, reverse=False, return_all_states=False, unroll=1)[source]#
Performs a dynamic unroll of an RNN.
An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:
state = initial_state
for t in range(len(input_sequence)):
    outputs, state = core(input_sequence[t], state)
A dynamic unroll preserves the loop structure when executed inside jax.jit(). See static_unroll() for an unroll function which replaces a loop with its body repeated multiple times.
- Parameters
core – An RNNCore to unroll.
input_sequence – An arbitrarily nested structure of tensors of shape [T, ...] if time_major=True, or [B, T, ...] if time_major=False, where T is the number of time steps.
initial_state – An initial state of the given core.
time_major – If True, inputs are expected time-major, otherwise they are expected batch-major.
reverse – If True, inputs are scanned in the reversed order. Equivalent to reversing the time dimension in both inputs and outputs. See https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html for more details.
return_all_states – If True, all intermediate states are returned rather than only the last one in time.
unroll – How many scan iterations to unroll within a single iteration of a loop.
- Returns
A tuple with two elements:
output_sequence - An arbitrarily nested structure of tensors of shape [T, ...] if time-major, otherwise [B, T, ...].
state_sequence - If return_all_states is True, returns the sequence of core states. Otherwise, core state at time step T.
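The semantics of the time_major, reverse and return_all_states arguments can be illustrated with a plain NumPy loop. This sketch is not how Haiku implements the unroll; the `unroll` helper and the toy `cumsum_core` are hypothetical, inputs are single arrays rather than nested structures, and stacked states are left time-major for simplicity.

```python
import numpy as np

def unroll(core, input_sequence, initial_state, time_major=True,
           reverse=False, return_all_states=False):
    # Work in time-major layout internally.
    xs = input_sequence if time_major else np.swapaxes(input_sequence, 0, 1)
    steps = range(len(xs))
    if reverse:
        steps = reversed(steps)
    state = initial_state
    outputs = [None] * len(xs)
    states = [None] * len(xs)
    for t in steps:
        # Output t always corresponds to input t, even when scanning backwards.
        outputs[t], state = core(xs[t], state)
        states[t] = state
    out = np.stack(outputs)
    if not time_major:
        out = np.swapaxes(out, 0, 1)
    return out, (np.stack(states) if return_all_states else state)

def cumsum_core(x, state):
    # Toy core: the state is a running sum, and the output equals the new state.
    new_state = state + x
    return new_state, new_state

xs = np.arange(1.0, 5.0).reshape(4, 1)   # [T=4, B=1]
out, final = unroll(cumsum_core, xs, np.zeros(1))
rev_out, _ = unroll(cumsum_core, xs, np.zeros(1), reverse=True)
```

With reverse=True the running sum is accumulated from the last time step backwards, so the largest value lands at index 0 of the output.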
static_unroll#
- haiku.static_unroll(core, input_sequence, initial_state, time_major=True)[source]#
Performs a static unroll of an RNN.
An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:
state = initial_state
for t in range(len(input_sequence)):
    outputs, state = core(input_sequence[t], state)
A static unroll replaces a loop with its body repeated multiple times when executed inside jax.jit():
state = initial_state
outputs0, state = core(input_sequence[0], state)
outputs1, state = core(input_sequence[1], state)
outputs2, state = core(input_sequence[2], state)
...
See dynamic_unroll() for a loop-preserving unroll function.
- Parameters
core – An RNNCore to unroll.
input_sequence – An arbitrarily nested structure of tensors of shape [T, ...] if time_major=True, or [B, T, ...] if time_major=False, where T is the number of time steps.
initial_state – An initial state of the given core.
time_major – If True, inputs are expected time-major, otherwise they are expected batch-major.
- Returns
A tuple with two elements:
output_sequence - An arbitrarily nested structure of tensors of shape [T, ...] if time-major, otherwise [B, T, ...].
final_state - Core state at time step T.
expand_apply#
- haiku.expand_apply(f, axis=0)[source]#
Wraps f to temporarily add a size-1 axis to its inputs.
Syntactic sugar for:
ins = jax.tree_util.tree_map(lambda t: np.expand_dims(t, axis=axis), ins)
out = f(ins)
out = jax.tree_util.tree_map(lambda t: np.squeeze(t, axis=axis), out)
This may be useful for applying a function built for [Time, Batch, ...] arrays to a single timestep.
- Parameters
f – The callable to be applied to the expanded inputs.
axis – Where to add the extra axis.
- Returns
f, wrapped as described above.
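A minimal NumPy version of this wrapper, restricted to a single array instead of a nested structure, might look like the sketch below. The names `expand_apply_sketch` and `sum_over_time` are hypothetical.

```python
import numpy as np

def expand_apply_sketch(f, axis=0):
    """Wraps f so that it sees an extra size-1 axis that is removed afterwards."""
    def wrapped(x):
        out = f(np.expand_dims(x, axis=axis))
        return np.squeeze(out, axis=axis)
    return wrapped

def sum_over_time(seq):
    # Expects a [Time, Batch, Features] array and returns a running sum over time.
    return np.cumsum(seq, axis=0)

step = np.ones((2, 3))                              # a single [Batch, Features] timestep
result = expand_apply_sketch(sum_over_time)(step)   # shape stays [Batch, Features]
```

The wrapped function can then be called on one timestep without the caller adding and removing the time axis by hand.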
VanillaRNN#
- class haiku.VanillaRNN(hidden_size, double_bias=True, name=None)[source]#
Basic fully-connected RNN core.
Given \(x_t\) and the previous hidden state \(h_{t-1}\) the core computes
\[h_t = \operatorname{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h)\]
The output is equal to the new state, \(h_t\).
- __init__(hidden_size, double_bias=True, name=None)[source]#
Constructs a vanilla RNN core.
- Parameters
hidden_size (int) – Hidden layer size.
double_bias (bool) – Whether to use a bias in each of the two linear layers. This does not affect the learning performance of the cell, but it creates two sets of bias parameters rather than one.
name (Optional[str]) – Name of the module.
- __call__(inputs, prev_state)[source]#
Run one step of the RNN.
- Parameters
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
- Returns
A tuple with two elements, output and next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns
Arbitrarily nested initial state for this core.
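The VanillaRNN update can be written out directly in NumPy, which also shows why double_bias does not change what the cell can represent. The function name and the weight layout (inputs multiplied on the left) are assumptions of this sketch.

```python
import numpy as np

def vanilla_rnn_step(x, h_prev, w_i, b_i, w_h, b_h):
    """One step: h_t = ReLU(w_i x_t + b_i + w_h h_{t-1} + b_h)."""
    return np.maximum(0.0, x @ w_i + b_i + h_prev @ w_h + b_h)

rng = np.random.default_rng(0)
x, h_prev = rng.normal(size=(2, 3)), rng.normal(size=(2, 4))   # [B, D], [B, H]
w_i, w_h = rng.normal(size=(3, 4)), rng.normal(size=(4, 4))
b_i, b_h = rng.normal(size=4), rng.normal(size=4)

h_double = vanilla_rnn_step(x, h_prev, w_i, b_i, w_h, b_h)
# The two biases only ever appear as a sum, so one merged bias gives the same result.
h_single = vanilla_rnn_step(x, h_prev, w_i, b_i + b_h, w_h, np.zeros(4))
```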
LSTM#
- class haiku.LSTM(hidden_size, name=None)[source]#
Long short-term memory (LSTM) RNN core.
The implementation is based on [1]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]where \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
The output is equal to the new hidden, \(h_t\).
Notes
- Forget gate initialization:
Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.
- __init__(hidden_size, name=None)[source]#
Constructs an LSTM.
- Parameters
hidden_size (int) – Hidden layer size.
name (Optional[str]) – Name of the module.
- __call__(inputs, prev_state)[source]#
Run one step of the RNN.
- Parameters
inputs (jax.Array) – An arbitrarily nested structure.
prev_state (LSTMState) – Previous core state.
- Return type
tuple[jax.Array, LSTMState]
- Returns
A tuple with two elements, output and next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Return type
LSTMState
- Returns
Arbitrarily nested initial state for this core.
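The LSTM equations above can be followed step by step in NumPy. In this sketch the four gate projections are fused into one matrix per input, and the order in which they are split (i, g, f, o) is an arbitrary choice made here; the forget-gate offset of 1.0 is applied at call time to mirror the note on forget gate initialization. All names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, w_x, w_h, b):
    """One LSTM step; w_x is [D, 4H], w_h is [H, 4H], b is [4H]."""
    gates = x @ w_x + h_prev @ w_h + b
    i, g, f, o = np.split(gates, 4, axis=-1)
    f = sigmoid(f + 1.0)                       # forget gate, biased towards remembering
    c = f * c_prev + sigmoid(i) * np.tanh(g)   # new cell state
    h = sigmoid(o) * np.tanh(c)                # new hidden state, also the output
    return h, (h, c)

hidden = 3
x = np.zeros((1, 2))
h_prev, c_prev = np.zeros((1, hidden)), np.ones((1, hidden))
w_x, w_h, b = np.zeros((2, 4 * hidden)), np.zeros((hidden, 4 * hidden)), np.zeros(4 * hidden)
h, (_, c) = lstm_step(x, h_prev, c_prev, w_x, w_h, b)
```

With all-zero weights the cell state decays by a factor of sigmoid(1.0), roughly 0.73, per step instead of the 0.5 it would decay by without the offset.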
GRU#
- class haiku.GRU(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]#
Gated Recurrent Unit.
The implementation is based on: https://arxiv.org/pdf/1412.3555v1.pdf with biases.
Given \(x_t\) and the previous state \(h_{t-1}\) the core computes
\[\begin{array}{ll} z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t \end{array}\]
where \(z_t\) and \(r_t\) are the update and reset gates, respectively.
The output is equal to the new hidden state, \(h_t\).
- __init__(hidden_size, w_i_init=None, w_h_init=None, b_init=None, name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
- __call__(inputs, state)[source]#
Run one step of the RNN.
- Parameters
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
- Returns
A tuple with two elements, output and next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns
Arbitrarily nested initial state for this core.
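The GRU equations translate almost line for line into NumPy. The sketch below assumes fused weight matrices split in z, r, a order, which is a layout chosen for this illustration; the function names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x, h_prev, w_i, w_h, b):
    """One GRU step; w_i is [D, 3H], w_h is [H, 3H], b is [3H]."""
    w_iz, w_ir, w_ia = np.split(w_i, 3, axis=-1)
    w_hz, w_hr, w_ha = np.split(w_h, 3, axis=-1)
    b_z, b_r, b_a = np.split(b, 3)
    z = sigmoid(x @ w_iz + h_prev @ w_hz + b_z)          # update gate
    r = sigmoid(x @ w_ir + h_prev @ w_hr + b_r)          # reset gate
    a = np.tanh(x @ w_ia + (r * h_prev) @ w_ha + b_a)    # candidate state
    return (1.0 - z) * h_prev + z * a

rng = np.random.default_rng(0)
x, h_prev = rng.normal(size=(1, 2)), rng.normal(size=(1, 3))
w_i, w_h = rng.normal(size=(2, 9)), rng.normal(size=(3, 9))
# A very negative update-gate bias drives z to 0, so the previous state is kept.
keep = gru_step(x, h_prev, w_i, w_h, np.concatenate([np.full(3, -50.0), np.zeros(6)]))
```

The example shows the role of z: when it is close to 0 the new state copies the previous one, and when it is close to 1 the new state is the candidate a.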
DeepRNN#
- class haiku.DeepRNN(layers, name=None)[source]#
Wraps a sequence of cores and callables as a single core.
>>> deep_rnn = hk.DeepRNN([
...     hk.LSTM(hidden_size=4),
...     jax.nn.relu,
...     hk.LSTM(hidden_size=2),
... ])
The state of a DeepRNN is a tuple with one element per RNNCore. If no layers are RNNCores, the state is an empty tuple.
- __init__(layers, name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
- haiku.deep_rnn_with_skip_connections(layers, name=None)[source]#
Constructs a DeepRNN with skip connections.
Skip connections alter the dependency structure within a DeepRNN. Specifically, input to the i-th layer (i > 0) is given by a concatenation of the core’s inputs and the outputs of the (i-1)-th layer.
The output of the DeepRNN is the concatenation of the outputs of all cores.
outputs0, ... = layers[0](inputs, ...)
outputs1, ... = layers[1](jnp.concatenate([inputs, outputs0], axis=-1), ...)
outputs2, ... = layers[2](jnp.concatenate([inputs, outputs1], axis=-1), ...)
...
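The data flow of the skip connections can be illustrated with stateless layers in NumPy. This sketch drops the recurrent state entirely and uses two toy callables as layers, so it only shows how inputs and outputs are concatenated; `skip_forward` is a hypothetical helper.

```python
import numpy as np

def skip_forward(layers, inputs):
    """Applies layers with skip connections and concatenates all layer outputs."""
    outputs = []
    layer_in = inputs
    for layer in layers:
        out = layer(layer_in)
        outputs.append(out)
        # Every later layer sees the original inputs next to the previous output.
        layer_in = np.concatenate([inputs, out], axis=-1)
    return np.concatenate(outputs, axis=-1)

layers = [
    lambda x: x[..., :2] * 2.0,                # toy layer 0: doubles the first two features
    lambda x: x.sum(axis=-1, keepdims=True),   # toy layer 1: sums all features it receives
]
x = np.array([[1.0, 2.0, 3.0]])
y = skip_forward(layers, x)
```

Layer 1 receives five features (three inputs plus two outputs of layer 0), and the final output has one entry per output feature of every layer.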
ResetCore#
- class haiku.ResetCore(core, name=None)[source]#
A wrapper for managing state resets during unrolls.
When unrolling an RNNCore on a batch of input sequences it may be necessary to reset the core’s state at different timesteps for different elements of the batch. The ResetCore class enables this by taking a batch of should_reset booleans in addition to the batch of inputs, and conditionally resetting the core’s state for individual elements of the batch. You may also reset individual entries of the state by passing a should_reset nest compatible with the state structure.
- __init__(core, name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
- __call__(inputs, state)[source]#
Run one step of the wrapped core, handling state reset.
- Parameters
inputs – Tuple with two elements, inputs, should_reset, where should_reset is the signal used to reset the wrapped core’s state. should_reset can be either a tensor or a nest. If a nest, should_reset must match the state structure, and its components’ shapes must be prefixes of the shapes of the corresponding tensors in the state nest. If a tensor, supported shapes are all common shape prefixes of the state component tensors, e.g. [batch_size].
state – Previous wrapped core state.
- Returns
Tuple of the wrapped core’s output, next_state.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns
Arbitrarily nested initial state for this core.
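The per-element reset can be pictured as a broadcasted selection between the current state and a fresh initial state. The sketch below handles a single state array and a should_reset array whose shape is a prefix of the state shape; `reset_state` is a hypothetical helper, not part of Haiku.

```python
import numpy as np

def reset_state(state, initial_state, should_reset):
    """Replaces state entries with initial_state wherever should_reset is True."""
    # Append size-1 axes so that a [B] flag array broadcasts against [B, ...].
    extra_axes = (1,) * (state.ndim - should_reset.ndim)
    flags = should_reset.reshape(should_reset.shape + extra_axes)
    return np.where(flags, initial_state, state)

state = np.array([[1.0, 2.0], [3.0, 4.0]])   # [batch_size=2, hidden=2]
fresh = np.zeros_like(state)
out = reset_state(state, fresh, np.array([True, False]))
```

Only the first batch element is reset here; the second keeps its state and continues its sequence.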
IdentityCore#
- class haiku.IdentityCore(name=None)[source]#
A recurrent core that forwards the inputs and an empty state.
This is commonly used when switching between recurrent and feedforward versions of a model while preserving the same interface.
- __call__(inputs, state)[source]#
Run one step of the RNN.
- Parameters
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
- Returns
A tuple with two elements, output and next_state. output is an arbitrarily nested structure. next_state is the next core state; it must have the same shape as prev_state.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters
batch_size (Optional[int]) – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns
Arbitrarily nested initial state for this core.
Conv1DLSTM#
- class haiku.Conv1DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]#
1-D convolutional LSTM.
The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
The output is equal to the new hidden state, \(h_t\).
Notes
- Forget gate initialization:
Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.
- __init__(input_shape, output_channels, kernel_shape, name=None)[source]#
Constructs a 1-D convolutional LSTM.
- Parameters
input_shape (Sequence[int]) – Shape of the inputs excluding batch size.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 1), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.
name (Optional[str]) – Name of the module.
Conv2DLSTM#
- class haiku.Conv2DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]#
2-D convolutional LSTM.
The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
The output is equal to the new hidden state, \(h_t\).
Notes
- Forget gate initialization:
Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.
- __init__(input_shape, output_channels, kernel_shape, name=None)[source]#
Constructs a 2-D convolutional LSTM.
- Parameters
input_shape (Sequence[int]) – Shape of the inputs excluding batch size.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 2), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.
name (Optional[str]) – Name of the module.
Conv3DLSTM#
- class haiku.Conv3DLSTM(input_shape, output_channels, kernel_shape, name=None)[source]#
3-D convolutional LSTM.
The implementation is based on [3]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes
\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]where \(*\) denotes the convolution operator; \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.
The output is equal to the new hidden state, \(h_t\).
Notes
- Forget gate initialization:
Following [2] we add 1.0 to \(b_f\) after initialization in order to reduce the scale of forgetting in the beginning of the training.
- __init__(input_shape, output_channels, kernel_shape, name=None)[source]#
Constructs a 3-D convolutional LSTM.
- Parameters
input_shape (Sequence[int]) – Shape of the inputs excluding batch size.
output_channels (int) – Number of output channels.
kernel_shape (Union[int, Sequence[int]]) – Sequence of kernel sizes (of length 3), or an int. kernel_shape will be expanded to define a kernel size in all dimensions.
name (Optional[str]) – Name of the module.
Attention#
MultiHeadAttention#
- class haiku.MultiHeadAttention(num_heads, key_size, w_init_scale=None, *, w_init=None, with_bias=True, b_init=None, value_size=None, model_size=None, name=None)[source]#
Multi-headed attention (MHA) module.
This module is intended for attending over sequences of vectors.
Rough sketch:
- Compute keys (K), queries (Q), and values (V) as projections of inputs.
- Attention weights are computed as W = softmax(QK^T / sqrt(key_size)).
- Output is another projection of WV.
- For more detail, see the original Transformer paper:
“Attention is all you need” https://arxiv.org/abs/1706.03762.
Glossary of shapes:
- T: Sequence length.
- D: Vector (embedding) size.
- H: Number of attention heads.
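The rough sketch above can be made concrete for a single head. The following NumPy example is only an illustration under simplifying assumptions: the input and output projections are omitted, there is one head and no batch dimension, and `single_head_attention` and the boolean mask convention (True means "may attend") are choices made for this example rather than a description of the module's internals.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def single_head_attention(q, k, v, mask=None):
    """q: [T', key_size], k: [T, key_size], v: [T, value_size]."""
    key_size = q.shape[-1]
    logits = q @ k.T / np.sqrt(key_size)        # [T', T]
    if mask is not None:
        logits = np.where(mask, logits, -1e30)  # masked positions get ~0 weight
    weights = softmax(logits, axis=-1)          # each row sums to 1
    return weights @ v, weights                 # [T', value_size], [T', T]


rng = np.random.default_rng(0)
q = rng.normal(size=(3, 8))  # T' = 3 queries
k = rng.normal(size=(5, 8))  # T = 5 keys
v = rng.normal(size=(5, 4))  # value_size = 4
mask = np.tril(np.ones((3, 5), dtype=bool))
out, weights = single_head_attention(q, k, v, mask=mask)
print(out.shape, weights.shape)
```

Each output row is a weighted average of the value rows, so the result has one row per query (T') and the width of the values.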
- __init__(num_heads, key_size, w_init_scale=None, *, w_init=None, with_bias=True, b_init=None, value_size=None, model_size=None, name=None)[source]#
Initialises the module.
- Parameters
num_heads (int) – Number of independent attention heads (H).
key_size (int) – The size of keys (K) and queries used for attention.
w_init_scale (Optional[float]) – DEPRECATED. Please use w_init instead.
w_init (Optional[hk.initializers.Initializer]) – Initialiser for weights in the linear map. Once w_init_scale is fully deprecated w_init will become mandatory. Until then it has a default value of None for backwards compatibility.
with_bias (bool) – Whether to add a bias when computing various linear projections.
b_init (Optional[hk.initializers.Initializer]) – Optional initializer for bias. By default, zero.
value_size (Optional[int]) – Optional size of the value projection (V). If None, defaults to the key size (K).
model_size (Optional[int]) – Optional size of the output embedding (D’). If None, defaults to the key size multiplied by the number of heads (K * H).
name (Optional[str]) – Optional name for this module.
- __call__(query, key, value, mask=None)[source]#
Computes (optionally masked) MHA with queries, keys & values.
This module broadcasts over zero or more ‘batch-like’ leading dimensions.
- Parameters
query (jax.Array) – Embeddings sequence used to compute queries; shape […, T’, D_q].
key (jax.Array) – Embeddings sequence used to compute keys; shape […, T, D_k].
value (jax.Array) – Embeddings sequence used to compute values; shape […, T, D_v].
mask (Optional[jax.Array]) – Optional mask applied to attention weights; shape […, H=1, T’, T].
- Return type
jax.Array
- Returns
- A new sequence of embeddings, consisting of a projection of the
attention-weighted value projections; shape […, T’, D’].
Batch#
| Reshape | Reshapes input Tensor, preserving the batch dimension. |
| Flatten | Flattens the input, preserving the batch dimension(s). |
| BatchApply | Temporarily merges leading dimensions of input tensors. |
Reshape#
- class haiku.Reshape(output_shape, preserve_dims=1, name=None)[source]#
Reshapes input Tensor, preserving the batch dimension.
For example, given an input tensor with shape [B, H, W, C, D]:
>>> B, H, W, C, D = range(1, 6)
>>> x = jnp.ones([B, H, W, C, D])
The default behavior when output_shape is (-1, D) is to flatten all dimensions between B and D:
>>> mod = hk.Reshape(output_shape=(-1, D))
>>> assert mod(x).shape == (B, H*W*C, D)
You can change the number of preserved leading dimensions via preserve_dims:
>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=2)
>>> assert mod(x).shape == (B, H, W*C, D)
>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=3)
>>> assert mod(x).shape == (B, H, W, C, D)
>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=4)
>>> assert mod(x).shape == (B, H, W, C, 1, D)
Alternatively, a negative value of preserve_dims specifies the number of trailing dimensions to replace with output_shape:
>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=-3)
>>> assert mod(x).shape == (B, H, W*C, D)
This is useful in the case of applying the same module to batched and unbatched outputs:
>>> mod = hk.Reshape(output_shape=(-1, D), preserve_dims=-3)
>>> assert mod(x[0]).shape == (H, W*C, D)
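The shape arithmetic behind these examples can be reproduced without Haiku. The helper below is a hypothetical sketch (`reshaped_shape` is not part of the library) that computes the resulting shape from `output_shape` and `preserve_dims`, assuming at most one `-1` entry and a rank large enough for the requested split.

```python
import math


def reshaped_shape(input_shape, output_shape, preserve_dims=1):
    """Shape after keeping some leading dims and reshaping the rest."""
    if preserve_dims == 0:
        raise ValueError("preserve_dims must not be zero")
    if preserve_dims > 0:
        split = preserve_dims
    else:
        # Negative: number of trailing dims to replace with output_shape.
        split = len(input_shape) + preserve_dims
    kept, rest = tuple(input_shape[:split]), input_shape[split:]
    total = math.prod(rest)
    known = math.prod(d for d in output_shape if d != -1)
    # Infer the single -1 entry, if present.
    new = tuple(total // known if d == -1 else d for d in output_shape)
    return kept + new


B, H, W, C, D = range(1, 6)
shape = (B, H, W, C, D)
print(reshaped_shape(shape, (-1, D)))
print(reshaped_shape(shape[1:], (-1, D), preserve_dims=-3))
```

The second call mirrors the unbatched case: with a negative `preserve_dims` the split point is counted from the end, so the same setting works with or without the leading batch dimension.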
- __init__(output_shape, preserve_dims=1, name=None)[source]#
Constructs a Reshape module.
- Parameters
output_shape (Sequence[int]) – Shape to reshape the input tensor to while preserving its first preserve_dims dimensions. When the special value -1 appears in output_shape the corresponding size is automatically inferred. Note that -1 can only appear once in output_shape. To flatten all non-batch dimensions use Flatten.
preserve_dims (int) – Number of leading dimensions that will not be reshaped. If negative, this is interpreted instead as the number of trailing dimensions to replace with the new shape.
name (Optional[str]) – Name of the module.
- Raises
ValueError – If preserve_dims is zero.
Flatten#
- class haiku.Flatten(preserve_dims=1, name=None)[source]#
Flattens the input, preserving the batch dimension(s).
By default, Flatten combines all dimensions except the first. Additional leading dimensions can be preserved by setting preserve_dims.
>>> x = jnp.ones([3, 2, 4])
>>> flat = hk.Flatten()
>>> flat(x).shape
(3, 8)
When the input to flatten has fewer than preserve_dims dimensions it is returned unchanged:
>>> x = jnp.ones([3])
>>> flat(x).shape
(3,)
Alternatively, a negative value of preserve_dims specifies the number of trailing dimensions flattened:
>>> x = jnp.ones([3, 2, 4])
>>> negative_flat = hk.Flatten(preserve_dims=-2)
>>> negative_flat(x).shape
(3, 8)
This allows the same module to be seamlessly applied to a single element or a batch of elements with the same element shape:
>>> negative_flat(x[0]).shape
(8,)
- __init__(preserve_dims=1, name=None)[source]#
Constructs a Flatten module.
- Parameters
preserve_dims (int) – Number of leading dimensions that will not be reshaped. If negative, this is interpreted instead as the number of trailing dimensions to flatten.
name (Optional[str]) – Name of the module.
- Raises
ValueError – If preserve_dims is zero.
BatchApply#
- class haiku.BatchApply(f, num_dims=2)[source]#
Temporarily merges leading dimensions of input tensors.
Merges the leading dimensions of a tensor into a single dimension, runs the given callable, then splits the leading dimension of the result to match the input.
Input arrays whose rank is smaller than the number of dimensions to collapse are passed unmodified.
This may be useful for applying a module to each timestep of e.g. a [Time, Batch, ...] array.
For some fs and platforms, this may be more efficient than jax.vmap(), especially when combined with other transformations like jax.grad().
- __init__(f, num_dims=2)[source]#
Constructs a BatchApply module.
- Parameters
f – The callable to be applied to the reshaped array.
num_dims – The number of dimensions to merge.
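The merge, apply, split pattern described for BatchApply can be sketched in NumPy. This is an illustration of the idea only, for a single array argument; `batch_apply` is a hypothetical helper, not the module's code.

```python
import numpy as np


def batch_apply(f, x, num_dims=2):
    """Merge the first `num_dims` axes, apply `f`, then split them again."""
    if x.ndim < num_dims:
        return f(x)  # too few axes to merge: pass the input through as is
    lead = x.shape[:num_dims]
    merged = x.reshape((-1,) + x.shape[num_dims:])
    out = f(merged)
    return out.reshape(lead + out.shape[1:])


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 2))
x = rng.normal(size=(5, 3, 4))           # [Time, Batch, features]
y = batch_apply(lambda a: a @ w, x)      # f only ever sees a rank-2 array
print(y.shape)
```

Here `f` receives a `[Time * Batch, features]` array, and the result is reshaped back so that its leading dimensions match the input.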
Embedding#
| Embed | Module for embedding tokens in a low-dimensional space. |
| EmbedLookupStyle | How to return the embedding matrices given IDs. |
Embed#
- class haiku.Embed(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None, precision=<Precision.HIGHEST: 2>)[source]#
Module for embedding tokens in a low-dimensional space.
- __init__(vocab_size=None, embed_dim=None, embedding_matrix=None, w_init=None, lookup_style='ARRAY_INDEX', name=None, precision=<Precision.HIGHEST: 2>)[source]#
Constructs an Embed module.
- Parameters
vocab_size (Optional[int]) – The number of unique tokens to embed. If not provided, an existing vocabulary matrix from which vocab_size can be inferred must be provided as embedding_matrix.
embed_dim (Optional[int]) – Number of dimensions to assign to each embedding. If an existing vocabulary matrix initializes the module, this should not be provided as it will be inferred.
embedding_matrix (Optional[Union[np.ndarray, jax.Array]]) – A matrix-like object equivalent in size to [vocab_size, embed_dim]. If given, it is used as the initial value for the embedding matrix and neither vocab_size nor embed_dim need be given. If they are given, their values are checked to be consistent with the dimensions of embedding_matrix.
w_init (Optional[hk.initializers.Initializer]) – An initializer for the embeddings matrix. As a default, embeddings are initialized via a truncated normal distribution.
lookup_style (Union[str, EmbedLookupStyle]) – One of the enum values of EmbedLookupStyle determining how to access the value of the embeddings given an ID. Regardless, the input should be a dense array of integer values representing ids. This setting changes how internally this module maps those ids to embeddings. The result is the same, but the speed and memory tradeoffs are different. It defaults to using NumPy-style array indexing. This value is only the default for the module, and at any given invocation can be overridden in __call__().
name (Optional[str]) – Optional name for this module.
precision (jax.lax.Precision) – Only used when lookup_style is ONE_HOT. The precision to use for the dot-product between the one-hot-encoded inputs and the embedding vectors. It is possible to attain a ~2x speedup on TPU using jax.lax.Precision.DEFAULT at the cost of a slightly lower precision.
- Raises
ValueError – If none of embed_dim, embedding_matrix and vocab_size are supplied, or if embedding_matrix is supplied and embed_dim or vocab_size is not consistent with the supplied matrix.
- __call__(ids, lookup_style=None, precision=None)[source]#
Lookup embeddings.
Looks up an embedding vector for each value in ids. All ids must be within [0, vocab_size) to prevent NaNs from propagating.
- Parameters
ids (Union[jax.Array, Sequence[int]]) – integer array.
lookup_style (Optional[Union[str, hk.EmbedLookupStyle]]) – Overrides the lookup_style given in the constructor.
precision (Optional[jax.lax.Precision]) – Overrides the precision given in the constructor.
- Return type
jax.Array
- Returns
Tensor of ids.shape + [embedding_dim].
- Raises
AttributeError – If lookup_style is not valid.
ValueError – If ids is not an integer array.
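To see why the two lookup styles named above (ARRAY_INDEX and ONE_HOT) give the same result, the following NumPy sketch performs both on a small matrix. It only illustrates the arithmetic; the variable names are made up for this example and nothing here calls Haiku.

```python
import numpy as np

vocab_size, embed_dim = 6, 3
embeddings = np.arange(vocab_size * embed_dim, dtype=np.float32)
embeddings = embeddings.reshape(vocab_size, embed_dim)
ids = np.array([[0, 2], [5, 2]])

# Array-index style: pick rows of the matrix directly.
by_index = embeddings[ids]                            # [2, 2, embed_dim]

# One-hot style: encode the ids, then take a dot product with the matrix.
one_hot = np.eye(vocab_size, dtype=np.float32)[ids]   # [2, 2, vocab_size]
by_one_hot = one_hot @ embeddings                     # [2, 2, embed_dim]

print(by_index.shape, np.array_equal(by_index, by_one_hot))
```

Both results have shape `ids.shape + (embed_dim,)`; the one-hot route trades extra memory and a matrix product for an operation that some accelerators handle better than gathers.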
EmbedLookupStyle#
Utilities#
| Deferred | Defers the construction of another module until the first call. |
Deferred#
- class haiku.Deferred(factory, call_methods=('__call__',))[source]#
Defers the construction of another module until the first call.
Deferred can be used to declare modules that depend on computed properties of other modules before those modules are defined. This allows users to separate the declaration and use of modules. For example at the start of your program you can declare two modules which are coupled:
>>> encoder = hk.Linear(64)
>>> decoder = hk.Deferred(lambda: hk.Linear(encoder.input_size))
Later you can use these naturally (note that using decoder first would cause an error, since encoder.input_size is only defined after encoder has been called):
>>> x = jnp.ones([8, 32])
>>> y = encoder(x)
>>> z = decoder(y)  # Constructs the Linear decoder by calling the lambda.
The result will satisfy the following conditions:
>>> assert x.shape == z.shape
>>> assert y.shape == (8, 64)
>>> assert decoder.input_size == encoder.output_size
>>> assert decoder.output_size == encoder.input_size
- __init__(factory, call_methods=('__call__',))[source]#
Initializes the Deferred module.
- Parameters
factory (Callable[[], T]) – A no-argument callable which constructs the module to defer to. The first time one of the call_methods is called the factory will be run and then the constructed module will be called with the same method and arguments as the deferred module.
call_methods (Sequence[str]) – Methods which should trigger construction of the target module. The default value configures this module to construct the first time __call__ is run. If you want to add methods other than __call__ you should pass them explicitly, for example call_methods=("__call__", "encode", "decode").
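The deferral mechanism can be pictured with a few lines of plain Python. This is a simplified sketch, not Haiku's implementation: `DeferredSketch`, `ToyEncoder` and `ToyDecoder` are hypothetical classes, and only `__call__` triggers construction (the `call_methods` option is not modelled).

```python
class DeferredSketch:
    """Builds the wrapped object on first call (illustration only)."""

    def __init__(self, factory):
        self._factory = factory
        self._target = None

    def __call__(self, *args, **kwargs):
        if self._target is None:
            self._target = self._factory()  # the factory runs exactly once
        return self._target(*args, **kwargs)


class ToyEncoder:
    """Stand-in for a module whose input_size is only known after a call."""

    def __init__(self):
        self.input_size = None

    def __call__(self, xs):
        self.input_size = len(xs)
        return sum(xs)


class ToyDecoder:
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, y):
        return [y / self.output_size] * self.output_size


encoder = ToyEncoder()
decoder = DeferredSketch(lambda: ToyDecoder(encoder.input_size))
y = encoder([1.0, 2.0, 3.0])  # sets encoder.input_size to 3
z = decoder(y)                # the factory runs now and reads input_size
print(z)
```

Because the lambda is evaluated lazily, `decoder` can be declared before `encoder.input_size` has a value, as long as `encoder` is called first.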
Initializers#
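| Initializer | alias of Callable[[collections.abc.Sequence[int], Any], jax.Array] |
| Constant | Initializes with a constant. |
| Identity | Initializer that generates the identity matrix. |
| Orthogonal | Initializer for uniformly distributed orthogonal matrices. |
| RandomNormal | Initializes by sampling from a normal distribution. |
| RandomUniform | Initializes by sampling from a uniform distribution. |
| TruncatedNormal | Initializes by sampling from a truncated normal distribution. |
| VarianceScaling | Initializer which adapts its scale to the shape of the initialized array. |
| UniformScaling | Uniform scaling initializer. |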
Initializer#
- haiku.initializers.Initializer#
alias of Callable[[collections.abc.Sequence[int], Any], jax.Array]
Constant#
Identity#
- class haiku.initializers.Identity(gain=1.0)[source]#
Initializer that generates the identity matrix.
Constructs a 2D identity matrix or batches of these.
Orthogonal#
- class haiku.initializers.Orthogonal(scale=1.0, axis=-1)[source]#
Initializer for uniformly distributed orthogonal matrices.
- __init__(scale=1.0, axis=-1)[source]#
Construct an initializer for uniformly distributed orthogonal matrices.
These matrices will be row-orthonormal along the axis specified by axis. If the rank of the weight is greater than 2, the shape will be flattened in all other dimensions and then will be row-orthonormal along the final dimension. Note that this only works if the axis dimension is larger, otherwise the matrix will be transposed (equivalently, it will be column orthonormal instead of row orthonormal).
If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.
- Parameters
scale – Scale factor.
axis – Which axis corresponds to the “output dimension” of the tensor.
- Returns
An orthogonally initialized parameter.
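One common way to draw such a matrix is a QR decomposition of a Gaussian matrix. The NumPy sketch below shows that technique for the 2-D case only; it is an assumption about how one might do this, not a transcription of the Haiku initializer, and `orthonormal_rows` is a hypothetical helper.

```python
import numpy as np


def orthonormal_rows(n_rows, n_cols, scale=1.0, seed=0):
    """Random matrix with orthonormal rows; requires n_rows <= n_cols."""
    rng = np.random.default_rng(seed)
    a = rng.normal(size=(n_cols, n_rows))
    q, r = np.linalg.qr(a)        # q: [n_cols, n_rows], orthonormal columns
    q = q * np.sign(np.diag(r))   # standard sign correction per column
    return scale * q.T            # [n_rows, n_cols], orthonormal rows


w = orthonormal_rows(3, 5)
print(np.round(w @ w.T, 6))
```

With `scale=1.0` the product `w @ w.T` is the identity, which is the row-orthonormal property described above; for a shape with more rows than columns the same construction would yield orthonormal columns instead.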
RandomNormal#
- class haiku.initializers.RandomNormal(stddev=1.0, mean=0.0)[source]#
Initializes by sampling from a normal distribution.
- __init__(stddev=1.0, mean=0.0)[source]#
Constructs a RandomNormal initializer.
- Parameters
stddev – The standard deviation of the normal distribution to sample from.
mean – The mean of the normal distribution to sample from.
RandomUniform#
- class haiku.initializers.RandomUniform(minval=0.0, maxval=1.0)[source]#
Initializes by sampling from a uniform distribution.
- __init__(minval=0.0, maxval=1.0)[source]#
Constructs a RandomUniform initializer.
- Parameters
minval – The lower limit of the uniform distribution.
maxval – The upper limit of the uniform distribution.
TruncatedNormal#
- class haiku.initializers.TruncatedNormal(stddev=1.0, mean=0.0, lower=-2.0, upper=2.0)[source]#
Initializes by sampling from a truncated normal distribution.
- __init__(stddev=1.0, mean=0.0, lower=-2.0, upper=2.0)[source]#
Constructs a TruncatedNormal initializer.
- Parameters
stddev (Union[float, jax.Array]) – The standard deviation parameter of the truncated normal distribution.
mean (Union[float, complex, jax.Array]) – The mean of the truncated normal distribution.
lower (Union[float, jax.Array]) – Float or array representing the lower bound for truncation.
upper (Union[float, jax.Array]) – Float or array representing the upper bound for truncation.
VarianceScaling#
- class haiku.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal', fan_in_axes=None)[source]#
Initializer which adapts its scale to the shape of the initialized array.
The initializer first computes the scaling factor s = scale / n, where n is:
- Number of input units in the weight tensor, if mode = fan_in.
- Number of output units, if mode = fan_out.
- Average of the numbers of input and output units, if mode = fan_avg.
Then, with distribution="truncated_normal" or "normal", samples are drawn from a distribution with a mean of zero and a standard deviation (after truncation, if used) stddev = sqrt(s).
With distribution="uniform", samples are drawn from a uniform distribution within [-limit, limit], with limit = sqrt(3 * s).
The variance scaling initializer can be configured to generate other standard initializers using the scale, mode and distribution arguments. Here are some example configurations:
| Name | Parameters |
| glorot_uniform | VarianceScaling(1.0, "fan_avg", "uniform") |
| glorot_normal | VarianceScaling(1.0, "fan_avg", "truncated_normal") |
| lecun_uniform | VarianceScaling(1.0, "fan_in", "uniform") |
| lecun_normal | VarianceScaling(1.0, "fan_in", "truncated_normal") |
| he_uniform | VarianceScaling(2.0, "fan_in", "uniform") |
| he_normal | VarianceScaling(2.0, "fan_in", "truncated_normal") |
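The formulas above are easy to check by hand. The helper below is a hypothetical sketch (`variance_scaling_spread` is not a Haiku function) that takes the fan-in and fan-out as given and returns the target standard deviation or the uniform limit; it does not sample anything and does not model how the fans are derived from a shape.

```python
import math


def variance_scaling_spread(fan_in, fan_out, scale=1.0, mode="fan_in",
                            distribution="truncated_normal"):
    """Returns the target stddev (normal variants) or limit (uniform)."""
    n = {"fan_in": fan_in,
         "fan_out": fan_out,
         "fan_avg": (fan_in + fan_out) / 2}[mode]
    s = scale / n
    if distribution in ("truncated_normal", "normal"):
        return math.sqrt(s)      # stddev
    if distribution == "uniform":
        return math.sqrt(3 * s)  # limit of [-limit, limit]
    raise ValueError(f"unknown distribution: {distribution}")


# glorot_uniform and he_normal for a weight with 300 inputs and 100 outputs.
glorot_limit = variance_scaling_spread(300, 100, 1.0, "fan_avg", "uniform")
he_stddev = variance_scaling_spread(300, 100, 2.0, "fan_in", "truncated_normal")
print(glorot_limit, he_stddev)
```

The factor of 3 in the uniform case follows from the variance of a uniform distribution on [-limit, limit], which is limit² / 3, so both branches target the same variance s.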
- __init__(scale=1.0, mode='fan_in', distribution='truncated_normal', fan_in_axes=None)[source]#
Constructs the VarianceScaling initializer.
- Parameters
scale – Scale to multiply the variance by.
mode – One of fan_in, fan_out, fan_avg.
distribution – Random distribution to use. One of truncated_normal, normal or uniform.
fan_in_axes – Optional sequence of int specifying which axes of the shape are part of the fan-in. If none provided, then the weight is assumed to be like a convolution kernel, where all leading dimensions are part of the fan-in, and only the trailing dimension is part of the fan-out. Useful if instantiating multi-headed attention weights.
UniformScaling#
- class haiku.initializers.UniformScaling(scale=1.0)[source]#
Uniform scaling initializer.
Initializes by sampling from a uniform distribution, but with the variance scaled by the inverse square root of the number of input units, multiplied by the scale.
- __init__(scale=1.0)[source]#
Constructs the UniformScaling initializer.
- Parameters
scale – Scale to multiply the upper limit of the uniform distribution by.
Paddings#
| PadFn | alias of Callable[[int], tuple[int, int]] |
| is_padfn | Tests whether the given argument is a single or sequence of PadFns. |
| create | Generates the padding required for a given padding algorithm. |
| create_from_padfn | Generates the padding required for a given padding algorithm. |
| create_from_tuple | Create a padding tuple using partially specified padding tuple. |
| causal | Pre-padding such that output has no dependence on the future. |
| full | Maximal padding whilst not convolving over just padded elements. |
| reverse_causal | Post-padding such that output has no dependence on the past. |
| same | Pads such that the output size matches input size for stride=1. |
| valid | No padding. |
PadFn#
- haiku.pad.PadFn#
alias of Callable[[int], tuple[int, int]]
is_padfn#
create#
- haiku.pad.create(padding, kernel, rate, n)[source]#
Generates the padding required for a given padding algorithm.
- Parameters
padding (Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – callable/tuple or a sequence of callables/tuples. The callables take an integer representing the effective kernel size (kernel size when the rate is 1) and return a sequence of two integers representing the padding before and padding after for that dimension. The tuples are defined with two elements, padding before and after. If padding is a sequence it must be of length 1 or n.
kernel (Union[int, Sequence[int]]) – int or sequence of ints of length n. The size of the kernel for each dimension. If it is an int it will be replicated for the non channel and batch dimensions.
rate (Union[int, Sequence[int]]) – int or sequence of ints of length n. The dilation rate for each dimension. If it is an int it will be replicated for the non channel and batch dimensions.
n (int) – the number of spatial dimensions.
- Return type
Sequence[tuple[int, int]]
- Returns
A sequence of length n containing the padding for each element. These are of the form [pad_before, pad_after].
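The sketch below shows one way the described behaviour could work for callable paddings. Everything in it is an assumption made for illustration: `create_padding` is a hypothetical stand-in, the three padding functions are written from their one-line summaries rather than copied from `haiku.pad`, tuple paddings are not handled, and the effective size of a dilated kernel is taken to be `(kernel - 1) * rate + 1`.

```python
def valid(effective_kernel_size):
    return (0, 0)  # no padding


def same(effective_kernel_size):
    total = effective_kernel_size - 1
    return (total // 2, total - total // 2)  # output size == input size at stride 1


def causal(effective_kernel_size):
    return (effective_kernel_size - 1, 0)  # pad only before: no look-ahead


def create_padding(padding, kernel, rate, n):
    """One (pad_before, pad_after) pair per spatial dimension."""
    if callable(padding):
        padding = [padding]
    padding = list(padding) * n if len(padding) == 1 else list(padding)
    kernel = [kernel] * n if isinstance(kernel, int) else list(kernel)
    rate = [rate] * n if isinstance(rate, int) else list(rate)
    if not (len(padding) == len(kernel) == len(rate) == n):
        raise ValueError("padding, kernel and rate must have length 1 or n")
    # Effective kernel size of a dilated kernel: (k - 1) * rate + 1.
    return [fn((k - 1) * r + 1) for fn, k, r in zip(padding, kernel, rate)]


print(create_padding(same, 3, 1, 2))
print(create_padding(causal, 3, 2, 1))
```

Scalars are replicated across the `n` spatial dimensions, and each padding function only ever sees the effective kernel size for its own dimension.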
create_from_padfn#
- haiku.pad.create_from_padfn(padding, kernel, rate, n)[source]#
Generates the padding required for a given padding algorithm.
- Parameters
padding (Union[hk.pad.PadFn, Sequence[hk.pad.PadFn]]) – callable/tuple or a sequence of callables/tuples. The callables take an integer representing the effective kernel size (kernel size when the rate is 1) and return a sequence of two integers representing the padding before and padding after for that dimension. The tuples are defined with two elements, padding before and after. If padding is a sequence it must be of length 1 or n.
kernel (Union[int, Sequence[int]]) – int or sequence of ints of length n. The size of the kernel for each dimension. If it is an int it will be replicated for the non channel and batch dimensions.
rate (Union[int, Sequence[int]]) – int or sequence of ints of length n. The dilation rate for each dimension. If it is an int it will be replicated for the non channel and batch dimensions.
n (int) – the number of spatial dimensions.
- Return type
Sequence[tuple[int, int]]
- Returns
A sequence of length n containing the padding for each element. These are of the form [pad_before, pad_after].
create_from_tuple#
causal#
full#
reverse_causal#
same#
valid#
Full Networks#
MLP#
- class haiku.nets.MLP(output_sizes, w_init=None, b_init=None, with_bias=True, activation=jax.nn.relu, activate_final=False, name=None)[source]#
A multi-layer perceptron module.
- __init__(output_sizes, w_init=None, b_init=None, with_bias=True, activation=jax.nn.relu, activate_final=False, name=None)[source]#
Constructs an MLP.
- Parameters
output_sizes (Iterable[int]) – Sequence of layer sizes.
w_init (Optional[hk.initializers.Initializer]) – Initializer for Linear weights.
b_init (Optional[hk.initializers.Initializer]) – Initializer for Linear bias. Must be None if with_bias=False.
with_bias (bool) – Whether or not to apply a bias in each layer.
activation (Callable[[jax.Array], jax.Array]) – Activation function to apply between Linear layers. Defaults to ReLU.
activate_final (bool) – Whether or not to activate the final layer of the MLP.
name (Optional[str]) – Optional name for this module.
- Raises
ValueError – If with_bias is False and b_init is not None.
- __call__(inputs, dropout_rate=None, rng=None)[source]#
Connects the module to some inputs.
- Parameters
inputs (jax.Array) – A Tensor of shape [batch_size, input_size].
dropout_rate (Optional[float]) – Optional dropout rate.
rng – Optional RNG key. Required when using dropout.
- Return type
jax.Array
- Returns
The output of the model of size [batch_size, output_size].
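The roles of `activation` and `activate_final` can be illustrated with a small NumPy forward pass. This is a sketch under stated assumptions rather than the module's code: `mlp_forward` and `relu` are hypothetical helpers, the weights are drawn at random for the example, and dropout is left out.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def mlp_forward(x, weights, biases, activation=relu, activate_final=False):
    """Applies linear layers with `activation` between them."""
    num_layers = len(weights)
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w + b
        # The last layer is only activated when activate_final is set.
        if i < num_layers - 1 or activate_final:
            x = activation(x)
    return x


rng = np.random.default_rng(0)
sizes = [4, 8, 8, 2]  # input_size followed by the output_sizes [8, 8, 2]
weights = [rng.normal(size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]
inputs = np.ones((5, 4))
out = mlp_forward(inputs, weights, biases)
out_final = mlp_forward(inputs, weights, biases, activate_final=True)
print(out.shape)
```

The output has shape `[batch_size, output_sizes[-1]]`; with `activate_final=True` the ReLU is also applied to the last layer, so every output is non-negative.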
- reverse(activate_final=None, name=None)[source]#
Returns a new MLP which is the layer-wise reverse of this MLP.
NOTE: Since computing the reverse of an MLP requires knowing the input size of each linear layer this method will fail if the module has not been called at least once.
The contract of reverse is that the reversed module will accept the output of the parent module as input and produce an output which is the input size of the parent.
>>> mlp = hk.nets.MLP([1, 2, 3])
>>> mlp_in = jnp.ones([1, 2])
>>> y = mlp(mlp_in)
>>> rev = mlp.reverse()
>>> rev_mlp_out = rev(y)
>>> mlp_in.shape == rev_mlp_out.shape
True
- Parameters
activate_final (Optional[bool]) – Whether the final layer of the MLP should be activated.
name (Optional[str]) – Optional name for the new module. The default name will be the name of the current module prefixed with "reversed_".
- Return type
MLP
- Returns
An MLP instance which is the reverse of the current instance. Note these instances do not share weights and, apart from being symmetric to each other, are not coupled in any way.
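The contract of reverse can be checked with layer sizes alone. The helper below is a hypothetical sketch (`reversed_output_sizes` is not part of Haiku) that derives the output sizes a reversed MLP would need, given the parent's input size and `output_sizes`.

```python
def reversed_output_sizes(input_size, output_sizes):
    """Output sizes of the layer-wise reverse of an MLP."""
    # Input size of each layer: the MLP input, then each earlier output.
    layer_inputs = [input_size] + list(output_sizes[:-1])
    # The reversed MLP maps each layer's output back to that layer's input.
    return list(reversed(layer_inputs))


print(reversed_output_sizes(2, [1, 2, 3]))
print(reversed_output_sizes(4, [8, 16, 2]))
```

The last entry always equals the parent's input size, which is why the reversed module produces an output of the same shape as the parent's input, and why the parent must have been called once so that its input size is known.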
MobileNet#
MobileNetV1#
- class haiku.nets.MobileNetV1(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]#
MobileNetV1 model.
- __init__(strides=(1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1), channels=(64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024), num_classes=1000, use_bn=True, name=None)[source]#
Constructs a MobileNetV1 model.
- Parameters
strides (Sequence[int]) – The stride to use in the depthwise convolution in each mobilenet block.
channels (Sequence[int]) – Number of output channels from the pointwise convolution to use in each block.
num_classes (int) – Number of classes.
use_bn (bool) – Whether or not to use batch normalization. Defaults to True. When true, biases are not used. When false, biases are used.
name (Optional[str]) – Name of the module.
ResNet#
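| ResNet | ResNet model. |
| ResNet.BlockGroup | Higher level block for ResNet implementation. |
| ResNet.BlockV1 | ResNet V1 block with optional bottleneck. |
| ResNet.BlockV2 | ResNet V2 block with optional bottleneck. |
| ResNet18 | ResNet18. |
| ResNet34 | ResNet34. |
| ResNet50 | ResNet50. |
| ResNet101 | ResNet101. |
| ResNet152 | ResNet152. |
| ResNet200 | ResNet200. |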
ResNet#
- class haiku.nets.ResNet(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
ResNet model.
- class BlockGroup(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]#
Higher level block for ResNet implementation.
- __init__(channels, num_blocks, stride, bn_config, resnet_v2, bottleneck, use_projection, name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
- class BlockV1(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]#
ResNet V1 block with optional bottleneck.
- __init__(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
- class BlockV2(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]#
ResNet V2 block with optional bottleneck.
- __init__(channels, stride, use_projection, bn_config, bottleneck, name=None)[source]#
Initializes the current module with the given name.
Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.
- Parameters
name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.
- __init__(blocks_per_group, num_classes, bn_config=None, resnet_v2=False, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), use_projection=(True, True, True, True), logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
Constructs a ResNet model.
- Parameters
blocks_per_group (Sequence[int]) – A sequence of length 4 that indicates the number of blocks created in each group.
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers. By default the decay_rate is 0.9 and eps is 1e-5.
resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.
bottleneck (bool) – Whether the block should bottleneck or not. Defaults to True.
channels_per_group (Sequence[int]) – A sequence of length 4 that indicates the number of channels used for each block in each group.
use_projection (Sequence[bool]) – A sequence of length 4 that indicates whether each residual block should use projection.
logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.
initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.
strides (Sequence[int]) – A sequence of length 4 that indicates the size of stride of convolutions for each block in each group.
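To give a feel for how `blocks_per_group` and `bottleneck` relate to the named variants below, the following sketch recomputes the nominal depth of each network. The `blocks_per_group` values are the standard ResNet configurations and are an assumption here, since this section does not list them; `CONFIGS` and `depth` are hypothetical names.

```python
# (blocks_per_group, bottleneck) per variant. These are the commonly used
# ResNet configurations, assumed for illustration only.
CONFIGS = {
    "ResNet18": ((2, 2, 2, 2), False),
    "ResNet34": ((3, 4, 6, 3), False),
    "ResNet50": ((3, 4, 6, 3), True),
    "ResNet101": ((3, 4, 23, 3), True),
    "ResNet152": ((3, 8, 36, 3), True),
    "ResNet200": ((3, 24, 36, 3), True),
}


def depth(blocks_per_group, bottleneck):
    """Nominal layer count: convolutions in all blocks plus stem and logits."""
    convs_per_block = 3 if bottleneck else 2
    # +2 accounts for the initial convolution and the final logits layer.
    return sum(blocks_per_group) * convs_per_block + 2


for name, (blocks, bottleneck) in CONFIGS.items():
    print(name, depth(blocks, bottleneck))
```

Under these assumptions each computed depth matches the number in the variant's name, which is a quick sanity check when passing a custom `blocks_per_group` to the generic ResNet class.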
ResNet18#
- class haiku.nets.ResNet18(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
ResNet18.
- __init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
Constructs a ResNet model.
- Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.
resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.
logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.
initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.
strides (Sequence[int]) – A sequence of length 4 that indicates the size of stride of convolutions for each block in each group.
ResNet34#
- class haiku.nets.ResNet34(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
ResNet34.
- __init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
Constructs a ResNet model.
- Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.
resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.
logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.
initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.
strides (Sequence[int]) – A sequence of length 4 that indicates the size of stride of convolutions for each block in each group.
ResNet50#
- class haiku.nets.ResNet50(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
ResNet50.
- __init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
Constructs a ResNet model.
- Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps to be passed on to the BatchNorm layers.
resnet_v2 (bool) – Whether to use the v1 or v2 ResNet implementation. Defaults to False.
logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.
initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.
strides (Sequence[int]) – A sequence of length 4 that indicates the size of stride of convolutions for each block in each group.
ResNet101#
- class haiku.nets.ResNet101(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
ResNet101.
- __init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
Constructs a ResNet model.
- Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps, to be passed on to the BatchNorm layers.
resnet_v2 (bool) – Whether to use the v2 ResNet implementation instead of v1. Defaults to False.
logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.
initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.
strides (Sequence[int]) – A sequence of length 4 that indicates the stride of the convolutions for each block in each group.
ResNet152#
- class haiku.nets.ResNet152(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
ResNet152.
- __init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
Constructs a ResNet model.
- Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps, to be passed on to the BatchNorm layers.
resnet_v2 (bool) – Whether to use the v2 ResNet implementation instead of v1. Defaults to False.
logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.
initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.
strides (Sequence[int]) – A sequence of length 4 that indicates the stride of the convolutions for each block in each group.
ResNet200#
- class haiku.nets.ResNet200(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
ResNet200.
- __init__(num_classes, bn_config=None, resnet_v2=False, logits_config=None, name=None, initial_conv_config=None, strides=(1, 2, 2, 2))[source]#
Constructs a ResNet model.
- Parameters
num_classes (int) – The number of classes to classify the inputs into.
bn_config (Optional[Mapping[str, FloatStrOrBool]]) – A dictionary of two elements, decay_rate and eps, to be passed on to the BatchNorm layers.
resnet_v2 (bool) – Whether to use the v2 ResNet implementation instead of v1. Defaults to False.
logits_config (Optional[Mapping[str, Any]]) – A dictionary of keyword arguments for the logits layer.
name (Optional[str]) – Name of the module.
initial_conv_config (Optional[Mapping[str, FloatStrOrBool]]) – Keyword arguments passed to the constructor of the initial Conv2D module.
strides (Sequence[int]) – A sequence of length 4 that indicates the stride of the convolutions for each block in each group.
VectorQuantizer#
| VectorQuantizer | Haiku module representing the VQ-VAE layer. |
| VectorQuantizerEMA | Haiku module representing the VQ-VAE layer. |
VectorQuantizer#
- class haiku.nets.VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, dtype=<class 'jax.numpy.float32'>, name=None, cross_replica_axis=None)[source]#
Haiku module representing the VQ-VAE layer.
Implements the algorithm presented in “Neural Discrete Representation Learning” by van den Oord et al. https://arxiv.org/abs/1711.00937
Input any tensor to be quantized. Last dimension will be used as space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.
The output tensor will have the same shape as the input.
For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.
- embedding_dim#
integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.
- num_embeddings#
integer, the number of vectors in the quantized space.
- commitment_cost#
scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).
- __init__(embedding_dim, num_embeddings, commitment_cost, dtype=<class 'jax.numpy.float32'>, name=None, cross_replica_axis=None)[source]#
Initializes a VQ-VAE module.
- Parameters
embedding_dim (int) – dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.
num_embeddings (int) – number of vectors in the quantized space.
commitment_cost (float) – scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).
dtype (Any) – dtype for the embeddings variable, defaults to float32.
name (Optional[str]) – name of the module.
cross_replica_axis (Optional[str]) – If not None, it should be a string representing the axis name over which this module is being run within a jax.pmap(). Supplying this argument means that perplexity is calculated across all replicas on that axis.
- __call__(inputs, is_training)[source]#
Connects the module to some inputs.
- Parameters
inputs – Tensor, final dimension must be equal to embedding_dim. All other leading dimensions will be flattened and treated as a large batch.
is_training – boolean, whether this connection is to training data.
- Returns
- Dictionary containing the following keys and values:
quantize: Tensor containing the quantized version of the input.
loss: Tensor containing the loss to optimize.
perplexity: Tensor containing the perplexity of the encodings.
encodings: Tensor containing the discrete encodings, i.e. which element of the quantized space each input element was mapped to.
encoding_indices: Tensor containing the discrete encoding indices, i.e. which element of the quantized space each input element was mapped to.
- Return type
dict
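The flattening and nearest-embedding lookup described for VectorQuantizer can be sketched in plain Python. This is a simplified illustration, not Haiku's implementation: `quantize` is a hypothetical helper, the codebook is a hand-written list, and squared Euclidean distance is assumed as the nearest-neighbour metric, following the paper cited above.

```python
import math


def quantize(vectors, codebook):
    """Maps each vector to its nearest codebook entry (squared Euclidean)."""
    indices = []
    for v in vectors:
        dists = [sum((a - b) ** 2 for a, b in zip(v, e)) for e in codebook]
        indices.append(dists.index(min(dists)))
    quantized = [codebook[i] for i in indices]
    return quantized, indices


# Shape bookkeeping for the example in the text: every leading dimension
# is folded into a single batch dimension.
shape = (16, 32, 32, 64)
flat_shape = (math.prod(shape[:-1]), shape[-1])
print(flat_shape)  # (16384, 64)

# Toy codebook with num_embeddings=3 and embedding_dim=2.
codebook = [[0.0, 0.0], [1.0, 1.0], [4.0, 0.0]]
vectors = [[0.2, -0.1], [0.9, 1.2], [3.0, 0.5]]
quantized, indices = quantize(vectors, codebook)
print(indices)  # [0, 1, 2]
```

Here `indices` plays the role of encoding_indices and `quantized` the role of quantize in the returned dictionary; the loss and perplexity terms are left out of the sketch.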
VectorQuantizerEMA#
- class haiku.nets.VectorQuantizerEMA(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=<class 'jax.numpy.float32'>, cross_replica_axis=None, name=None)[source]#
Haiku module representing the VQ-VAE layer.
Implements a slightly modified version of the algorithm presented in “Neural Discrete Representation Learning” by van den Oord et al. https://arxiv.org/abs/1711.00937
The difference between VectorQuantizerEMA and VectorQuantizer is that this module uses ExponentialMovingAverages to update the embedding vectors instead of an auxiliary loss. This has the advantage that the embedding updates are independent of the choice of optimizer (SGD, RMSProp, Adam, K-Fac, …) used for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.
Input any tensor to be quantized. Last dimension will be used as space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.
The output tensor will have the same shape as the input.
For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.
- embedding_dim#
integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.
- num_embeddings#
integer, the number of vectors in the quantized space.
- commitment_cost#
scalar which controls the weighting of the loss terms (see equation 4 in the paper).
- decay#
float, decay for the moving averages.
- epsilon#
small float constant to avoid numerical instability.
- __init__(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=<class 'jax.numpy.float32'>, cross_replica_axis=None, name=None)[source]#
Initializes a VQ-VAE EMA module.
- Parameters
embedding_dim – integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.
num_embeddings – integer, the number of vectors in the quantized space.
commitment_cost – scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).
decay – float between 0 and 1, controls the speed of the Exponential Moving Averages.
epsilon (float) – small constant to aid numerical stability, default 1e-5.
dtype (Any) – dtype for the embeddings variable, defaults to float32.
cross_replica_axis (Optional[str]) – If not None, it should be a string representing the axis name over which this module is being run within a jax.pmap(). Supplying this argument means that cluster statistics and the perplexity are calculated across all replicas on that axis.
name (Optional[str]) – name of the module.
- __call__(inputs, is_training)[source]#
Connects the module to some inputs.
- Parameters
inputs – Tensor, final dimension must be equal to embedding_dim. All other leading dimensions will be flattened and treated as a large batch.
is_training – boolean, whether this connection is to training data. When this is set to False, the internal moving average statistics will not be updated.
- Returns
- Dictionary containing the following keys and values:
quantize: Tensor containing the quantized version of the input.
loss: Tensor containing the loss to optimize.
perplexity: Tensor containing the perplexity of the encodings.
encodings: Tensor containing the discrete encodings, i.e. which element of the quantized space each input element was mapped to.
encoding_indices: Tensor containing the discrete encoding indices, i.e. which element of the quantized space each input element was mapped to.
- Return type
dict
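As a rough illustration of how a moving average can stand in for a gradient-based codebook update, the sketch below applies a plain exponential moving average to a single embedding vector. It assumes the textbook rule new = decay * old + (1 - decay) * target. The module documented above also maintains cluster statistics and an epsilon term, which this sketch omits. `ema_update` is a hypothetical helper.

```python
def ema_update(old, target, decay):
    """Moves `old` towards `target`; a larger decay means slower updates."""
    return [decay * o + (1.0 - decay) * t for o, t in zip(old, target)]


embedding = [0.0, 0.0]
# Assumed target: mean of the encoder outputs assigned to this embedding.
encoder_mean = [1.0, 2.0]
for _ in range(3):
    embedding = ema_update(embedding, encoder_mean, decay=0.5)
print(embedding)  # [0.875, 1.75]
```

Because the update depends only on decay and the assigned encoder outputs, it does not interact with the optimizer used for the rest of the model, which is the advantage the class description points out.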
JAX Fundamentals#
Control Flow#
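| cond | Conditionally apply true_fun or false_fun. |
| fori_loop | Equivalent to jax.lax.fori_loop() with Haiku state passed in/out. |
| map | Equivalent to jax.lax.map() but with Haiku state passed in/out. |
| scan | Equivalent to jax.lax.scan() but with Haiku state passed in/out. |
| switch | Equivalent to jax.lax.switch() but with Haiku state passed in/out. |
| while_loop | Equivalent to jax.lax.while_loop with Haiku state threaded in/out. |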
cond#
- haiku.cond(pred, true_fun, false_fun, *operands, operand=<object object>, linear=None)[source]#
Conditionally apply true_fun or false_fun.
Wraps XLA’s Conditional operator.
Provided arguments are correctly typed, cond() has equivalent semantics to this Python implementation, where pred must be a scalar type:
    def cond(pred, true_fun, false_fun, *operands):
      if pred:
        return true_fun(*operands)
      else:
        return false_fun(*operands)
In contrast with jax.lax.select(), using cond indicates that only one of the two branches is executed (up to compiler rewrites and optimizations). However, when transformed with vmap() to operate over a batch of predicates, cond is converted to select().
- Parameters
pred – Boolean scalar type, indicating which branch function to apply.
true_fun (Callable) – Function (A -> B), to be applied if pred is True.
false_fun (Callable) – Function (A -> B), to be applied if pred is False.
operands – Operands (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
- Returns
Value (B) of either true_fun(*operands) or false_fun(*operands), depending on the value of pred. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
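The difference between cond and select noted above can be made concrete with a plain-Python model. This sketch does not call JAX: `cond` mirrors the reference implementation given above, and `select_batched` is a hypothetical stand-in for what a batch of predicates implies, namely that both branches are evaluated for every element and the result is picked elementwise.

```python
calls = {"true": 0, "false": 0}


def true_fun(x):
    calls["true"] += 1
    return x + 1


def false_fun(x):
    calls["false"] += 1
    return x - 1


def cond(pred, true_fun, false_fun, *operands):
    # Scalar predicate: only the chosen branch runs.
    return true_fun(*operands) if pred else false_fun(*operands)


def select_batched(preds, true_fun, false_fun, xs):
    # Batched predicates: evaluate both branches, then pick per element.
    on_true = [true_fun(x) for x in xs]
    on_false = [false_fun(x) for x in xs]
    return [t if p else f for p, t, f in zip(preds, on_true, on_false)]


single = cond(True, true_fun, false_fun, 10)
calls_after_cond = dict(calls)
batched = select_batched([True, False], true_fun, false_fun, [10, 20])
print(single, calls_after_cond)  # 11 {'true': 1, 'false': 0}
print(batched)  # [11, 19]
```

The call counters show the practical consequence: after the scalar cond only true_fun has run, whereas the batched version runs both branches on every element.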
fori_loop#
map#
scan#
switch#
- haiku.switch(index, branches, *operands)[source]#
Equivalent to jax.lax.switch() but with Haiku state passed in/out.
Note that creating parameters inside a switch branch is not supported. As such, at init time we recommend you unconditionally evaluate all branches of your switch and only use the switch at apply. For example:
>>> experts = [hk.nets.MLP([300, 100, 10]) for _ in range(5)]
>>> x = jnp.ones([1, 28 * 28])
>>> if hk.running_init():
...   # During init unconditionally create params/state for all experts.
...   for expert in experts:
...     out = expert(x)
... else:
...   # During apply conditionally apply (and update) only one expert.
...   # The upper bound of randint is exclusive, so every expert can be chosen.
...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts))
...   out = hk.switch(index, experts, x)
- Parameters
index – Integer scalar type, indicating which branch function to apply.
branches – Sequence of functions (A -> B) to be applied based on index.
operands – Operands (A) input to whichever branch is applied.
- Returns
Value (B) of branch(*operands) for the branch that was selected based on index.
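By analogy with the Python reference given for cond(), the selection rule can be sketched without JAX. The clamping of out-of-range indices shown here follows the documented behaviour of jax.lax.switch() and is assumed to carry over to the Haiku wrapper; `switch_reference` is a hypothetical name.

```python
def switch_reference(index, branches, *operands):
    """Pure-Python model: clamp the index, then call the chosen branch."""
    index = max(0, min(index, len(branches) - 1))
    return branches[index](*operands)


branches = [lambda x: x + 1, lambda x: x * 2, lambda x: -x]
print(switch_reference(1, branches, 5))   # 10
print(switch_reference(7, branches, 5))   # -5 (index clamped to 2)
print(switch_reference(-3, branches, 5))  # 6 (index clamped to 0)
```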
while_loop#
JAX Transforms#
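| eval_shape | Equivalent to jax.eval_shape with any changed Haiku state discarded. |
| grad | Creates a function which evaluates the gradient of fun. |
| remat | Equivalent to jax.checkpoint but passing Haiku state. |
| value_and_grad | Creates a function which evaluates both fun and the grad of fun. |
| vmap | Equivalent to jax.vmap() with module parameters/state not mapped. |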
eval_shape#
grad#
- haiku.grad(fun, argnums=0, has_aux=False, holomorphic=False)[source]#
Creates a function which evaluates the gradient of fun.
NOTE: You only need this in the very specific case where you want to take a gradient inside a transform()ed function and the function you are differentiating uses set_state(). For example:
>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     hk.set_state("last", x ** 2)
...     return x ** 2
>>> def f(x):
...   m = MyModule()
...   g = hk.grad(m)(x)
...   return g
>>> f = hk.transform_with_state(f)
>>> x = jnp.array(2.)
>>> params, state = jax.jit(f.init)(None, x)
>>> print(state["my_module"]["last"])
4.0
- Parameters
fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Returns
A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of gradient, auxiliary_data is returned.
For example:
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.96...
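The printed value in the tanh example can be checked by hand, since the derivative of tanh(x) is 1 - tanh(x)^2. The sketch below uses only the standard library; `grad_tanh_closed_form` and `central_difference` are illustrative names, not part of JAX or Haiku.

```python
import math


def grad_tanh_closed_form(x):
    """d/dx tanh(x) = 1 - tanh(x)**2."""
    return 1.0 - math.tanh(x) ** 2


def central_difference(f, x, eps=1e-6):
    """Numerical derivative, used here only as a cross-check."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


analytic = grad_tanh_closed_form(0.2)
numeric = central_difference(math.tanh, 0.2)
print(f"{analytic:.4f}")  # 0.9610
```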
remat#
value_and_grad#
- haiku.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)[source]#
Creates a function which evaluates both fun and the grad of fun.
NOTE: You only need this in the very specific case where you want to take a gradient inside a transform()ed function and the function you are differentiating uses set_state(). For example:
>>> class MyModule(hk.Module):
...   def __call__(self, x):
...     hk.set_state("last", jnp.sum(x))
...     return x ** 2
>>> def f(x):
...   m = MyModule()
...   y, g = hk.value_and_grad(m)(x)
...   return y, g
>>> f = hk.transform_with_state(f)
>>> x = jnp.array(2.)
>>> _ = jax.jit(f.init)(None, x)
- Parameters
fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Returns
A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
vmap#
- haiku.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, *, split_rng)[source]#
Equivalent to jax.vmap() with module parameters/state not mapped.
The behaviour of Haiku random key APIs under vmap() is controlled by the split_rng argument:
>>> x = jnp.arange(2)
>>> f = hk.vmap(lambda _: hk.next_rng_key(), split_rng=False)
>>> key1, key2 = f(x)
>>> assert (key1 == key2).all()
>>> f = hk.vmap(lambda _: hk.next_rng_key(), split_rng=True)
>>> key1, key2 = f(x)
>>> assert not (key1 == key2).all()
Random numbers in Haiku are typically used for two things: firstly for initialising model parameters, and secondly for creating random samples as part of the forward pass of a neural network (e.g. for dropout). If you are using vmap() with a module that uses Haiku random keys for both (e.g. you don’t pass keys explicitly into the network), then it is quite likely that you will want to vary the value of split_rng depending on whether you are initialising (e.g. creating model parameters) or applying the model. An easy way to do this is to set split_rng=(not hk.running_init()).
For more advanced use cases, such as mapping module parameters, we suggest users instead use lift() or transparent_lift() in combination with jax.vmap().
- Parameters
fun (Callable[..., Any]) – See jax.vmap().
in_axes – See jax.vmap().
out_axes – See jax.vmap().
axis_name (Optional[str]) – See jax.vmap().
axis_size (Optional[int]) – See jax.vmap().
split_rng (bool) – Controls whether random key APIs in Haiku (e.g. next_rng_key()) return different keys (the internal key is split before calling your mapped function) or the same key (the internal key is broadcast before calling your mapped function). See the docstring for examples.
- Return type
Callable[…, Any]
- Returns
See jax.vmap().
Mixed Precision#
Automatic Mixed Precision#
| set_policy | Uses the given policy for all instances of the module class. |
| current_policy | Retrieves the currently active policy in the current context. |
| get_policy | Retrieves the currently active policy for the given class. |
| clear_policy | Clears any policy associated with the given class. |
| push_policy | Sets the given policy for the given class while the context is active. |
set_policy#
- haiku.mixed_precision.set_policy(cls, policy)[source]#
Uses the given policy for all instances of the module class.
NOTE: Policies are only applied to modules created in the current thread.
A mixed precision policy describes how inputs, module parameters and module outputs should be cast at runtime. By applying a policy to a given type of module, you can control how all instances of that module behave in your program.
For example, you might want to try running a ResNet50 model in a mixture of float16 and float32 on GPU to get higher throughput. To do so you can apply a mixed precision policy to the ResNet50 type that will create parameters in float32, but cast them to float16 before use, along with all module inputs:
>>> policy = jmp.get_policy('params=float32,compute=float16,output=float32')
>>> hk.mixed_precision.set_policy(hk.nets.ResNet50, policy)
>>> net = hk.nets.ResNet50(4)
>>> x = jnp.ones([4, 224, 224, 3])
>>> print(net(x, is_training=True))
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
For a fully worked mixed precision example see the imagenet example in Haiku’s examples directory. This example shows mixed precision on GPU offering a 2x speedup in training time with only a small impact on final top-1 accuracy.
>>> hk.mixed_precision.clear_policy(hk.nets.ResNet50)
- Parameters
cls (type[hk.Module]) – A Haiku module class.
policy (jmp.Policy) – A JMP policy to apply to the module.
See also
push_policy(): Context manager for setting policies.
current_policy(): Retrieves the currently active policy (if any).
clear_policy(): Clears any policies associated with a class.
get_policy(): Gets the policy for a given class.
current_policy#
- haiku.mixed_precision.current_policy()[source]#
Retrieves the currently active policy in the current context.
- Return type
Optional[jmp.Policy]
- Returns
The currently active mixed precision policy, or None.
See also
clear_policy(): Clears any policies associated with a class.
get_policy(): Gets the policy for a given class.
set_policy(): Sets a policy for a given class.
push_policy(): Context manager for setting policies.
get_policy#
- haiku.mixed_precision.get_policy(cls)[source]#
Retrieves the currently active policy for the given class.
Note that policies applied explicitly to a top level class (e.g. ResNet) will be applied implicitly to all child modules (e.g. ConvND) called from the parent. This function only returns policies that have been applied explicitly (e.g. via set_policy()).
- Parameters
cls (type[hk.Module]) – A Haiku module class.
- Return type
Optional[jmp.Policy]
- Returns
A JMP policy that is used for the given class, or None if one is not active.
See also
current_policy(): Retrieves the currently active policy (if any).
clear_policy(): Clears any policies associated with a class.
set_policy(): Sets a policy for a given class.
push_policy(): Context manager for setting policies.
clear_policy#
- haiku.mixed_precision.clear_policy(cls)[source]#
Clears any policy associated with the given class.
- Parameters
cls (type[hk.Module]) – A Haiku module class.
See also
current_policy(): Retrieves the currently active policy (if any).
get_policy(): Gets the policy for a given class.
set_policy(): Sets a policy for a given class.
push_policy(): Context manager for setting policies.
push_policy#
- haiku.mixed_precision.push_policy(cls, policy)[source]#
Sets the given policy for the given class while the context is active.
- Parameters
cls (type[hk.Module]) – A Haiku module class.
policy (jmp.Policy) – A JMP policy to apply to the module.
- Yields
None.
See also
clear_policy(): Clears any policies associated with a class.
get_policy(): Gets the policy for a given class.
set_policy(): Sets a policy for a given class.
current_policy(): Retrieves the currently active policy (if any).
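The relationship between set_policy(), get_policy(), clear_policy() and push_policy() can be modelled with a small registry. This is a toy sketch, not Haiku's implementation: the dictionary, the string-valued policies and the restore-on-exit behaviour are assumptions made for illustration, and unlike the real functions it takes no account of threads.

```python
import contextlib

_policies = {}  # hypothetical registry: class -> policy


def set_policy(cls, policy):
    _policies[cls] = policy


def get_policy(cls):
    return _policies.get(cls)


def clear_policy(cls):
    _policies.pop(cls, None)


@contextlib.contextmanager
def push_policy(cls, policy):
    """Applies `policy` to `cls` inside the block, then restores the old one."""
    previous = get_policy(cls)
    set_policy(cls, policy)
    try:
        yield
    finally:
        if previous is None:
            clear_policy(cls)
        else:
            set_policy(cls, previous)


class MyNet:  # stand-in for a Haiku module class
    pass


with push_policy(MyNet, "compute=float16"):
    inside = get_policy(MyNet)
outside = get_policy(MyNet)
print(inside, outside)  # compute=float16 None
```

The context-manager form is convenient in tests or notebooks, where a policy set with set_policy() would otherwise stay in effect until clear_policy() is called.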
🚧 Experimental#
Graphviz Visualisation#
| abstract_to_dot | Converts a function using Haiku modules to a dot graph. |
abstract_to_dot#
- haiku.experimental.abstract_to_dot(fun)[source]#
Converts a function using Haiku modules to a dot graph.
Same as to_dot() but uses JAX’s abstract interpretation machinery to evaluate the function without requiring concrete inputs. Valid inputs for the wrapped function include jax.ShapeDtypeStruct.
abstract_to_dot() does not support data-dependent control-flow, because no concrete values are provided to the function.
- Parameters
fun (Callable[..., Any]) – A function using Haiku modules.
- Return type
Callable[…, str]
- Returns
A function that returns the source code string to a graphviz graph describing the operations executed by the given function clustered by Haiku module.
See also
to_dot(): Generates a graphviz graph using concrete inputs.
Summarisation#
| tabulate | Produces a summarised view of the execution of f. |
| eval_summary | Records module method calls performed by f. |
| ArraySpec | Shaped and sized specification of an array. |
| MethodInvocation | Record of a method being invoked on a given module. |
| ModuleDetails | Module and method related information. |
tabulate#
- haiku.experimental.tabulate(f, *, columns=('module', 'config', 'owned_params', 'input', 'output', 'params_size', 'params_bytes'), filters=('has_output',), tabulate_kwargs={'tablefmt': 'grid'})[source]#
Produces a summarised view of the execution of f.
>>> def f(x):
...   return hk.nets.MLP([300, 100, 10])(x)
>>> x = jnp.ones([8, 28 * 28])
>>> f = hk.transform(f)
>>> print(hk.experimental.tabulate(f)(x))
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| Module                  | Config                                   | Module params   | Input      | Output     | Param count   | Param bytes   |
+=========================+==========================================+=================+============+============+===============+===============+
| mlp (MLP)               | MLP(output_sizes=[300, 100, 10])         |                 | f32[8,784] | f32[8,10]  | 266,610       | 1.07 MB       |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_0 (Linear) | Linear(output_size=300, name='linear_0') | w: f32[784,300] | f32[8,784] | f32[8,300] | 235,500       | 942.00 KB     |
|  └ mlp (MLP)            |                                          | b: f32[300]     |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_1 (Linear) | Linear(output_size=100, name='linear_1') | w: f32[300,100] | f32[8,300] | f32[8,100] | 30,100        | 120.40 KB     |
|  └ mlp (MLP)            |                                          | b: f32[100]     |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
| mlp/~/linear_2 (Linear) | Linear(output_size=10, name='linear_2')  | w: f32[100,10]  | f32[8,100] | f32[8,10]  | 1,010         | 4.04 KB       |
|  └ mlp (MLP)            |                                          | b: f32[10]      |            |            |               |               |
+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+
Possible values for columns:
module: Displays module and method name.
config: Displays the constructor arguments used for the module.
owned_params: Displays parameters directly owned by this module.
input: Displays module inputs.
output: Displays module output.
params_size: Displays the number of parameters.
params_bytes: Displays parameter size in bytes.
Possible values for filters:
has_output: Only include methods returning a value other than None.
has_params: Removes methods from modules that do not have parameters.
- Parameters
f (Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState]) – A function to transform OR one of the init/apply functions from Haiku or the result of transform() or transform_with_state().
columns (Optional[Sequence[str]]) – A list of column names to enable.
filters (Optional[Sequence[str]]) – A list of filters to apply to remove certain module methods.
tabulate_kwargs – Keyword arguments to pass to tabulate.tabulate(..).
- Return type
Callable[…, str]
- Returns
A callable that takes the same arguments as f but returns a string summarising the modules used during the execution of f.
See also
eval_summary(): Raw data used to generate this table.
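The Param count and Param bytes columns in the example table can be reproduced by hand. The sketch below assumes each Linear layer owns a weight of shape [in, out] plus a bias of shape [out], and that every parameter is a 4-byte float32, as the f32 entries in the table suggest; `linear_param_count` is a hypothetical helper.

```python
def linear_param_count(in_size, out_size):
    """Weights (in_size x out_size) plus one bias per output unit."""
    return in_size * out_size + out_size


# Input width followed by the MLP output sizes from the example.
layer_sizes = [784, 300, 100, 10]
counts = [linear_param_count(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]
total = sum(counts)
total_bytes = total * 4  # float32 uses 4 bytes per parameter

print(counts)       # [235500, 30100, 1010]
print(total)        # 266610
print(total_bytes)  # 1066440
```

The per-layer counts match the three Linear rows, and the total of 1,066,440 bytes is the 1.07 MB shown on the mlp row.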
eval_summary#
- haiku.experimental.eval_summary(f)[source]#
Records module method calls performed by f.
>>> f = lambda x: hk.nets.MLP([300, 100, 10])(x)
>>> x = jnp.ones([8, 28 * 28])
>>> for i in hk.experimental.eval_summary(f)(x):
...   print("mod := {:14} | in := {} out := {}".format(
...       i.module_details.module.module_name, i.args_spec[0], i.output_spec))
mod := mlp            | in := f32[8,784] out := f32[8,10]
mod := mlp/~/linear_0 | in := f32[8,784] out := f32[8,300]
mod := mlp/~/linear_1 | in := f32[8,300] out := f32[8,100]
mod := mlp/~/linear_2 | in := f32[8,100] out := f32[8,10]
- Parameters
f (Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState]) – A function or transformed function to trace.
- Return type
Callable[…, Sequence[MethodInvocation]]
- Returns
A callable taking the same arguments as the provided function, but returning a sequence of MethodInvocation instances revealing the methods called on each module when applying f.
See also
tabulate(): Pretty prints a summary of the execution of a function.
ArraySpec#
- class haiku.experimental.ArraySpec(shape, dtype)[source]#
Shaped and sized specification of an array.
- shape#
Shape of the array.
- Type
Sequence[int]
- dtype#
DType of the array.
- Type
jnp.dtype
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(shape, dtype)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
MethodInvocation#
- class haiku.experimental.MethodInvocation(module_details, args_spec, kwargs_spec, output_spec, context, call_stack)[source]#
Record of a method being invoked on a given module.
- module_details#
Details about which module and method were invoked.
- Type
ModuleDetails
- args_spec#
Positional arguments to the method invocation with arrays replaced by ArraySpec.
- Type
tuple[Any, …]
- kwargs_spec#
Keyword arguments to the method invocation with arrays replaced by ArraySpec.
- Type
dict[str, Any]
- context#
Additional context information for the method call as provided by intercept_methods().
- Type
hk.MethodContext
- call_stack#
Stack of modules currently active while calling this module method. For example if A calls B which calls C then the call stack for C will be [B_DETAILS, A_DETAILS].
- Type
Sequence[ModuleDetails]
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(module_details, args_spec, kwargs_spec, output_spec, context, call_stack)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
ModuleDetails#
- class haiku.experimental.ModuleDetails(module, method_name, params, state)[source]#
Module and method related information.
- method_name#
The method name that was invoked on the module.
- Type
str
- __delattr__(name)#
Implement delattr(self, name).
- __eq__(other)#
Return self==value.
- __hash__()#
Return hash(self).
- __init__(module, method_name, params, state)#
- __setattr__(name, value)#
Implement setattr(self, name, value).
Managing State#
| check_jax_usage | Ensures JAX APIs (e.g. jax.vmap()) are used correctly with Haiku. |
check_jax_usage#
- haiku.experimental.check_jax_usage(enabled=True)[source]#
Ensures JAX APIs (e.g. jax.vmap()) are used correctly with Haiku.
JAX transforms (like jax.vmap()) and control flow (e.g. jax.lax.cond()) expect pure functions to be passed in. Some functions in Haiku (for example get_parameter()) have side effects and thus functions using them are only pure after using transform() (et al).
Sometimes it is convenient to use JAX transforms or control flow before transforming your function (for example, to vmap() the application of a module) but when doing so you need to be careful to use the Haiku overloaded version of the underlying JAX function, which carefully makes the function(s) you pass in pure functions before calling the underlying JAX function.
check_jax_usage() enables checking that raw JAX transforms are used appropriately inside Haiku transformed functions. Incorrect usage of JAX transforms will result in an error.
Consider the function below. It is not a pure function (a function of its inputs with no side effects) because we call into a Haiku API (get_parameter()) which during init will create a parameter and register it with Haiku.
>>> def f():
...   return hk.get_parameter("some_param", [], init=jnp.zeros)
We should not use this with JAX APIs like jax.vmap() (because it is not a pure function). check_jax_usage() allows you to tell Haiku to make incorrect usages of JAX APIs an error:
>>> previous_value = hk.experimental.check_jax_usage(True)
>>> jax.vmap(f, axis_size=2)()
Traceback (most recent call last):
  ...
haiku.JaxUsageError: ...
Using the Haiku wrapped version works correctly:
>>> print(hk.vmap(f, axis_size=2, split_rng=False)())
[0. 0.]
- Parameters
enabled (bool) – Boolean indicating whether usage should be checked or not.
- Return type
bool
- Returns
Boolean with the previous value for this setting.
Optimizations#
| optimize_rng_use | Optimizes RNG key splitting in fun. |
| module_auto_repr | Disables automatically generating an implementation of Module.__repr__. |
| fast_eval_shape | Equivalent to eval_shape in JAX. |
|
Change amount of RNG keys reserved when calling |
optimize_rng_use#
- haiku.experimental.optimize_rng_use(fun)[source]#
Optimizes RNG key splitting in fun.
Our strategy here is to use abstract interpretation to run your function twice. The first time we use jax.eval_shape() to avoid spending any flops and simply observe how many times you call next_rng_key(). We then run your function again, but this time we reserve enough RNG keys ahead of time such that we only need to call jax.random.split() once.
In the following example, we need three random samples for our weight matrices in our 3-layer MLP. To draw these samples we use next_rng_key() which will split a new key for each sample. By using optimize_rng_use() Haiku will pre-allocate exactly enough RNGs for f to be evaluated by splitting the input key once and only once. For large models (unlike this example) this can lead to a reduction in compilation time of both init and apply, with init seeing a larger expected speedup as it performs more RNG key splitting in general.
>>> def f(x):
...   net = hk.nets.MLP([300, 100, 10])
...   return net(x)
>>> f = hk.experimental.optimize_rng_use(f)
>>> f = hk.transform(f)
>>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
- Parameters
fun – A function to wrap.
- Returns
A function that applies fun but only requires one call to jax.random.split() by Haiku.
module_auto_repr#
- haiku.experimental.module_auto_repr(enabled)[source]#
Disables automatically generating an implementation of Module.__repr__.
By default, Haiku will automatically generate a useful string representation of modules for printing. For example:
>>> print(hk.Linear(1))
Linear(output_size=1)
In some cases, objects passed into module constructors may be slow to print, for example very nested data structures, or you may be rapidly creating and throwing away modules (e.g. in a test) and don’t want to pay the overhead of converting to string.
This config option enables users to disable the automatic repr feature globally in Haiku:
>>> previous_value = hk.experimental.module_auto_repr(False)
>>> print(hk.Linear(1))
<...Linear object at ...>
>>> previous_value = hk.experimental.module_auto_repr(True)
>>> print(hk.Linear(1))
Linear(output_size=1)
To disable the feature on a per-subclass basis assign AUTO_REPR = False as a property on your class, for example:
>>> class NoAutoRepr(hk.Module):
...   AUTO_REPR = False
>>> print(NoAutoRepr())
<...NoAutoRepr object at ...>
- Parameters
enabled (bool) – Boolean indicating whether automatic generation of Module.__repr__ should be enabled.
- Return type
bool
- Returns
The previous value of this config setting.
fast_eval_shape#
- haiku.experimental.fast_eval_shape(fun, *args, **kwargs)[source]#
Equivalent to eval_shape in JAX.
This utility is equivalent to eval_shape in JAX except that it avoids running Haiku functions whose shapes are trivially known. This can avoid some Python overheads in JAX which can accumulate for very large models.
Optimizations:
- All parameter/state initialisers replaced with zeros.
- hk.dropout replaced with identity.
- jax.random.fold_in replaced with identity.
- Parameters
fun – The function to trace.
*args – Positional arguments to fun.
**kwargs – Keyword arguments to fun.
- Returns
The shape produced by fun for the given args/kwargs.
rng_reserve_size#
- haiku.experimental.rng_reserve_size(size)[source]#
Change amount of RNG keys reserved when calling next_rng_key.
- Parameters
size (int) – amount of keys to reserve when splitting off a key through next_rng_key(), defaults to 1. Reserving larger blocks of keys can improve compilation and run-time of your model. Changing the reservation size will change RNG keys returned by next_rng_key, and will change the generated random numbers.
- Return type
int
- Returns
The previous value of the rng_reserve_size setting.
jaxpr_info#
| make_model_info | Creates a function that computes flop, param and state information. |
| as_html | Formats a Module as a tree of interactive HTML elements. |
| as_html_page | Formats the output of make_model_info as an interactive HTML page. |
| css | The CSS for HTML visualization of a Module. |
| format_module | Recursively formats module information as a human readable string. |
| js | The JavaScript for HTML visualization of a Module. |
| Expression | Information about a single JAX expression. |
| Module | Information about a Haiku module. |
make_model_info#
- haiku.experimental.jaxpr_info.make_model_info(f, name=None, include_module_info=True, compute_flops=None, axis_env=None)[source]#
Creates a function that computes flop, param and state information.
- Parameters
f (Callable[..., Any]) – The function for which to compute information. Haiku modules and jax.named_call expressions will be represented as nested Modules in the result.
name (Optional[str]) – Optional, the name of the root expression.
include_module_info (bool) – Whether to include parameter and state count information for haiku modules. Can be slow for very large computations.
compute_flops (Optional[ComputeFlopsFn]) – Optional, a function that returns an estimate of the number of flops required to execute an equation.
axis_env (Optional[Sequence[tuple[Any, int]]]) – Sizes of pmapped axes. See docs of jax.make_jaxpr for details.
- Return type
Callable[…, Module]
- Returns
A wrapped version of f that when applied to example arguments returns a Module representation of f for those arguments.
Module and Expression contain high level information about JAX operations (jaxprs) and can be visualized in concise and interactive formats; see format_module, as_html_page or as_html.
as_html#
- haiku.experimental.jaxpr_info.as_html(module, min_flop=1000, outvars='', last=False)[source]#
Formats a Module as a tree of interactive HTML elements.
When embedding this in a page, the outputs of css and js must be embedded too for the visualization to work. To only visualize a single module directly, see as_html_page.
- Parameters
module (Module) – The module to visualize, as an interactive HTML tree.
min_flop (int) – Minimum number of flops for an operation to be shown.
outvars (str) – For internal use, the outputs of this module.
last (bool) – For internal use, whether this module is the last of its siblings.
- Return type
str
- Returns
HTML representation of module.
as_html_page#
css#
format_module#
js#
Expression#
Module#
Configuration#
| context | Context manager for setting config options. |
| set | Sets the given config option(s). |
context#
- haiku.config.context(*, check_jax_usage=None, module_auto_repr=None, restore_flatmap=None, rng_reserve_size=None)[source]#
Context manager for setting config options.
This context manager can be used to override config settings in a given context. Values that are not explicitly passed as keyword arguments retain their current value:
>>> with hk.config.context(check_jax_usage=True):
...   pass
- Parameters
check_jax_usage (Optional[bool]) – Checks that jax transforms and control flow are used appropriately in Haiku transformed functions.
module_auto_repr (Optional[bool]) – Can be used to disable the “to string” functionality that is part of Haiku’s base constructor.
restore_flatmap (Optional[bool]) – Whether legacy checkpoints should be restored in the old FlatMap datatype (as returned by to_immutable_dict), default is to restore these as plain dicts.
rng_reserve_size (Optional[int]) – amount of keys to reserve when splitting off a key through next_rng_key(), defaults to 1. Reserving larger blocks of keys can improve compilation and run-time of your model. Changing the reservation size will change RNG keys returned by next_rng_key, and will change the generated random numbers.
- Returns
Context manager that applies the given configs while active.
set#
- haiku.config.set(*, check_jax_usage=None, module_auto_repr=None, restore_flatmap=None, rng_reserve_size=None)[source]#
Sets the given config option(s).
>>> hk.config.set(module_auto_repr=False)
>>> hk.Linear(1)
<...Linear object at ...>
>>> hk.config.set(module_auto_repr=True)
>>> hk.Linear(1)
Linear(output_size=1)
- Parameters
check_jax_usage (Optional[bool]) – Checks that jax transforms and control flow are used appropriately in Haiku transformed functions.
module_auto_repr (Optional[bool]) – Can be used to disable the “to string” functionality that is part of Haiku’s base constructor.
restore_flatmap (Optional[bool]) – Whether legacy checkpoints should be restored in the old FlatMap datatype (as returned by to_immutable_dict), default is to restore these as plain dicts.
rng_reserve_size (Optional[int]) – amount of keys to reserve when splitting off a key through next_rng_key(), defaults to 1. Reserving larger blocks of keys can improve compilation and run-time of your model. Changing the reservation size will change RNG keys returned by next_rng_key, and will change the generated random numbers.
Utilities#
Data Structures#
| filter | Filters an input structure according to a user specified predicate. |
| is_subset | Checks whether the leaves of subset appear in superset. |
| map | Maps a function to an input structure accordingly. |
| merge | Merges multiple input structures. |
| partition | Partitions the input structure in two according to a given predicate. |
| partition_n | Partitions a structure into n structures. |
| to_haiku_dict | Returns a copy of the given two level structure. |
| to_immutable_dict | Returns an immutable copy of the given mapping. |
| to_mutable_dict | Turns an immutable FlatMapping into a mutable dict. |
| traverse | Iterates over a structure yielding module names, names and values. |
| tree_bytes | Sums the size in bytes of all arrays in a pytree. |
| tree_size | Sums the sizes of all arrays in a pytree. |
filter#
- haiku.data_structures.filter(predicate, structure)[source]#
Filters an input structure according to a user specified predicate.
>>> params = {'linear': {'w': None, 'b': None}}
>>> predicate = lambda module_name, name, value: name == 'w'
>>> hk.data_structures.filter(predicate, params)
{'linear': {'w': None}}
Note: returns a new structure not a view.
- Parameters
predicate (Callable[[str, str, T], bool]) – criterion to be used to filter the input data. The predicate argument is expected to be a boolean function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data.
structure (Mapping[str, Mapping[str, T]]) – Haiku params or state data structure to be filtered.
- Return type
Mapping[str, Mapping[str, T]]
- Returns
All the input parameters or state as selected by the input predicate.
is_subset#
- haiku.data_structures.is_subset(*, subset, superset)[source]#
Checks whether the leaves of subset appear in superset.
Note that this is vacuously true in the case that both structures have no leaves:
>>> hk.data_structures.is_subset(subset={'a': {}}, superset={})
True
- Parameters
subset (Mapping[str, Mapping[str, Any]]) – The subset to check.
superset (Mapping[str, Mapping[str, Any]]) – The superset to check.
- Return type
bool
- Returns
A boolean indicating whether all elements in subset are contained in superset.
map#
- haiku.data_structures.map(fn, structure)[source]#
Maps a function to an input structure accordingly.
>>> params = {'linear': {'w': 1.0, 'b': 2.0}}
>>> fn = lambda module_name, name, value: 2 * value if name == 'w' else value
>>> hk.data_structures.map(fn, params)
{'linear': {'b': 2.0, 'w': 2.0}}
Note: returns a new structure not a view.
- Parameters
fn (Callable[[str, str, InT], OutT]) – criterion to be used to map the input data. The fn argument is expected to be a function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data, and returning a new value.
structure (Mapping[str, Mapping[str, InT]]) – Haiku params or state data structure to be mapped.
- Return type
Mapping[str, Mapping[str, OutT]]
- Returns
All the input parameters or state as mapped by the input fn.
merge#
- haiku.data_structures.merge(*structures, check_duplicates=False)[source]#
Merges multiple input structures.
>>> weights = {'linear': {'w': None}}
>>> biases = {'linear': {'b': None}}
>>> hk.data_structures.merge(weights, biases)
{'linear': {'w': None, 'b': None}}
When structures are not disjoint the output will contain the value from the last structure for each path:
>>> weights1 = {'linear': {'w': 1}}
>>> weights2 = {'linear': {'w': 2}}
>>> hk.data_structures.merge(weights1, weights2)
{'linear': {'w': 2}}
Note: returns a new structure not a view.
- Parameters
*structures – One or more structures to merge.
check_duplicates (bool) – If True, a ValueError will be thrown if an array is found in multiple structures but with a different shape and dtype.
- Return type
MutableMapping[str, MutableMapping[str, Any]]
- Returns
A single structure with an entry for each path in the input structures.
partition#
- haiku.data_structures.partition(predicate, structure)[source]#
Partitions the input structure in two according to a given predicate.
For a given set of parameters, you can use partition() to split them:
>>> params = {'linear': {'w': None, 'b': None}}
>>> predicate = lambda module_name, name, value: name == 'w'
>>> weights, biases = hk.data_structures.partition(predicate, params)
>>> weights
{'linear': {'w': None}}
>>> biases
{'linear': {'b': None}}
Note: returns new structures not a view.
- Parameters
predicate (Callable[[str, str, jax.Array], bool]) – criterion to be used to partition the input data. The predicate argument is expected to be a boolean function taking as inputs the name of the module, the name of a given entry in the module data bundle (e.g. parameter name) and the corresponding data.
structure (Mapping[str, Mapping[str, T]]) – Haiku params or state data structure to be partitioned.
- Return type
tuple[Mapping[str, Mapping[str, T]], Mapping[str, Mapping[str, T]]]
- Returns
A tuple containing all the params or state as partitioned by the input predicate. Entries matching the predicate will be in the first structure, and the rest will be in the second.
partition_n#
- haiku.data_structures.partition_n(fn, structure, n)[source]#
Partitions a structure into n structures.
For a given set of parameters, you can use partition_n() to split them into n groups. For example, to split your parameters/gradients by module name:
>>> import collections
>>> import itertools
>>> def partition_by_module(structure):
...   cnt = itertools.count()
...   d = collections.defaultdict(lambda: next(cnt))
...   fn = lambda m, n, v: d[m]
...   return hk.data_structures.partition_n(fn, structure, len(structure))
>>> structure = {f'layer_{i}': {'w': None, 'b': None} for i in range(3)}
>>> for substructure in partition_by_module(structure):
...   print(substructure)
{'layer_0': {'b': None, 'w': None}}
{'layer_1': {'b': None, 'w': None}}
{'layer_2': {'b': None, 'w': None}}
- Parameters
fn (Callable[[str, str, T], int]) – Callable returning which bucket in [0, n) the given element should be output.
structure (Mapping[str, Mapping[str, T]]) – Haiku params or state data structure to be partitioned.
n (int) – The total number of buckets.
- Return type
tuple[Mapping[str, Mapping[str, T]], …]
- Returns
A tuple of size n, where each element will contain the values for which the function returned the current index.
to_haiku_dict#
- haiku.data_structures.to_haiku_dict(structure)[source]#
Returns a copy of the given two level structure.
Uses the same mapping type as Haiku will return from init or apply functions.
- Parameters
structure (Mapping[K, V]) – A two level mapping to copy.
- Return type
MutableMapping[K, V]
- Returns
A new two level mapping with the same contents as the input.
to_immutable_dict#
to_mutable_dict#
traverse#
- haiku.data_structures.traverse(structure)[source]#
Iterates over a structure yielding module names, names and values.
NOTE: Items are iterated in key sorted order.
- Parameters
structure (Mapping[str, Mapping[str, T]]) – The structure to traverse.
- Yields
Tuples of the module name, name and value from the given structure.
- Return type
Generator[tuple[str, str, T], None, None]
tree_bytes#
- haiku.data_structures.tree_bytes(tree)[source]#
Sums the size in bytes of all arrays in a pytree.
Note that this is the minimum size of the array (e.g. for a float32 we need at least 4 bytes) however on some accelerators buffers may occupy more memory due to padding/alignment constraints.
For example given a ResNet50 model:
>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([128, 224, 224, 3])
>>> params, state = f.init(rng, x)
We can count the number of parameters and their size at f32:
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 102.23MB
And compare that with casting our parameters to bf16:
>>> params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 51.11MB
- Parameters
tree – A tree of jax.Arrays.
- Return type
int
- Returns
The total size in bytes of the array(s) in the input.
tree_size#
- haiku.data_structures.tree_size(tree)[source]#
Sums the sizes of all arrays in a pytree.
For example given a ResNet50 model:
>>> f = hk.transform_with_state(lambda x: hk.nets.ResNet50(1000)(x, True))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([128, 224, 224, 3])
>>> params, state = f.init(rng, x)
We can count the number of parameters and their size at f32:
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 102.23MB
And compare that with casting our parameters to bf16:
>>> params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
>>> num_params = hk.data_structures.tree_size(params)
>>> byte_size = hk.data_structures.tree_bytes(params)
>>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
25557032 params, size: 51.11MB
- Parameters
tree – A tree of jax.Arrays.
- Return type
int
- Returns
The total size (number of elements) of the array(s) in the input.
Testing#
| transform_and_run | Transforms the given function and runs init then (optionally) apply. |
transform_and_run#
- haiku.testing.transform_and_run(f=None, seed=42, run_apply=True, jax_transform=None, *, map_rng=None)[source]#
Transforms the given function and runs init then (optionally) apply.
Equivalent to:
>>> def f(x):
...   return x
>>> x = jnp.ones([])
>>> rng = jax.random.PRNGKey(42)
>>> f = hk.transform_with_state(f)
>>> params, state = f.init(rng, x)
>>> out, state = f.apply(params, state, rng, x)
This function makes it very convenient to unit test Haiku:
>>> class MyTest(unittest.TestCase):
...   @hk.testing.transform_and_run
...   def test_linear_output(self):
...     mod = hk.Linear(1)
...     out = mod(jnp.ones([1, 1]))
...     self.assertEqual(out.ndim, 2)
It can also be combined with chex to test all pure/jit/pmap versions of a function:
>>> class MyTest(unittest.TestCase):
...   @chex.all_variants
...   def test_linear_output(self):
...     @hk.testing.transform_and_run(jax_transform=self.variant)
...     def f(inputs):
...       mod = hk.Linear(1)
...       return mod(inputs)
...     out = f(jnp.ones([1, 1]))
...     self.assertEqual(out.ndim, 2)
And can also be useful in an interactive environment like ipython, Jupyter or Google Colaboratory:
>>> f = lambda x: hk.Bias()(x)
>>> print(hk.testing.transform_and_run(f)(jnp.ones([1, 1])))
[[1.]]
See transform() for more details.
To use this with pmap (without chex) you need to additionally pass in a function to map the init/apply rng keys. For example, if you want every instance of your pmap to have the same key:
>>> def same_key_on_all_devices(key):
...   return jnp.broadcast_to(key, (jax.local_device_count(), *key.shape))
>>> @hk.testing.transform_and_run(jax_transform=jax.pmap,
...                               map_rng=same_key_on_all_devices)
... def test_something():
...   ...
Or you can use a different key:
>>> def different_key_on_all_devices(key):
...   return jax.random.split(key, jax.local_device_count())
>>> @hk.testing.transform_and_run(jax_transform=jax.pmap,
...                               map_rng=different_key_on_all_devices)
... def test_something_else():
...   ...
- Parameters
f (Optional[Fn]) – A function or method to transform.
seed (Optional[int]) – A seed to pass to init and apply.
run_apply (bool) – Whether to run apply as well as init. Defaults to true.
jax_transform (Optional[Callable[[Fn], Fn]]) – An optional jax transform to apply on the init and apply functions.
map_rng (Optional[Callable[[Key], Key]]) – If set to a non-None value, a function used to map the init/apply rng keys, for example to broadcast or split them across devices.
- Return type
T
- Returns
A function that transform()s f and runs init and optionally apply.
Conditional Computation#
| running_init | Return True if running the init function of a Haiku transform. |
running_init#
- haiku.running_init()[source]#
Return True if running the init function of a Haiku transform.
In general you should not need to gate behaviour of your module based on whether you are running init or apply, but sometimes (e.g. when making use of JAX control flow) this is required.
For example, if you want to use switch() to pick between experts, when we run your init function we need to ensure that params/state for all experts are created (unconditionally) but during apply we want to conditionally apply (and perhaps update the internal state) of only one of our experts:
>>> experts = [hk.nets.ResNet50(10) for _ in range(5)]
>>> x = jnp.ones([1, 224, 224, 3])
>>> if hk.running_init():
...   # During init unconditionally create params/state for all experts.
...   for expert in experts:
...     out = expert(x, is_training=True)
... else:
...   # During apply conditionally apply (and update) only one expert.
...   # The upper bound of randint is exclusive, so len(experts) covers every expert.
...   index = jax.random.randint(hk.next_rng_key(), [], 0, len(experts))
...   out = hk.switch(index, experts, x)
- Return type
bool
- Returns
True if running init, otherwise False.
Functions#
| multinomial | Draws samples from a multinomial distribution. |
| one_hot | Returns a one-hot version of indices. |
multinomial#
- haiku.multinomial(rng, logits, num_samples)[source]#
Draws samples from a multinomial distribution.
DEPRECATED: Use jax.random.categorical instead.
- Parameters
rng – A JAX PRNGKey.
logits – Unnormalized log-probabilities, where last dimension is categories.
num_samples – Number of samples to draw.
- Returns
Chosen categories, of shape logits.shape[:-1] + (num_samples,).
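For code migrating to the suggested replacement, the sketch below shows one possible way to get the same output shape from jax.random.categorical. The helper name sample_categories is hypothetical, and the drawn values are not expected to match those of haiku.multinomial for the same key, since the sampling procedure may differ.

```python
import jax
import jax.numpy as jnp


def sample_categories(rng, logits, num_samples):
  """Draws `num_samples` categories per row of `logits` (hypothetical helper)."""
  batch_shape = logits.shape[:-1]
  # Draw with the sample dimension leading so it broadcasts against the batch.
  samples = jax.random.categorical(
      rng, logits, axis=-1, shape=(num_samples,) + batch_shape)
  # Move the sample dimension last: logits.shape[:-1] + (num_samples,).
  return jnp.moveaxis(samples, 0, -1)


logits = jnp.log(jnp.array([[0.1, 0.2, 0.7], [0.5, 0.25, 0.25]]))
samples = sample_categories(jax.random.PRNGKey(0), logits, num_samples=5)
print(samples.shape)
```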
one_hot#
- haiku.one_hot(x, num_classes, dtype=<class 'jax.numpy.float32'>)[source]#
Returns a one-hot version of indices.
DEPRECATED: Use jax.nn.one_hot(x, num_classes).astype(dtype) instead.
- Parameters
x – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – The dtype.
- Returns
The one-hot tensor. If indices’ shape is [A, B, …], shape is [A, B, …, num_classes].
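The suggested replacement can be used as in the following brief sketch. The indices and the int32 dtype are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

indices = jnp.array([0, 2, 1])
# One row per index, with a trailing dimension of size num_classes (3 here).
one_hot = jax.nn.one_hot(indices, 3).astype(jnp.int32)
print(one_hot)
```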
References#
- 1
Wojciech Zaremba, Ilya Sutskever, and Oriol Vinyals. Recurrent neural network regularization. arXiv preprint arXiv:1409.2329, 2014. URL: https://arxiv.org/abs/1409.2329.
- 2
Rafal Jozefowicz, Wojciech Zaremba, and Ilya Sutskever. An empirical exploration of recurrent network architectures. In International Conference on Machine Learning, 2342–2350. 2015.
- 3
Xingjian Shi, Zhourong Chen, Hao Wang, Dit-Yan Yeung, Wai-Kin Wong, and Wang-chun Woo. Convolutional LSTM network: a machine learning approach for precipitation nowcasting. In Advances in Neural Information Processing Systems, 802–810. 2015.