diff --git a/README.md b/README.md
index 74124c41a7cca1df66f14fdd8ab33b55d6817f86..f963d2891c8dd976361d15366515b3acc892d3fa 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,30 @@
-# MultiExitViT
+# Introduction
+
+This code was developed and executed using Jupyter notebooks.
+
+The following instructions assume an Ubuntu 20.04 system with superuser access, Nvidia GPUs with drivers already installed, and CUDA version 10.1, 11.0 or 11.2.
+
+# Setting Up the Environment
+
+1. [Install Docker](https://docs.docker.com/engine/install/ubuntu/)
+2. [Install Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#setting-up-nvidia-container-toolkit)
+3. `sudo docker run --gpus all -it -p 8888:8888 tensorflow/tensorflow:2.3.2-gpu-jupyter` (also tested with the `tensorflow/tensorflow:2.4.1-gpu-jupyter` image)
+4. Copy the URL printed in the Docker logs (including the token).
+5. <kbd>CTRL</kbd>+<kbd>P</kbd> then <kbd>CTRL</kbd>+<kbd>Q</kbd> to detach from the container without terminating the execution.
+6. Install SciPy inside the container: `sudo docker exec -it [container_name] bash` (you can find the container name from the output of `sudo docker ps`) then `pip install scipy==1.5` (use <kbd>CTRL</kbd>+<kbd>D</kbd> to terminate and detach)
+7. Paste the copied URL in your browser to open Jupyter (if you are running the docker container on a remote server, you need to replace the IP address with that of the server).
+8. Upload all of the `.ipynb` files in this repository.
+
+# Running the Experiments
+
+Note: in each of the notebooks, you can modify `SELECTED_GPUS` to specify which GPUs to use. If you only have a single GPU available, set `SELECTED_GPUS = [0]`. Distributed training may not be supported in some notebooks.
+
+1. Run the `train_cifar10_backbone`, `train_cifar100_backbone`, `train_fashion_mnist_backbone` and `train_disco_backbone` notebooks to train the backbones.
+
+2. Run the `precompute_cifar_features`, `precompute_disco_features` and `precompute_fashion_mnist_features` notebooks to precompute the intermediate representations of the backbones.
+
+3. Run the `ee` and `cw` notebooks to run the end-to-end and classifier-wise experiments, respectively. You can change the `dataset`, `head_type`, `version` and other parameters given to the `train` function.
+
+4. Run the `calculate_flops` notebook to calculate the FLOPs, the `calculate_maes` notebook to calculate the MAEs for the DISCO dataset cases, and the `plots` notebook to draw the plots.
+
 
diff --git a/calculate_flops.ipynb b/calculate_flops.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..b1d5f68d1d5fb58a0f577d369cd92bf9e22e378e
--- /dev/null
+++ b/calculate_flops.ipynb
@@ -0,0 +1,377 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [7]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import sys\n",
+    "from skimage import transform\n",
+    "from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "IMAGE_SIZE = 384\n",
+    "PATCH_SIZE = 16\n",
+    "HIDDEN_DIM = 768\n",
+    "MLP_DIM = 3072\n",
+    "CHANNELS_MLP_DIM = 3072\n",
+    "TOKENS_MLP_DIM = 384"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_flops(model):\n",
+    "    \"\"\"\n",
+    "    from https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-768977280\n",
+    "    \"\"\"\n",
+    "    concrete = tf.function(lambda inputs: model(inputs))\n",
+    "    concrete_func = concrete.get_concrete_function(\n",
+    "        [tf.TensorSpec([1, *inputs.shape[1:]]) for inputs in model.inputs])\n",
+    "    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(concrete_func)\n",
+    "    with tf.Graph().as_default() as graph:\n",
+    "        tf.graph_util.import_graph_def(graph_def, name='')\n",
+    "        run_meta = tf.compat.v1.RunMetadata()\n",
+    "        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()\n",
+    "        flops = tf.compat.v1.profiler.profile(graph=graph, run_meta=run_meta, cmd=\"op\", options=opts)\n",
+    "        return flops.total_float_ops"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from https://github.com/leondgarse/Keras_mlp/blob/main/res_mlp.py\n",
+    "\n",
+    "def channel_affine(inputs, use_bias=True, weight_init_value=1, name=''):  # per-channel scale (+ optional bias)\n",
+    "    ww_init = tf.keras.initializers.Constant(weight_init_value) if weight_init_value != 1 else 'ones'\n",
+    "    nn = tf.keras.backend.expand_dims(inputs, 1)  # insert a dummy axis so the 1x1 DepthwiseConv2D can be applied (assumes 3-D input)\n",
+    "    nn = tf.keras.layers.DepthwiseConv2D(1, depthwise_initializer=ww_init, use_bias=use_bias, name=name + 'affine')(nn)\n",
+    "    return tf.keras.backend.squeeze(nn, 1)  # drop the dummy axis again\n",
+    "\n",
+    "def mlp_block(inputs, mlp_dim, activation='gelu', name=''):\n",
+    "    affine_inputs = channel_affine(inputs, use_bias=True, name=name + '1_')\n",
+    "    nn = tf.keras.layers.Permute((2, 1), name=name + 'permute_1')(affine_inputs)\n",
+    "    nn = tf.keras.layers.Dense(nn.shape[-1], name=name + 'dense_1')(nn)\n",
+    "    nn = tf.keras.layers.Permute((2, 1), name=name + 'permute_2')(nn)\n",
+    "    nn = channel_affine(nn, use_bias=False, name=name + '1_gamma_')\n",
+    "    skip_conn = tf.keras.layers.Add(name=name + 'add_1')([nn, affine_inputs])\n",
+    "\n",
+    "    affine_skip = channel_affine(skip_conn, use_bias=True, name=name + '2_')\n",
+    "    nn = tf.keras.layers.Dense(mlp_dim, name=name + 'dense_2_1')(affine_skip)\n",
+    "    nn = tf.keras.layers.Activation(activation, name=name + 'gelu')(nn)\n",
+    "    nn = tf.keras.layers.Dense(inputs.shape[-1], name=name + 'dense_2_2')(nn)\n",
+    "    nn = channel_affine(nn, use_bias=False, name=name + '2_gamma_')\n",
+    "    nn = tf.keras.layers.Add(name=name + 'add_2')([nn, affine_skip])\n",
+    "    return nn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from https://github.com/Benjamin-Etheredge/mlp-mixer-keras/blob/main/mlp_mixer_keras/mlp_mixer.py\n",
+    "\n",
+    "class MlpBlock(tf.keras.layers.Layer):\n",
+    "    def __init__(self, dim, hidden_dim, activation=None, **kwargs):\n",
+    "        super(MlpBlock, self).__init__(**kwargs)\n",
+    "\n",
+    "        if activation is None:\n",
+    "            activation = tf.keras.activations.gelu\n",
+    "\n",
+    "        self.dim = dim\n",
+    "        self.dense1 = tf.keras.layers.Dense(hidden_dim)\n",
+    "        self.activation = tf.keras.layers.Activation(activation)\n",
+    "        self.dense2 = tf.keras.layers.Dense(dim)\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        x = inputs\n",
+    "        x = self.dense1(x)\n",
+    "        x = self.activation(x)\n",
+    "        x = self.dense2(x)\n",
+    "        return x\n",
+    "\n",
+    "    def compute_output_shape(self, input_signature):\n",
+    "        return (input_signature[0], self.dim)\n",
+    "\n",
+    "class MixerBlock(tf.keras.layers.Layer):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        num_patches,\n",
+    "        channel_dim,\n",
+    "        token_mixer_hidden_dim,\n",
+    "        channel_mixer_hidden_dim=None,\n",
+    "        activation=None,\n",
+    "        **kwargs\n",
+    "    ):\n",
+    "        super(MixerBlock, self).__init__(**kwargs)\n",
+    "\n",
+    "        if activation is None:\n",
+    "            activation = tf.keras.activations.gelu\n",
+    "\n",
+    "        if channel_mixer_hidden_dim is None:\n",
+    "            channel_mixer_hidden_dim = token_mixer_hidden_dim\n",
+    "\n",
+    "        self.norm1 = tf.keras.layers.LayerNormalization(axis=1)\n",
+    "        self.permute1 = tf.keras.layers.Permute((2, 1))\n",
+    "        self.token_mixer = MlpBlock(num_patches, token_mixer_hidden_dim, name='token_mixer')\n",
+    "\n",
+    "        self.permute2 = tf.keras.layers.Permute((2, 1))\n",
+    "        self.norm2 = tf.keras.layers.LayerNormalization(axis=1)\n",
+    "        self.channel_mixer = MlpBlock(channel_dim, channel_mixer_hidden_dim, name='channel_mixer')\n",
+    "\n",
+    "        self.skip_connection1 = tf.keras.layers.Add()\n",
+    "        self.skip_connection2 = tf.keras.layers.Add()\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        x = inputs\n",
+    "        skip_x = x\n",
+    "        x = self.norm1(x)\n",
+    "        x = self.permute1(x)\n",
+    "        x = self.token_mixer(x)\n",
+    "\n",
+    "        x = self.permute2(x)\n",
+    "\n",
+    "        x = self.skip_connection1([x, skip_x])\n",
+    "        skip_x = x\n",
+    "\n",
+    "        x = self.norm2(x)\n",
+    "        x = self.channel_mixer(x)\n",
+    "\n",
+    "        x = self.skip_connection2([x, skip_x])\n",
+    "\n",
+    "        return x\n",
+    "\n",
+    "    def compute_output_shape(self, input_shape):\n",
+    "        return input_shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_branch_id(branch_number):\n",
+    "    if branch_number == 1:\n",
+    "        return 'transformer_block'\n",
+    "    else:\n",
+    "        return 'transformer_block_%d' % (branch_number - 1)\n",
+    "\n",
+    "def get_model(dataset_name, branch_type, branch_number):\n",
+    "    if dataset_name == 'disco':\n",
+    "        model_file_name = 'vit_cc_backbone_v2.h5'\n",
+    "        output_units = 1\n",
+    "        output_activation = None\n",
+    "    elif dataset_name == 'fashion_mnist':\n",
+    "        model_file_name = 'vit_fashion_mnist_v1.h5'\n",
+    "        output_units = 10\n",
+    "        output_activation = 'softmax'\n",
+    "    elif dataset_name == 'cifar10':\n",
+    "        model_file_name = 'vit_cifar10_v1.h5'\n",
+    "        output_units = 10\n",
+    "        output_activation = 'softmax'\n",
+    "    else:\n",
+    "        model_file_name = 'vit_cifar100_v1.h5'\n",
+    "        output_units = 100\n",
+    "        output_activation = 'softmax'\n",
+    "\n",
+    "    backbone_model = tf.keras.models.load_model(model_file_name, custom_objects={\n",
+    "        'ClassToken': ClassToken,\n",
+    "        'AddPositionEmbs': AddPositionEmbs,\n",
+    "        'MultiHeadSelfAttention': MultiHeadSelfAttention,\n",
+    "        'TransformerBlock': TransformerBlock,\n",
+    "    })\n",
+    "    \n",
+    "    # freeze\n",
+    "    for layer in backbone_model.layers:\n",
+    "        layer.trainable = False\n",
+    "    \n",
+    "    if branch_type == 'mlp':\n",
+    "        y, _ = backbone_model.get_layer(get_branch_id(branch_number)).output\n",
+    "        y = tf.keras.layers.LayerNormalization(\n",
+    "            epsilon=1e-6, name=\"Transformer/encoder_norm\"\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.Lambda(lambda v: v[:, 0], name=\"ExtractToken\")(y)\n",
+    "\n",
+    "    elif branch_type == 'vit':\n",
+    "        y, _ = backbone_model.get_layer(get_branch_id(branch_number)).output\n",
+    "        y, _ = TransformerBlock(\n",
+    "            num_heads=12,\n",
+    "            mlp_dim=3072,\n",
+    "            dropout=0.1,\n",
+    "            name=f\"Transformer/encoderblock_x\",\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.LayerNormalization(\n",
+    "            epsilon=1e-6, name=\"Transformer/encoder_norm\"\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.Lambda(lambda v: v[:, 0], name=\"ExtractToken\")(y)\n",
+    "\n",
+    "    elif branch_type.startswith('cnn_'):\n",
+    "        y0, _ = backbone_model.get_layer(get_branch_id(branch_number)).output\n",
+    "        channels = HIDDEN_DIM\n",
+    "        width = height = IMAGE_SIZE // PATCH_SIZE\n",
+    "        y1 = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken')(y0)\n",
+    "        y1 = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape')(y1)\n",
+    "        y2 = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken')(y0)\n",
+    "        y2 = tf.keras.layers.RepeatVector(width * height)(y2)\n",
+    "        y2 = tf.keras.layers.Reshape((width, height, channels), name='cls_reshape')(y2)\n",
+    "        if branch_type == 'cnn_ignore':\n",
+    "            y = y1\n",
+    "        elif branch_type == 'cnn_add':\n",
+    "            y = tf.keras.layers.Add()([y1, y2])\n",
+    "        elif branch_type == 'cnn_project':\n",
+    "            y = tf.keras.layers.Concatenate()([y1, y2])\n",
+    "        y = tf.keras.layers.Conv2D(\n",
+    "            filters=16,\n",
+    "            kernel_size=(3, 3),\n",
+    "            activation='elu',\n",
+    "            padding='same'\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)\n",
+    "        y = tf.keras.layers.Flatten()(y)\n",
+    "\n",
+    "    elif branch_type == 'resmlp':\n",
+    "        y, _ = backbone_model.get_layer(get_branch_id(branch_number)).output\n",
+    "        y = mlp_block(y, mlp_dim=MLP_DIM, name='mlp_mixer')\n",
+    "        y = tf.keras.layers.GlobalAveragePooling1D()(y)\n",
+    "\n",
+    "    elif branch_type == 'mlp_mixer':\n",
+    "        y, _ = backbone_model.get_layer(get_branch_id(branch_number)).output\n",
+    "        num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1\n",
+    "        y = MixerBlock(\n",
+    "            num_patches=num_patches,\n",
+    "            channel_dim=HIDDEN_DIM,\n",
+    "            token_mixer_hidden_dim=TOKENS_MLP_DIM,\n",
+    "            channel_mixer_hidden_dim=CHANNELS_MLP_DIM\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.GlobalAveragePooling1D()(y)\n",
+    "\n",
+    "    else:\n",
+    "        raise Exception('Unknown branch type: %s' % branch_type)\n",
+    "    \n",
+    "    # MLP head\n",
+    "    initializer = tf.keras.initializers.he_normal()\n",
+    "    regularizer = tf.keras.regularizers.l2()\n",
+    "    y = tf.keras.layers.Dense(\n",
+    "        units=256,\n",
+    "        activation='elu',\n",
+    "        kernel_initializer=initializer,\n",
+    "        kernel_regularizer=regularizer\n",
+    "    )(y)\n",
+    "    y = tf.keras.layers.Dropout(0.5)(y)\n",
+    "    y = tf.keras.layers.Dense(\n",
+    "        units=256,\n",
+    "        activation='elu',\n",
+    "        kernel_initializer=initializer,\n",
+    "        kernel_regularizer=regularizer\n",
+    "    )(y)\n",
+    "    y = tf.keras.layers.Dropout(0.5)(y)\n",
+    "    y = tf.keras.layers.Dense(\n",
+    "        units=output_units,\n",
+    "        activation=output_activation,\n",
+    "        kernel_initializer=initializer,\n",
+    "        kernel_regularizer=regularizer\n",
+    "    )(y)\n",
+    "\n",
+    "    model = tf.keras.models.Model(\n",
+    "        inputs=backbone_model.get_layer(index=0).input,\n",
+    "        outputs=y\n",
+    "    )\n",
+    "\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "branch_types = [\n",
+    "    'mlp',\n",
+    "    'vit',\n",
+    "    'cnn_ignore',\n",
+    "    'cnn_add',\n",
+    "    'cnn_project',\n",
+    "    'resmlp',\n",
+    "    'mlp_mixer',\n",
+    "]\n",
+    "\n",
+    "dataset_names = [\n",
+    "    'cifar10',\n",
+    "    'cifar100',\n",
+    "    'disco',\n",
+    "    'fashion_mnist',\n",
+    "]\n",
+    "\n",
+    "for dataset_name in dataset_names:\n",
+    "    for branch_type in branch_types:\n",
+    "        flops = []\n",
+    "        for branch_number in range(1, 12):\n",
+    "            tf.keras.backend.clear_session()\n",
+    "            flops.append(get_flops(get_model(dataset_name, branch_type, branch_number)) / 10 ** 9)\n",
+    "        print('###', dataset_name, branch_type)\n",
+    "        print(flops)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/calculate_maes.ipynb b/calculate_maes.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..e02b37bd2b0775aefd8871ad72178a6f5f1eb8c7
--- /dev/null
+++ b/calculate_maes.ipynb
@@ -0,0 +1,346 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [7]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import random\n",
+    "import string\n",
+    "import sys\n",
+    "from skimage import transform\n",
+    "from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "VIDEO_PATCHES = (2, 3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from https://github.com/Benjamin-Etheredge/mlp-mixer-keras/blob/main/mlp_mixer_keras/mlp_mixer.py\n",
+    "\n",
+    "class MlpBlock(tf.keras.layers.Layer):\n",
+    "    def __init__(self, dim, hidden_dim, activation=None, **kwargs):\n",
+    "        super(MlpBlock, self).__init__(**kwargs)\n",
+    "\n",
+    "        if activation is None:\n",
+    "            activation = tf.keras.activations.gelu\n",
+    "\n",
+    "        self.dim = dim\n",
+    "        self.hidden_dim = hidden_dim\n",
+    "        self.activation = activation\n",
+    "        self.dense1 = tf.keras.layers.Dense(hidden_dim)\n",
+    "        self.activation = tf.keras.layers.Activation(activation)\n",
+    "        self.dense2 = tf.keras.layers.Dense(dim)\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        x = inputs\n",
+    "        x = self.dense1(x)\n",
+    "        x = self.activation(x)\n",
+    "        x = self.dense2(x)\n",
+    "        return x\n",
+    "\n",
+    "    def compute_output_shape(self, input_signature):\n",
+    "        return (input_signature[0], self.dim)\n",
+    "\n",
+    "    def get_config(self):\n",
+    "        config = super(MlpBlock, self).get_config().copy()\n",
+    "        config.update({\n",
+    "            'dim': self.dim,\n",
+    "            'hidden_dim': self.hidden_dim,\n",
+    "            'activation': self.activation,\n",
+    "        })\n",
+    "        return config\n",
+    "\n",
+    "class MixerBlock(tf.keras.layers.Layer):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        num_patches,\n",
+    "        channel_dim,\n",
+    "        token_mixer_hidden_dim,\n",
+    "        channel_mixer_hidden_dim=None,\n",
+    "        activation=None,\n",
+    "        **kwargs\n",
+    "    ):\n",
+    "        super(MixerBlock, self).__init__(**kwargs)\n",
+    "\n",
+    "        if activation is None:\n",
+    "            activation = tf.keras.activations.gelu\n",
+    "\n",
+    "        if channel_mixer_hidden_dim is None:\n",
+    "            channel_mixer_hidden_dim = token_mixer_hidden_dim\n",
+    "\n",
+    "        self.num_patches = num_patches\n",
+    "        self.channel_dim = channel_dim\n",
+    "        self.token_mixer_hidden_dim = token_mixer_hidden_dim\n",
+    "        self.channel_mixer_hidden_dim = channel_mixer_hidden_dim\n",
+    "        self.activation = activation\n",
+    "        \n",
+    "        self.norm1 = tf.keras.layers.LayerNormalization(axis=1)\n",
+    "        self.permute1 = tf.keras.layers.Permute((2, 1))\n",
+    "        self.token_mixer = MlpBlock(num_patches, token_mixer_hidden_dim, name='token_mixer')\n",
+    "\n",
+    "        self.permute2 = tf.keras.layers.Permute((2, 1))\n",
+    "        self.norm2 = tf.keras.layers.LayerNormalization(axis=1)\n",
+    "        self.channel_mixer = MlpBlock(channel_dim, channel_mixer_hidden_dim, name='channel_mixer')\n",
+    "\n",
+    "        self.skip_connection1 = tf.keras.layers.Add()\n",
+    "        self.skip_connection2 = tf.keras.layers.Add()\n",
+    "\n",
+    "    def get_config(self):\n",
+    "        config = super(MixerBlock, self).get_config().copy()\n",
+    "        config.update({\n",
+    "            'num_patches': self.num_patches,\n",
+    "            'channel_dim': self.channel_dim,\n",
+    "            'token_mixer_hidden_dim': self.token_mixer_hidden_dim,\n",
+    "            'channel_mixer_hidden_dim': self.channel_mixer_hidden_dim,\n",
+    "            'activation': self.activation,\n",
+    "        })\n",
+    "        return config\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        x = inputs\n",
+    "        skip_x = x\n",
+    "        x = self.norm1(x)\n",
+    "        x = self.permute1(x)\n",
+    "        x = self.token_mixer(x)\n",
+    "\n",
+    "        x = self.permute2(x)\n",
+    "\n",
+    "        x = self.skip_connection1([x, skip_x])\n",
+    "        skip_x = x\n",
+    "\n",
+    "        x = self.norm2(x)\n",
+    "        x = self.channel_mixer(x)\n",
+    "\n",
+    "        x = self.skip_connection2([x, skip_x])\n",
+    "\n",
+    "        return x\n",
+    "\n",
+    "    def compute_output_shape(self, input_shape):\n",
+    "        return input_shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_branch_id(branch_number):\n",
+    "    if branch_number == 1:\n",
+    "        return 'transformer_block'\n",
+    "    else:\n",
+    "        return 'transformer_block_%d' % (branch_number - 1)\n",
+    "\n",
+    "def get_model(branch_type, branch_number, version):\n",
+    "    backbone_model = tf.keras.models.load_model('vit_cc_backbone_v2.h5', custom_objects={\n",
+    "        'ClassToken': ClassToken,\n",
+    "        'AddPositionEmbs': AddPositionEmbs,\n",
+    "        'MultiHeadSelfAttention': MultiHeadSelfAttention,\n",
+    "        'TransformerBlock': TransformerBlock,\n",
+    "    })\n",
+    "    y, _ = backbone_model.get_layer(get_branch_id(branch_number)).output\n",
+    "    backend_model = tf.keras.models.Model(\n",
+    "        inputs=backbone_model.get_layer(index=0).input,\n",
+    "        outputs=y\n",
+    "    )\n",
+    "    backend_model._name='backend_model'\n",
+    "    frontend_model = tf.keras.models.load_model(\n",
+    "        'vit_disco_cw_%d_%s_head_precomputed_%s.h5' % (branch_number, branch_type, version),\n",
+    "        custom_objects={\n",
+    "            'ClassToken': ClassToken,\n",
+    "            'AddPositionEmbs': AddPositionEmbs,\n",
+    "            'MultiHeadSelfAttention': MultiHeadSelfAttention,\n",
+    "            'TransformerBlock': TransformerBlock,\n",
+    "            'MlpBlock': MlpBlock,\n",
+    "            'MixerBlock': MixerBlock,\n",
+    "        }\n",
+    "    )\n",
+    "    frontend_model._name = 'frontend_model'\n",
+    "    model = tf.keras.Sequential([\n",
+    "        backend_model,\n",
+    "        frontend_model\n",
+    "    ])\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "DISCO_PATH = 'disco'\n",
+    "CACHE_DIR = os.path.join(DISCO_PATH, 'vit_cache')\n",
+    "\n",
+    "def horizontal_flip(image):\n",
+    "    return np.flip(image, axis=1)\n",
+    "\n",
+    "class CCSequence(tf.keras.utils.Sequence):\n",
+    "    def __init__(self, split, batch_size):\n",
+    "        self.split = split\n",
+    "        self.split_len = sum([\n",
+    "            1 if file_name.startswith(self.split) else 0 for file_name in os.listdir(CACHE_DIR)\n",
+    "        ])\n",
+    "        self.batch_size = batch_size\n",
+    "        self.random_permutation = np.random.permutation(self.split_len)\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return math.ceil(self.split_len / self.batch_size)\n",
+    "\n",
+    "    def on_epoch_end(self):\n",
+    "        self.random_permutation = np.random.permutation(self.split_len)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        spectrograms = []\n",
+    "        images = []\n",
+    "        density_maps = []\n",
+    "        if self.split == 'test':\n",
+    "            index_generator = range(\n",
+    "                index * self.batch_size,\n",
+    "                min((index + 1) * self.batch_size, self.split_len)\n",
+    "            )\n",
+    "        else:\n",
+    "            index_generator = self.random_permutation[index * self.batch_size:(index + 1) * self.batch_size]\n",
+    "        for random_index in index_generator:\n",
+    "            all_path = os.path.join(\n",
+    "                CACHE_DIR,\n",
+    "                '%s_%d.pkl' % (self.split, random_index)\n",
+    "            )\n",
+    "            with open(all_path, 'rb') as all_file:\n",
+    "                data = pickle.load(all_file)\n",
+    "                if self.split == 'train' and random.random() < 0.5:  # flip augmentation\n",
+    "                    images.append(horizontal_flip(data['image']))\n",
+    "                else:\n",
+    "                    images.append(data['image'])\n",
+    "                density_maps.append(np.sum(data['density_map']))\n",
+    "\n",
+    "        return np.array(images), np.array(density_maps)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_mae(branch_type, branch_number, version):\n",
+    "    tf.keras.backend.clear_session()\n",
+    "    test_sequence = CCSequence('test', 32)\n",
+    "    model = get_model(branch_type, branch_number, version)\n",
+    "    gt = None\n",
+    "    out = None\n",
+    "    for i, (images, density_maps) in enumerate(test_sequence):\n",
+    "        sys.stdout.write('\\r%d' % (i + 1))\n",
+    "        sys.stdout.flush()\n",
+    "        if gt is not None:\n",
+    "            gt = np.concatenate((gt, density_maps))\n",
+    "        else:\n",
+    "            gt = density_maps\n",
+    "        if out is not None:\n",
+    "            out = np.concatenate((out, model(images).numpy().flatten()))\n",
+    "        else:\n",
+    "            out = model(images).numpy().flatten()\n",
+    "    print()  # newline\n",
+    "    mae = []\n",
+    "    img_patches = VIDEO_PATCHES[0] * VIDEO_PATCHES[1]\n",
+    "    for i in range(0, gt.shape[0], img_patches):\n",
+    "        gt_subset = gt[i:i + img_patches]\n",
+    "        out_subset = out[i:i + img_patches]\n",
+    "        mae.append(np.abs(np.sum(gt_subset) - np.sum(out_subset)))\n",
+    "    return np.mean(np.array(mae))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_maes():\n",
+    "    for branch_type in [\n",
+    "        'vit',\n",
+    "        'mlp',\n",
+    "        'cnn_ignore',\n",
+    "        'cnn_add',\n",
+    "        'cnn_project',\n",
+    "        'resmlp',\n",
+    "        'mlp_mixer',\n",
+    "    ]:\n",
+    "        maes = []\n",
+    "        for branch_number in range(1, 12):\n",
+    "            mae = get_mae(branch_type, branch_number, 'v1')\n",
+    "            print(mae)\n",
+    "            maes.append(mae)\n",
+    "        print('###', branch_type, maes)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "get_maes()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/cw.ipynb b/cw.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..bfd3b2164607f36b0232bc3f0f3babb8d2ff99c5
--- /dev/null
+++ b/cw.ipynb
@@ -0,0 +1,579 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [4, 5]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import json\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import sys\n",
+    "from skimage import transform\n",
+    "from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "IMAGE_SIZE = 384\n",
+    "PATCH_SIZE = 16\n",
+    "NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1\n",
+    "HIDDEN_DIM = 768\n",
+    "VIDEO_PATCHES = (2, 3)\n",
+    "VIDEO_SIZE = (VIDEO_PATCHES[0] * IMAGE_SIZE, VIDEO_PATCHES[1] * IMAGE_SIZE)\n",
+    "MLP_DIM = 3072  # ResMLP\n",
+    "CHANNELS_MLP_DIM = 3072  # MLP-Mixer\n",
+    "TOKENS_MLP_DIM = 384  # MLP-Mixer\n",
+    "PRECOMPUTE_DIR = 'precompute'\n",
+    "PRECOMPUTE_FASHION_MNIST_DIR = os.path.join(PRECOMPUTE_DIR, 'fashion_mnist')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_params(model):\n",
+    "    string_list = []\n",
+    "    model.summary(print_fn=lambda x: string_list.append(x))\n",
+    "    for string in string_list:\n",
+    "        if string.startswith('Trainable params:'):\n",
+    "            return int(string.split()[-1].replace(',', ''))\n",
+    "    return None\n",
+    "\n",
+    "def get_flops(model):\n",
+    "    \"\"\"\n",
+    "    from https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-768977280\n",
+    "    \"\"\"\n",
+    "    concrete = tf.function(lambda inputs: model(inputs))\n",
+    "    concrete_func = concrete.get_concrete_function(\n",
+    "        [tf.TensorSpec([1, *inputs.shape[1:]]) for inputs in model.inputs])\n",
+    "    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(concrete_func)\n",
+    "    with tf.Graph().as_default() as graph:\n",
+    "        tf.graph_util.import_graph_def(graph_def, name='')\n",
+    "        run_meta = tf.compat.v1.RunMetadata()\n",
+    "        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()\n",
+    "        flops = tf.compat.v1.profiler.profile(graph=graph, run_meta=run_meta, cmd=\"op\", options=opts)\n",
+    "        return flops.total_float_ops"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from https://github.com/leondgarse/Keras_mlp/blob/main/res_mlp.py\n",
+    "\n",
+    "def channel_affine(inputs, use_bias=True, weight_init_value=1, name=''):\n",
+    "    ww_init = tf.keras.initializers.Constant(weight_init_value) if weight_init_value != 1 else 'ones'\n",
+    "    nn = tf.keras.backend.expand_dims(inputs, 1)\n",
+    "    nn = tf.keras.layers.DepthwiseConv2D(1, depthwise_initializer=ww_init, use_bias=use_bias, name=name + 'affine')(nn)\n",
+    "    return tf.keras.backend.squeeze(nn, 1)\n",
+    "\n",
+    "def mlp_block(inputs, mlp_dim, activation='gelu', name=''):\n",
+    "    affine_inputs = channel_affine(inputs, use_bias=True, name=name + '1_')\n",
+    "    nn = tf.keras.layers.Permute((2, 1), name=name + 'permute_1')(affine_inputs)\n",
+    "    nn = tf.keras.layers.Dense(nn.shape[-1], name=name + 'dense_1')(nn)\n",
+    "    nn = tf.keras.layers.Permute((2, 1), name=name + 'permute_2')(nn)\n",
+    "    nn = channel_affine(nn, use_bias=False, name=name + '1_gamma_')\n",
+    "    skip_conn = tf.keras.layers.Add(name=name + 'add_1')([nn, affine_inputs])\n",
+    "\n",
+    "    affine_skip = channel_affine(skip_conn, use_bias=True, name=name + '2_')\n",
+    "    nn = tf.keras.layers.Dense(mlp_dim, name=name + 'dense_2_1')(affine_skip)\n",
+    "    nn = tf.keras.layers.Activation(activation, name=name + 'gelu')(nn)\n",
+    "    nn = tf.keras.layers.Dense(inputs.shape[-1], name=name + 'dense_2_2')(nn)\n",
+    "    nn = channel_affine(nn, use_bias=False, name=name + '2_gamma_')\n",
+    "    nn = tf.keras.layers.Add(name=name + 'add_2')([nn, affine_skip])\n",
+    "    return nn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from https://github.com/Benjamin-Etheredge/mlp-mixer-keras/blob/main/mlp_mixer_keras/mlp_mixer.py\n",
+    "\n",
+    "class MlpBlock(tf.keras.layers.Layer):\n",
+    "    def __init__(self, dim, hidden_dim, activation=None, **kwargs):\n",
+    "        super(MlpBlock, self).__init__(**kwargs)\n",
+    "\n",
+    "        if activation is None:\n",
+    "            activation = tf.keras.activations.gelu\n",
+    "\n",
+    "        self.dim = dim\n",
+    "        self.hidden_dim = hidden_dim\n",
+    "        self.activation = activation\n",
+    "        self.dense1 = tf.keras.layers.Dense(hidden_dim)\n",
+    "        self.activation = tf.keras.layers.Activation(activation)\n",
+    "        self.dense2 = tf.keras.layers.Dense(dim)\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        x = inputs\n",
+    "        x = self.dense1(x)\n",
+    "        x = self.activation(x)\n",
+    "        x = self.dense2(x)\n",
+    "        return x\n",
+    "\n",
+    "    def compute_output_shape(self, input_signature):\n",
+    "        return (input_signature[0], self.dim)\n",
+    "\n",
+    "    def get_config(self):\n",
+    "        config = super(MlpBlock, self).get_config().copy()\n",
+    "        config.update({\n",
+    "            'dim': self.dim,\n",
+    "            'hidden_dim': self.hidden_dim,\n",
+    "            'activation': self.activation,\n",
+    "        })\n",
+    "        return config\n",
+    "\n",
+    "class MixerBlock(tf.keras.layers.Layer):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        num_patches,\n",
+    "        channel_dim,\n",
+    "        token_mixer_hidden_dim,\n",
+    "        channel_mixer_hidden_dim=None,\n",
+    "        activation=None,\n",
+    "        **kwargs\n",
+    "    ):\n",
+    "        super(MixerBlock, self).__init__(**kwargs)\n",
+    "\n",
+    "        if activation is None:\n",
+    "            activation = tf.keras.activations.gelu\n",
+    "\n",
+    "        if channel_mixer_hidden_dim is None:\n",
+    "            channel_mixer_hidden_dim = token_mixer_hidden_dim\n",
+    "\n",
+    "        self.num_patches = num_patches\n",
+    "        self.channel_dim = channel_dim\n",
+    "        self.token_mixer_hidden_dim = token_mixer_hidden_dim\n",
+    "        self.channel_mixer_hidden_dim = channel_mixer_hidden_dim\n",
+    "        self.activation = activation\n",
+    "        \n",
+    "        self.norm1 = tf.keras.layers.LayerNormalization(axis=1)\n",
+    "        self.permute1 = tf.keras.layers.Permute((2, 1))\n",
+    "        self.token_mixer = MlpBlock(num_patches, token_mixer_hidden_dim, name='token_mixer')\n",
+    "\n",
+    "        self.permute2 = tf.keras.layers.Permute((2, 1))\n",
+    "        self.norm2 = tf.keras.layers.LayerNormalization(axis=1)\n",
+    "        self.channel_mixer = MlpBlock(channel_dim, channel_mixer_hidden_dim, name='channel_mixer')\n",
+    "\n",
+    "        self.skip_connection1 = tf.keras.layers.Add()\n",
+    "        self.skip_connection2 = tf.keras.layers.Add()\n",
+    "\n",
+    "    def get_config(self):\n",
+    "        config = super(MixerBlock, self).get_config().copy()\n",
+    "        config.update({\n",
+    "            'num_patches': self.num_patches,\n",
+    "            'channel_dim': self.channel_dim,\n",
+    "            'token_mixer_hidden_dim': self.token_mixer_hidden_dim,\n",
+    "            'channel_mixer_hidden_dim': self.channel_mixer_hidden_dim,\n",
+    "            'activation': self.activation,\n",
+    "        })\n",
+    "        return config\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        x = inputs\n",
+    "        skip_x = x\n",
+    "        x = self.norm1(x)\n",
+    "        x = self.permute1(x)\n",
+    "        x = self.token_mixer(x)\n",
+    "\n",
+    "        x = self.permute2(x)\n",
+    "\n",
+    "        x = self.skip_connection1([x, skip_x])\n",
+    "        skip_x = x\n",
+    "\n",
+    "        x = self.norm2(x)\n",
+    "        x = self.channel_mixer(x)\n",
+    "\n",
+    "        x = self.skip_connection2([x, skip_x])\n",
+    "\n",
+    "        return x\n",
+    "\n",
+    "    def compute_output_shape(self, input_shape):\n",
+    "        return input_shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_model(branch_number, head_type):\n",
+    "    model_input = tf.keras.Input(shape=(NUM_PATCHES, HIDDEN_DIM))\n",
+    "    y = model_input\n",
+    "    if head_type == 'resmlp':\n",
+    "        y = mlp_block(y, mlp_dim=MLP_DIM, name='mlp_mixer')\n",
+    "        y = tf.keras.layers.GlobalAveragePooling1D()(y)\n",
+    "    elif head_type == 'mlp':\n",
+    "        y = tf.keras.layers.LayerNormalization(\n",
+    "            epsilon=1e-6,\n",
+    "            name='Transformer/encoder_norm_x'\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x')(y)\n",
+    "    elif head_type == 'vit':\n",
+    "        y, _ = TransformerBlock(\n",
+    "            num_heads=12,\n",
+    "            mlp_dim=3072,\n",
+    "            dropout=0.1,\n",
+    "            name='Transformer/encoderblock_x'\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.LayerNormalization(\n",
+    "            epsilon=1e-6,\n",
+    "            name='Transformer/encoder_norm_x'\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x')(y)\n",
+    "    elif head_type == 'cnn_ignore':\n",
+    "        channels = HIDDEN_DIM\n",
+    "        width = height = IMAGE_SIZE // PATCH_SIZE\n",
+    "        y = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken')(y)\n",
+    "        y = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape')(y)\n",
+    "        y = tf.keras.layers.Conv2D(\n",
+    "            filters=16,\n",
+    "            kernel_size=(3, 3),\n",
+    "            activation='elu',\n",
+    "            padding='same'\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)\n",
+    "        y = tf.keras.layers.Flatten()(y)\n",
+    "    elif head_type == 'cnn_add':    \n",
+    "        channels = HIDDEN_DIM\n",
+    "        width = height = IMAGE_SIZE // PATCH_SIZE\n",
+    "\n",
+    "        y1 = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken_x')(y)\n",
+    "        y1 = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape')(y1)\n",
+    "\n",
+    "        y2 = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x')(y)\n",
+    "        y2 = tf.keras.layers.RepeatVector(width * height)(y2)\n",
+    "        y2 = tf.keras.layers.Reshape((width, height, channels), name='cls_reshape')(y2)\n",
+    "\n",
+    "        y = tf.keras.layers.Add()([y1, y2])\n",
+    "\n",
+    "        y = tf.keras.layers.Conv2D(\n",
+    "            filters=16,\n",
+    "            kernel_size=(3, 3),\n",
+    "            activation='elu',\n",
+    "            padding='same'\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)\n",
+    "        y = tf.keras.layers.Flatten()(y)\n",
+    "    elif head_type == 'cnn_project':\n",
+    "        channels = HIDDEN_DIM\n",
+    "        width = height = IMAGE_SIZE // PATCH_SIZE\n",
+    "\n",
+    "        y1 = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken_x')(y)\n",
+    "        y1 = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape')(y1)\n",
+    "\n",
+    "        y2 = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x')(y)\n",
+    "        y2 = tf.keras.layers.RepeatVector(width * height)(y2)\n",
+    "        y2 = tf.keras.layers.Reshape((width, height, channels), name='cls_reshape')(y2)\n",
+    "\n",
+    "        y = tf.keras.layers.Concatenate()([y1, y2])\n",
+    "\n",
+    "        y = tf.keras.layers.Conv2D(\n",
+    "            filters=16,\n",
+    "            kernel_size=(3, 3),\n",
+    "            activation='elu',\n",
+    "            padding='same'\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)\n",
+    "        y = tf.keras.layers.Flatten()(y)\n",
+    "    elif head_type == 'mlp_mixer':\n",
+    "        num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1\n",
+    "        y = MixerBlock(\n",
+    "            num_patches=num_patches,\n",
+    "            channel_dim=HIDDEN_DIM,\n",
+    "            token_mixer_hidden_dim=TOKENS_MLP_DIM,\n",
+    "            channel_mixer_hidden_dim=CHANNELS_MLP_DIM\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.GlobalAveragePooling1D()(y)\n",
+    "\n",
+    "    # MLP head\n",
+    "    initializer = tf.keras.initializers.he_normal()\n",
+    "    regularizer = tf.keras.regularizers.l2()\n",
+    "    y = tf.keras.layers.Dense(\n",
+    "        units=256,\n",
+    "        activation='elu',\n",
+    "        kernel_initializer=initializer,\n",
+    "        kernel_regularizer=regularizer\n",
+    "    )(y)\n",
+    "    y = tf.keras.layers.Dropout(0.5)(y)\n",
+    "    y = tf.keras.layers.Dense(\n",
+    "        units=256,\n",
+    "        activation='elu',\n",
+    "        kernel_initializer=initializer,\n",
+    "        kernel_regularizer=regularizer\n",
+    "    )(y)\n",
+    "    y = tf.keras.layers.Dropout(0.5)(y)\n",
+    "    y = tf.keras.layers.Dense(\n",
+    "        units=10,\n",
+    "        activation='softmax',\n",
+    "        kernel_initializer=initializer,\n",
+    "        kernel_regularizer=regularizer\n",
+    "    )(y)\n",
+    "\n",
+    "    model = tf.keras.models.Model(\n",
+    "        inputs=model_input,\n",
+    "        outputs=y\n",
+    "    )\n",
+    "\n",
+    "    model.compile(\n",
+    "        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n",
+    "        loss='categorical_crossentropy',\n",
+    "        metrics=['accuracy']\n",
+    "    )\n",
+    "\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class FashionMNISTSequence(tf.keras.utils.Sequence):\n",
+    "    def __init__(self, split, branch_number, batch_size):\n",
+    "        self.split = split\n",
+    "        self.branch_number = branch_number\n",
+    "        self.batch_size = batch_size * NUM_GPUS\n",
+    "        self.dir = PRECOMPUTE_FASHION_MNIST_DIR\n",
+    "        self.count = sum([\n",
+    "            1 if file_name.startswith('%s_branch%d_sample' % (\n",
+    "                self.split,\n",
+    "                self.branch_number\n",
+    "            )) else 0 for file_name in os.listdir(self.dir)\n",
+    "        ])\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return math.ceil(self.count / self.batch_size)\n",
+    "\n",
+    "    def on_epoch_end(self):\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        features = []\n",
+    "        labels = []\n",
+    "        for i in self.random_permutation[index * self.batch_size:(index + 1) * self.batch_size]:\n",
+    "            cache_file_path = os.path.join(\n",
+    "                self.dir,\n",
+    "                '%s_branch%d_sample%d.pkl' % (self.split, self.branch_number, i)\n",
+    "            )\n",
+    "            with open(cache_file_path, 'rb') as cache_file:\n",
+    "                contents = pickle.load(cache_file)\n",
+    "                features.append(contents['features'])\n",
+    "                labels.append(contents['label'])\n",
+    "        return np.array(features), np.array(labels)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(max_epochs, branch_number, head_type, batch_size=64):\n",
+    "    tf.keras.backend.clear_session()\n",
+    "\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model(branch_number, head_type)\n",
+    "        branch_params = get_params(model) / 10 ** 6\n",
+    "        total_flops = get_flops(model) / 10 ** 9\n",
+    "\n",
+    "    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(\n",
+    "        monitor='val_accuracy',\n",
+    "        factor=0.6,\n",
+    "        patience=2,\n",
+    "        verbose=1,\n",
+    "        mode='max',\n",
+    "        min_lr=1e-7\n",
+    "    )\n",
+    "\n",
+    "    early_stop = tf.keras.callbacks.EarlyStopping(\n",
+    "        monitor='val_accuracy',\n",
+    "        patience=5,\n",
+    "        verbose=1,\n",
+    "        mode='max'\n",
+    "    )\n",
+    "\n",
+    "    save_model_checkpoint_file = 'vit_shtb_cw_%d_%s_head_precomputed_v1.h5' % (branch_number, head_type)\n",
+    "\n",
+    "    checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
+    "        save_model_checkpoint_file,\n",
+    "        monitor='val_accuracy',\n",
+    "        verbose=1,\n",
+    "        save_weights_only=False,\n",
+    "        save_best_only=True,\n",
+    "        mode='max',\n",
+    "        save_freq='epoch'\n",
+    "    )\n",
+    "\n",
+    "    history = model.fit(\n",
+    "        FashionMNISTSequence('train', branch_number, batch_size),\n",
+    "        validation_data=FashionMNISTSequence('val', branch_number, batch_size),\n",
+    "        epochs=max_epochs,\n",
+    "        shuffle=True,\n",
+    "        callbacks=[\n",
+    "            lr_reduce,\n",
+    "            early_stop,\n",
+    "            checkpoint\n",
+    "        ],\n",
+    "        verbose=1\n",
+    "    )\n",
+    "\n",
+    "    test_accuracy = model.evaluate(FashionMNISTSequence('test', branch_number, batch_size))\n",
+    "\n",
+    "    return model, test_accuracy, branch_params, total_flops"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def save_results(results, results_path):\n",
+    "    with open(results_path, 'w') as results_file:\n",
+    "        results_file.write(json.dumps(results))\n",
+    "\n",
+    "def print_results(results_path):\n",
+    "    with open(results_path, 'r') as results_file:\n",
+    "        print(json.loads(results_file.read()))\n",
+    "\n",
+    "def get_results_path(head_type):\n",
+    "    return 'shtb_%s.json' % head_type\n",
+    "\n",
+    "def run_experiment(head_type):\n",
+    "    results = []\n",
+    "    for i in reversed(range(1, 12)):\n",
+    "        model, test_accuracy, branch_params, total_flops = train(100, i, head_type)        \n",
+    "        results.append({\n",
+    "            'exit': i,\n",
+    "            'test_accuracy': test_accuracy,\n",
+    "            'branch_params': branch_params,\n",
+    "            'total_flops': total_flops,\n",
+    "        })\n",
+    "        results_path = get_results_path(head_type)\n",
+    "        save_results(results, results_path)\n",
+    "        print_results(results_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_experiment('vit')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_experiment('resmlp')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_experiment('mlp_mixer')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_experiment('mlp')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_experiment('cnn_ignore')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_experiment('cnn_add')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run_experiment('cnn_project')"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/ee.ipynb b/ee.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..ea8125408d211d17d2dfbc1bd33f3889699f2af9
--- /dev/null
+++ b/ee.ipynb
@@ -0,0 +1,677 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [4, 5, 6, 7]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import random\n",
+    "import sys\n",
+    "from skimage import transform\n",
+    "from tensorflow.python.framework.convert_to_constants import  convert_variables_to_constants_v2_as_graph\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "IMAGE_SIZE = 384\n",
+    "HIDDEN_DIM = 768\n",
+    "PATCH_SIZE = 16\n",
+    "MLP_DIM = 3072  # ResMLP\n",
+    "CHANNELS_MLP_DIM = 3072  # MLP-Mixer\n",
+    "TOKENS_MLP_DIM = 384  # MLP-Mixer\n",
+    "VIDEO_PATCHES = (2, 3)  # how many sub-images there are in each image for crowd counting\n",
+    "VIDEO_SIZE = (VIDEO_PATCHES[0] * IMAGE_SIZE, VIDEO_PATCHES[1] * IMAGE_SIZE)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_params(model):\n",
+    "    string_list = []\n",
+    "    model.summary(print_fn=lambda x: string_list.append(x))\n",
+    "    for string in string_list:\n",
+    "        if string.startswith('Trainable params:'):\n",
+    "            return int(string.split()[-1].replace(',', ''))\n",
+    "    return None\n",
+    "\n",
+    "def get_flops(model):\n",
+    "    \"\"\"\n",
+    "    from https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-768977280\n",
+    "    \"\"\"\n",
+    "    concrete = tf.function(lambda inputs: model(inputs))\n",
+    "    concrete_func = concrete.get_concrete_function(\n",
+    "        [tf.TensorSpec([1, *inputs.shape[1:]]) for inputs in model.inputs])\n",
+    "    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(concrete_func)\n",
+    "    with tf.Graph().as_default() as graph:\n",
+    "        tf.graph_util.import_graph_def(graph_def, name='')\n",
+    "        run_meta = tf.compat.v1.RunMetadata()\n",
+    "        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()\n",
+    "        flops = tf.compat.v1.profiler.profile(graph=graph, run_meta=run_meta, cmd=\"op\", options=opts)\n",
+    "        return flops.total_float_ops"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from https://github.com/leondgarse/Keras_mlp/blob/main/res_mlp.py\n",
+    "\n",
+    "def channel_affine(inputs, use_bias=True, weight_init_value=1, name=''):\n",
+    "    ww_init = tf.keras.initializers.Constant(weight_init_value) if weight_init_value != 1 else 'ones'\n",
+    "    nn = tf.keras.backend.expand_dims(inputs, 1)\n",
+    "    nn = tf.keras.layers.DepthwiseConv2D(1, depthwise_initializer=ww_init, use_bias=use_bias, name=name + 'affine')(nn)\n",
+    "    return tf.keras.backend.squeeze(nn, 1)\n",
+    "\n",
+    "def mlp_block(inputs, mlp_dim, activation='gelu', name=''):\n",
+    "    affine_inputs = channel_affine(inputs, use_bias=True, name=name + '1_')\n",
+    "    nn = tf.keras.layers.Permute((2, 1), name=name + 'permute_1')(affine_inputs)\n",
+    "    nn = tf.keras.layers.Dense(nn.shape[-1], name=name + 'dense_1')(nn)\n",
+    "    nn = tf.keras.layers.Permute((2, 1), name=name + 'permute_2')(nn)\n",
+    "    nn = channel_affine(nn, use_bias=False, name=name + '1_gamma_')\n",
+    "    skip_conn = tf.keras.layers.Add(name=name + 'add_1')([nn, affine_inputs])\n",
+    "\n",
+    "    affine_skip = channel_affine(skip_conn, use_bias=True, name=name + '2_')\n",
+    "    nn = tf.keras.layers.Dense(mlp_dim, name=name + 'dense_2_1')(affine_skip)\n",
+    "    nn = tf.keras.layers.Activation(activation, name=name + 'gelu')(nn)\n",
+    "    nn = tf.keras.layers.Dense(inputs.shape[-1], name=name + 'dense_2_2')(nn)\n",
+    "    nn = channel_affine(nn, use_bias=False, name=name + '2_gamma_')\n",
+    "    nn = tf.keras.layers.Add(name=name + 'add_2')([nn, affine_skip])\n",
+    "    return nn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# from https://github.com/Benjamin-Etheredge/mlp-mixer-keras/blob/main/mlp_mixer_keras/mlp_mixer.py\n",
+    "\n",
+    "class MlpBlock(tf.keras.layers.Layer):\n",
+    "    def __init__(self, dim, hidden_dim, activation=None, **kwargs):\n",
+    "        super(MlpBlock, self).__init__(**kwargs)\n",
+    "\n",
+    "        if activation is None:\n",
+    "            activation = tf.keras.activations.gelu\n",
+    "\n",
+    "        self.dim = dim\n",
+    "        self.hidden_dim = hidden_dim\n",
+    "        self.activation = activation\n",
+    "        self.dense1 = tf.keras.layers.Dense(hidden_dim)\n",
+    "        self.activation = tf.keras.layers.Activation(activation)\n",
+    "        self.dense2 = tf.keras.layers.Dense(dim)\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        x = inputs\n",
+    "        x = self.dense1(x)\n",
+    "        x = self.activation(x)\n",
+    "        x = self.dense2(x)\n",
+    "        return x\n",
+    "\n",
+    "    def compute_output_shape(self, input_signature):\n",
+    "        return (input_signature[0], self.dim)\n",
+    "\n",
+    "    def get_config(self):\n",
+    "        config = super(MlpBlock, self).get_config().copy()\n",
+    "        config.update({\n",
+    "            'dim': self.dim,\n",
+    "            'hidden_dim': self.hidden_dim,\n",
+    "            'activation': self.activation,\n",
+    "        })\n",
+    "        return config\n",
+    "\n",
+    "class MixerBlock(tf.keras.layers.Layer):\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        num_patches,\n",
+    "        channel_dim,\n",
+    "        token_mixer_hidden_dim,\n",
+    "        channel_mixer_hidden_dim=None,\n",
+    "        activation=None,\n",
+    "        **kwargs\n",
+    "    ):\n",
+    "        super(MixerBlock, self).__init__(**kwargs)\n",
+    "\n",
+    "        if activation is None:\n",
+    "            activation = tf.keras.activations.gelu\n",
+    "\n",
+    "        if channel_mixer_hidden_dim is None:\n",
+    "            channel_mixer_hidden_dim = token_mixer_hidden_dim\n",
+    "\n",
+    "        self.num_patches = num_patches\n",
+    "        self.channel_dim = channel_dim\n",
+    "        self.token_mixer_hidden_dim = token_mixer_hidden_dim\n",
+    "        self.channel_mixer_hidden_dim = channel_mixer_hidden_dim\n",
+    "        self.activation = activation\n",
+    "        \n",
+    "        self.norm1 = tf.keras.layers.LayerNormalization(axis=1)\n",
+    "        self.permute1 = tf.keras.layers.Permute((2, 1))\n",
+    "        self.token_mixer = MlpBlock(num_patches, token_mixer_hidden_dim, name='token_mixer')\n",
+    "\n",
+    "        self.permute2 = tf.keras.layers.Permute((2, 1))\n",
+    "        self.norm2 = tf.keras.layers.LayerNormalization(axis=1)\n",
+    "        self.channel_mixer = MlpBlock(channel_dim, channel_mixer_hidden_dim, name='channel_mixer')\n",
+    "\n",
+    "        self.skip_connection1 = tf.keras.layers.Add()\n",
+    "        self.skip_connection2 = tf.keras.layers.Add()\n",
+    "\n",
+    "    def get_config(self):\n",
+    "        config = super(MixerBlock, self).get_config().copy()\n",
+    "        config.update({\n",
+    "            'num_patches': self.num_patches,\n",
+    "            'channel_dim': self.channel_dim,\n",
+    "            'token_mixer_hidden_dim': self.token_mixer_hidden_dim,\n",
+    "            'channel_mixer_hidden_dim': self.channel_mixer_hidden_dim,\n",
+    "            'activation': self.activation,\n",
+    "        })\n",
+    "        return config\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "        x = inputs\n",
+    "        skip_x = x\n",
+    "        x = self.norm1(x)\n",
+    "        x = self.permute1(x)\n",
+    "        x = self.token_mixer(x)\n",
+    "\n",
+    "        x = self.permute2(x)\n",
+    "\n",
+    "        x = self.skip_connection1([x, skip_x])\n",
+    "        skip_x = x\n",
+    "\n",
+    "        x = self.norm2(x)\n",
+    "        x = self.channel_mixer(x)\n",
+    "\n",
+    "        x = self.skip_connection2([x, skip_x])\n",
+    "\n",
+    "        return x\n",
+    "\n",
+    "    def compute_output_shape(self, input_shape):\n",
+    "        return input_shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_branch_id(branch_number):\n",
+    "    if branch_number == 1:\n",
+    "        return 'transformer_block'\n",
+    "    else:\n",
+    "        return 'transformer_block_%d' % (branch_number - 1)\n",
+    "\n",
+    "def get_model(branch_numbers, head_type, dataset):\n",
+    "    # Load the dataset's saved backbone, attach one `head_type` early-exit head to each\n",
+    "    # requested transformer block, and return the compiled multi-output model\n",
+    "    # (exit heads first, backbone output last). Raises Exception for an unknown dataset.\n",
+    "    if dataset == 'cifar10':\n",
+    "        model_file_name = 'vit_cifar10_v1.h5'\n",
+    "    elif dataset == 'cifar100':\n",
+    "        model_file_name = 'vit_cifar100_v1.h5'\n",
+    "    elif dataset == 'disco':\n",
+    "        model_file_name = 'vit_cc_backbone_v2.h5'\n",
+    "    else:\n",
+    "        raise Exception('Unknown dataset: %s' % dataset)  # fail before load_model(None)\n",
+    "    \n",
+    "    backbone_model = tf.keras.models.load_model(model_file_name, custom_objects={\n",
+    "        'ClassToken': ClassToken,\n",
+    "        'AddPositionEmbs': AddPositionEmbs,\n",
+    "        'MultiHeadSelfAttention': MultiHeadSelfAttention,\n",
+    "        'TransformerBlock': TransformerBlock,\n",
+    "    })\n",
+    "\n",
+    "    outputs = []\n",
+    "    for i, branch_number in enumerate(branch_numbers):\n",
+    "        y, _ = backbone_model.get_layer(get_branch_id(branch_number)).output\n",
+    "        if head_type == 'resmlp':\n",
+    "            y = mlp_block(y, mlp_dim=MLP_DIM, name='mlp_mixer_%d' % i)\n",
+    "            y = tf.keras.layers.GlobalAveragePooling1D()(y)\n",
+    "        elif head_type == 'mlp':\n",
+    "            y = tf.keras.layers.LayerNormalization(\n",
+    "                epsilon=1e-6,\n",
+    "                name='Transformer/encoder_norm_x_%d' % i\n",
+    "            )(y)\n",
+    "            y = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x_%d' % i)(y)\n",
+    "        elif head_type == 'vit':\n",
+    "            y, _ = TransformerBlock(\n",
+    "                num_heads=12,\n",
+    "                mlp_dim=3072,\n",
+    "                dropout=0.1,\n",
+    "                name='Transformer/encoderblock_x_%d' % i\n",
+    "            )(y)\n",
+    "            y = tf.keras.layers.LayerNormalization(\n",
+    "                epsilon=1e-6,\n",
+    "                name='Transformer/encoder_norm_x_%d' % i\n",
+    "            )(y)\n",
+    "            y = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x_%d' % i)(y)\n",
+    "        elif head_type == 'cnn_ignore':\n",
+    "            channels = HIDDEN_DIM\n",
+    "            width = height = IMAGE_SIZE // PATCH_SIZE\n",
+    "            y = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken_%d' % i)(y)\n",
+    "            y = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape_%d' % i)(y)\n",
+    "            y = tf.keras.layers.Conv2D(\n",
+    "                filters=16,\n",
+    "                kernel_size=(3, 3),\n",
+    "                activation='elu',\n",
+    "                padding='same'\n",
+    "            )(y)\n",
+    "            y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)\n",
+    "            y = tf.keras.layers.Flatten()(y)\n",
+    "        elif head_type == 'cnn_add':    \n",
+    "            channels = HIDDEN_DIM\n",
+    "            width = height = IMAGE_SIZE // PATCH_SIZE\n",
+    "\n",
+    "            y1 = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken_x_%d' % i)(y)\n",
+    "            y1 = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape_%d' % i)(y1)\n",
+    "\n",
+    "            y2 = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x_%d' % i)(y)\n",
+    "            y2 = tf.keras.layers.RepeatVector(width * height)(y2)\n",
+    "            y2 = tf.keras.layers.Reshape((width, height, channels), name='cls_reshape_%d' % i)(y2)\n",
+    "\n",
+    "            y = tf.keras.layers.Add()([y1, y2])\n",
+    "\n",
+    "            y = tf.keras.layers.Conv2D(\n",
+    "                filters=16,\n",
+    "                kernel_size=(3, 3),\n",
+    "                activation='elu',\n",
+    "                padding='same'\n",
+    "            )(y)\n",
+    "            y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)\n",
+    "            y = tf.keras.layers.Flatten()(y)\n",
+    "        elif head_type == 'cnn_project':\n",
+    "            channels = HIDDEN_DIM\n",
+    "            width = height = IMAGE_SIZE // PATCH_SIZE\n",
+    "\n",
+    "            y1 = tf.keras.layers.Lambda(lambda v: v[:, 1:], name='RemoveToken_x_%d' % i)(y)\n",
+    "            y1 = tf.keras.layers.Reshape((width, height, channels), name='cnn_reshape_%d' % i)(y1)\n",
+    "\n",
+    "            y2 = tf.keras.layers.Lambda(lambda v: v[:, 0], name='ExtractToken_x_%d' % i)(y)\n",
+    "            y2 = tf.keras.layers.RepeatVector(width * height)(y2)\n",
+    "            y2 = tf.keras.layers.Reshape((width, height, channels), name='cls_reshape_%d' % i)(y2)\n",
+    "\n",
+    "            y = tf.keras.layers.Concatenate()([y1, y2])\n",
+    "\n",
+    "            y = tf.keras.layers.Conv2D(\n",
+    "                filters=16,\n",
+    "                kernel_size=(3, 3),\n",
+    "                activation='elu',\n",
+    "                padding='same'\n",
+    "            )(y)\n",
+    "            y = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(y)\n",
+    "            y = tf.keras.layers.Flatten()(y)\n",
+    "        elif head_type == 'mlp_mixer':\n",
+    "            num_patches = (IMAGE_SIZE // PATCH_SIZE) ** 2 + 1\n",
+    "            y = MixerBlock(\n",
+    "                num_patches=num_patches,\n",
+    "                channel_dim=HIDDEN_DIM,\n",
+    "                token_mixer_hidden_dim=TOKENS_MLP_DIM,\n",
+    "                channel_mixer_hidden_dim=CHANNELS_MLP_DIM\n",
+    "            )(y)\n",
+    "            y = tf.keras.layers.GlobalAveragePooling1D()(y)\n",
+    "\n",
+    "        if dataset == 'cifar10':\n",
+    "            output_units = 10\n",
+    "            output_activation = 'softmax'\n",
+    "        elif dataset == 'cifar100':\n",
+    "            output_units = 100\n",
+    "            output_activation = 'softmax'\n",
+    "        elif dataset == 'disco':\n",
+    "            output_units = 1\n",
+    "            output_activation = None\n",
+    "\n",
+    "        # MLP head\n",
+    "        initializer = tf.keras.initializers.he_normal()\n",
+    "        regularizer = tf.keras.regularizers.l2()\n",
+    "        y = tf.keras.layers.Dense(\n",
+    "            units=256,\n",
+    "            activation='elu',\n",
+    "            kernel_initializer=initializer,\n",
+    "            kernel_regularizer=regularizer\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.Dropout(0.5)(y)\n",
+    "        y = tf.keras.layers.Dense(\n",
+    "            units=256,\n",
+    "            activation='elu',\n",
+    "            kernel_initializer=initializer,\n",
+    "            kernel_regularizer=regularizer\n",
+    "        )(y)\n",
+    "        y = tf.keras.layers.Dropout(0.5)(y)\n",
+    "        y = tf.keras.layers.Dense(\n",
+    "            units=output_units,\n",
+    "            activation=output_activation,\n",
+    "            kernel_initializer=initializer,\n",
+    "            kernel_regularizer=regularizer\n",
+    "        )(y)\n",
+    "        outputs.append(y)\n",
+    "\n",
+    "    outputs.append(backbone_model.get_layer(index=-1).output)\n",
+    "    model = tf.keras.models.Model(\n",
+    "        inputs=backbone_model.get_layer(index=0).input,\n",
+    "        outputs=outputs\n",
+    "    )\n",
+    "\n",
+    "    if dataset == 'cifar10' or dataset == 'cifar100':\n",
+    "        loss_type = 'categorical_crossentropy'\n",
+    "        metric_type = 'accuracy'\n",
+    "    elif dataset == 'disco':\n",
+    "        loss_type = 'mean_absolute_error'\n",
+    "        metric_type = 'mean_absolute_error'\n",
+    "    \n",
+    "    model.compile(\n",
+    "        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n",
+    "        loss=[loss_type] * (len(branch_numbers) + 1),\n",
+    "        loss_weights=[1] * len(branch_numbers) + [2],\n",
+    "        metrics=[metric_type]\n",
+    "    )\n",
+    "\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def cache_split(cache_dir, images, labels, split):\n",
+    "    for i in range(images.shape[0]):\n",
+    "        if (i + 1) % 100 == 0:\n",
+    "            sys.stdout.write('\\r%d' % (i + 1))\n",
+    "            sys.stdout.flush()\n",
+    "        with open(os.path.join(cache_dir, '%s_%d.pkl' % (split, i)), 'wb') as cache_file:\n",
+    "            pickle.dump({\n",
+    "                'image': transform.resize(images[i], (IMAGE_SIZE, IMAGE_SIZE)),\n",
+    "                'label': labels[i],\n",
+    "            }, cache_file)\n",
+    "    print()  # newline\n",
+    "\n",
+    "def cache_all(dataset):\n",
+    "    if dataset == 'cifar10':\n",
+    "        (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()\n",
+    "    elif dataset == 'cifar100':\n",
+    "        (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data()\n",
+    "    else:\n",
+    "        raise Exception('Unknown dataset: %s' % dataset)\n",
+    "\n",
+    "    train_labels = tf.keras.utils.to_categorical(train_labels)\n",
+    "    test_labels = tf.keras.utils.to_categorical(test_labels)\n",
+    "\n",
+    "    val_index = int(len(train_images) * 0.8)\n",
+    "    val_images = train_images[val_index:]\n",
+    "    val_labels = train_labels[val_index:]\n",
+    "    train_images = train_images[:val_index]\n",
+    "    train_labels = train_labels[:val_index]\n",
+    "\n",
+    "    cache_split(dataset, train_images, train_labels, 'train')\n",
+    "    cache_split(dataset, val_images, val_labels, 'val')\n",
+    "    cache_split(dataset, test_images, test_labels, 'test')\n",
+    "\n",
+    "class CIFARSequence(tf.keras.utils.Sequence):\n",
+    "    def __init__(self, split, batch_size, dataset):\n",
+    "        self.split = split\n",
+    "        self.batch_size = batch_size * NUM_GPUS\n",
+    "        self.cache_dir = dataset\n",
+    "        self.count = sum([1 if file_name.startswith(split) else 0 for file_name in os.listdir(self.cache_dir)])\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return math.ceil(self.count / self.batch_size)\n",
+    "\n",
+    "    def on_epoch_end(self):\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        images = []\n",
+    "        labels = []\n",
+    "        for i in self.random_permutation[index * self.batch_size:(index + 1) * self.batch_size]:\n",
+    "            with open(os.path.join(self.cache_dir, '%s_%d.pkl' % (self.split, i)), 'rb') as cache_file:\n",
+    "                contents = pickle.load(cache_file)\n",
+    "                images.append(contents['image'])\n",
+    "                labels.append(contents['label'])\n",
+    "        return np.array(images), np.array(labels)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def horizontal_flip(image):\n",
+    "    return np.flip(image, axis=1)\n",
+    "\n",
+    "class DISCOSequence(tf.keras.utils.Sequence):\n",
+    "    # Yields (images, per-sample density-map sums) batches from the pickles in disco/vit_cache.\n",
+    "    def __init__(self, split, batch_size):\n",
+    "        self.split = split\n",
+    "        self.cache_dir = os.path.join('disco', 'vit_cache')\n",
+    "        self.split_len = sum([\n",
+    "            1 if file_name.startswith(self.split) else 0 for file_name in os.listdir(self.cache_dir)\n",
+    "        ])\n",
+    "        self.batch_size = batch_size * NUM_GPUS\n",
+    "        self.random_permutation = np.random.permutation(self.split_len)\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return math.ceil(self.split_len / self.batch_size)\n",
+    "\n",
+    "    def on_epoch_end(self):\n",
+    "        self.random_permutation = np.random.permutation(self.split_len)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        images = []\n",
+    "        density_maps = []\n",
+    "        if self.split == 'test':  # the test split is read in file-index order, not shuffled\n",
+    "            index_generator = range(\n",
+    "                index * self.batch_size,\n",
+    "                min((index + 1) * self.batch_size, self.split_len)  # range() end is exclusive\n",
+    "            )\n",
+    "        else:\n",
+    "            index_generator = self.random_permutation[index * self.batch_size:(index + 1) * self.batch_size]\n",
+    "        for random_index in index_generator:\n",
+    "            all_path = os.path.join(\n",
+    "                self.cache_dir,\n",
+    "                '%s_%d.pkl' % (self.split, random_index)\n",
+    "            )\n",
+    "            with open(all_path, 'rb') as all_file:\n",
+    "                data = pickle.load(all_file)\n",
+    "                if self.split == 'train' and random.random() < 0.5:  # flip augmentation\n",
+    "                    images.append(horizontal_flip(data['image']))\n",
+    "                else:\n",
+    "                    images.append(data['image'])\n",
+    "                density_maps.append(np.sum(data['density_map']))\n",
+    "\n",
+    "        return np.array(images), np.array(density_maps)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def test_cc(model, test_sequence, total_branches):\n",
+    "    gt = None\n",
+    "    outs = []\n",
+    "    for i, (images, density_maps) in enumerate(test_sequence):\n",
+    "        sys.stdout.write('\\r%d' % (i + 1))\n",
+    "        sys.stdout.flush()\n",
+    "        if gt is not None:\n",
+    "            gt = np.concatenate((gt, density_maps))\n",
+    "        else:\n",
+    "            gt = density_maps\n",
+    "        output = model(images)\n",
+    "        for j in range(total_branches):\n",
+    "            if i == 0:\n",
+    "                outs.append(output[j].numpy().flatten())\n",
+    "            else:\n",
+    "                outs[j] = np.concatenate((outs[j], output[j].numpy().flatten()))\n",
+    "    print()  # newline\n",
+    "    maes = []\n",
+    "    img_patches = VIDEO_PATCHES[0] * VIDEO_PATCHES[1]\n",
+    "    for i in range(0, gt.shape[0], img_patches):\n",
+    "        gt_subset = gt[i:i + img_patches]\n",
+    "        for j in range(total_branches):\n",
+    "            if i == 0:\n",
+    "                maes.append([np.abs(np.sum(gt_subset) - np.sum(outs[j][i:i + img_patches]))])\n",
+    "            else:\n",
+    "                maes[j].append(np.abs(np.sum(gt_subset) - np.sum(outs[j][i:i + img_patches])))\n",
+    "    return [np.mean(np.array(item)) for item in maes]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(max_epochs, branch_numbers, head_type, dataset, version, temporary):\n",
+    "    tf.keras.backend.clear_session()\n",
+    "\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model(branch_numbers, head_type, dataset)\n",
+    "\n",
+    "    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(\n",
+    "        monitor='val_loss',\n",
+    "        factor=0.6,\n",
+    "        patience=2,\n",
+    "        verbose=1,\n",
+    "        mode='min',\n",
+    "        min_lr=1e-7\n",
+    "    )\n",
+    "\n",
+    "    early_stop = tf.keras.callbacks.EarlyStopping(\n",
+    "        monitor='val_loss',\n",
+    "        patience=5,\n",
+    "        verbose=1,\n",
+    "        mode='min'\n",
+    "    )\n",
+    "\n",
+    "    save_model_checkpoint_file = 'bmvc_rebuttal_ee_v%d_%s_%s_%s.h5' % (\n",
+    "        version,\n",
+    "        head_type,\n",
+    "        dataset,\n",
+    "        '-'.join([str(branch_number) for branch_number in branch_numbers])\n",
+    "    )\n",
+    "\n",
+    "    checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
+    "        save_model_checkpoint_file,\n",
+    "        monitor='val_loss',\n",
+    "        verbose=1,\n",
+    "        save_weights_only=False,\n",
+    "        save_best_only=True,\n",
+    "        mode='min',\n",
+    "        save_freq='epoch'\n",
+    "    )\n",
+    "\n",
+    "    callbacks = [lr_reduce, early_stop]\n",
+    "    if not temporary:\n",
+    "        callbacks.append(checkpoint)\n",
+    "\n",
+    "    batch_size = 4\n",
+    "    if dataset == 'cifar10' or dataset == 'cifar100':\n",
+    "        train_sequence = CIFARSequence('train', batch_size, dataset)\n",
+    "        val_sequence = CIFARSequence('val', batch_size, dataset)\n",
+    "        test_sequence = CIFARSequence('test', batch_size, dataset)\n",
+    "    elif dataset == 'disco':\n",
+    "        train_sequence = DISCOSequence('train', batch_size)\n",
+    "        val_sequence = DISCOSequence('val', batch_size)\n",
+    "        test_sequence = DISCOSequence('test', 2 * batch_size)\n",
+    "    else:\n",
+    "        raise Exception('Unknown dataset: %s' % dataset)\n",
+    "\n",
+    "    history = model.fit(\n",
+    "        train_sequence,\n",
+    "        validation_data=val_sequence,\n",
+    "        epochs=max_epochs,\n",
+    "        shuffle=True,\n",
+    "        callbacks=callbacks,\n",
+    "        verbose=1\n",
+    "    )\n",
+    "\n",
+    "    if dataset == 'cifar10' or dataset == 'cifar100':\n",
+    "        test_accuracy = model.evaluate(test_sequence)[1]\n",
+    "    elif dataset == 'disco':\n",
+    "        test_accuracy = test_cc(model, test_sequence, len(branch_numbers) + 1)\n",
+    "\n",
+    "    model_params = get_params(model) / 10 ** 6\n",
+    "    model_flops = get_flops(model) / 10 ** 9\n",
+    "\n",
+    "    return model, test_accuracy, model_params, model_flops"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cache_all('cifar10')\n",
+    "cache_all('cifar100')\n",
+    "model, test_accuracy, model_params, model_flops = train(\n",
+    "    max_epochs=100,\n",
+    "    branch_numbers=[3, 6, 9],\n",
+    "    head_type='resmlp',\n",
+    "    dataset='disco',\n",
+    "    version=5,\n",
+    "    temporary=False\n",
+    ")\n",
+    "print(test_accuracy)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/plots.ipynb b/plots.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..78100b2678bf6c4ec1268235aeeec96cae28b98b
--- /dev/null
+++ b/plots.ipynb
@@ -0,0 +1,776 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib\n",
+    "import matplotlib.pyplot as plt\n",
+    "from matplotlib.patches import Circle\n",
+    "from operator import sub\n",
+    "\n",
+    "matplotlib.rc('font',**{'size': 20})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "EPSILON = 0.1\n",
+    "\n",
+    "def practical_points_generator(flops, accuracies, mode):\n",
+    "    all_data_points = []\n",
+    "    for i in range(len(flops)):\n",
+    "        for j in range(len(flops[i])):\n",
+    "            all_data_points.append((flops[i][j], accuracies[i][j]))\n",
+    "    if mode == 'max':\n",
+    "        sorted_data_points = sorted(all_data_points, key=lambda x: 10 ** 6 * x[0] - x[1])\n",
+    "        max_accuracy = 0\n",
+    "        for point in sorted_data_points:\n",
+    "            if point[1] - max_accuracy > EPSILON:\n",
+    "                max_accuracy = point[1]\n",
+    "                yield point\n",
+    "    else:\n",
+    "        sorted_data_points = sorted(all_data_points, key=lambda x: 10 ** 6 * x[0] + x[1])\n",
+    "        min_mae = 10 ** 6\n",
+    "        for point in sorted_data_points:\n",
+    "            if min_mae - point[1] > EPSILON:\n",
+    "                min_mae = point[1]\n",
+    "                yield point"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_marker_linestyle_color(label):\n",
+    "    if label == 'CNN-Ignore-EE':\n",
+    "        return 'o', '-', 'b'\n",
+    "    elif label == 'MLP-Mixer-EE':\n",
+    "        return 'v', '--', 'g'\n",
+    "    elif label == 'MLP-EE':\n",
+    "        return 'P', '-.', 'r'\n",
+    "    elif label == 'ViT-EE':\n",
+    "        return 'X', ':', 'c'\n",
+    "    elif label == 'CNN-Add-EE':\n",
+    "        return 'D', '-', 'm'\n",
+    "    elif label == 'ResMLP-EE':\n",
+    "        return '^', '--', 'y'\n",
+    "    elif label == 'CNN-Project-EE':\n",
+    "        return 's', '-.', 'k'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def correct_flops(flops):\n",
+    "    # Return a new nested list with every FLOPs value halved; the input\n",
+    "    # lists (one inner list per exit type) are left unmodified.\n",
+    "    # A comprehension avoids the old inner loop variable shadowing `flops`.\n",
+    "    return [\n",
+    "        [exit_flops / 2 for exit_flops in exit_type_flops]\n",
+    "        for exit_type_flops in flops\n",
+    "    ]\n",
+    "\n",
+    "def draw_plots(\n",
+    "    flops,\n",
+    "    accuracies,\n",
+    "    labels,\n",
+    "    ranges,\n",
+    "    name,\n",
+    "    backbone_flops,\n",
+    "    backbone_accuracy,\n",
+    "    mode='max',\n",
+    "    include_mlp=True,\n",
+    "    y_axis_limit=None\n",
+    "):\n",
+    "    flops = correct_flops(flops)\n",
+    "    backbone_flops /= 2\n",
+    "    fig, axes = plt.subplots(1, len(ranges), figsize=(30, 10))\n",
+    "    for range_index in range(len(ranges)):\n",
+    "        start = ranges[range_index][0]\n",
+    "        end = ranges[range_index][1]\n",
+    "        if len(ranges) == 1:\n",
+    "            ax = axes\n",
+    "        else:\n",
+    "            ax = axes[range_index]\n",
+    "        ax.set_title('%s, Exits %d to %d' % (name, start + 1, end))\n",
+    "        ax.set_xlabel('FLOPS (B)')\n",
+    "        ax.set_ylabel('Accuracy (%)' if mode == 'max' else 'MAE')\n",
+    "        used_flops = []\n",
+    "        used_accuracies = []\n",
+    "        for i in range(len(flops)):\n",
+    "            if include_mlp or labels[i] != 'MLP-EE':\n",
+    "                used_flops += flops[i][start:end]\n",
+    "                used_accuracies += accuracies[i][start:end]\n",
+    "                marker, linestyle, color = get_marker_linestyle_color(labels[i])\n",
+    "                ax.plot(\n",
+    "                    flops[i][start:end],\n",
+    "                    accuracies[i][start:end],\n",
+    "                    marker=marker,\n",
+    "                    linestyle=linestyle,\n",
+    "                    color=color,\n",
+    "                    label=labels[i],\n",
+    "                    markersize=8\n",
+    "                )\n",
+    "        if y_axis_limit is not None:\n",
+    "            ax.set_ylim(y_axis_limit)\n",
+    "        ax.legend()\n",
+    "        for point in practical_points_generator(flops, accuracies, mode):\n",
+    "            if point in zip(used_flops, used_accuracies):\n",
+    "                ax.plot(point[0], point[1], color='grey', marker='o', fillstyle='none', markersize=20)\n",
+    "    if len(ranges) == 1:\n",
+    "        last_ax = axes\n",
+    "    else:\n",
+    "        last_ax = axes[-1]\n",
+    "    last_ax.scatter([backbone_flops], [backbone_accuracy])\n",
+    "    last_ax.annotate(\n",
+    "        'Final',\n",
+    "        xy=(backbone_flops, backbone_accuracy),\n",
+    "        xytext=(-24, 8),\n",
+    "        textcoords='offset pixels'\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cnn_ignore_accuracies = [\n",
+    "    74.86,\n",
+    "    80.47,\n",
+    "    85.10,\n",
+    "    89.95,\n",
+    "    92.51,\n",
+    "    94.02,\n",
+    "    95.51,\n",
+    "    96.63,\n",
+    "    97.62,\n",
+    "    97.93,\n",
+    "    98.00,\n",
+    "]\n",
+    "\n",
+    "cnn_ignore_flops = [10.039666312, 19.27067618, 28.501686048, 37.732695916, 46.963705784, 56.194715652, 65.42572552, 74.656735388, 83.887745256, 93.118755124, 102.349764992]\n",
+    "\n",
+    "mlp_mixer_accuracies = [\n",
+    "    74.53,\n",
+    "    81.40,\n",
+    "    85.49,\n",
+    "    88.33,\n",
+    "    91.84,\n",
+    "    93.79,\n",
+    "    94.97,\n",
+    "    95.99,\n",
+    "    97.08,\n",
+    "    97.75,\n",
+    "    97.86,\n",
+    "]\n",
+    "\n",
+    "mlp_mixer_flops = [16.054064008, 25.285073876, 34.516083744, 43.747093612, 52.97810348, 62.209113348, 71.440123216, 80.671133084, 89.902142952, 99.13315282, 108.364162688]\n",
+    "\n",
+    "mlp_accuracies = [\n",
+    "    26.19,\n",
+    "    42.07,\n",
+    "    58.42,\n",
+    "    72.60,\n",
+    "    81.46,\n",
+    "    87.19,\n",
+    "    90.81,\n",
+    "    93.11,\n",
+    "    96.07,\n",
+    "    97.06,\n",
+    "    97.92,\n",
+    "]\n",
+    "\n",
+    "mlp_flops = [9.915005706, 19.146015574, 28.377025442, 37.60803531, 46.839045178, 56.070055046, 65.301064914, 74.532074782, 83.76308465, 92.994094518, 102.225104386]\n",
+    "\n",
+    "vit_accuracies = [\n",
+    "    79.14,\n",
+    "    84.91,\n",
+    "    89.83,\n",
+    "    92.60,\n",
+    "    94.61,\n",
+    "    96.04,\n",
+    "    96.65,\n",
+    "    97.39,\n",
+    "    97.67,\n",
+    "    98.02,\n",
+    "    98.09,\n",
+    "]\n",
+    "\n",
+    "vit_flops = [19.146015574, 28.377025442, 37.60803531, 46.839045178, 56.070055046, 65.301064914, 74.532074782, 83.76308465, 92.994094518, 102.225104386, 111.456114254]\n",
+    "\n",
+    "cnn_add_accuracies = [\n",
+    "    77.19,\n",
+    "    81.22,\n",
+    "    86.25,\n",
+    "    89.50,\n",
+    "    92.03,\n",
+    "    93.94,\n",
+    "    95.19,\n",
+    "    95.98,\n",
+    "    97.43,\n",
+    "    97.72,\n",
+    "    98.13,\n",
+    "]\n",
+    "\n",
+    "cnn_add_flops = [10.039666312, 19.27067618, 28.501686048, 37.732695916, 46.963705784, 56.194715652, 65.42572552, 74.656735388, 83.887745256, 93.118755124, 102.349764992]\n",
+    "\n",
+    "resmlp_accuracies = [\n",
+    "    74.99,\n",
+    "    85.49,\n",
+    "    90.45,\n",
+    "    92.44,\n",
+    "    94.14,\n",
+    "    94.76,\n",
+    "    95.93,\n",
+    "    96.60,\n",
+    "    97.63,\n",
+    "    97.94,\n",
+    "    98.12,\n",
+    "]\n",
+    "\n",
+    "resmlp_flops = [15.88094452, 25.111954388, 34.342964256, 43.573974124, 52.804983992, 62.03599386, 71.267003728, 80.498013596, 89.729023464, 98.960033332, 108.1910432]\n",
+    "\n",
+    "cnn_project_accuracies = [\n",
+    "    76.60,\n",
+    "    81.07,\n",
+    "    86.76,\n",
+    "    90.28,\n",
+    "    92.17,\n",
+    "    94.24,\n",
+    "    95.67,\n",
+    "    96.73,\n",
+    "    97.37,\n",
+    "    97.82,\n",
+    "    97.90,\n",
+    "]\n",
+    "\n",
+    "cnn_project_flops = [10.167068296, 19.398078164, 28.629088032, 37.8600979, 47.091107768, 56.322117636, 65.553127504, 74.784137372, 84.01514724, 93.246157108, 102.477166976]\n",
+    "\n",
+    "backbone_accuracy = 98.31\n",
+    "backbone_flops = 111.46"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "flops = [\n",
+    "    cnn_ignore_flops,\n",
+    "    mlp_mixer_flops,\n",
+    "    mlp_flops,\n",
+    "    vit_flops,\n",
+    "    cnn_add_flops,\n",
+    "    resmlp_flops,\n",
+    "    cnn_project_flops,\n",
+    "]\n",
+    "accuracies = [\n",
+    "    cnn_ignore_accuracies,\n",
+    "    mlp_mixer_accuracies,\n",
+    "    mlp_accuracies,\n",
+    "    vit_accuracies,\n",
+    "    cnn_add_accuracies,\n",
+    "    resmlp_accuracies,\n",
+    "    cnn_project_accuracies,\n",
+    "]\n",
+    "labels = [\n",
+    "    'CNN-Ignore-EE',\n",
+    "    'MLP-Mixer-EE',\n",
+    "    'MLP-EE',\n",
+    "    'ViT-EE',\n",
+    "    'CNN-Add-EE',\n",
+    "    'ResMLP-EE',\n",
+    "    'CNN-Project-EE',\n",
+    "]\n",
+    "ranges = [(0, 6), (5, 11)]\n",
+    "full_range = [(0, 11)]\n",
+    "name = 'CIFAR-10'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 2160x720 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "draw_plots(flops, accuracies, labels, full_range, name, backbone_flops, backbone_accuracy)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 2160x720 with 2 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "draw_plots(flops, accuracies, labels, ranges, name, backbone_flops, backbone_accuracy, include_mlp=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cnn_ignore_accuracies = [\n",
+    "    45.23,\n",
+    "    50.60,\n",
+    "    56.41,\n",
+    "    61.53,\n",
+    "    67.16,\n",
+    "    70.39,\n",
+    "    75.41,\n",
+    "    79.47,\n",
+    "    83.56,\n",
+    "    87.76,\n",
+    "    89.33,\n",
+    "]\n",
+    "\n",
+    "cnn_ignore_flops = [10.039666312, 19.27067618, 28.501686048, 37.732695916, 46.963705784, 56.194715652, 65.42572552, 74.656735388, 83.887745256, 93.118755124, 102.349764992]\n",
+    "\n",
+    "mlp_accuracies = [\n",
+    "    7.07,\n",
+    "    16.26,\n",
+    "    21.63,\n",
+    "    40.28,\n",
+    "    49.04,\n",
+    "    55.39,\n",
+    "    59.79,\n",
+    "    65.58,\n",
+    "    71.40,\n",
+    "    81.64,\n",
+    "    88.69,\n",
+    "]\n",
+    "\n",
+    "mlp_flops = [9.915005706, 19.146015574, 28.377025442, 37.60803531, 46.839045178, 56.070055046, 65.301064914, 74.532074782, 83.76308465, 92.994094518, 102.225104386]\n",
+    "\n",
+    "vit_accuracies = [\n",
+    "    31.08,\n",
+    "    45.04,\n",
+    "    58.59,\n",
+    "    66.54,\n",
+    "    73.44,\n",
+    "    79.18,\n",
+    "    83.10,\n",
+    "    86.26,\n",
+    "    88.38,\n",
+    "    90.12,\n",
+    "    90.92,\n",
+    "]\n",
+    "\n",
+    "vit_flops = [19.146015574, 28.377025442, 37.60803531, 46.839045178, 56.070055046, 65.301064914, 74.532074782, 83.76308465, 92.994094518, 102.225104386, 111.456114254]\n",
+    "\n",
+    "mlp_mixer_accuracies = [\n",
+    "    34.31,\n",
+    "    47.03,\n",
+    "    59.18,\n",
+    "    66.32,\n",
+    "    73.13,\n",
+    "    78.11,\n",
+    "    81.64,\n",
+    "    84.31,\n",
+    "    87.33,\n",
+    "    88.50,\n",
+    "    89.98,\n",
+    "]\n",
+    "\n",
+    "mlp_mixer_flops = [16.054064008, 25.285073876, 34.516083744, 43.747093612, 52.97810348, 62.209113348, 71.440123216, 80.671133084, 89.902142952, 99.13315282, 108.364162688]\n",
+    "\n",
+    "resmlp_accuracies = [\n",
+    "    34.65,\n",
+    "    58.73,\n",
+    "    66.71,\n",
+    "    72.44,\n",
+    "    76.88,\n",
+    "    80.94,\n",
+    "    84.51,\n",
+    "    86.83,\n",
+    "    88.51,\n",
+    "    90.20,\n",
+    "    91.13,\n",
+    "]\n",
+    "\n",
+    "resmlp_flops = [15.88094452, 25.111954388, 34.342964256, 43.573974124, 52.804983992, 62.03599386, 71.267003728, 80.498013596, 89.729023464, 98.960033332, 108.1910432]\n",
+    "\n",
+    "cnn_project_accuracies = [\n",
+    "    43.46,\n",
+    "    47.42,\n",
+    "    50.43,\n",
+    "    59.26,\n",
+    "    60.92,\n",
+    "    64.07,\n",
+    "    68.61,\n",
+    "    70.97,\n",
+    "    65.69,\n",
+    "    70.57,\n",
+    "    80.95,\n",
+    "]\n",
+    "\n",
+    "cnn_project_flops = [10.167068296, 19.398078164, 28.629088032, 37.8600979, 47.091107768, 56.322117636, 65.553127504, 74.784137372, 84.01514724, 93.246157108, 102.477166976]\n",
+    "\n",
+    "cnn_add_accuracies = [\n",
+    "    44.36,\n",
+    "    47.66,\n",
+    "    52.82,\n",
+    "    58.94,\n",
+    "    62.96,\n",
+    "    65.25,\n",
+    "    70.61,\n",
+    "    69.48,\n",
+    "    73.92,\n",
+    "    78.11,\n",
+    "    80.88,\n",
+    "]\n",
+    "\n",
+    "cnn_add_flops = [10.039666312, 19.27067618, 28.501686048, 37.732695916, 46.963705784, 56.194715652, 65.42572552, 74.656735388, 83.887745256, 93.118755124, 102.349764992]\n",
+    "\n",
+    "backbone_accuracy = 91.24\n",
+    "backbone_flops = 111.46"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "flops = [\n",
+    "    cnn_ignore_flops,\n",
+    "    mlp_mixer_flops,\n",
+    "    mlp_flops,\n",
+    "    vit_flops,\n",
+    "    cnn_add_flops,\n",
+    "    resmlp_flops,\n",
+    "    cnn_project_flops,\n",
+    "]\n",
+    "accuracies = [\n",
+    "    cnn_ignore_accuracies,\n",
+    "    mlp_mixer_accuracies,\n",
+    "    mlp_accuracies,\n",
+    "    vit_accuracies,\n",
+    "    cnn_add_accuracies,\n",
+    "    resmlp_accuracies,\n",
+    "    cnn_project_accuracies,\n",
+    "]\n",
+    "markers = ['o', 'v', 'P', 'X', 'D', '^', 's'] # https://matplotlib.org/2.0.2/api/lines_api.html\n",
+    "linestyles = ['-', '--', '-.', ':']\n",
+    "labels = [\n",
+    "    'CNN-Ignore-EE',\n",
+    "    'MLP-Mixer-EE',\n",
+    "    'MLP-EE',\n",
+    "    'ViT-EE',\n",
+    "    'CNN-Add-EE',\n",
+    "    'ResMLP-EE',\n",
+    "    'CNN-Project-EE',\n",
+    "]\n",
+    "ranges = [(0, 6), (5, 11)]\n",
+    "name = 'CIFAR-100'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 2160x720 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "draw_plots(flops, accuracies, labels, full_range, name, backbone_flops, backbone_accuracy)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "draw_plots(flops, accuracies, labels, ranges, name, backbone_flops, backbone_accuracy, include_mlp=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mlp_maes_5 = [66.95597652493295, 55.6043279293333, 40.01493319068866, 27.534241703863543, 20.099828833925915, 16.6587248431277, 14.207798832063078, 13.643488861268288, 12.692167538868029, 11.745789087804178, 11.116295196966094]\n",
+    "\n",
+    "mlp_maes_4 = [63.81760612188089, 48.84459626358478, 31.38019274014999, 21.35408311672693, 17.32409016078545, 15.096452528379604, 13.05824572280687, 12.952922174428489, 12.154847678503849, 11.278382839842044, 10.903999793224612]\n",
+    "\n",
+    "mlp_maes = [min(mlp_maes_4[i], mlp_maes_5[i]) for i in range(len(mlp_maes_4))]\n",
+    "\n",
+    "mlp_flops = [9.915005706, 19.146015574, 28.377025442, 37.60803531, 46.839045178, 56.070055046, 65.301064914, 74.532074782, 83.76308465, 92.994094518, 102.225104386]\n",
+    "\n",
+    "resmlp_maes_4 = [17.60861961755429, 15.780182688240421, 14.967578141838581, 12.967585373219885, 13.565632146464361, 17.838170967843872, 25.29613745927154, 18.395887481893116, 22.84772446918155, 21.12996455613271, 19.378696329954778]\n",
+    "\n",
+    "resmlp_maes_5 = [37.567863819194585, 16.379686318456177, 15.91074247730594, 14.830663657496425, 14.105883239636947, 13.518657890148038, 14.987332331531313, 13.092971086047813, 12.99657056636505, 11.822814100180334, 13.347536167698324]\n",
+    "\n",
+    "resmlp_maes = [min(resmlp_maes_4[i], resmlp_maes_5[i]) for i in range(len(resmlp_maes_4))]\n",
+    "\n",
+    "resmlp_flops = [15.88094452, 25.111954388, 34.342964256, 43.573974124, 52.804983992, 62.03599386, 71.267003728, 80.498013596, 89.729023464, 98.960033332, 108.1910432]\n",
+    "\n",
+    "vit_maes_4 = [17.25549441917847, 15.310706459980956, 13.595461759060662, 12.577511369468663, 12.067579316858268, 11.083785289440117, 11.188187754225888, 10.87809422029592, 11.642507894276399, 10.745084875372624, 10.597995807472484]\n",
+    "\n",
+    "vit_maes_5 = [20.92489066249554, 17.892692841223028, 15.841838350091864, 14.224327003689748, 13.197131156595411, 12.574951000025784, 12.357885581834648, 11.632370412163207, 11.68577344596502, 11.010207051810061, 10.896251262811639]\n",
+    "\n",
+    "vit_maes = [min(vit_maes_4[i], vit_maes_5[i]) for i in range(len(vit_maes_4))]\n",
+    "\n",
+    "vit_alt_maes = [18.049246351608634, 15.180274767263906, 13.628905414822626, 12.978167053696502, 11.69005954003767, 11.319822538837004, 10.915161482429214, 11.06383380322437, 10.836275275190616, 10.95410315675852, 10.759412091444661]\n",
+    "\n",
+    "vit_flops = [19.146015574, 28.377025442, 37.60803531, 46.839045178, 56.070055046, 65.301064914, 74.532074782, 83.76308465, 92.994094518, 102.225104386, 111.456114254]\n",
+    "\n",
+    "cnn_ignore_maes_4 = [16.324433793363568, 15.642264217512718, 15.203111727147673, 14.539004881089028, 14.859935972397198, 16.9033360170686, 14.710937910877693, 15.024733024942853, 13.002627540657842, 15.919211207802743, 13.1617159714426]\n",
+    "\n",
+    "cnn_ignore_maes_5 = [19.472250306927545, 16.382709773809054, 15.30624921188889, 16.640002978804905, 15.339903997527072, 14.558104298405594, 14.186327249106556, 18.8666830396562, 15.023331387561072, 16.011291563485088, 16.721815794863026]\n",
+    "\n",
+    "cnn_ignore_maes = [min(cnn_ignore_maes_4[i], cnn_ignore_maes_5[i]) for i in range(len(cnn_ignore_maes_4))]\n",
+    "\n",
+    "cnn_ignore_flops = [10.039666312, 19.27067618, 28.501686048, 37.732695916, 46.963705784, 56.194715652, 65.42572552, 74.656735388, 83.887745256, 93.118755124, 102.349764992]\n",
+    "\n",
+    "cnn_add_maes_4 = [18.90440685073256, 16.726752551982553, 13.927111037765874, 13.950359128337304, 15.485235473291226, 13.63742827597631, 15.302497663269468, 14.101453771361111, 13.903442669190742, 13.146807598164951, 11.929169691797588]\n",
+    "\n",
+    "cnn_add_maes_5 = [23.029530046001007, 17.36398403153728, 17.532784224678203, 15.0629114230157, 14.053487970473283, 14.495352183394191, 12.894946989563001, 15.040746399772036, 14.855697106636278, 13.728063926384943, 12.44930448742072]\n",
+    "\n",
+    "cnn_add_maes = [min(cnn_add_maes_4[i], cnn_add_maes_5[i]) for i in range(len(cnn_add_maes_4))]\n",
+    "\n",
+    "cnn_add_flops = [10.039666312, 19.27067618, 28.501686048, 37.732695916, 46.963705784, 56.194715652, 65.42572552, 74.656735388, 83.887745256, 93.118755124, 102.349764992]\n",
+    "\n",
+    "cnn_project_maes_4 = [17.97360331761429, 15.467450911446946, 14.959393072910776, 14.132034754839491, 16.076203802519633, 14.18988822844617, 13.815687876052461, 12.58639125568036, 11.517910521763955, 12.874515941075424, 13.256840143438424]\n",
+    "\n",
+    "cnn_project_maes_5 = [27.43422174678897, 16.961936098398724, 15.575427034930241, 14.76191664899309, 13.80299366597453, 13.76081104695114, 13.545493860247136, 12.36295470406632, 12.355132346136038, 11.9225329570658, 13.00598683830469]\n",
+    "\n",
+    "cnn_project_maes = [min(cnn_project_maes_4[i], cnn_project_maes_5[i]) for i in range(len(cnn_project_maes_4))]\n",
+    "\n",
+    "cnn_project_flops = [10.167068296, 19.398078164, 28.629088032, 37.8600979, 47.091107768, 56.322117636, 65.553127504, 74.784137372, 84.01514724, 93.246157108, 102.477166976]\n",
+    "\n",
+    "mlp_mixer_maes_4 = [17.449155579073565, 15.024769666981092, 13.918311021642689, 13.102047395174287, 13.740980489502375, 13.465846355746645, 13.171297758797463, 15.58475445045778, 14.971250198540853, 18.186729623156797, 17.877043017754772]\n",
+    "\n",
+    "mlp_mixer_maes_5 = [17.895658858100944, 15.83869162005821, 14.261989270241047, 13.886250207755124, 13.046678551942659, 12.559346749946299, 14.470972782496279, 12.07975543923034, 12.112579300768346, 13.57506639775688, 12.628674099279955]\n",
+    "\n",
+    "mlp_mixer_maes = [min(mlp_mixer_maes_4[i], mlp_mixer_maes_5[i]) for i in range(len(mlp_mixer_maes_4))]\n",
+    "\n",
+    "mlp_mixer_flops = [16.054064008, 25.285073876, 34.516083744, 43.747093612, 52.97810348, 62.209113348, 71.440123216, 80.671133084, 89.902142952, 99.13315282, 108.364162688]\n",
+    "\n",
+    "backbone_mae = 11.07\n",
+    "backbone_flops = 111.46"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "flops = [\n",
+    "    mlp_flops,\n",
+    "    vit_flops,\n",
+    "    resmlp_flops,\n",
+    "    cnn_ignore_flops,\n",
+    "    cnn_add_flops,\n",
+    "    cnn_project_flops,\n",
+    "    mlp_mixer_flops,\n",
+    "]\n",
+    "maes = [\n",
+    "    mlp_maes,\n",
+    "    vit_maes,\n",
+    "    resmlp_maes,\n",
+    "    cnn_ignore_maes,\n",
+    "    cnn_add_maes,\n",
+    "    cnn_project_maes,\n",
+    "    mlp_mixer_maes,\n",
+    "]\n",
+    "markers = ['o', 'v', 'P', 'X', 'D', '^', 's'] # https://matplotlib.org/2.0.2/api/lines_api.html\n",
+    "linestyles = ['-', '--', '-.', ':']\n",
+    "labels = [\n",
+    "    'MLP-EE',\n",
+    "    'ViT-EE',\n",
+    "    'ResMLP-EE',\n",
+    "    'CNN-Ignore-EE',\n",
+    "    'CNN-Add-EE',\n",
+    "    'CNN-Project-EE',\n",
+    "    'MLP-Mixer-EE'\n",
+    "]\n",
+    "ranges = [(0, 6), (5, 11)]\n",
+    "name = 'DISCO'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 2160x720 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "draw_plots(flops, maes, labels, full_range, name, backbone_flops, backbone_mae, mode='min', y_axis_limit=[10, 20])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mlp_flops = [9.915005706, 19.146015574, 28.377025442, 37.60803531, 46.839045178, 56.070055046, 65.301064914, 74.532074782, 83.76308465, 92.994094518, 102.225104386]\n",
+    "mlp_accuracies = [60.66, 73.76, 82.65, 87.55, 89.60, 92.23, 92.97, 94.04, 94.67, 94.83, 94.81]\n",
+    "cnn_ignore_flops = [10.039666312, 19.27067618, 28.501686048, 37.732695916, 46.963705784, 56.194715652, 65.42572552, 74.656735388, 83.887745256, 93.118755124, 102.349764992]\n",
+    "cnn_ignore_accuracies = [92.86, 93.42, 93.49, 93.74, 93.85, 94.12, 94.16, 94.49, 94.58, 94.73, 94.79]\n",
+    "cnn_add_flops = [10.039666312, 19.27067618, 28.501686048, 37.732695916, 46.963705784, 56.194715652, 65.42572552, 74.656735388, 83.887745256, 93.118755124, 102.349764992]\n",
+    "cnn_add_accuracies = [92.45, 93.08, 93.83, 93.75, 93.67, 93.54, 93.91, 94.52, 94.71, 94.68, 94.68]\n",
+    "cnn_project_flops = [10.167068296, 19.398078164, 28.629088032, 37.8600979, 47.091107768, 56.322117636, 65.553127504, 74.784137372, 84.01514724, 93.246157108, 102.477166976]\n",
+    "cnn_project_accuracies = [92.72, 92.95, 94.13, 93.68, 93.71, 93.84, 94.15, 94.52, 94.65, 94.49, 94.88]\n",
+    "vit_flops = [19.146015574, 28.377025442, 37.60803531, 46.839045178, 56.070055046, 65.301064914, 74.532074782, 83.76308465, 92.994094518, 102.225104386, 111.456114254]\n",
+    "vit_accuracies = [91.38, 92.28, 92.61, 93.48, 93.86, 94.19, 94.54, 94.65, 94.85, 94.94, 94.89]\n",
+    "resmlp_flops = [15.88094452, 25.111954388, 34.342964256, 43.573974124, 52.804983992, 62.03599386, 71.267003728, 80.498013596, 89.729023464, 98.960033332, 108.1910432]\n",
+    "resmlp_accuracies = [90.25, 92.54, 93.64, 93.76, 93.63, 93.68, 94.08, 94.40, 94.74, 94.82, 94.85]\n",
+    "mlp_mixer_flops = [16.054064008, 25.285073876, 34.516083744, 43.747093612, 52.97810348, 62.209113348, 71.440123216, 80.671133084, 89.902142952, 99.13315282, 108.364162688]\n",
+    "mlp_mixer_accuracies = [90.81, 92.52, 93.26, 93.78, 93.42, 93.94, 94.08, 94.52, 94.78, 94.77, 94.86]\n",
+    "backbone_accuracy = 95.00\n",
+    "backbone_flops = 111.46"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "flops = [\n",
+    "    cnn_ignore_flops,\n",
+    "    mlp_mixer_flops,\n",
+    "    mlp_flops,\n",
+    "    vit_flops,\n",
+    "    cnn_add_flops,\n",
+    "    resmlp_flops,\n",
+    "    cnn_project_flops,\n",
+    "]\n",
+    "accuracies = [\n",
+    "    cnn_ignore_accuracies,\n",
+    "    mlp_mixer_accuracies,\n",
+    "    mlp_accuracies,\n",
+    "    vit_accuracies,\n",
+    "    cnn_add_accuracies,\n",
+    "    resmlp_accuracies,\n",
+    "    cnn_project_accuracies,\n",
+    "]\n",
+    "markers = ['o', 'v', 'P', 'X', 'D', '^', 's'] # https://matplotlib.org/2.0.2/api/lines_api.html\n",
+    "linestyles = ['-', '--', '-.', ':']\n",
+    "labels = [\n",
+    "    'CNN-Ignore-EE',\n",
+    "    'MLP-Mixer-EE',\n",
+    "    'MLP-EE',\n",
+    "    'ViT-EE',\n",
+    "    'CNN-Add-EE',\n",
+    "    'ResMLP-EE',\n",
+    "    'CNN-Project-EE',\n",
+    "]\n",
+    "ranges = [(0, 6), (5, 11)]\n",
+    "name = 'Fashion MNIST'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 2160x720 with 2 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "draw_plots(flops, accuracies, labels, ranges, name, backbone_flops, backbone_accuracy, include_mlp=False)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/precompute_cifar_features.ipynb b/precompute_cifar_features.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..b5829ca5e0e27b85c37d93b2df990df70120d88d
--- /dev/null
+++ b/precompute_cifar_features.ipynb
@@ -0,0 +1,159 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [4]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import random\n",
+    "import sys\n",
+    "import time\n",
+    "from skimage import transform\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "PRECOMPUTE_DIR = 'precompute'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_branch_id(branch_number):\n",
+    "    # Translate a 1-based branch number into the backbone layer name it taps:\n",
+    "    # 1 -> 'transformer_block', k -> 'transformer_block_<k - 1>'.\n",
+    "    suffix = '' if branch_number == 1 else '_%d' % (branch_number - 1)\n",
+    "    return 'transformer_block' + suffix\n",
+    "\n",
+    "def get_model(dataset):\n",
+    "    # Load the saved ViT backbone 'vit_<dataset>_v1.h5' and wrap it in a model\n",
+    "    # that returns, for branches 1..11, the first output of each tapped layer.\n",
+    "    vit_layers = {\n",
+    "        'ClassToken': ClassToken,\n",
+    "        'AddPositionEmbs': AddPositionEmbs,\n",
+    "        'MultiHeadSelfAttention': MultiHeadSelfAttention,\n",
+    "        'TransformerBlock': TransformerBlock,\n",
+    "    }\n",
+    "    backbone = tf.keras.models.load_model('vit_%s_v1.h5' % dataset, custom_objects=vit_layers)\n",
+    "\n",
+    "    # mark every backbone layer as non-trainable\n",
+    "    for backbone_layer in backbone.layers:\n",
+    "        backbone_layer.trainable = False\n",
+    "\n",
+    "    # each tapped layer exposes a pair of outputs; only the first one is kept\n",
+    "    branch_features = []\n",
+    "    for layer_name in [get_branch_id(number) for number in range(1, 12)]:\n",
+    "        features, _ = backbone.get_layer(layer_name).output\n",
+    "        branch_features.append(features)\n",
+    "\n",
+    "    return tf.keras.models.Model(\n",
+    "        inputs=backbone.get_layer(index=0).input,\n",
+    "        outputs=branch_features\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def precompute(dataset, batch_size=32 * NUM_GPUS):\n",
+    "    # Run the backbone over every cached sample of `dataset` and pickle, for each\n",
+    "    # sample and each branch (1..11), the branch features together with the label.\n",
+    "    #\n",
+    "    # dataset: directory holding the '<split>_<index>.pkl' input files; it is also\n",
+    "    #          the name of the output sub-directory under PRECOMPUTE_DIR.\n",
+    "    # batch_size: number of samples sent through the model per forward pass.\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model(dataset)\n",
+    "    # open(..., 'wb') below does not create missing directories, so make sure the\n",
+    "    # output directory exists before the first sample is written.\n",
+    "    output_dir = os.path.join(PRECOMPUTE_DIR, dataset)\n",
+    "    os.makedirs(output_dir, exist_ok=True)\n",
+    "    for split in ['train', 'val', 'test']:\n",
+    "        print(split)\n",
+    "        total_count = sum(1 for file_name in os.listdir(dataset) if file_name.startswith(split))\n",
+    "        batch_count = math.ceil(total_count / batch_size)\n",
+    "        for batch_index in range(batch_count):\n",
+    "            sys.stdout.write('\\r[%d/%d]' % (batch_index + 1, batch_count))\n",
+    "            sys.stdout.flush()\n",
+    "            images = []\n",
+    "            labels = []\n",
+    "            for sample_index in range(batch_index * batch_size, (batch_index + 1) * batch_size):\n",
+    "                image_path = os.path.join(dataset, '%s_%d.pkl' % (split, sample_index))\n",
+    "                if os.path.exists(image_path):  # last batch may contain less\n",
+    "                    with open(image_path, 'rb') as cache_file:\n",
+    "                        contents = pickle.load(cache_file)\n",
+    "                        images.append(contents['image'])\n",
+    "                        labels.append(contents['label'])\n",
+    "            outputs = model(np.array(images))\n",
+    "            for branch_number in range(1, 12):\n",
+    "                branch_outputs = outputs[branch_number - 1]\n",
+    "                for i, branch_output in enumerate(branch_outputs):\n",
+    "                    # position i in the batch maps back to the global sample index\n",
+    "                    sample_index = batch_index * batch_size + i\n",
+    "                    sample_path = os.path.join(\n",
+    "                        output_dir,\n",
+    "                        '%s_branch%d_sample%d.pkl' % (split, branch_number, sample_index)\n",
+    "                    )\n",
+    "                    with open(sample_path, 'wb') as sample_file:\n",
+    "                        pickle.dump({\n",
+    "                            'features': branch_output,\n",
+    "                            'label': labels[i],\n",
+    "                        }, sample_file)\n",
+    "        print()  # newline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "precompute('cifar10')\n",
+    "precompute('cifar100')"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/precompute_disco_features.ipynb b/precompute_disco_features.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d2b1f52fe5152d0478e6afcd33b19e015470e22a
--- /dev/null
+++ b/precompute_disco_features.ipynb
@@ -0,0 +1,160 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [4]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import random\n",
+    "import sys\n",
+    "import time\n",
+    "from skimage import transform\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "PRECOMPUTE_DIR = 'precompute'\n",
+    "DISCO_PATH = 'disco'\n",
+    "CACHE_DIR = os.path.join(DISCO_PATH, 'vit_cache')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_branch_id(branch_number):\n",
+    "    # Translate a 1-based branch number into the backbone layer name it taps:\n",
+    "    # 1 -> 'transformer_block', k -> 'transformer_block_<k - 1>'.\n",
+    "    suffix = '' if branch_number == 1 else '_%d' % (branch_number - 1)\n",
+    "    return 'transformer_block' + suffix\n",
+    "\n",
+    "def get_model():\n",
+    "    # Load the saved backbone 'vit_cc_backbone_v2.h5' and wrap it in a model that\n",
+    "    # returns, for branches 1..11, the first output of each tapped layer.\n",
+    "    vit_layers = {\n",
+    "        'ClassToken': ClassToken,\n",
+    "        'AddPositionEmbs': AddPositionEmbs,\n",
+    "        'MultiHeadSelfAttention': MultiHeadSelfAttention,\n",
+    "        'TransformerBlock': TransformerBlock,\n",
+    "    }\n",
+    "    backbone = tf.keras.models.load_model('vit_cc_backbone_v2.h5', custom_objects=vit_layers)\n",
+    "\n",
+    "    # mark every backbone layer as non-trainable\n",
+    "    for backbone_layer in backbone.layers:\n",
+    "        backbone_layer.trainable = False\n",
+    "\n",
+    "    # each tapped layer exposes a pair of outputs; only the first one is kept\n",
+    "    branch_features = []\n",
+    "    for layer_name in [get_branch_id(number) for number in range(1, 12)]:\n",
+    "        features, _ = backbone.get_layer(layer_name).output\n",
+    "        branch_features.append(features)\n",
+    "\n",
+    "    return tf.keras.models.Model(\n",
+    "        inputs=backbone.get_layer(index=0).input,\n",
+    "        outputs=branch_features\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def precompute(batch_size=32 * NUM_GPUS):\n",
+    "    # Run the backbone over every cached sample in CACHE_DIR and pickle, for each\n",
+    "    # sample and each branch (1..11), the branch features together with the label.\n",
+    "    # The label is the sum over the sample's 'density_map' array.\n",
+    "    #\n",
+    "    # batch_size: number of samples sent through the model per forward pass.\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model()\n",
+    "    # open(..., 'wb') below does not create missing directories, so make sure the\n",
+    "    # output directory exists before the first sample is written.\n",
+    "    output_dir = os.path.join(PRECOMPUTE_DIR, 'disco')\n",
+    "    os.makedirs(output_dir, exist_ok=True)\n",
+    "    for split in ['train', 'val', 'test']:\n",
+    "        print(split)\n",
+    "        total_count = sum(1 for file_name in os.listdir(CACHE_DIR) if file_name.startswith(split))\n",
+    "        batch_count = math.ceil(total_count / batch_size)\n",
+    "        for batch_index in range(batch_count):\n",
+    "            sys.stdout.write('\\r[%d/%d]' % (batch_index + 1, batch_count))\n",
+    "            sys.stdout.flush()\n",
+    "            images = []\n",
+    "            labels = []\n",
+    "            for sample_index in range(batch_index * batch_size, (batch_index + 1) * batch_size):\n",
+    "                image_path = os.path.join(CACHE_DIR, '%s_%d.pkl' % (split, sample_index))\n",
+    "                if os.path.exists(image_path):  # last batch may contain less\n",
+    "                    with open(image_path, 'rb') as cache_file:\n",
+    "                        contents = pickle.load(cache_file)\n",
+    "                        images.append(contents['image'])\n",
+    "                        labels.append(np.sum(contents['density_map']))\n",
+    "            outputs = model(np.array(images))\n",
+    "            for branch_number in range(1, 12):\n",
+    "                branch_outputs = outputs[branch_number - 1]\n",
+    "                for i, branch_output in enumerate(branch_outputs):\n",
+    "                    # position i in the batch maps back to the global sample index\n",
+    "                    sample_index = batch_index * batch_size + i\n",
+    "                    sample_path = os.path.join(\n",
+    "                        output_dir,\n",
+    "                        '%s_branch%d_sample%d.pkl' % (split, branch_number, sample_index)\n",
+    "                    )\n",
+    "                    with open(sample_path, 'wb') as sample_file:\n",
+    "                        pickle.dump({\n",
+    "                            'features': branch_output,\n",
+    "                            'label': labels[i],\n",
+    "                        }, sample_file)\n",
+    "        print()  # newline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "precompute()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/precompute_fashion_mnist_features.ipynb b/precompute_fashion_mnist_features.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..4926feb885617f15806c5bf8a456f7b6751adf61
--- /dev/null
+++ b/precompute_fashion_mnist_features.ipynb
@@ -0,0 +1,164 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [7]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import random\n",
+    "import sys\n",
+    "import time\n",
+    "from skimage import transform\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "PRECOMPUTE_DIR = 'precompute'\n",
+    "PRECOMPUTE_FASHION_MNIST_DIR = os.path.join(PRECOMPUTE_DIR, 'fashion_mnist')\n",
+    "if not os.path.exists(PRECOMPUTE_FASHION_MNIST_DIR):\n",
+    "    os.makedirs(PRECOMPUTE_FASHION_MNIST_DIR)\n",
+    "CACHE_DIR = 'fashion_mnist'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_branch_id(branch_number):\n",
+    "    # Translate a 1-based branch number into the backbone layer name it taps:\n",
+    "    # 1 -> 'transformer_block', k -> 'transformer_block_<k - 1>'.\n",
+    "    suffix = '' if branch_number == 1 else '_%d' % (branch_number - 1)\n",
+    "    return 'transformer_block' + suffix\n",
+    "\n",
+    "def get_model():\n",
+    "    # Load the saved backbone 'vit_fashion_mnist_v1.h5' and wrap it in a model that\n",
+    "    # returns, for branches 1..11, the first output of each tapped layer.\n",
+    "    vit_layers = {\n",
+    "        'ClassToken': ClassToken,\n",
+    "        'AddPositionEmbs': AddPositionEmbs,\n",
+    "        'MultiHeadSelfAttention': MultiHeadSelfAttention,\n",
+    "        'TransformerBlock': TransformerBlock,\n",
+    "    }\n",
+    "    backbone = tf.keras.models.load_model('vit_fashion_mnist_v1.h5', custom_objects=vit_layers)\n",
+    "\n",
+    "    # mark every backbone layer as non-trainable\n",
+    "    for backbone_layer in backbone.layers:\n",
+    "        backbone_layer.trainable = False\n",
+    "\n",
+    "    # each tapped layer exposes a pair of outputs; only the first one is kept\n",
+    "    branch_features = []\n",
+    "    for layer_name in [get_branch_id(number) for number in range(1, 12)]:\n",
+    "        features, _ = backbone.get_layer(layer_name).output\n",
+    "        branch_features.append(features)\n",
+    "\n",
+    "    return tf.keras.models.Model(\n",
+    "        inputs=backbone.get_layer(index=0).input,\n",
+    "        outputs=branch_features\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def precompute(batch_size=32 * NUM_GPUS):\n",
+    "    # For each split, push the samples cached in CACHE_DIR through the backbone in\n",
+    "    # batches and write one pickle per (branch, sample) with its features and label.\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model()\n",
+    "    for split in ['train', 'val', 'test']:\n",
+    "        print(split)\n",
+    "        split_files = [entry for entry in os.listdir(CACHE_DIR) if entry.startswith(split)]\n",
+    "        total_count = len(split_files)\n",
+    "        batch_count = math.ceil(total_count / batch_size)\n",
+    "        for batch_index in range(batch_count):\n",
+    "            sys.stdout.write('\\r[%d/%d]' % (batch_index + 1, batch_count))\n",
+    "            sys.stdout.flush()\n",
+    "            first_index = batch_index * batch_size\n",
+    "            images, labels = [], []\n",
+    "            for offset in range(batch_size):\n",
+    "                image_path = os.path.join(CACHE_DIR, '%s_%d.pkl' % (split, first_index + offset))\n",
+    "                if not os.path.exists(image_path):  # last batch may contain less\n",
+    "                    continue\n",
+    "                with open(image_path, 'rb') as cache_file:\n",
+    "                    contents = pickle.load(cache_file)\n",
+    "                # append a trailing axis to the image and repeat it three times along it\n",
+    "                images.append(np.repeat(np.expand_dims(contents['image'], axis=-1), 3, axis=-1))\n",
+    "                labels.append(contents['label'])\n",
+    "            outputs = model(np.array(images))\n",
+    "            for branch_number in range(1, 12):\n",
+    "                for i, branch_output in enumerate(outputs[branch_number - 1]):\n",
+    "                    file_name = '%s_branch%d_sample%d.pkl' % (split, branch_number, first_index + i)\n",
+    "                    sample_path = os.path.join(PRECOMPUTE_FASHION_MNIST_DIR, file_name)\n",
+    "                    with open(sample_path, 'wb') as sample_file:\n",
+    "                        pickle.dump({\n",
+    "                            'features': branch_output,\n",
+    "                            'label': labels[i],\n",
+    "                        }, sample_file)\n",
+    "        print()  # newline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "precompute()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/train_cifar100_backbone.ipynb b/train_cifar100_backbone.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..435e1ea5e403f62d3f1f7664806f8378d4241ca0
--- /dev/null
+++ b/train_cifar100_backbone.ipynb
@@ -0,0 +1,214 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [4, 5, 6, 7]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import sys\n",
+    "from skimage import transform\n",
+    "from vit_keras import vit\n",
+    "\n",
+    "BATCH_SIZE = 8 * NUM_GPUS\n",
+    "IMAGE_SIZE = 384\n",
+    "CACHE_DIR = 'cifar100'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_model():\n",
+    "    model = vit.vit_b16(\n",
+    "        image_size=IMAGE_SIZE,\n",
+    "        activation='sigmoid',\n",
+    "        pretrained=True,\n",
+    "        include_top=True,\n",
+    "        pretrained_top=False,\n",
+    "        classes=100\n",
+    "    )\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def cache_split(images, labels, split):\n",
+    "    for i in range(images.shape[0]):\n",
+    "        if (i + 1) % 100 == 0:\n",
+    "            sys.stdout.write('\\r%d' % (i + 1))\n",
+    "            sys.stdout.flush()\n",
+    "        with open(os.path.join(CACHE_DIR, '%s_%d.pkl' % (split, i)), 'wb') as cache_file:\n",
+    "            pickle.dump({\n",
+    "                'image': transform.resize(images[i], (IMAGE_SIZE, IMAGE_SIZE)),\n",
+    "                'label': labels[i],\n",
+    "            }, cache_file)\n",
+    "    print()  # newline\n",
+    "\n",
+    "def cache_all():\n",
+    "    \"\"\"Load CIFAR-100, one-hot encode the labels and pickle train/val/test samples into CACHE_DIR.\"\"\"\n",
+    "    os.makedirs(CACHE_DIR, exist_ok=True)  # open(..., 'wb') in cache_split() does not create it\n",
+    "    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data()\n",
+    "    train_labels = tf.keras.utils.to_categorical(train_labels)\n",
+    "    test_labels = tf.keras.utils.to_categorical(test_labels)\n",
+    "    val_index = int(len(train_images) * 0.8)  # the last 20% of the training set becomes validation\n",
+    "    val_images = train_images[val_index:]\n",
+    "    val_labels = train_labels[val_index:]\n",
+    "    train_images = train_images[:val_index]\n",
+    "    train_labels = train_labels[:val_index]\n",
+    "\n",
+    "    cache_split(train_images, train_labels, 'train')\n",
+    "    cache_split(val_images, val_labels, 'val')\n",
+    "    cache_split(test_images, test_labels, 'test')\n",
+    "\n",
+    "class CIFAR100Sequence(tf.keras.utils.Sequence):\n",
+    "    \"\"\"Serves batches of the pickled samples of one split, reshuffled after every epoch.\"\"\"\n",
+    "    def __init__(self, split):\n",
+    "        self.split = split\n",
+    "        self.count = len([name for name in os.listdir(CACHE_DIR) if name.startswith(split)])\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return math.ceil(self.count / BATCH_SIZE)\n",
+    "\n",
+    "    def on_epoch_end(self):\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def _load(self, sample_index):\n",
+    "        with open(os.path.join(CACHE_DIR, '%s_%d.pkl' % (self.split, sample_index)), 'rb') as cache_file:\n",
+    "            return pickle.load(cache_file)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        start = index * BATCH_SIZE\n",
+    "        batch = [self._load(i) for i in self.random_permutation[start:start + BATCH_SIZE]]\n",
+    "        return np.array([s['image'] for s in batch]), np.array([s['label'] for s in batch])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(max_epochs):\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model()\n",
+    "        model.compile(\n",
+    "            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n",
+    "            loss='categorical_crossentropy',\n",
+    "            metrics=['accuracy']\n",
+    "        )\n",
+    "\n",
+    "    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(\n",
+    "        monitor='val_accuracy',\n",
+    "        factor=0.6,\n",
+    "        patience=2,\n",
+    "        verbose=1,\n",
+    "        mode='max',\n",
+    "        min_lr=1e-7\n",
+    "    )\n",
+    "\n",
+    "    early_stop = tf.keras.callbacks.EarlyStopping(\n",
+    "        monitor='val_accuracy',\n",
+    "        patience=5,\n",
+    "        verbose=1,\n",
+    "        mode='max'\n",
+    "    )\n",
+    "\n",
+    "    model_checkpoint_file = 'vit_cifar100_v1.h5'\n",
+    "\n",
+    "    checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
+    "        model_checkpoint_file,\n",
+    "        monitor='val_accuracy',\n",
+    "        verbose=1,\n",
+    "        save_weights_only=False,\n",
+    "        save_best_only=True,\n",
+    "        mode='max',\n",
+    "        save_freq='epoch'\n",
+    "    )\n",
+    "\n",
+    "    history = model.fit(\n",
+    "        CIFAR100Sequence('train'),\n",
+    "        validation_data=CIFAR100Sequence('val'),\n",
+    "        epochs=max_epochs,\n",
+    "        shuffle=True,\n",
+    "        callbacks=[\n",
+    "            lr_reduce,\n",
+    "            early_stop,\n",
+    "            checkpoint\n",
+    "        ],\n",
+    "        verbose=1\n",
+    "    )\n",
+    "\n",
+    "    test_accuracy = model.evaluate(CIFAR100Sequence('test'))[1]\n",
+    "\n",
+    "    return model, test_accuracy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cache_all()\n",
+    "model, test_accuracy = train(100)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/train_cifar10_backbone.ipynb b/train_cifar10_backbone.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..6cdb3097501039763eed07e3d7e1d8c13e4c4ade
--- /dev/null
+++ b/train_cifar10_backbone.ipynb
@@ -0,0 +1,214 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [4, 5, 6, 7]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import sys\n",
+    "from skimage import transform\n",
+    "from vit_keras import vit\n",
+    "\n",
+    "BATCH_SIZE = 8 * NUM_GPUS\n",
+    "IMAGE_SIZE = 384\n",
+    "CACHE_DIR = 'cifar10'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_model():\n",
+    "    model = vit.vit_b16(\n",
+    "        image_size=IMAGE_SIZE,\n",
+    "        activation='sigmoid',\n",
+    "        pretrained=True,\n",
+    "        include_top=True,\n",
+    "        pretrained_top=False,\n",
+    "        classes=10\n",
+    "    )\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def cache_split(images, labels, split):\n",
+    "    for i in range(images.shape[0]):\n",
+    "        if (i + 1) % 100 == 0:\n",
+    "            sys.stdout.write('\\r%d' % (i + 1))\n",
+    "            sys.stdout.flush()\n",
+    "        with open(os.path.join(CACHE_DIR, '%s_%d.pkl' % (split, i)), 'wb') as cache_file:\n",
+    "            pickle.dump({\n",
+    "                'image': transform.resize(images[i], (IMAGE_SIZE, IMAGE_SIZE)),\n",
+    "                'label': labels[i],\n",
+    "            }, cache_file)\n",
+    "    print()  # newline\n",
+    "\n",
+    "def cache_all():\n",
+    "    \"\"\"Load CIFAR-10, one-hot encode the labels and pickle train/val/test samples into CACHE_DIR.\"\"\"\n",
+    "    os.makedirs(CACHE_DIR, exist_ok=True)  # open(..., 'wb') in cache_split() does not create it\n",
+    "    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()\n",
+    "    train_labels = tf.keras.utils.to_categorical(train_labels)\n",
+    "    test_labels = tf.keras.utils.to_categorical(test_labels)\n",
+    "    val_index = int(len(train_images) * 0.8)  # the last 20% of the training set becomes validation\n",
+    "    val_images = train_images[val_index:]\n",
+    "    val_labels = train_labels[val_index:]\n",
+    "    train_images = train_images[:val_index]\n",
+    "    train_labels = train_labels[:val_index]\n",
+    "\n",
+    "    cache_split(train_images, train_labels, 'train')\n",
+    "    cache_split(val_images, val_labels, 'val')\n",
+    "    cache_split(test_images, test_labels, 'test')\n",
+    "\n",
+    "class CIFAR10Sequence(tf.keras.utils.Sequence):\n",
+    "    \"\"\"Serves batches of the pickled samples of one split, reshuffled after every epoch.\"\"\"\n",
+    "    def __init__(self, split):\n",
+    "        self.split = split\n",
+    "        self.count = len([name for name in os.listdir(CACHE_DIR) if name.startswith(split)])\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return math.ceil(self.count / BATCH_SIZE)\n",
+    "\n",
+    "    def on_epoch_end(self):\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def _load(self, sample_index):\n",
+    "        with open(os.path.join(CACHE_DIR, '%s_%d.pkl' % (self.split, sample_index)), 'rb') as cache_file:\n",
+    "            return pickle.load(cache_file)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        start = index * BATCH_SIZE\n",
+    "        batch = [self._load(i) for i in self.random_permutation[start:start + BATCH_SIZE]]\n",
+    "        return np.array([s['image'] for s in batch]), np.array([s['label'] for s in batch])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(max_epochs):\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model()\n",
+    "        model.compile(\n",
+    "            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n",
+    "            loss='categorical_crossentropy',\n",
+    "            metrics=['accuracy']\n",
+    "        )\n",
+    "\n",
+    "    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(\n",
+    "        monitor='val_accuracy',\n",
+    "        factor=0.6,\n",
+    "        patience=2,\n",
+    "        verbose=1,\n",
+    "        mode='max',\n",
+    "        min_lr=1e-7\n",
+    "    )\n",
+    "\n",
+    "    early_stop = tf.keras.callbacks.EarlyStopping(\n",
+    "        monitor='val_accuracy',\n",
+    "        patience=5,\n",
+    "        verbose=1,\n",
+    "        mode='max'\n",
+    "    )\n",
+    "\n",
+    "    model_checkpoint_file = 'vit_cifar10_v1.h5'\n",
+    "\n",
+    "    checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
+    "        model_checkpoint_file,\n",
+    "        monitor='val_accuracy',\n",
+    "        verbose=1,\n",
+    "        save_weights_only=False,\n",
+    "        save_best_only=True,\n",
+    "        mode='max',\n",
+    "        save_freq='epoch'\n",
+    "    )\n",
+    "\n",
+    "    history = model.fit(\n",
+    "        CIFAR10Sequence('train'),\n",
+    "        validation_data=CIFAR10Sequence('val'),\n",
+    "        epochs=max_epochs,\n",
+    "        shuffle=True,\n",
+    "        callbacks=[\n",
+    "            lr_reduce,\n",
+    "            early_stop,\n",
+    "            checkpoint\n",
+    "        ],\n",
+    "        verbose=1\n",
+    "    )\n",
+    "\n",
+    "    test_accuracy = model.evaluate(CIFAR10Sequence('test'))[1]\n",
+    "\n",
+    "    return model, test_accuracy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cache_all()\n",
+    "model, test_accuracy = train(100)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/train_disco_backbone.ipynb b/train_disco_backbone.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..0b8cb3964740b2ac260b6498f58254ffd3f3c932
--- /dev/null
+++ b/train_disco_backbone.ipynb
@@ -0,0 +1,395 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [4, 5, 6, 7]  # which GPUs to use\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import random\n",
+    "import scipy.io\n",
+    "import scipy.stats as st\n",
+    "import sys\n",
+    "import tensorflow_addons as tfa\n",
+    "from scipy import signal\n",
+    "from skimage.transform import resize\n",
+    "\n",
+    "DISCO_PATH = 'disco'\n",
+    "WAVEFORMS_PATH = os.path.join(DISCO_PATH, 'auds')\n",
+    "IMAGES_PATH = os.path.join(DISCO_PATH, 'imgs')\n",
+    "TRAIN_DENSITY_MAPS_PATH = os.path.join(DISCO_PATH, 'train')\n",
+    "VAL_DENSITY_MAPS_PATH = os.path.join(DISCO_PATH, 'val')\n",
+    "TEST_DENSITY_MAPS_PATH = os.path.join(DISCO_PATH, 'test')\n",
+    "CACHE_DIR = os.path.join(DISCO_PATH, 'vit_cache')\n",
+    "\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "IMAGE_SIZE = 384\n",
+    "VIDEO_PATCHES = (2, 3)\n",
+    "VIDEO_SIZE = (VIDEO_PATCHES[0] * IMAGE_SIZE, VIDEO_PATCHES[1] * IMAGE_SIZE)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_model():\n",
+    "    backbone_model = vit.vit_b16(\n",
+    "        image_size=IMAGE_SIZE,\n",
+    "        pretrained=True,\n",
+    "        include_top=False,\n",
+    "        pretrained_top=False\n",
+    "    )\n",
+    "    y = backbone_model.get_layer(index=-1).output\n",
+    "    y = tf.keras.layers.Dense(1, name='regression_head')(y)\n",
+    "    model = tf.keras.models.Model(\n",
+    "        inputs=backbone_model.get_layer(index=0).input,\n",
+    "        outputs=y\n",
+    "    )\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_dataset_split(split):\n",
+    "    examples = {}\n",
+    "    for file_name in os.listdir(WAVEFORMS_PATH):\n",
+    "        waveform_path = os.path.join(WAVEFORMS_PATH, file_name)\n",
+    "        if os.path.isfile(waveform_path) and file_name.endswith('.wav'):\n",
+    "            key = '.'.join(file_name.split('.')[:-1])\n",
+    "            if key not in examples:\n",
+    "                examples[key] = {}\n",
+    "            examples[key]['waveform_path'] = waveform_path\n",
+    "    for file_name in os.listdir(IMAGES_PATH):\n",
+    "        image_path = os.path.join(IMAGES_PATH, file_name)\n",
+    "        if os.path.isfile(image_path) and file_name.endswith('.jpg'):\n",
+    "            key = '.'.join(file_name.split('.')[:-1])\n",
+    "            if key not in examples:\n",
+    "                examples[key] = {}\n",
+    "            examples[key]['image_path'] = image_path\n",
+    "    for file_name in os.listdir(TRAIN_DENSITY_MAPS_PATH):\n",
+    "        density_map_path = os.path.join(TRAIN_DENSITY_MAPS_PATH, file_name)\n",
+    "        if os.path.isfile(density_map_path) and file_name.endswith('.mat'):\n",
+    "            key = '.'.join(file_name.split('.')[:-1])\n",
+    "            if key not in examples:\n",
+    "                examples[key] = {}\n",
+    "            examples[key]['density_map_path'] = density_map_path\n",
+    "            examples[key]['split'] = 'train'\n",
+    "    for file_name in os.listdir(VAL_DENSITY_MAPS_PATH):\n",
+    "        density_map_path = os.path.join(VAL_DENSITY_MAPS_PATH, file_name)\n",
+    "        if os.path.isfile(density_map_path) and file_name.endswith('.mat'):\n",
+    "            key = '.'.join(file_name.split('.')[:-1])\n",
+    "            if key not in examples:\n",
+    "                examples[key] = {}\n",
+    "            examples[key]['density_map_path'] = density_map_path\n",
+    "            examples[key]['split'] = 'val'\n",
+    "    for file_name in os.listdir(TEST_DENSITY_MAPS_PATH):\n",
+    "        density_map_path = os.path.join(TEST_DENSITY_MAPS_PATH, file_name)\n",
+    "        if os.path.isfile(density_map_path) and file_name.endswith('.mat'):\n",
+    "            key = '.'.join(file_name.split('.')[:-1])\n",
+    "            if key not in examples:\n",
+    "                examples[key] = {}\n",
+    "            examples[key]['density_map_path'] = density_map_path\n",
+    "            examples[key]['split'] = 'test'\n",
+    "    final_examples = []\n",
+    "    for key, info in examples.items():\n",
+    "        if 'split' in info and info['split'] == split:\n",
+    "            final_examples.append(info)\n",
+    "    return final_examples\n",
+    "\n",
+    "def visualize_data(image, density_map):\n",
+    "    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 10))\n",
+    "    ax1.imshow(image)\n",
+    "    ax2.imshow(density_map)\n",
+    "    ax1.axis('off')\n",
+    "    ax2.axis('off')\n",
+    "    plt.show()\n",
+    "\n",
+    "def get_gaussian_kernel(kernel_size, sigma):\n",
+    "    \"\"\"\n",
+    "    Build a kernel_size x kernel_size Gaussian kernel normalised so that its entries sum to 1.\n",
+    "    Adapted from:\n",
+    "    https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy\n",
+    "    \"\"\"\n",
+    "    edges = np.linspace(-sigma, sigma, kernel_size + 1)  # bin edges fed to the standard normal CDF\n",
+    "    weights_1d = np.diff(st.norm.cdf(edges))\n",
+    "    weights_2d = np.outer(weights_1d, weights_1d)\n",
+    "    return weights_2d / weights_2d.sum()\n",
+    "\n",
+    "def extract_patches(image):\n",
+    "    \"\"\"Cut image into VIDEO_PATCHES[0] x VIDEO_PATCHES[1] tiles of IMAGE_SIZE x IMAGE_SIZE.\n",
+    "\n",
+    "    Tiles are stacked row by row into one array; axes after the first two are left untouched.\n",
+    "    \"\"\"\n",
+    "    rows, cols = VIDEO_PATCHES\n",
+    "    return np.array([\n",
+    "        image[r * IMAGE_SIZE:(r + 1) * IMAGE_SIZE, c * IMAGE_SIZE:(c + 1) * IMAGE_SIZE]\n",
+    "        for r in range(rows)\n",
+    "        for c in range(cols)\n",
+    "    ])\n",
+    "\n",
+    "def precompute_batches():\n",
+    "    \"\"\"Cut every image and density map into patches and pickle each patch into CACHE_DIR.\n",
+    "\n",
+    "    Returns the number of cached patches per split, in the order train, val, test.\n",
+    "    \"\"\"\n",
+    "    os.makedirs(CACHE_DIR, exist_ok=True)  # open(..., 'wb') below does not create directories\n",
+    "    gaussian_kernel = get_gaussian_kernel(15, 4)\n",
+    "    patches_per_image = VIDEO_PATCHES[0] * VIDEO_PATCHES[1]\n",
+    "    split_lens = []\n",
+    "    resize_errors = []\n",
+    "    for split in ['train', 'val', 'test']:\n",
+    "        infos = get_dataset_split(split)\n",
+    "        split_lens.append(len(infos) * patches_per_image)\n",
+    "        for index, info in enumerate(infos):\n",
+    "            sys.stdout.write('\\r%d' % (index + 1))\n",
+    "            sys.stdout.flush()\n",
+    "\n",
+    "            crowd_image = plt.imread(info['image_path'], format='jpeg')\n",
+    "            crowd_image_patches = extract_patches(resize(crowd_image, VIDEO_SIZE))\n",
+    "\n",
+    "            head_annotation = scipy.io.loadmat(info['density_map_path'])['map']\n",
+    "            density_map = signal.convolve2d(head_annotation, gaussian_kernel)\n",
+    "            resize_factor = density_map.shape[0] / VIDEO_SIZE[0] * density_map.shape[1] / VIDEO_SIZE[1]\n",
+    "            resized_density_map = resize(density_map, VIDEO_SIZE) * resize_factor  # to preserve sum\n",
+    "            density_patches = extract_patches(resized_density_map)\n",
+    "            resize_errors.append(np.abs(np.sum(density_patches) - np.sum(resized_density_map)))\n",
+    "\n",
+    "            for patch_index in range(patches_per_image):\n",
+    "                sample_index = index * patches_per_image + patch_index\n",
+    "                all_path = os.path.join(CACHE_DIR, '%s_%d.pkl' % (split, sample_index))\n",
+    "                with open(all_path, 'wb') as all_file:\n",
+    "                    pickle.dump({\n",
+    "                        'image': crowd_image_patches[patch_index],\n",
+    "                        'density_map': density_patches[patch_index],\n",
+    "                    }, all_file)\n",
+    "        print()  # newline\n",
+    "    if resize_errors:\n",
+    "        print('Mean absolute resize error:', np.mean(resize_errors))\n",
+    "    return split_lens\n",
+    "\n",
+    "def horizontal_flip(image):\n",
+    "    return np.flip(image, axis=1)\n",
+    "\n",
+    "class CCSequence(tf.keras.utils.Sequence):\n",
+    "    \"\"\"Batches of (image patches, summed density maps) unpickled from CACHE_DIR for one split.\"\"\"\n",
+    "    def __init__(self, split, batch_size):\n",
+    "        self.split = split\n",
+    "        self.split_len = sum([\n",
+    "            1 if file_name.startswith(self.split) else 0 for file_name in os.listdir(CACHE_DIR)\n",
+    "        ])\n",
+    "        self.batch_size = batch_size\n",
+    "        self.random_permutation = np.random.permutation(self.split_len)\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return math.ceil(self.split_len / self.batch_size)\n",
+    "\n",
+    "    def on_epoch_end(self):\n",
+    "        self.random_permutation = np.random.permutation(self.split_len)\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        images = []\n",
+    "        density_maps = []\n",
+    "        if self.split == 'test':\n",
+    "            # The test split is read in cache-index order rather than in a random order.\n",
+    "            # range() already excludes its stop value, so the bound has to be split_len;\n",
+    "            # a bound of split_len - 1 would silently drop the last test sample.\n",
+    "            index_generator = range(\n",
+    "                index * self.batch_size,\n",
+    "                min((index + 1) * self.batch_size, self.split_len)\n",
+    "            )\n",
+    "        else:\n",
+    "            index_generator = self.random_permutation[index * self.batch_size:(index + 1) * self.batch_size]\n",
+    "        for random_index in index_generator:\n",
+    "            all_path = os.path.join(CACHE_DIR, '%s_%d.pkl' % (self.split, random_index))\n",
+    "            with open(all_path, 'rb') as all_file:\n",
+    "                data = pickle.load(all_file)\n",
+    "                if self.split == 'train' and random.random() < 0.5:  # flip augmentation\n",
+    "                    images.append(horizontal_flip(data['image']))\n",
+    "                else:\n",
+    "                    images.append(data['image'])\n",
+    "                density_maps.append(np.sum(data['density_map']))\n",
+    "\n",
+    "        return np.array(images), np.array(density_maps)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_backbone(epochs):\n",
+    "    tf.keras.backend.clear_session()\n",
+    "\n",
+    "    batch_size=4 * NUM_GPUS\n",
+    "    train_sequence = CCSequence('train', batch_size)\n",
+    "    val_sequence = CCSequence('val', batch_size)\n",
+    "    test_sequence = CCSequence('test', batch_size)\n",
+    "\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model()\n",
+    "        model.compile(\n",
+    "            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n",
+    "            loss='mean_absolute_error',\n",
+    "            metrics=['mean_absolute_error']\n",
+    "        )\n",
+    "\n",
+    "    lr_reduce = tf.keras.callbacks.ReduceLROnPlateau(\n",
+    "        monitor='val_mean_absolute_error',\n",
+    "        factor=0.6,\n",
+    "        patience=2,\n",
+    "        verbose=1,\n",
+    "        mode='min',\n",
+    "        min_lr=1e-7\n",
+    "    )\n",
+    "\n",
+    "    model_checkpoint_file = 'vit_cc_backbone_v2.h5'\n",
+    "\n",
+    "    checkpoint = tf.keras.callbacks.ModelCheckpoint(\n",
+    "        model_checkpoint_file,\n",
+    "        monitor='val_mean_absolute_error',\n",
+    "        verbose=1,\n",
+    "        save_weights_only=False,\n",
+    "        save_best_only=True,\n",
+    "        mode='min',\n",
+    "        save_freq='epoch'\n",
+    "    )\n",
+    "\n",
+    "    history = model.fit(\n",
+    "        train_sequence,\n",
+    "        validation_data=val_sequence,\n",
+    "        epochs=epochs,\n",
+    "        shuffle=True,\n",
+    "        callbacks=[\n",
+    "            lr_reduce,\n",
+    "            checkpoint\n",
+    "        ],\n",
+    "        verbose=1\n",
+    "    )\n",
+    "\n",
+    "    model.evaluate(test_sequence)\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "precompute_batches()\n",
+    "model = train_backbone(100)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_sequence = CCSequence('test', 4 * NUM_GPUS)\n",
+    "model = tf.keras.models.load_model('vit_cc_backbone_v2.h5', custom_objects={\n",
+    "    'ClassToken': ClassToken,\n",
+    "    'AddPositionEmbs': AddPositionEmbs,\n",
+    "    'MultiHeadSelfAttention': MultiHeadSelfAttention,\n",
+    "    'TransformerBlock': TransformerBlock,\n",
+    "})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gt = None\n",
+    "out = None\n",
+    "for i, (images, density_maps) in enumerate(test_sequence):\n",
+    "    sys.stdout.write('\\r%d' % (i + 1))\n",
+    "    sys.stdout.flush()\n",
+    "    if gt is not None:\n",
+    "        gt = np.concatenate((gt, density_maps))\n",
+    "    else:\n",
+    "        gt = density_maps\n",
+    "    if out is not None:\n",
+    "        out = np.concatenate((out, model(images).numpy().flatten()))\n",
+    "    else:\n",
+    "        out = model(images).numpy().flatten()\n",
+    "print()  # newline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mae = []\n",
+    "img_patches = VIDEO_PATCHES[0] * VIDEO_PATCHES[1]\n",
+    "for i in range(0, gt.shape[0], img_patches):\n",
+    "    gt_subset = gt[i:i + img_patches]\n",
+    "    out_subset = out[i:i + img_patches]\n",
+    "    mae.append(np.abs(np.sum(gt_subset) - np.sum(out_subset)))\n",
+    "print(np.mean(np.array(mae)))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/train_fashion_mnist_backbone.ipynb b/train_fashion_mnist_backbone.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..bcceff9360d2f8810c6168786864237a3180632e
--- /dev/null
+++ b/train_fashion_mnist_backbone.ipynb
@@ -0,0 +1,220 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SELECTED_GPUS = [7]\n",
+    "\n",
+    "import os\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(gpu_number) for gpu_number in SELECTED_GPUS])\n",
+    "\n",
+    "import tensorflow as tf \n",
+    "\n",
+    "tf.get_logger().setLevel('INFO')\n",
+    "\n",
+    "assert len(tf.config.list_physical_devices('GPU')) > 0\n",
+    "\n",
+    "GPUS = tf.config.experimental.list_physical_devices('GPU')\n",
+    "for gpu in GPUS:\n",
+    "    tf.config.experimental.set_memory_growth(gpu, True)\n",
+    "\n",
+    "DISTRIBUTED_STRATEGY = tf.distribute.MirroredStrategy(\n",
+    "    cross_device_ops=tf.distribute.NcclAllReduce(),\n",
+    "    devices=['/gpu:%d' % index for index in range(len(SELECTED_GPUS))]\n",
+    ")\n",
+    "\n",
+    "NUM_GPUS = DISTRIBUTED_STRATEGY.num_replicas_in_sync\n",
+    "\n",
+    "print('Number of devices: {}'.format(NUM_GPUS))\n",
+    "\n",
+    "import math\n",
+    "import numpy as np\n",
+    "import pickle\n",
+    "import sys\n",
+    "from skimage import transform\n",
+    "from vit_keras import vit\n",
+    "from vit_keras.layers import ClassToken, AddPositionEmbs, MultiHeadSelfAttention, TransformerBlock\n",
+    "\n",
+    "BATCH_SIZE = 8 * NUM_GPUS\n",
+    "IMAGE_SIZE = 384\n",
+    "CACHE_DIR = 'fashion_mnist'\n",
+    "if not os.path.exists(CACHE_DIR):\n",
+    "    os.makedirs(CACHE_DIR)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_model():\n",
+    "    \"\"\"Return a ViT-B/16 with a 10-class, sigmoid-activated top.\n",
+    "\n",
+    "    Pretrained weights are requested for the network but not for the top.\n",
+    "    \"\"\"\n",
+    "    vit_options = {\n",
+    "        'image_size': IMAGE_SIZE, 'activation': 'sigmoid', 'classes': 10,\n",
+    "        'pretrained': True, 'include_top': True, 'pretrained_top': False,\n",
+    "    }\n",
+    "    return vit.vit_b16(**vit_options)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def cache_split(images, labels, split):\n",
+    "    \"\"\"Pickle every resized image with its label to CACHE_DIR/<split>_<index>.pkl.\"\"\"\n",
+    "    for index, image in enumerate(images):\n",
+    "        done = index + 1\n",
+    "        if done % 100 == 0:  # refresh the in-place progress counter every 100 items\n",
+    "            sys.stdout.write('\\r%d' % done)\n",
+    "            sys.stdout.flush()\n",
+    "        with open(os.path.join(CACHE_DIR, '%s_%d.pkl' % (split, index)), 'wb') as cache_file:\n",
+    "            record = {'image': transform.resize(image, (IMAGE_SIZE, IMAGE_SIZE)), 'label': labels[index]}\n",
+    "            pickle.dump(record, cache_file)\n",
+    "    print()  # newline\n",
+    "\n",
+    "def cache_all():\n",
+    "    \"\"\"Load Fashion-MNIST and cache its train, val and test splits.\n",
+    "\n",
+    "    Labels are one-hot encoded; the last 20% of the original training set\n",
+    "    (in its loaded order) is held out as the validation split.\n",
+    "    \"\"\"\n",
+    "    (images, labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()\n",
+    "    one_hot = tf.keras.utils.to_categorical\n",
+    "    labels, test_labels = one_hot(labels), one_hot(test_labels)\n",
+    "\n",
+    "    cut = int(len(images) * 0.8)  # first index of the validation portion\n",
+    "    for split, split_images, split_labels in (('train', images[:cut], labels[:cut]),\n",
+    "                                              ('val', images[cut:], labels[cut:]),\n",
+    "                                              ('test', test_images, test_labels)):\n",
+    "        cache_split(split_images, split_labels, split)\n",
+    "\n",
+    "class FashionMNISTSequence(tf.keras.utils.Sequence):\n",
+    "    \"\"\"Keras Sequence serving randomly ordered batches of one cached split.\"\"\"\n",
+    "\n",
+    "    def __init__(self, split):\n",
+    "        self.split = split\n",
+    "        # One sample per entry in CACHE_DIR whose file name starts with the split name.\n",
+    "        self.count = len([name for name in os.listdir(CACHE_DIR) if name.startswith(split)])\n",
+    "        self.random_permutation = np.random.permutation(self.count)\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return math.ceil(self.count / BATCH_SIZE)  # the last batch may be partial\n",
+    "\n",
+    "    def on_epoch_end(self):\n",
+    "        self.random_permutation = np.random.permutation(self.count)  # new sample order\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        \"\"\"Return (images, labels) arrays for batch `index`; images are repeated to 3 channels.\"\"\"\n",
+    "        images, labels = [], []\n",
+    "        for i in self.random_permutation[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]:\n",
+    "            with open(os.path.join(CACHE_DIR, '%s_%d.pkl' % (self.split, i)), 'rb') as cache_file:\n",
+    "                contents = pickle.load(cache_file)\n",
+    "            images.append(np.repeat(np.expand_dims(contents['image'], axis=-1), 3, axis=-1))\n",
+    "            labels.append(contents['label'])\n",
+    "        return np.array(images), np.array(labels)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(max_epochs):\n",
+    "    \"\"\"Train the ViT on the cached 'train' split and score it on 'test'.\n",
+    "\n",
+    "    Validation accuracy drives learning-rate reduction, early stopping and\n",
+    "    the best-model checkpoint written to 'vit_fashion_mnist_v1.h5'.\n",
+    "\n",
+    "    Args:\n",
+    "        max_epochs: upper bound on the number of training epochs.\n",
+    "\n",
+    "    Returns:\n",
+    "        (model, test_accuracy): the model as it stands when fit() returns (it\n",
+    "        is not reloaded from the checkpoint) and its accuracy on 'test'.\n",
+    "    \"\"\"\n",
+    "    # Build and compile the model inside the distribution strategy scope.\n",
+    "    with DISTRIBUTED_STRATEGY.scope():\n",
+    "        model = get_model()\n",
+    "        adam = tf.keras.optimizers.Adam(learning_rate=1e-4)\n",
+    "        model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])\n",
+    "\n",
+    "    # Every callback watches the same quantity in the same direction.\n",
+    "    watch = {'monitor': 'val_accuracy', 'mode': 'max', 'verbose': 1}\n",
+    "\n",
+    "    model_checkpoint_file = 'vit_fashion_mnist_v1.h5'\n",
+    "    callbacks = [\n",
+    "        tf.keras.callbacks.ReduceLROnPlateau(\n",
+    "            factor=0.6, patience=2, min_lr=1e-7, **watch\n",
+    "        ),\n",
+    "        tf.keras.callbacks.EarlyStopping(patience=5, **watch),\n",
+    "        tf.keras.callbacks.ModelCheckpoint(\n",
+    "            model_checkpoint_file,\n",
+    "            save_weights_only=False,\n",
+    "            save_best_only=True,\n",
+    "            save_freq='epoch',\n",
+    "            **watch\n",
+    "        ),\n",
+    "    ]\n",
+    "\n",
+    "    train_sequence = FashionMNISTSequence('train')\n",
+    "    val_sequence = FashionMNISTSequence('val')\n",
+    "    model.fit(\n",
+    "        train_sequence,\n",
+    "        validation_data=val_sequence,\n",
+    "        epochs=max_epochs,\n",
+    "        shuffle=True,\n",
+    "        callbacks=callbacks,\n",
+    "        verbose=1\n",
+    "    )\n",
+    "\n",
+    "    # evaluate() yields the loss followed by the compiled 'accuracy' metric.\n",
+    "    test_sequence = FashionMNISTSequence('test')\n",
+    "    test_accuracy = model.evaluate(test_sequence)[1]\n",
+    "\n",
+    "    return model, test_accuracy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cache_all()\n",
+    "model, test_accuracy = train(100)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}