{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "_" }, "source": [ "# Visualization\n", "\n", "Haiku supports two ways to visualize your program. To use these you need to install two additional dependencies:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: dm-tree in /tmp/haiku-env/lib/python3.11/site-packages (0.1.8)\r\n", "Requirement already satisfied: graphviz in /tmp/haiku-env/lib/python3.11/site-packages (0.20.1)\r\n" ] } ], "source": [ "!pip install dm-tree graphviz" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "_" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import haiku as hk" ] }, { "cell_type": "markdown", "metadata": { "id": "_" }, "source": [ "## Tabulate\n", "\n", "Like many neural network libraries, Haiku supports showing a summary of the execution of your program as a table of modules. Haiku's approach is to trace the execution of your program and to produce a table of (interesting) module method calls.\n", "\n", "For example, the interesting methods for a 3 layer MLP would be `MLP.__call__` which in turns calls `Linear.__call__` on three inner modules. For each module method we show columns relating to the input/output size of arrays, as well as details of the modules parameters and where it fits in the module hierarchy." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "_" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+\n", "| Module | Config | Module params | Input | Output | Param count | Param bytes |\n", "+=========================+==========================================+=================+============+============+===============+===============+\n", "| mlp (MLP) | MLP(output_sizes=[300, 100, 10]) | | f32[8,784] | f32[8,10] | 266,610 | 1.07 MB |\n", "+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+\n", "| mlp/~/linear_0 (Linear) | Linear(output_size=300, name='linear_0') | w: f32[784,300] | f32[8,784] | f32[8,300] | 235,500 | 942.00 KB |\n", "| └ mlp (MLP) | | b: f32[300] | | | | |\n", "+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+\n", "| mlp/~/linear_1 (Linear) | Linear(output_size=100, name='linear_1') | w: f32[300,100] | f32[8,300] | f32[8,100] | 30,100 | 120.40 KB |\n", "| └ mlp (MLP) | | b: f32[100] | | | | |\n", "+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+\n", "| mlp/~/linear_2 (Linear) | Linear(output_size=10, name='linear_2') | w: f32[100,10] | f32[8,100] | f32[8,10] | 1,010 | 4.04 KB |\n", "| └ mlp (MLP) | | b: f32[10] | | | | |\n", "+-------------------------+------------------------------------------+-----------------+------------+------------+---------------+---------------+\n" ] } ], "source": [ "def f(x):\n", " return hk.nets.MLP([300, 100, 10])(x)\n", "\n", "f = hk.transform(f)\n", "x = jnp.ones([8, 28 * 28])\n", "\n", "print(hk.experimental.tabulate(f)(x))" ] }, { "cell_type": "markdown", "metadata": { "id": "_" }, "source": [ "We also offer access to the raw data used to build this table if you want to create your own summary:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "_" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MethodInvocation(module_details=ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), args_spec=(f32[8,784],), kwargs_spec={}, output_spec=f32[8,10], context=MethodContext(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', orig_method=functools.partial(, MLP(output_sizes=[300, 100, 10])), orig_class=), call_stack=(ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}),))\n", "MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=300, name='linear_0'), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300]}, state={}), args_spec=(f32[8,784],), kwargs_spec={}, output_spec=f32[8,300], context=MethodContext(module=Linear(output_size=300, name='linear_0'), method_name='__call__', orig_method=functools.partial(, Linear(output_size=300, name='linear_0')), orig_class=), call_stack=(ModuleDetails(module=Linear(output_size=300, name='linear_0'), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))\n", "MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=100, name='linear_1'), method_name='__call__', params={'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100]}, state={}), args_spec=(f32[8,300],), kwargs_spec={}, output_spec=f32[8,100], context=MethodContext(module=Linear(output_size=100, name='linear_1'), method_name='__call__', orig_method=functools.partial(, Linear(output_size=100, name='linear_1')), orig_class=), call_stack=(ModuleDetails(module=Linear(output_size=100, name='linear_1'), method_name='__call__', params={'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))\n", "MethodInvocation(module_details=ModuleDetails(module=Linear(output_size=10, name='linear_2'), method_name='__call__', params={'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), args_spec=(f32[8,100],), kwargs_spec={}, output_spec=f32[8,10], context=MethodContext(module=Linear(output_size=10, name='linear_2'), method_name='__call__', orig_method=functools.partial(, Linear(output_size=10, name='linear_2')), orig_class=), call_stack=(ModuleDetails(module=Linear(output_size=10, name='linear_2'), method_name='__call__', params={'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={}), ModuleDetails(module=MLP(output_sizes=[300, 100, 10]), method_name='__call__', params={'mlp/~/linear_0/b': f32[300], 'mlp/~/linear_0/w': f32[784,300], 'mlp/~/linear_1/b': f32[100], 'mlp/~/linear_1/w': f32[300,100], 'mlp/~/linear_2/b': f32[10], 'mlp/~/linear_2/w': f32[100,10]}, state={})))\n" ] } ], "source": [ "for method_invocation in hk.experimental.eval_summary(f)(x):\n", " print(method_invocation)" ] }, { "cell_type": "markdown", "metadata": { "id": "_" }, "source": [ "## Graphviz (aka. `to_dot`)\n", "\n", "Haiku supports rendering your program as a graphviz graph. We show all of the JAX primitives involved in a given computation clustered by Haiku module.\n", "\n", "Lets start by visualizing a simple program not using Haiku modules:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "_" }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "G\n", "\n", "f\n", "\n", "\n", "49830512\n", "\n", "\n", "sin\n", "f32[1]\n", "\n", "\n", "\n", "\n", "\n", "51865824\n", "\n", "\n", "cos\n", "f32[1]\n", "\n", "\n", "\n", "\n", "\n", "49830512->51865824\n", "\n", "\n", "\n", "\n", "\n", "52344992\n", "\n", "\n", "add\n", "f32[1]\n", "\n", "\n", "\n", "\n", "\n", "49830512->52344992\n", "\n", "\n", "\n", "\n", "\n", "51865824->52344992\n", "\n", "\n", "\n", "\n", "\n", "49103760\n", "\n", "\n", "output\n", "add\n", "f32[1]\n", "\n", "\n", "\n", "\n", "\n", "52344992->49103760\n", "\n", "\n", "\n", "\n", "\n", "486529600\n", "\n", "args[0]\n", "f32[1]\n", "\n", "\n", "\n", "486529600->49830512\n", "\n", "\n", "\n", "\n", "\n", "486529601\n", "\n", "(reuse)\n", "args[0]\n", "f32[1]\n", "\n", "\n", "\n", "486529601->49103760\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(a):\n", " b = jnp.sin(a)\n", " c = jnp.cos(b)\n", " d = b + c\n", " e = a + d\n", " return e\n", "\n", "x = jnp.ones([1])\n", "dot = hk.to_dot(f)(x)\n", "\n", "import graphviz\n", "graphviz.Source(dot)" ] }, { "cell_type": "markdown", "metadata": { "id": "_" }, "source": [ "The visualization above shows our program as a simple dataflow graph of our single input highligted in orange (`args[0]`) being passed through some operations and producing a result (highlighted in blue). Primitive operations (e.g. `sin`, `cos` and `add`) are highlighted in yellow.\n", "\n", "Actual Haiku programs are often far more complex, involving many modules and many more primitive operations. For these programs it is often useful to visualize the program on a module by module basis.\n", "\n", "`to_dot` offers this by clustering operations by their module. Again it is probably simplest to see an example:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "_" }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "G\n", "\n", "apply_fn\n", "\n", "cluster_139737768113904\n", "\n", "mlp\n", "\n", "\n", "cluster_139737768115264\n", "\n", "mlp/~/linear_0\n", "\n", "\n", "cluster_139737790661312\n", "\n", "pjit (relu)\n", "\n", "\n", "cluster_139737790121520\n", "\n", "mlp/~/linear_1\n", "\n", "\n", "cluster_139737767817952\n", "\n", "pjit (relu)\n", "\n", "\n", "cluster_139737768114384\n", "\n", "mlp/~/linear_2\n", "\n", "\n", "\n", "59282016\n", "\n", "\n", "dot_general\n", "f32[8,300]\n", "\n", "\n", "\n", "\n", "\n", "71412688\n", "\n", "\n", "add\n", "f32[8,300]\n", "\n", "\n", "\n", "\n", "\n", "59282016->71412688\n", "\n", "\n", "\n", "\n", "\n", "69211888\n", "\n", "\n", "broadcast_in_dim\n", "f32[8,300]\n", "\n", "\n", "\n", "\n", "\n", "69211888->71412688\n", "\n", "\n", "\n", "\n", "\n", "67211248\n", "\n", "\n", "max\n", "f32[8,300]\n", "\n", "\n", "\n", "\n", "\n", "71412688->67211248\n", "\n", "\n", "\n", "\n", "\n", "523449920\n", "\n", "args[2]\n", "f32[8,784]\n", "\n", "\n", "\n", "523449920->59282016\n", "\n", "\n", "\n", "\n", "\n", "527083520\n", "\n", "args[0]: mlp/~/linear_0/w\n", "f32[784,300]\n", "\n", "\n", "\n", "527083520->59282016\n", "\n", "\n", "\n", "\n", "\n", "548696480\n", "\n", "args[0]: mlp/~/linear_0/b\n", "f32[300]\n", "\n", "\n", "\n", "548696480->69211888\n", "\n", "\n", "\n", "\n", "\n", "70422144\n", "\n", "\n", "dot_general\n", "f32[8,100]\n", "\n", "\n", "\n", "\n", "\n", "67211248->70422144\n", "\n", "\n", "\n", "\n", "\n", "139737768045680\n", "\n", "0.0\n", "\n", "\n", "\n", "139737768045680->67211248\n", "\n", "\n", "\n", "\n", "\n", "50960960\n", "\n", "\n", "add\n", "f32[8,100]\n", "\n", "\n", "\n", "\n", "\n", "70422144->50960960\n", "\n", "\n", "\n", "\n", "\n", "69926688\n", "\n", "\n", "broadcast_in_dim\n", "f32[8,100]\n", "\n", "\n", "\n", "\n", "\n", "69926688->50960960\n", "\n", "\n", "\n", "\n", "\n", "68842336\n", "\n", "\n", "max\n", "f32[8,100]\n", "\n", "\n", "\n", "\n", "\n", "50960960->68842336\n", "\n", "\n", "\n", "\n", "\n", "653177280\n", "\n", "args[0]: mlp/~/linear_1/w\n", "f32[300,100]\n", "\n", "\n", "\n", "653177280->70422144\n", "\n", "\n", "\n", "\n", "\n", "642729120\n", "\n", "args[0]: mlp/~/linear_1/b\n", "f32[100]\n", "\n", "\n", "\n", "642729120->69926688\n", "\n", "\n", "\n", "\n", "\n", "71785552\n", "\n", "\n", "dot_general\n", "f32[8,10]\n", "\n", "\n", "\n", "\n", "\n", "68842336->71785552\n", "\n", "\n", "\n", "\n", "\n", "139737768255984\n", "\n", "0.0\n", "\n", "\n", "\n", "139737768255984->68842336\n", "\n", "\n", "\n", "\n", "\n", "72015984\n", "\n", "\n", "output\n", "add\n", "f32[8,10]\n", "\n", "\n", "\n", "\n", "\n", "71785552->72015984\n", "\n", "\n", "\n", "\n", "\n", "70466816\n", "\n", "\n", "broadcast_in_dim\n", "f32[8,10]\n", "\n", "\n", "\n", "\n", "\n", "70466816->72015984\n", "\n", "\n", "\n", "\n", "\n", "689771200\n", "\n", "args[0]: mlp/~/linear_2/w\n", "f32[100,10]\n", "\n", "\n", "\n", "689771200->71785552\n", "\n", "\n", "\n", "\n", "\n", "690080640\n", "\n", "args[0]: mlp/~/linear_2/b\n", "f32[10]\n", "\n", "\n", "\n", "690080640->70466816\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(x):\n", " return hk.nets.MLP([300, 100, 10])(x)\n", "\n", "f = hk.transform(f)\n", "\n", "rng = jax.random.PRNGKey(42)\n", "x = jnp.ones([8, 28 * 28])\n", "params = f.init(rng, x)\n", "\n", "dot = hk.to_dot(f.apply)(params, None, x)\n", "graphviz.Source(dot)" ] } ], "metadata": { "colab": {} }, "nbformat": 4, "nbformat_minor": 0 }