{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "_" }, "outputs": [], "source": [ "import haiku as hk\n", "import jax\n", "import jax.numpy as jnp" ] }, { "cell_type": "markdown", "metadata": { "id": "7ag4NOGwC3lR" }, "source": [ "**TL;DR:** A JAX transform inside of a `hk.transform` is likely to transform a side-effecting function, which will result in an `UnexpectedTracerError`. This page describes two ways to get around this." ] }, { "cell_type": "markdown", "metadata": { "id": "kXvm0ueZC3lU" }, "source": [ "# Limitations of Nesting JAX Functions and Haiku Modules\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "XKDdO7n_C3lU" }, "source": [ "\n", "\n", "Once a Haiku network has been transformed to a pair of pure functions using `hk.transform`, it's possible to freely combine these with any JAX transformations like `jax.jit`, `jax.grad`, `jax.lax.scan` and so on.\n", "\n", "If you want to use JAX transformations **inside** of a `hk.transform`, however, you need to be more careful. It's possible, but most functions inside of the `hk.transform` boundary are still side-effecting and cannot safely be transformed by JAX.\n", "This is a common cause of `UnexpectedTracerError`s in code using Haiku. These errors are a result of using a JAX transform on a side-effecting function (for more information on this JAX error, see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError).\n", "\n", "An example with `jax.eval_shape`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CMnZLQgPC3lV", "outputId": "2305be6d-2976-4f0b-a4f7-eb999a46c16d" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "UnexpectedTracerError: applied JAX transform to side effecting function\n" ] } ], "source": [ "def net(x): # inside of a hk.transform, this is still side-effecting\n", " w = hk.get_parameter(\"w\", (2, 2), init=jnp.ones)\n", " return w @ x\n", "\n", "def eval_shape_net(x):\n", " output_shape = jax.eval_shape(net, x) # eval_shape on side-effecting function\n", " return net(x) # UnexpectedTracerError!\n", "\n", "init, _ = hk.transform(eval_shape_net)\n", "try:\n", " init(jax.random.PRNGKey(666), jnp.ones((2, 2)))\n", "except jax.errors.UnexpectedTracerError:\n", " print(\"UnexpectedTracerError: applied JAX transform to side effecting function\")" ] }, { "cell_type": "markdown", "metadata": { "id": "pZrXKtC0C3lX" }, "source": [ "These examples use `jax.eval_shape`, but could have used any higher-order JAX function (e.g. `jax.vmap`, `jax.lax.scan`,\n", "`jax.lax.while_loop`, ...).\n", "\n", "The error points to `hk.get_parameter`. This is the operation which makes `net` a side-effecting function. The side effect in this case is the creation of a parameter, which gets stored into the Haiku state. Similarly, you would get an error using `hk.next_rng_key`, because it advances the Haiku RNG state and stores a new PRNGKey into the Haiku state. 
In general, applying a JAX transform to a non-transformed Haiku module will result in an `UnexpectedTracerError`.\n", "\n", "You could rewrite the code above to create the parameter outside of the `eval_shape` transformation, making `net` a pure function by threading the parameter through explicitly as an argument:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BG68aS9lC3lY", "outputId": "3fe34887-cb81-499c-8dda-e8f8f818e4a3" }, "outputs": [ { "data": { "text/plain": [ "(ShapeDtypeStruct(shape=(3, 3), dtype=float32),\n", " DeviceArray([[2., 2., 2.],\n", " [2., 2., 2.],\n", " [2., 2., 2.]], dtype=float32))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def net(w, x): # no side effects!\n", " return w @ x\n", "\n", "def eval_shape_net(x):\n", " w = hk.get_parameter(\"w\", (3, 2), init=jnp.ones)\n", " output_shape = jax.eval_shape(net, w, x) # net is now side-effect free\n", " return output_shape, net(w, x)\n", "\n", "key = jax.random.PRNGKey(777)\n", "x = jnp.ones((2, 3))\n", "init, apply = hk.transform(eval_shape_net)\n", "params = init(key, x)\n", "apply(params, key, x)" ] }, { "cell_type": "markdown", "metadata": { "id": "FfVWrRg9C3lZ" }, "source": [ "However, that's not always possible. Consider the following code, which calls a Haiku module (`hk.nets.MLP`) that we don't own. This module will internally call `hk.get_parameter`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iBJ-JpDYC3la", "outputId": "6bbb43bf-bfc3-4b7e-fdd9-e471c627be32" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "UnexpectedTracerError: applied JAX transform to side effecting function\n" ] } ], "source": [ "def eval_shape_net(x):\n", " net = hk.nets.MLP([300, 100])\n", " output_shape = jax.eval_shape(net, x)\n", " return output_shape, net(x)\n", "\n", "init, _ = hk.transform(eval_shape_net)\n", "try:\n", " init(jax.random.PRNGKey(666), jnp.ones((2, 2)))\n", "except jax.errors.UnexpectedTracerError:\n", " print(\"UnexpectedTracerError: applied JAX transform to side effecting function\")" ] }, { "cell_type": "markdown", "metadata": { "id": "e04UC5ZbC3lb" }, "source": [ "## Using hk.lift\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "O7E0ldV7C3lc" }, "source": [ "We want a way to get access to our implicit Haiku state, and\n", "get a functionally pure version of `hk.nets.MLP`. The usual way to achieve this is with `hk.transform`, so all we need is a way to nest an inner `hk.transform` inside an outer `hk.transform`! We'll create another pair of `init` and `apply` functions through `hk.transform`, and these can then be safely combined with any higher-order JAX function.\n", "\n", "However, we need a way to register the state of this nested `hk.transform` in the outer scope. We can use `hk.lift` for this. Wrapping our inner `init` function with `hk.lift` will register our inner `params` into the outer parameter scope." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9w7nuPF1C3lc", "outputId": "042d3e7c-cac7-4fc1-cfd0-495b69a193c2" }, "outputs": [ { "data": { "text/plain": [ "FlatMap({\n", " 'inner/mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),\n", " 'inner/mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),\n", " 'mlp/~/linear_0': FlatMap({'b': (300,), 'w': (100, 300)}),\n", " 'mlp/~/linear_1': FlatMap({'b': (100,), 'w': (300, 100)}),\n", "})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def eval_shape_net(x):\n", " net = hk.nets.MLP([300, 100]) # still side-effecting\n", " init, apply = hk.transform(net) # nested transform\n", " params = hk.lift(init, name=\"inner\")(hk.next_rng_key(), x) # register parameters in outer module scope with name \"inner\"\n", " output_shape = jax.eval_shape(apply, params, hk.next_rng_key(), x) # apply is a functionally pure function and can be transformed!\n", " out = net(x)\n", " return out, output_shape\n", "\n", "\n", "init, apply = hk.transform(eval_shape_net)\n", "params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))\n", "apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))\n", "jax.tree_util.tree_map(lambda x: x.shape, params)" ] }, { "cell_type": "markdown", "metadata": { "id": "90dfm3H6C3ld" }, "source": [ "## Using Haiku versions of JAX transforms\n" ] }, { "cell_type": "markdown", "metadata": { "id": "damGG_NkC3ld" }, "source": [ "Haiku also provides wrapped versions of some of the JAX functions for convenience, for example `hk.grad` and `hk.vmap`. See https://dm-haiku.readthedocs.io/en/latest/api.html#jax-fundamentals for a full list of available functions.\n", "\n", "These wrappers apply the JAX function to a functionally pure version of the Haiku function, by doing the explicit state threading for you. They don't introduce an extra namescoping level like `lift` does." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "S-b4sFboC3ld" }, "outputs": [], "source": [ "def eval_shape_net(x):\n", " net = hk.nets.MLP([300, 100]) # still side-effecting\n", " output_shape = hk.eval_shape(net, x) # hk.eval_shape threads through the Haiku state for you\n", " out = net(x)\n", " return out, output_shape\n", "\n", "\n", "init, apply = hk.transform(eval_shape_net)\n", "params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))\n", "out = apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))" ] }, { "cell_type": "markdown", "metadata": { "id": "syzC6xUHCR13" }, "source": [ "## Summary\n", "\n", "To summarize, some good and bad examples of combining JAX transforms and Haiku modules:\n", "\n", "| What? | Works? | Example |\n", "|-----------------------------------------------|-------------------------------|-------------------------------------------------------|\n", "| vmapping outside a hk.transform | ✔ yes! | jax.vmap(hk.transform(hk.nets.ResNet50)) |\n", "| vmapping inside a hk.transform | ✖ unexpected tracer error | hk.transform(jax.vmap(hk.nets.ResNet50)) |\n", "| vmapping a nested hk.transform (without lift) | ✖ inner state is not registered | hk.transform(jax.vmap(hk.transform(hk.nets.ResNet50))) |\n", "| vmapping a nested hk.transform (with lift) | ✔ yes! | hk.transform(jax.vmap(hk.lift(hk.transform(hk.nets.ResNet50)))) |\n", "| using hk.vmap | ✔ yes! | hk.transform(hk.vmap(hk.nets.ResNet50)) |\n" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "transforms.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }