diff --git a/doc/gallery/scan/benchmark_backends.ipynb b/doc/gallery/scan/benchmark_backends.ipynb
new file mode 100644
index 0000000000..4d55a80da1
--- /dev/null
+++ b/doc/gallery/scan/benchmark_backends.ipynb
@@ -0,0 +1,7171 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "9571ee33",
+ "metadata": {},
+ "source": [
+ "# Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "73cc811c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import jax.numpy as jnp\n",
+ "from numba import jit\n",
+ "import numba\n",
+ "import pytensor\n",
+ "import pytensor.tensor as pt\n",
+ "from pytensor.compile.mode import get_mode\n",
+ "import timeit\n",
+ "import jax\n",
+ "import math\n",
+ "from jax.scipy.special import gammaln\n",
+ "from functools import partial\n",
+ "\n",
+ "import plotly.graph_objects as go\n",
+ "from plotly.subplots import make_subplots\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "d5fc1350",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import plotly.io as pio\n",
+ "pio.renderers.default = \"notebook\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "daa0969b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pytensor version: 0+untagged.31345.g422eafd.dirty\n",
+ "jax version: 0.7.2\n",
+ "numba version: 0.62.1\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"pytensor version:\", pytensor.__version__)\n",
+ "print(\"jax version:\", jax.__version__)\n",
+ "print(\"numba version:\", numba.__version__)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "8294718d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Benchmarker:\n",
+ " \"\"\"\n",
+ " Benchmark a set of functions by timing execution and summarizing statistics.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " functions : list of callables\n",
+ " List of callables to benchmark.\n",
+ " names : list of str, optional\n",
+ " Names corresponding to each function. Default is ['func_0', 'func_1', ...].\n",
+ " number : int or None, optional\n",
+ " Number of loops per timing. If None, auto-calibrated via Timer.autorange().\n",
+ " Default is None.\n",
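+ " min_rounds : int, optional\n",
+ " Minimum number of timing rounds (samples) per function/input pair. Default is 5.\n",
+ " max_time : float or None, optional\n",
+ " Time budget in seconds used to derive the number of rounds as\n",
+ " max(min_rounds, ceil(max_time / calibration_time)). If None, exactly\n",
+ " `min_rounds` rounds are run. Default is 1.0.\n",
+ " target_time : float, optional\n",
+ " Target duration in seconds for auto-calibration. Default is 0.2.\n",
+ " Stored on the instance but not currently used by `run()`.\n",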
+ "\n",
+ " Attributes\n",
+ " ----------\n",
+ " results : dict\n",
+ " Mapping from (function name, input name) tuples to a dict with keys:\n",
+ " - 'raw_us': numpy.ndarray of raw timings in microseconds\n",
+ " - 'loops': number of loops used per timing\n",
+ " - 'rounds': number of timing rounds (samples) collected\n",
+ "\n",
+ " Methods\n",
+ " -------\n",
+ " run(inputs)\n",
+ " Auto-calibrate (if needed) and run timings for every function on every named input.\n",
+ " summary(unit='us') -> pandas.DataFrame\n",
+ " Return a summary DataFrame with statistics converted to the given unit.\n",
+ " raw(name=None) -> dict or numpy.ndarray\n",
+ " Return raw timing data in microseconds for a specific function or all.\n",
+ " _convert_times(times, unit) -> numpy.ndarray\n",
+ " Convert an array of times from microseconds to the specified unit.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self, functions, names=None, number=None, min_rounds=5, max_time=1.0, target_time=0.2\n",
+ " ):\n",
+ " self.functions = functions\n",
+ " self.names = names or [f\"func_{i}\" for i in range(len(functions))]\n",
+ " self.number = number\n",
+ " self.min_rounds = min_rounds\n",
+ " self.max_time = max_time\n",
+ " self.target_time = target_time\n",
+ " self.results = {}\n",
+ "\n",
+ " def run(self, inputs: dict[str, dict]):\n",
+ " \"\"\"\n",
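+ " Auto-calibrate loop count and sample rounds if needed, then time each function.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " inputs : dict of str to dict\n",
+ " Maps an input-set name to the keyword arguments passed to every function.\n",
+ " Results are stored in `self.results` under (function name, input name).\n",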
+ " \"\"\"\n",
+ " for name, func in zip(self.names, self.functions):\n",
+ " for input_name, kwargs in inputs.items():\n",
+ " timer = timeit.Timer(partial(func, **kwargs))\n",
+ "\n",
+ " # Calibrate loops\n",
+ " if self.number is None:\n",
+ " loops, calib_time = timer.autorange()\n",
+ " else:\n",
+ " loops = self.number\n",
+ " calib_time = timer.timeit(number=loops)\n",
+ "\n",
+ " # Determine rounds based on max_time and min_rounds\n",
+ " if self.max_time is not None:\n",
+ " rounds = max(self.min_rounds, int(np.ceil(self.max_time / calib_time)))\n",
+ " else:\n",
+ " rounds = self.min_rounds\n",
+ "\n",
+ " raw_round_times = np.array(timer.repeat(repeat=rounds, number=loops))\n",
+ "\n",
+ " # Convert to microseconds per single execution\n",
+ " raw_us = raw_round_times / loops * 1e6\n",
+ "\n",
+ " self.results[(name, input_name)] = {\n",
+ " \"raw_us\": raw_us,\n",
+ " \"loops\": loops,\n",
+ " \"rounds\": rounds,\n",
+ " }\n",
+ "\n",
+ " def summary(self, unit=\"us\"):\n",
+ " \"\"\"\n",
+ " Summarize benchmark statistics in a DataFrame.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " unit : {'us', 'ms', 'ns', 's'}, optional\n",
+ " Unit for output times. 'us' means microseconds, 'ms' milliseconds,\n",
+ " 'ns' nanoseconds, 's' seconds. Default is 'us'.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " pandas.DataFrame\n",
+ " Summary indexed by the keys of `results`, with columns:\n",
+ " Loops, Min, Max, Mean, StdDev, Median, IQR (times in the given unit),\n",
+ " OPS (Kops/s, independent of `unit`), Samples.\n",
+ " \"\"\"\n",
+ " records = []\n",
+ " indexes = []\n",
+ " for name, data in self.results.items():\n",
+ " raw_us = data[\"raw_us\"]\n",
+ " # Convert to target unit\n",
+ " times = self._convert_times(raw_us, unit)\n",
+ " if isinstance(name, tuple) and len(name) > 1:\n",
+ " indexes.append(name)\n",
+ " elif isinstance(name, tuple) and len(name) == 1:\n",
+ " indexes.append(name[0])\n",
+ " else:\n",
+ " indexes.append(name)\n",
+ "\n",
+ " stats = {\n",
+ " \"Loops\": data[\"loops\"],\n",
+ " f\"Min ({unit})\": np.min(times),\n",
+ " f\"Max ({unit})\": np.max(times),\n",
+ " f\"Mean ({unit})\": np.mean(times),\n",
+ " f\"StdDev ({unit})\": np.std(times),\n",
+ " f\"Median ({unit})\": np.median(times),\n",
+ " f\"IQR ({unit})\": np.percentile(times, 75) - np.percentile(times, 25),\n",
+ " \"OPS (Kops/s)\": 1e3 / (np.mean(raw_us)),\n",
+ " \"Samples\": len(raw_us),\n",
+ " }\n",
+ " records.append(stats)\n",
+ "\n",
+ " if all(isinstance(idx, tuple) for idx in indexes):\n",
+ " index = pd.MultiIndex.from_tuples(indexes)\n",
+ " else:\n",
+ " index = pd.Index(indexes)\n",
+ " return pd.DataFrame(records, index=index)\n",
+ "\n",
+ " def raw(self, name=None):\n",
+ " \"\"\"\n",
+ " Get raw timing data in microseconds.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " name : tuple of (str, str), optional\n",
+ " Key of `results`, i.e. (function name, input name). If given, returns the\n",
+ " raw_us array for that entry (None if absent). Otherwise returns a dict of all raw results.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " numpy.ndarray or dict\n",
+ " \"\"\"\n",
+ " if name:\n",
+ " return self.results.get(name, {}).get(\"raw_us\")\n",
+ " return {n: d[\"raw_us\"] for n, d in self.results.items()}\n",
+ "\n",
+ " def _convert_times(self, times, unit):\n",
+ " \"\"\"\n",
+ " Convert an array of times from microseconds to the specified unit.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " times : array-like\n",
+ " Times in microseconds.\n",
+ " unit : {'us', 'ms', 'ns', 's'}\n",
+ " Target unit: 'us' microseconds, 'ms' milliseconds,\n",
+ " 'ns' nanoseconds, 's' seconds.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " numpy.ndarray\n",
+ " Converted times.\n",
+ "\n",
+ " Raises\n",
+ " ------\n",
+ " ValueError\n",
+ " If `unit` is not one of the supported options.\n",
+ " \"\"\"\n",
+ " unit = unit.lower()\n",
+ " if unit == \"us\":\n",
+ " factor = 1.0\n",
+ " elif unit == \"ms\":\n",
+ " factor = 1e-3\n",
+ " elif unit == \"ns\":\n",
+ " factor = 1e3\n",
+ " elif unit == \"s\":\n",
+ " factor = 1e-6\n",
+ " else:\n",
+ " raise ValueError(f\"Unsupported unit: {unit}\")\n",
+ " return times * factor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "799a759b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# def block(func):\n",
+ "# def inner(*args, **kwargs):\n",
+ "# return jax.block_until_ready(func(*args, **kwargs))\n",
+ "# return inner\n",
+ "\n",
+ "\n",
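+ "# NOTE(review): the commented-out wrapper above was replaced on the assumption that it only handles a single array return. jax.block_until_ready is documented to accept arbitrary pytrees, so it should also work for tuple outputs; TODO confirm and simplify.\n",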
+ "\n",
+ "def block(func):\n",
+ " def inner(*args, **kwargs):\n",
+ " result = func(*args, **kwargs)\n",
+ " # Recursively block on all JAX arrays in result\n",
+ " jax.tree_util.tree_map(lambda x: x.block_until_ready() if hasattr(x, \"block_until_ready\") else None, result)\n",
+ " return result\n",
+ " return inner\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "6e76a860",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set Pytensor to use float32\n",
+ "pytensor.config.floatX = \"float32\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1c3eb543",
+ "metadata": {},
+ "source": [
+ "# Introduction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e066c9c4",
+ "metadata": {},
+ "source": [
+ "# Baby Steps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "962c851c",
+ "metadata": {},
+ "source": [
+ "## Fibonacci Algorithm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "1c8d1654",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_STEPS = 10_000\n",
+ "\n",
+ "b_symbolic = pt.vector(\"b\", dtype=\"int32\", shape=(1,))\n",
+ "\n",
+ "def step(a, b):\n",
+ " return a + b, a\n",
+ "\n",
+ "(outputs_a, outputs_b), _ = pytensor.scan(\n",
+ " fn=step,\n",
+ " outputs_info=[pt.ones(1, dtype=\"int32\"), b_symbolic],\n",
+ " n_steps=N_STEPS\n",
+ ")\n",
+ "\n",
+ "# compile function returning final a\n",
+ "fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)\n",
+ "fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "bf48d6bf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jit(nopython=True)\n",
+ "def fibonacci_numba(b):\n",
+ " a = np.ones(1, dtype=np.int32)\n",
+ " for _ in range(N_STEPS):\n",
+ " a[0], b[0] = a[0] + b[0], a[0]\n",
+ " return a[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "a5cb2f54",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jax.jit\n",
+ "def fibonacci_jax(b):\n",
+ " a = jnp.ones(1, dtype=np.int32)\n",
+ " \n",
+ " def step(_, carry):\n",
+ " a, b = carry\n",
+ " return (a + b, a)\n",
+ " \n",
+ " a = jax.lax.fori_loop(0, N_STEPS, step, (a, b))\n",
+ "\n",
+ " return a"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "63bfdffe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fibonacci_bench = Benchmarker(\n",
+ " functions=[fibonacci_pytensor, fibonacci_numba, fibonacci_jax, fibonacci_pytensor_numba], \n",
+ " names=['fibonacci_pytensor', 'fibonacci_numba', 'fibonacci_jax', 'fibonacci_pytensor_numba'],\n",
+ " number=10\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "65bf994e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " | \n",
+ " Loops | \n",
+ " Min (us) | \n",
+ " Max (us) | \n",
+ " Mean (us) | \n",
+ " StdDev (us) | \n",
+ " Median (us) | \n",
+ " IQR (us) | \n",
+ " OPS (Kops/s) | \n",
+ " Samples | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | fibonacci_pytensor | \n",
+ " fibonacci_inputs | \n",
+ " 10 | \n",
+ " 5314.345798 | \n",
+ " 6147.600000 | \n",
+ " 5749.518199 | \n",
+ " 241.196105 | \n",
+ " 5631.999997 | \n",
+ " 395.727096 | \n",
+ " 0.173928 | \n",
+ " 19 | \n",
+ "
\n",
+ " \n",
+ " | fibonacci_numba | \n",
+ " fibonacci_inputs | \n",
+ " 10 | \n",
+ " 8.475001 | \n",
+ " 11.362496 | \n",
+ " 9.303579 | \n",
+ " 0.653885 | \n",
+ " 9.139549 | \n",
+ " 0.234378 | \n",
+ " 107.485516 | \n",
+ " 14 | \n",
+ "
\n",
+ " \n",
+ " | fibonacci_jax | \n",
+ " fibonacci_inputs | \n",
+ " 10 | \n",
+ " 21.829200 | \n",
+ " 37.458300 | \n",
+ " 23.662635 | \n",
+ " 3.214998 | \n",
+ " 22.087496 | \n",
+ " 2.304175 | \n",
+ " 42.260721 | \n",
+ " 32 | \n",
+ "
\n",
+ " \n",
+ " | fibonacci_pytensor_numba | \n",
+ " fibonacci_inputs | \n",
+ " 10 | \n",
+ " 222.412497 | \n",
+ " 235.400000 | \n",
+ " 229.243340 | \n",
+ " 5.403845 | \n",
+ " 231.954205 | \n",
+ " 10.200002 | \n",
+ " 4.362177 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Loops Min (us) Max (us) \\\n",
+ "fibonacci_pytensor fibonacci_inputs 10 5314.345798 6147.600000 \n",
+ "fibonacci_numba fibonacci_inputs 10 8.475001 11.362496 \n",
+ "fibonacci_jax fibonacci_inputs 10 21.829200 37.458300 \n",
+ "fibonacci_pytensor_numba fibonacci_inputs 10 222.412497 235.400000 \n",
+ "\n",
+ " Mean (us) StdDev (us) \\\n",
+ "fibonacci_pytensor fibonacci_inputs 5749.518199 241.196105 \n",
+ "fibonacci_numba fibonacci_inputs 9.303579 0.653885 \n",
+ "fibonacci_jax fibonacci_inputs 23.662635 3.214998 \n",
+ "fibonacci_pytensor_numba fibonacci_inputs 229.243340 5.403845 \n",
+ "\n",
+ " Median (us) IQR (us) \\\n",
+ "fibonacci_pytensor fibonacci_inputs 5631.999997 395.727096 \n",
+ "fibonacci_numba fibonacci_inputs 9.139549 0.234378 \n",
+ "fibonacci_jax fibonacci_inputs 22.087496 2.304175 \n",
+ "fibonacci_pytensor_numba fibonacci_inputs 231.954205 10.200002 \n",
+ "\n",
+ " OPS (Kops/s) Samples \n",
+ "fibonacci_pytensor fibonacci_inputs 0.173928 19 \n",
+ "fibonacci_numba fibonacci_inputs 107.485516 14 \n",
+ "fibonacci_jax fibonacci_inputs 42.260721 32 \n",
+ "fibonacci_pytensor_numba fibonacci_inputs 4.362177 5 "
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "fibonacci_bench.run(\n",
+ " inputs={\n",
+ " \"fibonacci_inputs\": {\"b\": np.ones(1, dtype=np.int32)},\n",
+ " }\n",
+ ")\n",
+ "fibonacci_bench.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7bbb35a2",
+ "metadata": {},
+ "source": [
+ "## Element-wise multiplication Algorithm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "0b076f77",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Not sure this makes a difference\n",
+ "a_symbolic = pt.vector(\"a\", dtype=\"int32\")\n",
+ "b_symbolic = pt.vector(\"b\", dtype=\"int32\")\n",
+ "N_STEPS = 1000\n",
+ "\n",
+ "def step(a_element, b_element):\n",
+ " return a_element * b_element\n",
+ "\n",
+ "c, _ = pytensor.scan(\n",
+ " fn=step,\n",
+ " sequences=[a_symbolic, b_symbolic],\n",
+ " n_steps=N_STEPS\n",
+ ")\n",
+ "\n",
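+ "# compile functions returning the full product vector c; `scan_push_out_seq` is excluded from both modes (presumably so the multiplication stays inside the scan loop; confirm)\n",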
+ "c_mode = get_mode(\"FAST_RUN\").excluding(\"scan_push_out_seq\")\n",
+ "elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True, mode=c_mode)\n",
+ "\n",
+ "numba_mode = get_mode(\"NUMBA\").excluding(\"scan_push_out_seq\")\n",
+ "elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode=numba_mode, trust_input=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "14327cbf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jit(nopython=True)\n",
+ "def elementwise_multiply_numba(a, b):\n",
+ " n = a.shape[0]\n",
+ " c = np.empty(n, dtype=a.dtype)\n",
+ " for i in range(n):\n",
+ " c[i] = a[i] * b[i]\n",
+ " return c"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "13715c9b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@block\n",
+ "@jax.jit\n",
+ "def elementwise_multiply_jax(a, b):\n",
+ " n = a.shape[0]\n",
+ " c_init = jnp.empty(n, dtype=a.dtype)\n",
+ " def step(i, c):\n",
+ " return jax.lax.dynamic_update_index_in_dim(c, a[i] * b[i], i, axis=0)\n",
+ " \n",
+ " c = jax.lax.fori_loop(0, n, step, c_init)\n",
+ " return c"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "d43c4f9c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)\n",
+ "b = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "f0f4ede5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "elem_mult_bench = Benchmarker(\n",
+ " functions=[elementwise_multiply_pytensor, elementwise_multiply_numba, elementwise_multiply_jax, elementwise_multiply_pytensor_numba], \n",
+ " names=['elementwise_multiply_pytensor', 'elementwise_multiply_numba', 'elementwise_multiply_jax', 'elementwise_multiply_pytensor_numba'],\n",
+ " number=10\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "cdab8946",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " | \n",
+ " Loops | \n",
+ " Min (us) | \n",
+ " Max (us) | \n",
+ " Mean (us) | \n",
+ " StdDev (us) | \n",
+ " Median (us) | \n",
+ " IQR (us) | \n",
+ " OPS (Kops/s) | \n",
+ " Samples | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | elementwise_multiply_pytensor | \n",
+ " elem_mult_inputs | \n",
+ " 10 | \n",
+ " 431.329099 | \n",
+ " 496.695802 | \n",
+ " 465.787881 | \n",
+ " 9.956242 | \n",
+ " 466.149999 | \n",
+ " 12.630122 | \n",
+ " 2.146900 | \n",
+ " 218 | \n",
+ "
\n",
+ " \n",
+ " | elementwise_multiply_numba | \n",
+ " elem_mult_inputs | \n",
+ " 10 | \n",
+ " 0.362500 | \n",
+ " 0.604201 | \n",
+ " 0.383343 | \n",
+ " 0.044777 | \n",
+ " 0.374997 | \n",
+ " 0.008375 | \n",
+ " 2608.632115 | \n",
+ " 26 | \n",
+ "
\n",
+ " \n",
+ " | elementwise_multiply_jax | \n",
+ " elem_mult_inputs | \n",
+ " 10 | \n",
+ " 7.775001 | \n",
+ " 11.637498 | \n",
+ " 8.693605 | \n",
+ " 0.914430 | \n",
+ " 8.283349 | \n",
+ " 0.964597 | \n",
+ " 115.027083 | \n",
+ " 60 | \n",
+ "
\n",
+ " \n",
+ " | elementwise_multiply_pytensor_numba | \n",
+ " elem_mult_inputs | \n",
+ " 10 | \n",
+ " 28.791599 | \n",
+ " 31.770801 | \n",
+ " 29.504161 | \n",
+ " 1.138149 | \n",
+ " 28.958404 | \n",
+ " 0.224996 | \n",
+ " 33.893525 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Loops Min (us) \\\n",
+ "elementwise_multiply_pytensor elem_mult_inputs 10 431.329099 \n",
+ "elementwise_multiply_numba elem_mult_inputs 10 0.362500 \n",
+ "elementwise_multiply_jax elem_mult_inputs 10 7.775001 \n",
+ "elementwise_multiply_pytensor_numba elem_mult_inputs 10 28.791599 \n",
+ "\n",
+ " Max (us) Mean (us) \\\n",
+ "elementwise_multiply_pytensor elem_mult_inputs 496.695802 465.787881 \n",
+ "elementwise_multiply_numba elem_mult_inputs 0.604201 0.383343 \n",
+ "elementwise_multiply_jax elem_mult_inputs 11.637498 8.693605 \n",
+ "elementwise_multiply_pytensor_numba elem_mult_inputs 31.770801 29.504161 \n",
+ "\n",
+ " StdDev (us) \\\n",
+ "elementwise_multiply_pytensor elem_mult_inputs 9.956242 \n",
+ "elementwise_multiply_numba elem_mult_inputs 0.044777 \n",
+ "elementwise_multiply_jax elem_mult_inputs 0.914430 \n",
+ "elementwise_multiply_pytensor_numba elem_mult_inputs 1.138149 \n",
+ "\n",
+ " Median (us) IQR (us) \\\n",
+ "elementwise_multiply_pytensor elem_mult_inputs 466.149999 12.630122 \n",
+ "elementwise_multiply_numba elem_mult_inputs 0.374997 0.008375 \n",
+ "elementwise_multiply_jax elem_mult_inputs 8.283349 0.964597 \n",
+ "elementwise_multiply_pytensor_numba elem_mult_inputs 28.958404 0.224996 \n",
+ "\n",
+ " OPS (Kops/s) Samples \n",
+ "elementwise_multiply_pytensor elem_mult_inputs 2.146900 218 \n",
+ "elementwise_multiply_numba elem_mult_inputs 2608.632115 26 \n",
+ "elementwise_multiply_jax elem_mult_inputs 115.027083 60 \n",
+ "elementwise_multiply_pytensor_numba elem_mult_inputs 33.893525 5 "
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "elem_mult_bench.run(\n",
+ " inputs={\n",
+ " \"elem_mult_inputs\": {\"a\": a, \"b\": b},\n",
+ " }\n",
+ ")\n",
+ "elem_mult_bench.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "25a87710",
+ "metadata": {},
+ "source": [
+ "# Changepoint Detection Algorithms"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "16b2b312",
+ "metadata": {},
+ "source": [
+ "## Cumulative Sum (CUSUM) Algorithm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "c1226cfc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jit(nopython=True)\n",
+ "def cusum_adaptive_numba(x, alpha=0.01, k=0.5, h=5.0):\n",
+ " \"\"\"\n",
+ " Two-sided CUSUM with adaptive exponential moving average baseline.\n",
+ " \n",
+ " Parameters\n",
+ " ----------\n",
+ " x: np.ndarray\n",
+ " input signal\n",
+ " alpha: float\n",
+ " EMA smoothing factor (0 < alpha <= 1)\n",
+ " k: float\n",
+ " slack to avoid small changes triggering alarms\n",
+ " h: float\n",
+ " threshold for raising an alarm\n",
+ " \n",
+ " Returns\n",
+ " -------\n",
+ " s_pos: np.ndarray\n",
+ " upper CUSUM stats\n",
+ " s_neg: np.ndarray\n",
+ " lower CUSUM stats\n",
+ " mu_t: np.ndarray\n",
+ " evolving baseline estimate\n",
+ " alarms_pos: np.ndarray\n",
+ " alarms for upward changes\n",
+ " alarms_neg: np.ndarray\n",
+ " alarms for downward changes\n",
+ " \"\"\"\n",
+ " n = x.shape[0]\n",
+ "\n",
+ " s_pos = np.zeros(n, dtype=np.float32)\n",
+ " s_neg = np.zeros(n, dtype=np.float32)\n",
+ " mu_t = np.zeros(n, dtype=np.float32)\n",
+ " alarms_pos = np.zeros(n, dtype=np.bool_)\n",
+ " alarms_neg = np.zeros(n, dtype=np.bool_)\n",
+ "\n",
+ " # Initialization\n",
+ " mu_t[0] = x[0]\n",
+ "\n",
+ " for i in range(1, n):\n",
+ " # Update baseline (EMA)\n",
+ " mu_t[i] = alpha * x[i] + (1 - alpha) * mu_t[i-1]\n",
+ "\n",
+ " # Update CUSUM stats\n",
+ " s_pos[i] = max(0.0, s_pos[i-1] + x[i] - mu_t[i] - k)\n",
+ " s_neg[i] = max(0.0, s_neg[i-1] - (x[i] - mu_t[i]) - k)\n",
+ "\n",
+ " # Alarms\n",
+ " alarms_pos = s_pos > h\n",
+ " alarms_neg = s_neg > h\n",
+ "\n",
+ " return s_pos, s_neg, mu_t, alarms_pos, alarms_neg"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "14937a2a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@block\n",
+ "@jax.jit\n",
+ "def cusum_adaptive_jax(x, alpha=0.01, k=0.5, h=5.0):\n",
+ " \"\"\"\n",
+ " Two-sided CUSUM with adaptive exponential moving average baseline.\n",
+ " \n",
+ " Parameters\n",
+ " ----------\n",
+ " x: jnp.ndarray\n",
+ " input signal\n",
+ " alpha: float\n",
+ " EMA smoothing factor (0 < alpha <= 1)\n",
+ " k: float\n",
+ " slack to avoid small changes triggering alarms\n",
+ " h: float\n",
+ " threshold for raising an alarm\n",
+ " \n",
+ " Returns\n",
+ " -------\n",
+ " s_pos: jnp.ndarray\n",
+ " upper CUSUM stats\n",
+ " s_neg: jnp.ndarray\n",
+ " lower CUSUM stats\n",
+ " mu_t: jnp.ndarray\n",
+ " evolving baseline estimate\n",
+ " alarms_pos: jnp.ndarray\n",
+ " alarms for upward changes\n",
+ " alarms_neg: jnp.ndarray\n",
+ " alarms for downward changes\n",
+ " \"\"\"\n",
+ " def body(carry, x_t):\n",
+ " s_pos_prev, s_neg_prev, mu_prev = carry\n",
+ " \n",
+ " # Update EMA baseline\n",
+ " mu_t = alpha * x_t + (1 - alpha) * mu_prev\n",
+ " \n",
+ " # Update CUSUMs using updated baseline\n",
+ " s_pos = jnp.maximum(0.0, s_pos_prev + x_t - mu_t - k)\n",
+ " s_neg = jnp.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)\n",
+ " \n",
+ " new_carry = (s_pos, s_neg, mu_t)\n",
+ " output = (s_pos, s_neg, mu_t)\n",
+ " return new_carry, output\n",
+ "\n",
+ " # Initialize: CUSUMs at 0, initial mean = first sample\n",
+ " s0 = (0.0, 0.0, x[0])\n",
+ " _, (s_pos_vals, s_neg_vals, mu_vals) = jax.lax.scan(body, s0, x)\n",
+ " \n",
+ " # Thresholding\n",
+ " alarms_pos = s_pos_vals > h\n",
+ " alarms_neg = s_neg_vals > h\n",
+ "\n",
+ " return s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "815967a4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_symbolic = pt.vector(\"x\")\n",
+ "alpha_symbolic = pt.scalar(\"alpha\")\n",
+ "k_symbolic = pt.scalar(\"k\")\n",
+ "h_symbolic = pt.scalar(\"h\")\n",
+ "\n",
+ "N_STEPS = 100 # Fixed just in case; probably redundant, since scan is already given the input sequence x\n",
+ "\n",
+ "def step(x_t, s_pos_prev, s_neg_prev, mu_prev, alpha, k):\n",
+ " # Update EMA baseline\n",
+ " mu_t = alpha * x_t + (1 - alpha) * mu_prev\n",
+ " \n",
+ " # Update CUSUMs using updated baseline\n",
+ " s_pos = pt.maximum(0.0, s_pos_prev + x_t - mu_t - k)\n",
+ " s_neg = pt.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)\n",
+ " \n",
+ " return s_pos, s_neg, mu_t\n",
+ "\n",
+ "\n",
+ "(s_pos_vals, s_neg_vals, mu_vals), updates = pytensor.scan(\n",
+ " fn=step,\n",
+ " outputs_info=[pt.constant(0., dtype=\"float32\"), pt.constant(0., dtype=\"float32\"), x_symbolic[0]],\n",
+ " non_sequences=[alpha_symbolic, k_symbolic],\n",
+ " sequences=[x_symbolic],\n",
+ " n_steps=N_STEPS\n",
+ ")\n",
+ "\n",
+ "# Thresholding\n",
+ "alarms_pos = s_pos_vals > h_symbolic\n",
+ "alarms_neg = s_neg_vals > h_symbolic\n",
+ "\n",
+ "cusum_adaptive_pytensor = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], trust_input=True)\n",
+ "\n",
+ "cusum_adaptive_pytensor_numba = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode=\"NUMBA\", trust_input=True)\n",
+ "cusum_adaptive_pytensor_jax = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode=\"JAX\", trust_input=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "6c892129",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "xs1 = np.random.normal(80, 20, size=(int(N_STEPS/2)))\n",
+ "xs2 = np.random.normal(50, 20, size=(int(N_STEPS/2)))\n",
+ "xs = np.concat((xs1, xs2))\n",
+ "xs = xs.astype(np.float32)\n",
+ "xs = xs.astype(np.float32)\n",
+ "xs_std = (xs - xs.mean()) / xs.std()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "d21873d6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cusum_bench = Benchmarker(\n",
+ " functions=[cusum_adaptive_pytensor, cusum_adaptive_numba, cusum_adaptive_jax, cusum_adaptive_pytensor_numba, block(cusum_adaptive_pytensor_jax)], \n",
+ " names=['cusum_adaptive_pytensor', 'cusum_adaptive_numba', 'cusum_adaptive_jax', 'cusum_adaptive_pytensor_numba', 'cusum_adaptive_pytensor_jax'],\n",
+ " number=10\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "d8eab72d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " | \n",
+ " Loops | \n",
+ " Min (us) | \n",
+ " Max (us) | \n",
+ " Mean (us) | \n",
+ " StdDev (us) | \n",
+ " Median (us) | \n",
+ " IQR (us) | \n",
+ " OPS (Kops/s) | \n",
+ " Samples | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | cusum_adaptive_pytensor | \n",
+ " cusum_inputs | \n",
+ " 10 | \n",
+ " 127.283402 | \n",
+ " 161.174999 | \n",
+ " 137.054916 | \n",
+ " 4.574608 | \n",
+ " 136.816700 | \n",
+ " 5.375006 | \n",
+ " 7.296345 | \n",
+ " 697 | \n",
+ "
\n",
+ " \n",
+ " | cusum_adaptive_numba | \n",
+ " cusum_inputs | \n",
+ " 10 | \n",
+ " 1.658296 | \n",
+ " 2.295797 | \n",
+ " 1.988085 | \n",
+ " 0.223086 | \n",
+ " 2.041698 | \n",
+ " 0.372902 | \n",
+ " 502.996564 | \n",
+ " 7 | \n",
+ "
\n",
+ " \n",
+ " | cusum_adaptive_jax | \n",
+ " cusum_inputs | \n",
+ " 10 | \n",
+ " 13.975002 | \n",
+ " 19.087503 | \n",
+ " 14.758764 | \n",
+ " 1.196708 | \n",
+ " 14.233397 | \n",
+ " 0.585447 | \n",
+ " 67.756351 | \n",
+ " 39 | \n",
+ "
\n",
+ " \n",
+ " | cusum_adaptive_pytensor_numba | \n",
+ " cusum_inputs | \n",
+ " 10 | \n",
+ " 21.270796 | \n",
+ " 26.608299 | \n",
+ " 22.388300 | \n",
+ " 2.110692 | \n",
+ " 21.316600 | \n",
+ " 0.120798 | \n",
+ " 44.666188 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | cusum_adaptive_pytensor_jax | \n",
+ " cusum_inputs | \n",
+ " 10 | \n",
+ " 21.491601 | \n",
+ " 37.291698 | \n",
+ " 22.706185 | \n",
+ " 2.756775 | \n",
+ " 21.787500 | \n",
+ " 1.479208 | \n",
+ " 44.040863 | \n",
+ " 33 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Loops Min (us) Max (us) \\\n",
+ "cusum_adaptive_pytensor cusum_inputs 10 127.283402 161.174999 \n",
+ "cusum_adaptive_numba cusum_inputs 10 1.658296 2.295797 \n",
+ "cusum_adaptive_jax cusum_inputs 10 13.975002 19.087503 \n",
+ "cusum_adaptive_pytensor_numba cusum_inputs 10 21.270796 26.608299 \n",
+ "cusum_adaptive_pytensor_jax cusum_inputs 10 21.491601 37.291698 \n",
+ "\n",
+ " Mean (us) StdDev (us) \\\n",
+ "cusum_adaptive_pytensor cusum_inputs 137.054916 4.574608 \n",
+ "cusum_adaptive_numba cusum_inputs 1.988085 0.223086 \n",
+ "cusum_adaptive_jax cusum_inputs 14.758764 1.196708 \n",
+ "cusum_adaptive_pytensor_numba cusum_inputs 22.388300 2.110692 \n",
+ "cusum_adaptive_pytensor_jax cusum_inputs 22.706185 2.756775 \n",
+ "\n",
+ " Median (us) IQR (us) \\\n",
+ "cusum_adaptive_pytensor cusum_inputs 136.816700 5.375006 \n",
+ "cusum_adaptive_numba cusum_inputs 2.041698 0.372902 \n",
+ "cusum_adaptive_jax cusum_inputs 14.233397 0.585447 \n",
+ "cusum_adaptive_pytensor_numba cusum_inputs 21.316600 0.120798 \n",
+ "cusum_adaptive_pytensor_jax cusum_inputs 21.787500 1.479208 \n",
+ "\n",
+ " OPS (Kops/s) Samples \n",
+ "cusum_adaptive_pytensor cusum_inputs 7.296345 697 \n",
+ "cusum_adaptive_numba cusum_inputs 502.996564 7 \n",
+ "cusum_adaptive_jax cusum_inputs 67.756351 39 \n",
+ "cusum_adaptive_pytensor_numba cusum_inputs 44.666188 5 \n",
+ "cusum_adaptive_pytensor_jax cusum_inputs 44.040863 33 "
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cusum_bench.run(\n",
+ " inputs={\n",
+ " \"cusum_inputs\": {\"x\": xs, \"alpha\": 0.1, \"k\": 0.5, \"h\": 3.5},\n",
+ " }\n",
+ ")\n",
+ "cusum_bench.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "3e9c3339",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "outputs = cusum_adaptive_numba(xs_std, alpha=0.1, k=0.5, h=3.5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "b85d0c0e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig = go.Figure()\n",
+ "fig.add_traces(\n",
+ " [\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(xs)),\n",
+ " y = xs_std,\n",
+ " name=\"series\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(xs)),\n",
+ " y = outputs[0],\n",
+ " name=\"cum. positive devs.\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(xs)),\n",
+ " y = outputs[1],\n",
+ " name=\"cum. negative devs.\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(xs)),\n",
+ " y = outputs[2],\n",
+ " name=\"Exp. Mean\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(xs)),\n",
+ " y = outputs[3].astype(np.float16),\n",
+ " name=\"positive alarms\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(xs)),\n",
+ " y = outputs[4].astype(np.float16),\n",
+ " name=\"negative alarms\"\n",
+ " ),\n",
+ " \n",
+ " ]\n",
+ ")\n",
+ "fig.update_layout(\n",
+ " title = dict(\n",
+ " text = \"CUSUM Change Point Detection Algorithm\"\n",
+ " ),\n",
+ " xaxis=dict(\n",
+ " title = \"Time Index\"\n",
+ " ),\n",
+ " yaxis=dict(\n",
+ " title = \"Standardized Series Scaled\"\n",
+ " ),\n",
+ " legend=dict(\n",
+ " yanchor=\"top\",\n",
+ " y=1.1,\n",
+ " xanchor=\"left\",\n",
+ " x=0,\n",
+ " orientation=\"h\"\n",
+ " ),\n",
+ " template=\"plotly_dark\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c2ce03ea",
+ "metadata": {},
+ "source": [
+ "## Pruned Exact Linear Time (PELT) Algorithm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "eab4cbb9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jit(nopython=True)\n",
+ "def segment_cost_numba(S1, S2, i, j):\n",
+ " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n",
+ " n = j - i\n",
+ " sum_x = S1[j] - S1[i]\n",
+ " sum_x2 = S2[j] - S2[i]\n",
+ " if n > 0:\n",
+ " return sum_x2 - (sum_x ** 2) / n\n",
+ " else:\n",
+ " return np.inf\n",
+ "\n",
+ "@jit(nopython=True)\n",
+ "def pelt_numba(x, beta=10.0):\n",
+ " \"\"\"\n",
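+ " Change point detection via the penalized dynamic program underlying PELT\n",
+ " (Pruned Exact Linear Time).\n",
+ "\n",
+ " Note: every candidate split s is evaluated for each t (no candidates are\n",
+ " pruned), and segments shorter than min_size = 3 are never considered.\n",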
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x: np.ndarray\n",
+ " The timeseries signal\n",
+ " beta: float\n",
+ " Penalty of segmenting the series\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " C: np.ndarray\n",
+ " shape (n + 1,) - C[t] is the lowest penalized cost found for x[:t]\n",
+ " last_change: np.ndarray\n",
+ " shape (n + 1,) - last_change[t] is the start index s of the final segment x[s:t]\n",
+ " \"\"\"\n",
+ " n = len(x)\n",
+ "\n",
+ " # cumulative sums for cost\n",
+ " S1 = np.empty(n+1, dtype=np.float32)\n",
+ " S2 = np.empty(n+1, dtype=np.float32)\n",
+ " S1[0], S2[0] = 0.0, 0.0\n",
+ " for i in range(1, n+1):\n",
+ " S1[i] = S1[i-1] + x[i-1]\n",
+ " S2[i] = S2[i-1] + x[i-1]**2\n",
+ "\n",
+ " # DP arrays\n",
+ " C = np.full((n+1,), np.inf)\n",
+ " C[0] = -beta\n",
+ " last_change = np.full((n+1,), -1)\n",
+ " min_size = 3\n",
+ "\n",
+ " for t in range(1, n+1):\n",
+ " costs = np.full(n, np.inf)\n",
+ " for s in range(n):\n",
+ " if s < t and (t - s) >= min_size:\n",
+ " costs[s] = C[s] + segment_cost_numba(S1, S2, s, t) + beta\n",
+ " best_s = np.argmin(costs)\n",
+ " C[t] = costs[best_s]\n",
+ " last_change[t] = best_s\n",
+ "\n",
+ " return C, last_change"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "e6997389",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def segment_cost_jax(S1, S2, i, j):\n",
+ " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n",
+ " n = j - i\n",
+ " sum_x = S1[j] - S1[i]\n",
+ " sum_x2 = S2[j] - S2[i]\n",
+ " return jnp.where(n > 0, sum_x2 - (sum_x ** 2) / n, jnp.inf)\n",
+ "\n",
+ "@block\n",
+ "@jax.jit\n",
+ "def pelt_jax(x, beta=10.0):\n",
+ " \"\"\"\n",
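+ " Change point detection via the penalized dynamic program underlying PELT\n",
+ " (Pruned Exact Linear Time).\n",
+ "\n",
+ " Note: every candidate split s is evaluated for each t (no candidates are\n",
+ " pruned), and segments shorter than min_size = 3 are never considered.\n",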
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " x: np.ndarray\n",
+ " The timeseries signal\n",
+ " beta: float\n",
+ " Penalty of segmenting the series\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " C: jnp.ndarray\n",
+ " shape (n + 1,) - C[t] is the lowest penalized cost found for x[:t]\n",
+ " last_change: jnp.ndarray\n",
+ " shape (n + 1,) - last_change[t] is the start index s of the final segment x[s:t]\n",
+ " \"\"\"\n",
+ " n = len(x)\n",
+ "\n",
+ " # cumulative sums for cost\n",
+ " S1 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x)])\n",
+ " S2 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x**2)])\n",
+ "\n",
+ " # DP arrays\n",
+ " C = jnp.full((n+1,), jnp.inf)\n",
+ " C = C.at[0].set(-beta)\n",
+ " last_change = jnp.full((n+1,), -1)\n",
+ " min_size = 3\n",
+ "\n",
+ " s_all = jnp.arange(n) # all possible candidates\n",
+ "\n",
+ " def body(t, carry):\n",
+ " C, last_change = carry\n",
+ "\n",
+ " # Compute cost for all s < t, mask invalid\n",
+ " # NOTE: each comparison needs its own parentheses because & binds tighter than < and >=.\n",
+ " \n",
+ " valid = (s_all < t) & ((t - s_all) >= min_size)\n",
+ " costs = jnp.where(\n",
+ " valid,\n",
+ " C[s_all] + segment_cost_jax(S1, S2, s_all, t) + beta,\n",
+ " jnp.inf\n",
+ " )\n",
+ "\n",
+ " best_s = jnp.argmin(costs)\n",
+ " C = C.at[t].set(costs[best_s])\n",
+ " last_change = last_change.at[t].set(best_s)\n",
+ " return C, last_change\n",
+ "\n",
+ " C, last_change = jax.lax.fori_loop(1, n+1, body, (C, last_change))\n",
+ " return C, last_change"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "094b9e8e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def segment_cost_pytensor(S1, S2, i, j):\n",
+ " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n",
+ " n = j - i\n",
+ " sum_x = S1[j] - S1[i]\n",
+ " sum_x2 = S2[j] - S2[i]\n",
+ " return pt.switch(\n",
+ " pt.gt(n, 0),\n",
+ " sum_x2 - (sum_x ** 2) / n,\n",
+ " np.inf\n",
+ " )\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "03e5e927",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_symbolic = pt.vector(\"x\")\n",
+ "beta_symbolic = pt.scalar(\"beta\")\n",
+ "n = x_symbolic.shape[0]\n",
+ "N_STEPS=100\n",
+ "\n",
+ "# cumulative sums for cost\n",
+ "S1 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic)])\n",
+ "S2 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic**2)])\n",
+ "\n",
+ "# DP arrays\n",
+ "C_init = pt.alloc(np.inf, n+1)\n",
+ "C_init = pt.set_subtensor(C_init[0], -beta_symbolic)\n",
+ "last_change_init = pt.alloc(-1, n+1)\n",
+ "\n",
+ "s_all = pt.arange(n) # candidate change points\n",
+ "min_size = 3\n",
+ "\n",
+ "def step(t, C_prev, last_change_prev, S1, S2, beta_symbolic, s_all):\n",
+ " # valid = (s_all < t) & ((t - s_all) >= min_size)\n",
+ " valid = pt.and_(pt.lt(s_all, t), pt.ge(t - s_all, min_size))\n",
+ "\n",
+ " # compute costs for all candidates\n",
+ " costs, _ = pytensor.scan(\n",
+ " fn=lambda s: pt.switch(\n",
+ " valid[s],\n",
+ " C_prev[s] + segment_cost_pytensor(S1, S2, s, t) + beta_symbolic,\n",
+ " np.inf\n",
+ " ),\n",
+ " sequences=[pt.arange(n)]\n",
+ " )\n",
+ " costs = costs.flatten()\n",
+ "\n",
+ " best_s = pt.argmin(costs, axis=0)\n",
+ " C_new = pt.set_subtensor(C_prev[t], costs[best_s])\n",
+ " last_change_new = pt.set_subtensor(last_change_prev[t], best_s)\n",
+ "\n",
+ " return C_new, last_change_new\n",
+ "\n",
+ "(C_vals, last_change_vals), _ = pytensor.scan(\n",
+ " fn=step,\n",
+ " sequences=[pt.arange(1, n+1)],\n",
+ " outputs_info=[C_init, last_change_init],\n",
+ " non_sequences=[S1, S2, beta_symbolic, s_all],\n",
+ " n_steps=N_STEPS # Added fixed iterations here\n",
+ ")\n",
+ "\n",
+ "pelt_pytensor = pytensor.function([x_symbolic, beta_symbolic], [C_vals[-1], last_change_vals[-1]], trust_input=True)\n",
+ "pelt_pytensor_numba = pytensor.function(inputs=[x_symbolic, beta_symbolic], outputs=[C_vals[-1], last_change_vals[-1]], mode=\"NUMBA\", trust_input=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "69aea9a2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pelt_bench = Benchmarker(\n",
+ " functions=[pelt_pytensor, pelt_numba, pelt_jax, pelt_pytensor_numba], \n",
+ " names=['pelt_pytensor', 'pelt_numba', 'pelt_jax', 'pelt_pytensor_numba'],\n",
+ " number=10\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "5aee4a18",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " | \n",
+ " Loops | \n",
+ " Min (us) | \n",
+ " Max (us) | \n",
+ " Mean (us) | \n",
+ " StdDev (us) | \n",
+ " Median (us) | \n",
+ " IQR (us) | \n",
+ " OPS (Kops/s) | \n",
+ " Samples | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | pelt_pytensor | \n",
+ " pelt_inputs | \n",
+ " 10 | \n",
+ " 11072.008300 | \n",
+ " 11740.058300 | \n",
+ " 11258.694433 | \n",
+ " 190.103733 | \n",
+ " 11197.529198 | \n",
+ " 142.754201 | \n",
+ " 0.088820 | \n",
+ " 9 | \n",
+ "
\n",
+ " \n",
+ " | pelt_numba | \n",
+ " pelt_inputs | \n",
+ " 10 | \n",
+ " 18.575002 | \n",
+ " 21.354103 | \n",
+ " 19.505802 | \n",
+ " 1.110412 | \n",
+ " 18.766604 | \n",
+ " 1.624902 | \n",
+ " 51.266798 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | pelt_jax | \n",
+ " pelt_inputs | \n",
+ " 10 | \n",
+ " 64.749998 | \n",
+ " 77.745900 | \n",
+ " 67.179384 | \n",
+ " 2.991248 | \n",
+ " 66.441699 | \n",
+ " 2.021827 | \n",
+ " 14.885519 | \n",
+ " 20 | \n",
+ "
\n",
+ " \n",
+ " | pelt_pytensor_numba | \n",
+ " pelt_inputs | \n",
+ " 10 | \n",
+ " 2188.929199 | \n",
+ " 2497.112501 | \n",
+ " 2334.866661 | \n",
+ " 123.031774 | \n",
+ " 2302.750002 | \n",
+ " 231.975003 | \n",
+ " 0.428290 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Loops Min (us) Max (us) \\\n",
+ "pelt_pytensor pelt_inputs 10 11072.008300 11740.058300 \n",
+ "pelt_numba pelt_inputs 10 18.575002 21.354103 \n",
+ "pelt_jax pelt_inputs 10 64.749998 77.745900 \n",
+ "pelt_pytensor_numba pelt_inputs 10 2188.929199 2497.112501 \n",
+ "\n",
+ " Mean (us) StdDev (us) Median (us) \\\n",
+ "pelt_pytensor pelt_inputs 11258.694433 190.103733 11197.529198 \n",
+ "pelt_numba pelt_inputs 19.505802 1.110412 18.766604 \n",
+ "pelt_jax pelt_inputs 67.179384 2.991248 66.441699 \n",
+ "pelt_pytensor_numba pelt_inputs 2334.866661 123.031774 2302.750002 \n",
+ "\n",
+ " IQR (us) OPS (Kops/s) Samples \n",
+ "pelt_pytensor pelt_inputs 142.754201 0.088820 9 \n",
+ "pelt_numba pelt_inputs 1.624902 51.266798 5 \n",
+ "pelt_jax pelt_inputs 2.021827 14.885519 20 \n",
+ "pelt_pytensor_numba pelt_inputs 231.975003 0.428290 5 "
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pelt_bench.run(\n",
+ " inputs={\n",
+ " \"pelt_inputs\": {\"x\": xs_std, \"beta\": 2. * np.log(len(xs_std))},\n",
+ " }\n",
+ ")\n",
+ "pelt_bench.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "47cae530",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "outputs = pelt_numba(xs_std, 2. * np.log(len(xs_std)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "7b395b06",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plot_pelt_diagnostics(x, cps, C):\n",
+ " \"\"\"\n",
+ " Diagnostic plots for PELT changepoint detection.\n",
+ " \n",
+ " Args:\n",
+ " x: 1D array, original time series\n",
+ " C: 1D array, cumulative DP cost from pelt()\n",
+ " cps: list of changepoint indices (sorted ascending)\n",
+ " \"\"\"\n",
+ " n = len(x)\n",
+ " cps_full = [0] + cps + [n]\n",
+ "\n",
+ " # Segment means, std, SSE\n",
+ " segment_means = []\n",
+ " segment_stds = []\n",
+ " segment_costs = []\n",
+ " for start, end in zip(cps_full[:-1], cps_full[1:]):\n",
+ " seg = x[start:end]\n",
+ " mean = np.mean(seg)\n",
+ " std = np.std(seg)\n",
+ " cost = np.sum((seg - mean)**2)\n",
+ " segment_means.append(mean)\n",
+ " segment_stds.append(std)\n",
+ " segment_costs.append(cost)\n",
+ "\n",
+ " # Step function for segment mean\n",
+ " mean_step = np.zeros(n)\n",
+ " for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):\n",
+ " mean_step[start:end] = segment_means[i]\n",
+ "\n",
+ " # Step function for segment std\n",
+ " std_step = np.zeros(n)\n",
+ " for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):\n",
+ " std_step[start:end] = segment_stds[i]\n",
+ "\n",
+ " if len(x) < 20:\n",
+ " title1 = \"Warning: Sample size is small - Detected Changepoints\"\n",
+ " else:\n",
+ " title1 = \"Detected Changepoints\"\n",
+ "\n",
+ " fig = make_subplots(\n",
+ " rows=4, \n",
+ " cols=1,\n",
+ " subplot_titles=(title1, \"Average Shifts\", \"Variability Shifts\", \"Cumulative Cost\")\n",
+ " )\n",
+ "\n",
+ " fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(x)),\n",
+ " y = x,\n",
+ " line_color=\"royalblue\",\n",
+ " name = \"Actuals\",\n",
+ " mode=\"lines\",\n",
+ " showlegend=False,\n",
+ " hovertemplate=\"Time Point: %{x}
Actual: %{y}\"\n",
+ " ),\n",
+ " row=1, col=1\n",
+ " )\n",
+ "\n",
+ " for cp in cps:\n",
+ " fig.add_vline(x=cp, line_dash='dash', line_color=\"red\", row=1, col=1)\n",
+ "\n",
+ " fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(x)),\n",
+ " y = x,\n",
+ " name = \"Actuals\",\n",
+ " mode=\"lines\",\n",
+ " line_color=\"rgba(105, 105, 105, 0.25)\",\n",
+ " showlegend=False,\n",
+ " hoverinfo=\"skip\"\n",
+ " ),\n",
+ " row=2, col=1\n",
+ " )\n",
+ "\n",
+ " fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(x)),\n",
+ " y = mean_step,\n",
+ " name = \"Average\",\n",
+ " line_color=\"royalblue\",\n",
+ " showlegend=False,\n",
+ " hovertemplate=\"Time Point: %{x}
Average: %{y}\"\n",
+ " ),\n",
+ " row=2, col=1\n",
+ " )\n",
+ "\n",
+ " fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(x)),\n",
+ " y = std_step,\n",
+ " name = \"Standard Deviation\",\n",
+ " line_color=\"royalblue\",\n",
+ " showlegend=False,\n",
+ " hovertemplate=\"Time Point: %{x}
Standard Deviation: %{y}\"\n",
+ " ),\n",
+ " row=3, col=1\n",
+ " )\n",
+ "\n",
+ " fig.add_trace(\n",
+ " go.Scatter(\n",
+ " x = np.arange(len(x)),\n",
+ " y = C,\n",
+ " name = \"Cumulative Cost\",\n",
+ " line_color=\"royalblue\",\n",
+ " showlegend=False,\n",
+ " hovertemplate=\"Time Point: %{x}
Cost: %{y}\"\n",
+ " ),\n",
+ " row=4, col=1\n",
+ " )\n",
+ "\n",
+ " for cp in cps:\n",
+ " fig.add_vline(x=cp, line_dash='dash', line_color=\"red\", row=4, col=1)\n",
+ "\n",
+ " return fig.update_layout(height=1000, width=1200, template=\"plotly_dark\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "e1de7df5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_changepoints(last_change, n):\n",
+ " \"\"\"\n",
+ " Backtrack changepoints from last_change array.\n",
+ " \n",
+ " Args:\n",
+ " last_change: array from pelt()\n",
+ " n: length of input series\n",
+ "\n",
+ " Returns:\n",
+ " list of changepoint indices (sorted ascending)\n",
+ " \"\"\"\n",
+ " cps = []\n",
+ " t = n\n",
+ " while t > 0:\n",
+ " s = int(last_change[t])\n",
+ " if s <= 0:\n",
+ " break\n",
+ " cps.append(s)\n",
+ " t = s\n",
+ " return list(reversed(cps))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "b73d80ac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cps = get_changepoints(outputs[1], n=len(xs_std))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "e2e376de",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plot_pelt_diagnostics(xs, cps, outputs[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "09f2e235",
+ "metadata": {},
+ "source": [
+ "# Kalman Filter Algorithms"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1c42336f",
+ "metadata": {},
+ "source": [
+ "## Linear Gaussian Kalman Filter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "5004eeab",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jit(nopython=True)\n",
+ "def atrocious_kalman_filter_numba(z, F, H, Q, R, x0, P0):\n",
+ " \"\"\"\n",
+ " This Kalman filter is deliberately written with explicit scalar loops instead of\n",
+ " matrix products. In plain Python that would be a serious anti-pattern, but here\n",
+ " it significantly reduces Numba compilation time.\n",
+ " \n",
+ " Linear Gaussian Kalman filter algorithm\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " z: np.ndarray\n",
+ " shape (T, m) - observations\n",
+ " F: np.ndarray\n",
+ " state transition matrix - shape (n, n)\n",
+ " H: np.ndarray\n",
+ " observation/design matrix - shape (m, n)\n",
+ " Q: np.ndarray\n",
+ " process noise covariance - shape (n, n)\n",
+ " R: np.ndarray\n",
+ " observation noise covariance - shape (m, m)\n",
+ " x0: np.ndarray\n",
+ " initial state mean - shape (n,)\n",
+ " P0: np.ndarray\n",
+ " initial state covariance - shape (n, n)\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
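+ " xs: np.ndarray\n",
+ " shape (T, n) - filtered state means\n",
+ " Ps: np.ndarray\n",
+ " shape (T, n, n) - filtered state covariances\n",
+ " x_pred: np.ndarray\n",
+ " shape (T, n) - predicted state means (before the update step)\n",
+ " P_pred: np.ndarray\n",
+ " shape (T, n, n) - predicted state covariances (before the update step)\n",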
+ " \"\"\"\n",
+ " T = z.shape[0]\n",
+ " m = z.shape[1]\n",
+ " n = x0.shape[0]\n",
+ "\n",
+ " xs = np.empty((T, n), dtype=np.float32)\n",
+ " Ps = np.empty((T, n, n), dtype=np.float32)\n",
+ "\n",
+ " # local working arrays\n",
+ " x = np.empty(n, dtype=np.float32)\n",
+ " for i in range(n):\n",
+ " x[i] = x0[i]\n",
+ " P = np.empty((n, n), dtype=np.float32)\n",
+ " for i in range(n):\n",
+ " for j in range(n):\n",
+ " P[i, j] = P0[i, j]\n",
+ "\n",
+ " # temporary matrices/vectors\n",
+ " x_pred = np.empty((T, n), dtype=np.float32)\n",
+ " P_pred = np.empty((T, n, n), dtype=np.float32)\n",
+ " y = np.empty(m, dtype=np.float32)\n",
+ " S = np.empty((m, m), dtype=np.float32)\n",
+ " K = np.empty((n, m), dtype=np.float32)\n",
+ " I_n = np.eye(n, dtype=np.float32)\n",
+ "\n",
+ " for t in range(T):\n",
+ " # === Predict ===\n",
+ " # x_pred = F @ x\n",
+ " for i in range(n):\n",
+ " s = 0.0\n",
+ " for j in range(n):\n",
+ " s += F[i, j] * x[j]\n",
+ " x_pred[t, i] = s\n",
+ "\n",
+ " # P_pred = F @ P @ F.T + Q\n",
+ " # temp = F @ P\n",
+ " temp = np.empty((n, n), dtype=np.float32)\n",
+ " for i in range(n):\n",
+ " for j in range(n):\n",
+ " s = 0.0\n",
+ " for k in range(n):\n",
+ " s += F[i, k] * P[k, j]\n",
+ " temp[i, j] = s\n",
+ " # P_pred = temp @ F.T\n",
+ " for i in range(n):\n",
+ " for j in range(n):\n",
+ " s = 0.0\n",
+ " for k in range(n):\n",
+ " s += temp[i, k] * F[j, k] # F.T[k, j] = F[j, k]\n",
+ " P_pred[t, i, j] = s + Q[i, j]\n",
+ "\n",
+ " # === Update ===\n",
+ " # y = z[t] - H @ x_pred\n",
+ " for i in range(m):\n",
+ " s = 0.0\n",
+ " for j in range(n):\n",
+ " s += H[i, j] * x_pred[t, j]\n",
+ " y[i] = z[t, i] - s\n",
+ "\n",
+ " # S = H @ P_pred @ H.T + R\n",
+ " # temp2 = H @ P_pred\n",
+ " temp2 = np.empty((m, n), dtype=np.float32)\n",
+ " for i in range(m):\n",
+ " for j in range(n):\n",
+ " s = 0.0\n",
+ " for k in range(n):\n",
+ " s += H[i, k] * P_pred[t, k, j]\n",
+ " temp2[i, j] = s\n",
+ " # S = temp2 @ H.T\n",
+ " for i in range(m):\n",
+ " for j in range(m):\n",
+ " s = 0.0\n",
+ " for k in range(n):\n",
+ " s += temp2[i, k] * H[j, k] # H.T[k,j] = H[j,k]\n",
+ " S[i, j] = s + R[i, j]\n",
+ "\n",
+ " # K = P_pred @ H.T @ inv(S)\n",
+ " # first compute P_pred @ H.T -> (n, m)\n",
+ " P_Ht = np.empty((n, m), dtype=np.float32)\n",
+ " for i in range(n):\n",
+ " for j in range(m):\n",
+ " s = 0.0\n",
+ " for k in range(n):\n",
+ " s += P_pred[t, i, k] * H[j, k] # H.T[k,j] = H[j,k]\n",
+ " P_Ht[i, j] = s\n",
+ "\n",
+ " # invert S\n",
+ " S_inv = np.linalg.inv(S)\n",
+ "\n",
+ " # K = P_Ht @ S_inv (n,m) @ (m,m) -> (n,m)\n",
+ " for i in range(n):\n",
+ " for j in range(m):\n",
+ " s = 0.0\n",
+ " for k in range(m):\n",
+ " s += P_Ht[i, k] * S_inv[k, j]\n",
+ " K[i, j] = s\n",
+ "\n",
+ " # x = x_pred + K @ y\n",
+ " for i in range(n):\n",
+ " s = 0.0\n",
+ " for j in range(m):\n",
+ " s += K[i, j] * y[j]\n",
+ " x[i] = x_pred[t, i] + s\n",
+ "\n",
+ " # P = (I - K H) P_pred\n",
+ " # compute (I - K H)\n",
+ " KH = np.empty((n, n), dtype=np.float32)\n",
+ " for i in range(n):\n",
+ " for j in range(n):\n",
+ " s = 0.0\n",
+ " for k in range(m):\n",
+ " s += K[i, k] * H[k, j]\n",
+ " KH[i, j] = s\n",
+ "\n",
+ " I_minus_KH = np.empty((n, n), dtype=np.float32)\n",
+ " for i in range(n):\n",
+ " for j in range(n):\n",
+ " I_minus_KH[i, j] = I_n[i, j] - KH[i, j]\n",
+ "\n",
+ " # P = I_minus_KH @ P_pred\n",
+ " for i in range(n):\n",
+ " for j in range(n):\n",
+ " s = 0.0\n",
+ " for k in range(n):\n",
+ " s += I_minus_KH[i, k] * P_pred[t, k, j]\n",
+ " P[i, j] = s\n",
+ "\n",
+ " # store results\n",
+ " for i in range(n):\n",
+ " xs[t, i] = x[i]\n",
+ " for i in range(n):\n",
+ " for j in range(n):\n",
+ " Ps[t, i, j] = P[i, j]\n",
+ "\n",
+ " return xs, Ps, x_pred, P_pred\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "id": "25bcb14e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jit(nopython=True)\n",
+ "def kalman_filter_numba(z, F, H, Q, R, x0, P0):\n",
+ " \"\"\"\n",
+ " Linear Gaussian Kalman filter algorithm\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " z: np.ndarray\n",
+ " shape (T, m) - observations\n",
+ " F: np.ndarray\n",
+ " state transition matrix - shape (n, n)\n",
+ " H: np.ndarray\n",
+ " observation/design matrix - shape (m, n)\n",
+ " Q: np.ndarray\n",
+ " process noise covariance - shape (n, n)\n",
+ " R: np.ndarray\n",
+ " observation noise covariance - shape (m, m)\n",
+ " x0: np.ndarray\n",
+ " initial state mean - shape (n,)\n",
+ " P0: np.ndarray\n",
+ " initial state covariance - shape (n, n)\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
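+ " xs: np.ndarray\n",
+ " shape (T, n) - filtered state means\n",
+ " Ps: np.ndarray\n",
+ " shape (T, n, n) - filtered state covariances\n",
+ " x_pred: np.ndarray\n",
+ " shape (T, n) - predicted state means (before the update step)\n",
+ " P_pred: np.ndarray\n",
+ " shape (T, n, n) - predicted state covariances (before the update step)\n",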
+ " \"\"\"\n",
+ " T, m = z.shape\n",
+ " n = x0.shape[0]\n",
+ "\n",
+ " xs = np.zeros((T, n), dtype=np.float32)\n",
+ " Ps = np.zeros((T, n, n), dtype=np.float32)\n",
+ "\n",
+ " x_pred = np.zeros((T, n), dtype=np.float32)\n",
+ " P_pred = np.zeros((T, n, n), dtype=np.float32)\n",
+ "\n",
+ " x = x0.copy()\n",
+ " P = P0.copy()\n",
+ "\n",
+ " I = np.eye(n, dtype=np.float32)\n",
+ "\n",
+ " for t in range(T):\n",
+ " # --- Predict ---\n",
+ " x_pred[t] = F @ x\n",
+ " P_pred[t] = F @ P @ F.T + Q\n",
+ "\n",
+ " # --- Update ---\n",
+ " y = z[t] - H @ x_pred[t]\n",
+ " S = H @ P_pred[t] @ H.T + R\n",
+ " K = P_pred[t] @ H.T @ np.linalg.inv(S)\n",
+ "\n",
+ " x = x_pred[t] + K @ y\n",
+ " P = (I - K @ H) @ P_pred[t]\n",
+ "\n",
+ " xs[t] = x\n",
+ " Ps[t] = P\n",
+ "\n",
+ " return xs, Ps, x_pred, P_pred"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "cd715dfb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@block\n",
+ "@jax.jit\n",
+ "def kalman_filter_jax(z, F, H, Q, R, x0, P0):\n",
+ " \"\"\"\n",
+ " Linear Gaussian Kalman filter algorithm\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " z: np.ndarray\n",
+ " shape (T, m) - observations\n",
+ " F: np.ndarray\n",
+ " state transition matrix - shape (n, n)\n",
+ " H: np.ndarray\n",
+ " observation/design matrix - shape (m, n)\n",
+ " Q: np.ndarray\n",
+ " process noise covariance - shape (n, n)\n",
+ " R: np.ndarray\n",
+ " observation noise covariance - shape (m, m)\n",
+ " x0: np.ndarray\n",
+ " initial state mean - shape (n,)\n",
+ " P0: np.ndarray\n",
+ " initial state covariance - shape (n, n)\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
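+ " xs: jnp.ndarray\n",
+ " shape (T, n) - filtered state means\n",
+ " Ps: jnp.ndarray\n",
+ " shape (T, n, n) - filtered state covariances\n",
+ " x_pred: jnp.ndarray\n",
+ " shape (T, n) - predicted state means (before the update step)\n",
+ " P_pred: jnp.ndarray\n",
+ " shape (T, n, n) - predicted state covariances (before the update step)\n",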
+ " \"\"\"\n",
+ "\n",
+ " n = x0.shape[0]\n",
+ " I = jnp.eye(n)\n",
+ " X_pred_init = jnp.zeros((1,))\n",
+ " P_pred_init = jnp.zeros((1, 1,))\n",
+ "\n",
+ " def step(carry, z_t):\n",
+ " x, P, _, _ = carry\n",
+ "\n",
+ " # --- Predict ---\n",
+ " x_pred = F @ x\n",
+ " P_pred = F @ P @ F.T + Q\n",
+ "\n",
+ " # --- Update ---\n",
+ " y = z_t - H @ x_pred\n",
+ " S = H @ P_pred @ H.T + R\n",
+ " K = P_pred @ H.T @ jnp.linalg.inv(S)\n",
+ "\n",
+ " x_new = x_pred + K @ y\n",
+ " P_new = (I - K @ H) @ P_pred\n",
+ "\n",
+ " return (x_new, P_new, x_pred, P_pred), (x_new, P_new, x_pred, P_pred)\n",
+ "\n",
+ " # run scan\n",
+ " (_, _, _, _), (xs, Ps, x_pred, P_pred) = jax.lax.scan(step, (x0, P0, X_pred_init, P_pred_init), z)\n",
+ "\n",
+ " return xs, Ps, x_pred, P_pred"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "5af37c40",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "z_symbolic = pt.matrix(\"z\")\n",
+ "F_symbolic = pt.matrix(\"F\")\n",
+ "H_symbolic = pt.matrix(\"H\")\n",
+ "Q_symbolic = pt.matrix(\"Q\")\n",
+ "R_symbolic = pt.matrix(\"R\")\n",
+ "x0_symbolic = pt.vector(\"x0\")\n",
+ "P0_symbolic = pt.matrix(\"P0\")\n",
+ "\n",
+ "n = x0_symbolic.shape[0]\n",
+ "I = pt.eye(n)\n",
+ "X_pred_init = pt.zeros_like(x0_symbolic)\n",
+ "P_pred_init = pt.zeros_like(P0_symbolic)\n",
+ "\n",
+ "N_STEPS = 500\n",
+ "\n",
+ "def step(z_t, x, P, x_pred, P_pred, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I):\n",
+ "\n",
+ " # --- Predict ---\n",
+ " x_pred = F_symbolic @ x\n",
+ " P_pred = F_symbolic @ P @ F_symbolic.T + Q_symbolic\n",
+ "\n",
+ " # --- Update ---\n",
+ " y = z_t - H_symbolic @ x_pred\n",
+ " S = H_symbolic @ P_pred @ H_symbolic.T + R_symbolic\n",
+ " K = P_pred @ H_symbolic.T @ pt.linalg.inv(S)\n",
+ "\n",
+ " x_new = x_pred + K @ y\n",
+ " P_new = (I - K @ H_symbolic) @ P_pred\n",
+ "\n",
+ " return x_new, P_new, x_pred, P_pred\n",
+ "\n",
+ "# run scan\n",
+ "(xs, Ps, x_pred, P_pred), _ = pytensor.scan(\n",
+ " fn=step,\n",
+ " outputs_info=[x0_symbolic, P0_symbolic, X_pred_init, P_pred_init],\n",
+ " sequences=[z_symbolic],\n",
+ " non_sequences=[F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I],\n",
+ " n_steps=N_STEPS\n",
+ ")\n",
+ "\n",
+ "kalman_filter_pytensor = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], trust_input=True)\n",
+ "\n",
+ "kalman_filter_pytensor_numba = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode=\"NUMBA\", trust_input=True)\n",
+ "kalman_filter_pytensor_jax = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode=\"JAX\", trust_input=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "3c9e7254",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "T = 500\n",
+ "F = np.array([[1.0]]).astype(np.float32)\n",
+ "H = np.array([[1.0]]).astype(np.float32)\n",
+ "Q = np.array([[0.01]]).astype(np.float32)\n",
+ "R = np.array([[0.1]]).astype(np.float32)\n",
+ "x0 = np.array([0.0]).astype(np.float32)\n",
+ "P0 = np.array([[1.0]]).astype(np.float32)\n",
+ "\n",
+ "true = 1.0\n",
+ "z = (true + 0.4*np.random.randn(T)).reshape(T, 1).astype(np.float32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "00472afd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "kalman_filter_bench = Benchmarker(\n",
+ " functions=[kalman_filter_pytensor, atrocious_kalman_filter_numba, kalman_filter_numba, kalman_filter_jax, kalman_filter_pytensor_numba, block(kalman_filter_pytensor_jax)], \n",
+ " names=['kalman_filter_pytensor', 'atrocious_kalman_filter_numba', 'kalman_filter_numba', 'kalman_filter_jax', 'kalman_filter_pytensor_numba', 'kalman_filter_pytensor_jax'],\n",
+ " number=10\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "id": "68ec703d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/tmp7loid6xb:33: NumbaWarning:\n",
+ "\n",
+ "\u001b[1m\u001b[1mCannot cache compiled function \"scan\" as it uses dynamic globals (such as ctypes pointers and large global arrays)\u001b[0m\u001b[0m\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " | \n",
+ " Loops | \n",
+ " Min (us) | \n",
+ " Max (us) | \n",
+ " Mean (us) | \n",
+ " StdDev (us) | \n",
+ " Median (us) | \n",
+ " IQR (us) | \n",
+ " OPS (Kops/s) | \n",
+ " Samples | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | kalman_filter_pytensor | \n",
+ " kalman_filter_inputs | \n",
+ " 10 | \n",
+ " 5683.275004 | \n",
+ " 7487.900002 | \n",
+ " 6232.157842 | \n",
+ " 453.579690 | \n",
+ " 6133.425003 | \n",
+ " 294.808298 | \n",
+ " 0.160458 | \n",
+ " 17 | \n",
+ "
\n",
+ " \n",
+ " | atrocious_kalman_filter_numba | \n",
+ " kalman_filter_inputs | \n",
+ " 10 | \n",
+ " 305.670901 | \n",
+ " 337.154098 | \n",
+ " 318.397480 | \n",
+ " 11.714092 | \n",
+ " 313.770800 | \n",
+ " 17.316604 | \n",
+ " 3.140728 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | kalman_filter_numba | \n",
+ " kalman_filter_inputs | \n",
+ " 10 | \n",
+ " 610.762503 | \n",
+ " 686.358300 | \n",
+ " 657.730820 | \n",
+ " 26.215497 | \n",
+ " 670.712499 | \n",
+ " 21.345803 | \n",
+ " 1.520379 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | kalman_filter_jax | \n",
+ " kalman_filter_inputs | \n",
+ " 10 | \n",
+ " 300.266704 | \n",
+ " 360.233302 | \n",
+ " 323.625648 | \n",
+ " 15.903232 | \n",
+ " 319.320802 | \n",
+ " 20.610352 | \n",
+ " 3.089990 | \n",
+ " 19 | \n",
+ "
\n",
+ " \n",
+ " | kalman_filter_pytensor_numba | \n",
+ " kalman_filter_inputs | \n",
+ " 10 | \n",
+ " 748.704199 | \n",
+ " 764.174998 | \n",
+ " 759.811681 | \n",
+ " 5.623608 | \n",
+ " 761.879201 | \n",
+ " 1.000002 | \n",
+ " 1.316116 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | kalman_filter_pytensor_jax | \n",
+ " kalman_filter_inputs | \n",
+ " 10 | \n",
+ " 297.354202 | \n",
+ " 347.566698 | \n",
+ " 322.995273 | \n",
+ " 18.748797 | \n",
+ " 321.349999 | \n",
+ " 36.767675 | \n",
+ " 3.096021 | \n",
+ " 22 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Loops Min (us) \\\n",
+ "kalman_filter_pytensor kalman_filter_inputs 10 5683.275004 \n",
+ "atrocious_kalman_filter_numba kalman_filter_inputs 10 305.670901 \n",
+ "kalman_filter_numba kalman_filter_inputs 10 610.762503 \n",
+ "kalman_filter_jax kalman_filter_inputs 10 300.266704 \n",
+ "kalman_filter_pytensor_numba kalman_filter_inputs 10 748.704199 \n",
+ "kalman_filter_pytensor_jax kalman_filter_inputs 10 297.354202 \n",
+ "\n",
+ " Max (us) Mean (us) \\\n",
+ "kalman_filter_pytensor kalman_filter_inputs 7487.900002 6232.157842 \n",
+ "atrocious_kalman_filter_numba kalman_filter_inputs 337.154098 318.397480 \n",
+ "kalman_filter_numba kalman_filter_inputs 686.358300 657.730820 \n",
+ "kalman_filter_jax kalman_filter_inputs 360.233302 323.625648 \n",
+ "kalman_filter_pytensor_numba kalman_filter_inputs 764.174998 759.811681 \n",
+ "kalman_filter_pytensor_jax kalman_filter_inputs 347.566698 322.995273 \n",
+ "\n",
+ " StdDev (us) Median (us) \\\n",
+ "kalman_filter_pytensor kalman_filter_inputs 453.579690 6133.425003 \n",
+ "atrocious_kalman_filter_numba kalman_filter_inputs 11.714092 313.770800 \n",
+ "kalman_filter_numba kalman_filter_inputs 26.215497 670.712499 \n",
+ "kalman_filter_jax kalman_filter_inputs 15.903232 319.320802 \n",
+ "kalman_filter_pytensor_numba kalman_filter_inputs 5.623608 761.879201 \n",
+ "kalman_filter_pytensor_jax kalman_filter_inputs 18.748797 321.349999 \n",
+ "\n",
+ " IQR (us) OPS (Kops/s) \\\n",
+ "kalman_filter_pytensor kalman_filter_inputs 294.808298 0.160458 \n",
+ "atrocious_kalman_filter_numba kalman_filter_inputs 17.316604 3.140728 \n",
+ "kalman_filter_numba kalman_filter_inputs 21.345803 1.520379 \n",
+ "kalman_filter_jax kalman_filter_inputs 20.610352 3.089990 \n",
+ "kalman_filter_pytensor_numba kalman_filter_inputs 1.000002 1.316116 \n",
+ "kalman_filter_pytensor_jax kalman_filter_inputs 36.767675 3.096021 \n",
+ "\n",
+ " Samples \n",
+ "kalman_filter_pytensor kalman_filter_inputs 17 \n",
+ "atrocious_kalman_filter_numba kalman_filter_inputs 5 \n",
+ "kalman_filter_numba kalman_filter_inputs 5 \n",
+ "kalman_filter_jax kalman_filter_inputs 19 \n",
+ "kalman_filter_pytensor_numba kalman_filter_inputs 5 \n",
+ "kalman_filter_pytensor_jax kalman_filter_inputs 22 "
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "kalman_filter_bench.run(\n",
+ " inputs={\n",
+ " \"kalman_filter_inputs\": {\"z\": z, \"F\": F, \"H\": H, \"Q\": Q, \"R\": R, \"x0\": x0, \"P0\": P0},\n",
+ " }\n",
+ ")\n",
+ "kalman_filter_bench.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "id": "37526f86",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "xs, Ps, x_pred, P_pred = kalman_filter_jax(z, F, H, Q, R, x0, P0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "id": "231140ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compute_pred_intervals(z, x_pred, P_pred, H, R, zscore=1.96):\n",
+ " T = z.shape[0]\n",
+ " m = H.shape[0]\n",
+ " mean = np.zeros((T, m))\n",
+ " lower = np.zeros((T, m))\n",
+ " upper = np.zeros((T, m))\n",
+ " outside = np.zeros(T, dtype=np.bool_)\n",
+ "\n",
+ " for t in range(T):\n",
+ " mean[t] = H @ x_pred[t]\n",
+ " S = H @ P_pred[t] @ H.T + R\n",
+ " std = np.sqrt(np.diag(S))\n",
+ " lower[t] = mean[t] - zscore * std\n",
+ " upper[t] = mean[t] + zscore * std\n",
+ "\n",
+ " # check coverage of actual obs\n",
+ " outside[t] = np.any((z[t] < lower[t]) | (z[t] > upper[t]))\n",
+ "\n",
+ " coverage = 1 - outside.mean()\n",
+ " return mean, lower, upper, coverage\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "id": "ff739003",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mean, lower, upper, coverage = compute_pred_intervals(z, x_pred, P_pred, H, R)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "7546c95c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "np.float64(0.886)"
+ ]
+ },
+ "execution_count": 48,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "coverage"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "6b765a37",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig= go.Figure()\n",
+ "fig.add_traces(\n",
+ " [\n",
+ " go.Scatter(\n",
+ " x = np.arange(T),\n",
+ " y = z.ravel(),\n",
+ " mode=\"markers\",\n",
+ " marker_color = \"royalblue\",\n",
+ " name = \"actuals\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " x = np.arange(T),\n",
+ " y = xs.ravel(),\n",
+ " mode = \"lines\",\n",
+ " marker_color = \"orange\",\n",
+ " name = \"filtered mean\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " name=\"\", \n",
+ " x=np.arange(T), \n",
+ " y=upper.ravel(), \n",
+ " mode=\"lines\", \n",
+ " marker=dict(color=\"#eb8c34\"), \n",
+ " line=dict(width=0), \n",
+ " legendgroup=\"95% CI\",\n",
+ " showlegend=False\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " name=\"95% CI\", \n",
+ " x=np.arange(T), \n",
+ " y=lower.ravel(), \n",
+ " mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n",
+ " line=dict(width=0), \n",
+ " legendgroup=\"95% CI\", \n",
+ " fill='tonexty', \n",
+ " fillcolor='rgba(235, 140, 52, 0.2)'\n",
+ " ),\n",
+ "\n",
+ " ]\n",
+ ")\n",
+ "fig.update_layout(\n",
+ " xaxis=dict(\n",
+ " title = \"Time Index\",\n",
+ " ),\n",
+ " yaxis=dict(\n",
+ " title = \"y\"\n",
+ " ),\n",
+ " template = \"plotly_dark\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "052b0194",
+ "metadata": {},
+ "source": [
+ "## Non-linear Filter (Particle Filter)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "id": "6b088d6b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jit(nopython=True)\n",
+ "def loglik_poisson_numba(s, y):\n",
+ " \"\"\"Poisson Log Likelihood\"\"\"\n",
+ " mu = np.exp(s)\n",
+ " return y * np.log(mu + 1e-30) - mu - math.lgamma(y + 1.0) # numba does not support scipy.special gammaln\n",
+ "\n",
+ "@jit(nopython=True)\n",
+ "def particle_filter_1d_predict_numba(A, Q, x0_mean, x0_std, ys, N=1000, seed=2):\n",
+ " \"\"\"\n",
+ " 1D particle filter.\n",
+ " \n",
+ " Parameters\n",
+ " ----------\n",
+ " A: float\n",
+ " State transition\n",
+ " Q: float\n",
+ " Process covariance\n",
+ " x0_mean: float\n",
+ " Prior mean for the latent state\n",
+ " x0_std: float\n",
+ " Prior standard deviation \n",
+ " ys: np.ndarray\n",
+ " observations\n",
+ " N: int\n",
+ " number of particles\n",
+ " seed: int\n",
+ " rng seed for reproducibility\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " filtered_means: np.ndarray\n",
+ " The filtered mean for the latent state \n",
+ " filtered_vars: np.ndarray\n",
+ " The filtered variance for the latent state\n",
+ " pred_means: np.ndarray\n",
+ " observation predicted mean \n",
+ " \"\"\"\n",
+ " np.random.seed(seed)\n",
+ " T = ys.shape[0]\n",
+ " particles = np.random.normal(x0_mean, x0_std, size=N)\n",
+ " weights = np.ones(N) / N\n",
+ "\n",
+ " filtered_means = np.zeros(T)\n",
+ " filtered_vars = np.zeros(T)\n",
+ " pred_means = np.zeros(T)\n",
+ "\n",
+ " for t in range(T):\n",
+ " y = ys[t]\n",
+ "\n",
+ " # propagate (vectorized)\n",
+ " particles = A * particles + np.random.normal(0, np.sqrt(Q), size=N)\n",
+ "\n",
+ " # update weights\n",
+ " logw = np.zeros(N)\n",
+ " for i in range(N):\n",
+ " logw[i] = loglik_poisson_numba(particles[i], y)\n",
+ " logw = logw - np.max(logw)\n",
+ " weights *= np.exp(logw)\n",
+ " weights /= np.sum(weights) + 1e-12\n",
+ "\n",
+ " # filtered moments\n",
+ " mean_t = np.sum(weights * particles)\n",
+ " var_t = np.sum(weights * (particles - mean_t) ** 2)\n",
+ "\n",
+ " # predictive mean\n",
+ " pred_mean = np.sum(weights * np.exp(particles))\n",
+ "\n",
+ " filtered_means[t] = mean_t\n",
+ " filtered_vars[t] = var_t\n",
+ " pred_means[t] = pred_mean\n",
+ "\n",
+ " # resample (multinomial resampling) because numba doesn't support np.random.choice\n",
+ " cumulative_sum = np.cumsum(weights)\n",
+ " cumulative_sum[-1] = 1.0 # guard against rounding error\n",
+ " indices = np.searchsorted(cumulative_sum, np.random.rand(N))\n",
+ "\n",
+ " particles = particles[indices]\n",
+ " weights = np.ones(N) / N\n",
+ "\n",
+ " return filtered_means, filtered_vars, pred_means"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "id": "a468c403",
+ "metadata": {},
+ "outputs": [],
+ "source": [
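+ "# The log-likelihood function and the PRNG key are hard-coded inside the filter (instead of being passed as arguments) so that it can be used with the Benchmarker as is\n",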
+ "def loglik_poisson_jax(s, y):\n",
+ " \"\"\"Poisson Log Likelihood\"\"\"\n",
+ " mu = jnp.exp(s)\n",
+ " return y * jnp.log(mu + 1e-30) - mu - gammaln(y + 1.0)\n",
+ "\n",
+ "@block\n",
+ "@partial(jax.jit, static_argnums=5)\n",
+ "def particle_filter_1d_predict_jax(\n",
+ " A, Q, x0_mean, x0_std, ys, N=1000,\n",
+ "):\n",
+ " \"\"\"\n",
+ " 1D particle filter.\n",
+ " \n",
+ " Parameters\n",
+ " ----------\n",
+ " A: float\n",
+ " State transition\n",
+ " Q: float\n",
+ " Process covariance\n",
+ " x0_mean: float\n",
+ " Prior mean for the latent state\n",
+ " x0_std: float\n",
+ " Prior standard deviation \n",
+ " ys: np.ndarray\n",
+ " observations\n",
+ " loglik_fn: function\n",
+ " The log likelihood function\n",
+ " key: \n",
+ " JAX prng key\n",
+ " N: int\n",
+ " number of particles\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " filtered_means: jnp.ndarray\n",
+ " The filtered mean for the latent state \n",
+ " filtered_vars: jnp.ndarray\n",
+ " The filtered variance for the latent state\n",
+ " pred_means: jnp.ndarray\n",
+ " observation predicted mean \n",
+ " \"\"\"\n",
+ " key = jax.random.PRNGKey(0)\n",
+ " T = ys.shape[0]\n",
+ " particles = jax.random.normal(key, (N,)) * x0_std + x0_mean # init particles from gaussian priors\n",
+ " weights = jnp.ones(N) / N # particle weights, all particles equally likely prior\n",
+ "\n",
+ " def body_fun(carry, t):\n",
+ " particles, weights, key = carry\n",
+ " y = ys[t]\n",
+ "\n",
+ " # propagate\n",
+ " key, subkey = jax.random.split(key)\n",
+ " particles = A * particles + jax.random.normal(subkey, (N,)) * jnp.sqrt(Q) # state transition model\n",
+ "\n",
+ " # update weights\n",
+ " logw = jax.vmap(lambda x: loglik_poisson_jax(x, y))(particles) # update particles in parallel\n",
+ " logw = logw - jnp.max(logw) # avoid overflow\n",
+ " weights = weights * jnp.exp(logw) # old weights times the likelihood\n",
+ " weights /= jnp.sum(weights) + 1e-12 # normalize so that weights sum to 1\n",
+ "\n",
+ " # filtered moments\n",
+ " mean_t = jnp.sum(weights * particles) # posterior mean of latent state\n",
+ " var_t = jnp.sum(weights * (particles - mean_t)**2) # posterior variance of latent state\n",
+ "\n",
+ " # predictive mean\n",
+ " pred_mean = jnp.sum(weights * jnp.exp(particles))\n",
+ "\n",
+ " # resample to prevent dominant particles\n",
+ " key, subkey = jax.random.split(key)\n",
+ " indices = jax.random.choice(subkey, N, p=weights, shape=(N,))\n",
+ " particles = particles[indices]\n",
+ " weights = jnp.ones(N) / N\n",
+ "\n",
+ " carry = (particles, weights, key)\n",
+ " out = (mean_t, var_t, pred_mean)\n",
+ " return carry, out\n",
+ "\n",
+ " _, outputs = jax.lax.scan(body_fun, (particles, weights, key), jnp.arange(T))\n",
+ " return outputs\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "id": "15fd4a0c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pytensor.tensor.random.utils import RandomStream\n",
+ "\n",
+ "# Random stream for PyTensor\n",
+ "srng = RandomStream(seed=42)\n",
+ "\n",
+ "# Poisson log-likelihood\n",
+ "def loglik_poisson_pytensor(s, y):\n",
+ " mu = pt.exp(s)\n",
+ " return y.flatten() * pt.log(mu + 1e-30) - mu - pt.gammaln(y.flatten() + 1.0)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "id": "59525da5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ys_symbolic = pt.vector(\"ys\")\n",
+ "x0_mean_symbolic = pt.scalar(\"x0_mean\")\n",
+ "x0_std_symbolic = pt.scalar(\"x0_std\")\n",
+ "A_symbolic = pt.scalar(\"A\")\n",
+ "Q_symbolic = pt.scalar(\"Q\")\n",
+ "N_symbolic = pt.scalar(\"N\", dtype='int64')\n",
+ "\n",
+ "N_STEPS = 300\n",
+ "\n",
+ "# Initialize particles and weights\n",
+ "particles_init = srng.normal(size=(N_symbolic,)) * x0_std_symbolic + x0_mean_symbolic\n",
+ "weights_init = pt.ones((N_symbolic,)) / N_symbolic \n",
+ "\n",
+ "# Step function for scan\n",
+ "def step(y_t, particles_prev, weights_prev, A_symbolic, Q_symbolic):\n",
+ " # Propagate particles\n",
+ " particles_prop = A_symbolic * particles_prev + srng.normal(size=(N_symbolic,)) * pt.sqrt(Q_symbolic)\n",
+ "\n",
+ " # Update weights\n",
+ " # logw = pt.stack([loglik_poisson_pytensor(p, y_t) for p in particles_prop])\n",
+ " logw = loglik_poisson_pytensor(particles_prop, y_t)\n",
+ " logw_stable = logw - pt.max(logw)\n",
+ " w_unnorm = weights_prev * pt.exp(logw_stable)\n",
+ " w = w_unnorm / (pt.sum(w_unnorm) + 1e-12) \n",
+ "\n",
+ " # Filtered moments\n",
+ " mean_t = pt.sum(w * particles_prop)\n",
+ " var_t = pt.sum(w * (particles_prop - mean_t) ** 2)\n",
+ " pred_mean = pt.sum(w * pt.exp(particles_prop))\n",
+ "\n",
+ " # Resample particles\n",
+ " idx = srng.choice(size=(N_symbolic,), a=N_symbolic, p=w) \n",
+ " particles_resampled = particles_prop[idx]\n",
+ " weights_resampled = pt.ones((N_symbolic,)) / N_symbolic\n",
+ "\n",
+ " # Return flat tuple\n",
+ " return particles_resampled, weights_resampled, mean_t, var_t, pred_mean\n",
+ "\n",
+ "# first two are recurrent, rest are collected\n",
+ "outputs_info = [\n",
+ " particles_init,\n",
+ " weights_init,\n",
+ " None,\n",
+ " None,\n",
+ " None\n",
+ "]\n",
+ "\n",
+ "(particles_seq, weights_seq, means_seq, vars_seq, preds_seq), updates = pytensor.scan(\n",
+ " fn=step,\n",
+ " sequences=[ys_symbolic],\n",
+ " outputs_info=outputs_info,\n",
+ " non_sequences=[A_symbolic, Q_symbolic],\n",
+ " n_steps=N_STEPS\n",
+ ")\n",
+ "\n",
+ "particle_filter_1d_predict_pytensor = pytensor.function(\n",
+ " [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],\n",
+ " [means_seq, vars_seq, preds_seq],\n",
+ " updates=updates,\n",
+ " no_default_updates=True,\n",
+ " trust_input=True\n",
+ ")\n",
+ "\n",
+ "particle_filter_1d_predict_pytensor_numba = pytensor.function(\n",
+ " [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],\n",
+ " [means_seq, vars_seq, preds_seq],\n",
+ " updates=updates,\n",
+ " no_default_updates=True,\n",
+ " mode=\"NUMBA\", \n",
+ " trust_input=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "id": "c033a1d0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "key = jax.random.PRNGKey(0)\n",
+ "T = 300\n",
+ "A = 0.95\n",
+ "Q = 0.05\n",
+ "rng = np.random.RandomState(1)\n",
+ "\n",
+ "target_mean = 10.0\n",
+ "latent_var = Q / (1 - A**2)\n",
+ "x0_mean = np.log(target_mean) - 0.5 * latent_var\n",
+ "x0_std = 1.0\n",
+ "\n",
+ "# Simulate latent\n",
+ "x = np.zeros(T)\n",
+ "x[0] = rng.normal() * np.sqrt(latent_var) + x0_mean\n",
+ "for t in range(1, T):\n",
+ " x[t] = A * x[t-1] + rng.normal() * np.sqrt(Q)\n",
+ "\n",
+ "ys = np.array(rng.poisson(np.exp(x)), dtype=np.float32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "id": "2a9cbfa5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nonlinear_kalman_filter_bench = Benchmarker(\n",
+ " functions=[particle_filter_1d_predict_pytensor, particle_filter_1d_predict_numba, particle_filter_1d_predict_jax, particle_filter_1d_predict_pytensor_numba,], \n",
+ " names=['particle_filter_1d_predict_pytensor', 'particle_filter_1d_predict_numba', 'particle_filter_1d_predict_jax', 'particle_filter_1d_predict_pytensor_numba',],\n",
+ " number=5 # This takes a while to run reducing number of loops\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "id": "c782c42b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " | \n",
+ " Loops | \n",
+ " Min (us) | \n",
+ " Max (us) | \n",
+ " Mean (us) | \n",
+ " StdDev (us) | \n",
+ " Median (us) | \n",
+ " IQR (us) | \n",
+ " OPS (Kops/s) | \n",
+ " Samples | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | particle_filter_1d_predict_pytensor | \n",
+ " kalman_filter_inputs | \n",
+ " 5 | \n",
+ " 696693.225007 | \n",
+ " 742831.816606 | \n",
+ " 716956.188362 | \n",
+ " 17644.519347 | \n",
+ " 708739.150001 | \n",
+ " 28481.516603 | \n",
+ " 0.001395 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | particle_filter_1d_predict_numba | \n",
+ " kalman_filter_inputs | \n",
+ " 5 | \n",
+ " 46922.808408 | \n",
+ " 49887.733196 | \n",
+ " 48122.836680 | \n",
+ " 1115.921717 | \n",
+ " 48200.741794 | \n",
+ " 1708.116406 | \n",
+ " 0.020780 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | particle_filter_1d_predict_jax | \n",
+ " kalman_filter_inputs | \n",
+ " 5 | \n",
+ " 30683.133402 | \n",
+ " 31434.324989 | \n",
+ " 31081.088318 | \n",
+ " 299.728880 | \n",
+ " 30930.791609 | \n",
+ " 498.875196 | \n",
+ " 0.032174 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ " | particle_filter_1d_predict_pytensor_numba | \n",
+ " kalman_filter_inputs | \n",
+ " 5 | \n",
+ " 655818.125000 | \n",
+ " 702576.024993 | \n",
+ " 680401.395040 | \n",
+ " 18440.632807 | \n",
+ " 684621.258406 | \n",
+ " 34198.999999 | \n",
+ " 0.001470 | \n",
+ " 5 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Loops \\\n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 5 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 5 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 5 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 \n",
+ "\n",
+ " Min (us) \\\n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 696693.225007 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 46922.808408 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 30683.133402 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 655818.125000 \n",
+ "\n",
+ " Max (us) \\\n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 742831.816606 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 49887.733196 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 31434.324989 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 702576.024993 \n",
+ "\n",
+ " Mean (us) \\\n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 716956.188362 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 48122.836680 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 31081.088318 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 680401.395040 \n",
+ "\n",
+ " StdDev (us) \\\n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 17644.519347 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 1115.921717 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 299.728880 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 18440.632807 \n",
+ "\n",
+ " Median (us) \\\n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 708739.150001 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 48200.741794 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 30930.791609 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 684621.258406 \n",
+ "\n",
+ " IQR (us) \\\n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 28481.516603 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 1708.116406 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 498.875196 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 34198.999999 \n",
+ "\n",
+ " OPS (Kops/s) \\\n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 0.001395 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 0.020780 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 0.032174 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 0.001470 \n",
+ "\n",
+ " Samples \n",
+ "particle_filter_1d_predict_pytensor kalman_filter_inputs 5 \n",
+ "particle_filter_1d_predict_numba kalman_filter_inputs 5 \n",
+ "particle_filter_1d_predict_jax kalman_filter_inputs 5 \n",
+ "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 "
+ ]
+ },
+ "execution_count": 56,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nonlinear_kalman_filter_bench.run(\n",
+ " inputs={\n",
+ " \"kalman_filter_inputs\": {\"A\": A, \"Q\": Q, \"x0_mean\": x0_mean, \"x0_std\": x0_std, \"ys\": ys, \"N\": 2000},\n",
+ " }\n",
+ ")\n",
+ "nonlinear_kalman_filter_bench.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "29fe0cc4",
+ "metadata": {},
+ "source": [
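+ "The estimates differ slightly between implementations because each backend uses its own random number generator and seed, so the random draws could not be reproduced 1:1."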
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "id": "2bd23897",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "filtered_means, filtered_vars, pred_means = particle_filter_1d_predict_numba(\n",
+ " A, Q, x0_mean, x0_std, ys, N=2000, seed=2\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "id": "7caac5a5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "fig = make_subplots(\n",
+ " rows=2, cols=1,\n",
+ " subplot_titles=(\"Observation Predictions\", \"Latent State Estimation\"),\n",
+ " vertical_spacing=0.07,\n",
+ " shared_xaxes=True\n",
+ ")\n",
+ "\n",
+ "fig.add_traces(\n",
+ " [\n",
+ " go.Scatter(\n",
+ " x = np.arange(T),\n",
+ " y = ys,\n",
+ " mode = \"markers\",\n",
+ " marker_color = \"cornflowerblue\",\n",
+ " name = \"actuals\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " x = np.arange(T),\n",
+ " y = pred_means,\n",
+ " mode = \"lines\",\n",
+ " marker_color = \"#eb8c34\",\n",
+ " name = \"predicted mean\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " name=\"\", \n",
+ " x=np.arange(T), \n",
+ " y=pred_means + 2*jnp.sqrt(pred_means), \n",
+ " mode=\"lines\", \n",
+ " marker=dict(color=\"#eb8c34\"), \n",
+ " line=dict(width=0), \n",
+ " legendgroup=\"predicted mean 95% CI\",\n",
+ " showlegend=False\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " name=\"predicted mean 95% CI\", \n",
+ " x=np.arange(T), \n",
+ " y=pred_means - 2*jnp.sqrt(pred_means), \n",
+ " mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n",
+ " line=dict(width=0), \n",
+ " legendgroup=\"predicted mean 95% CI\", \n",
+ " fill='tonexty', \n",
+ " fillcolor='rgba(235, 140, 52, 0.2)'\n",
+ " ),\n",
+ " ],\n",
+ " rows=1, cols=1\n",
+ ")\n",
+ "\n",
+ "fig.add_traces(\n",
+ " [\n",
+ " go.Scatter(\n",
+ " x = np.arange(T),\n",
+ " y = x,\n",
+ " mode = \"lines\",\n",
+ " marker_color = \"cornflowerblue\",\n",
+ " name = \"true latent state\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " x = np.arange(T),\n",
+ " y = filtered_means,\n",
+ " mode = \"lines\",\n",
+ " marker_color = \"#eb8c34\",\n",
+ " name = \"filtered state mean\"\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " name=\"\", \n",
+ " x=np.arange(T), \n",
+ " y=filtered_means + 2*jnp.sqrt(filtered_vars), \n",
+ " mode=\"lines\", \n",
+ " marker=dict(color=\"#eb8c34\"), \n",
+ " line=dict(width=0), \n",
+ " legendgroup=\"filtered state mean 95% CI\",\n",
+ " showlegend=False\n",
+ " ),\n",
+ " go.Scatter(\n",
+ " name=\"filtered state mean 95% CI\", \n",
+ " x=np.arange(T), \n",
+ " y=filtered_means - 2*jnp.sqrt(filtered_vars), \n",
+ " mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n",
+ " line=dict(width=0), \n",
+ " legendgroup=\"filtered state mean 95% CI\", \n",
+ " fill='tonexty', \n",
+ " fillcolor='rgba(235, 140, 52, 0.2)'\n",
+ " ),\n",
+ " ],\n",
+ " rows=2, cols=1\n",
+ ")\n",
+ "\n",
+ "for i, yaxis in enumerate(fig.select_yaxes(), 1):\n",
+ " legend_name = f\"legend{i}\"\n",
+ " fig.update_layout({legend_name: dict(y=yaxis.domain[1], yanchor=\"top\")}, showlegend=True)\n",
+ " fig.update_traces(row=i, legend=legend_name)\n",
+ "\n",
+ "fig.update_layout(height=1000, width=1200, template=\"plotly_dark\")\n",
+ "\n",
+ "fig.update_layout(\n",
+ " legend1=dict(\n",
+ " yanchor=\"top\",\n",
+ " y=1.0,\n",
+ " xanchor=\"left\",\n",
+ " x=0,\n",
+ " orientation=\"h\"\n",
+ " ),\n",
+ " legend2=dict(\n",
+ " yanchor=\"top\",\n",
+ " y=.465,\n",
+ " xanchor=\"left\",\n",
+ " x=0,\n",
+ " orientation=\"h\"\n",
+ " ),\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "48ccc984",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pytensor-dev",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/doc/gallery/scan/numba_fib_scan.ipynb b/doc/gallery/scan/numba_fib_scan.ipynb
new file mode 100644
index 0000000000..9b1b81c5e2
--- /dev/null
+++ b/doc/gallery/scan/numba_fib_scan.ipynb
@@ -0,0 +1,573 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "id": "initial_id",
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:20:23.561430Z",
+ "start_time": "2025-10-07T10:20:21.620124Z"
+ }
+ },
+ "source": [
+ "import pytensor\n",
+ "import pytensor.tensor as pt\n",
+ "import numpy as np\n",
+ "\n",
+ "N_STEPS = 1000\n",
+ "\n",
+ "b_symbolic = pt.scalar(\"b_symbolic\", dtype=\"int32\")\n",
+ "\n",
+ "def step(a, b):\n",
+ " return a + b, a\n",
+ "\n",
+ "(outputs_a, outputs_b), _ = pytensor.scan(\n",
+ " fn=step,\n",
+ " outputs_info=[pt.constant(1, dtype=\"int32\"), b_symbolic],\n",
+ " n_steps=N_STEPS\n",
+ ")\n",
+ "\n",
+ "# compile function returning final a\n",
+ "fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)\n",
+ "fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)"
+ ],
+ "outputs": [],
+ "execution_count": 1
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:32:18.289971Z",
+ "start_time": "2025-10-07T10:32:18.284515Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "import numba\n",
+ "\n",
+ "@numba.njit\n",
+ "def fibonacci_numba_scalar(b):\n",
+ " b = b.copy()\n",
+ " a = np.ones((), dtype=np.int32)\n",
+ " for _ in range(N_STEPS):\n",
+ " a[()], b[()] = a[()] + b[()], a[()]\n",
+ " return a\n",
+ "\n",
+ "@numba.njit\n",
+ "def fibonacci_numba_array(b):\n",
+ " a = np.ones((), dtype=np.int32)\n",
+ " for _ in range(N_STEPS):\n",
+ " a, b = np.asarray(a + b), a\n",
+ " return a"
+ ],
+ "id": "b1d657d366647ada",
+ "outputs": [],
+ "execution_count": 66
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:32:19.607842Z",
+ "start_time": "2025-10-07T10:32:19.423324Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "b = np.ones((), dtype=np.int32)\n",
+ "assert fibonacci_numba_array(b) == fibonacci_numba_scalar(b)"
+ ],
+ "id": "7f45c87d259852e6",
+ "outputs": [],
+ "execution_count": 67
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:32:22.705191Z",
+ "start_time": "2025-10-07T10:32:20.090353Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "%timeit fibonacci_numba_scalar(b)",
+ "id": "b01c8978960c6e3d",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3.21 μs ± 20.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
+ ]
+ }
+ ],
+ "execution_count": 68
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:32:25.876514Z",
+ "start_time": "2025-10-07T10:32:23.122275Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "%timeit fibonacci_numba_array(b)",
+ "id": "bfc8794b219db03e",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "32.8 μs ± 2.48 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
+ ]
+ }
+ ],
+ "execution_count": 69
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null,
+ "source": [
+ "assert fibonacci_pytensor(b) == fibonacci_numba_scalar(b)\n",
+ "assert fibonacci_pytensor_numba(b) == fibonacci_numba_scalar(b)"
+ ],
+ "id": "a2185c1de1297a11"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:29:44.724064Z",
+ "start_time": "2025-10-07T10:29:42.655693Z"
+ }
+ },
+ "cell_type": "code",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2.49 ms ± 327 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
+ ]
+ }
+ ],
+ "execution_count": 54,
+ "source": "%timeit fibonacci_pytensor(b)",
+ "id": "f1e8bb6a0c673c8f"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:29:58.922566Z",
+ "start_time": "2025-10-07T10:29:44.752331Z"
+ }
+ },
+ "cell_type": "code",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "175 μs ± 6.13 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
+ ]
+ }
+ ],
+ "execution_count": 55,
+ "source": "%timeit fibonacci_pytensor_numba(b)",
+ "id": "17cd2859b4c6d3bd"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:30:11.832294Z",
+ "start_time": "2025-10-07T10:29:59.016709Z"
+ }
+ },
+ "cell_type": "code",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "158 μs ± 706 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
+ ]
+ }
+ ],
+ "execution_count": 56,
+ "source": "%timeit fibonacci_pytensor_numba.vm.jit_fn(b)",
+ "id": "6deb056f63953a42"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:20:58.015849Z",
+ "start_time": "2025-10-07T10:20:58.006831Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "fibonacci_pytensor_numba.dprint(print_type=True, print_memory_map=True)",
+ "id": "17580448648fdbcf",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Subtensor{i} [id A] v={0: [0]} 6\n",
+ " ├─ Scan{scan_fn, while_loop=False, inplace=all}.0 [id B] d={0: [1], 1: [2]} 5\n",
+ " │ ├─ 1000 [id C] \n",
+ " │ ├─ SetSubtensor{:stop} [id D] d={0: [0]} 4\n",
+ " │ │ ├─ AllocEmpty{dtype='int32'} [id E] 3\n",
+ " │ │ │ └─ 1 [id F] \n",
+ " │ │ ├─ [1] [id G] \n",
+ " │ │ └─ 1 [id H] \n",
+ " │ └─ SetSubtensor{:stop} [id I] d={0: [0]} 2\n",
+ " │ ├─ AllocEmpty{dtype='int32'} [id J] 1\n",
+ " │ │ └─ 2 [id K] \n",
+ " │ ├─ ExpandDims{axis=0} [id L] v={0: [0]} 0\n",
+ " │ │ └─ b_symbolic [id M] \n",
+ " │ └─ 1 [id H] \n",
+ " └─ 0 [id N] \n",
+ "\n",
+ "Inner graphs:\n",
+ "\n",
+ "Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1], 1: [2]}\n",
+ " ← Add [id O] \n",
+ " ├─ *0- [id P] -> [id D]\n",
+ " └─ *1- [id Q] -> [id I]\n",
+ " ← *0- [id P] -> [id D]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 8
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:20:58.063585Z",
+ "start_time": "2025-10-07T10:20:58.059985Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__source__)",
+ "id": "f9806651f5146bbd",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "def numba_funcified_fgraph(b_symbolic):\n",
+ " # ExpandDims{axis=0}(b_symbolic)\n",
+ " tensor_variable = dimshuffle(b_symbolic)\n",
+ " # AllocEmpty{dtype='int32'}(2)\n",
+ " tensor_variable_1 = allocempty(tensor_constant)\n",
+ " # SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, ExpandDims{axis=0}.0, 1)\n",
+ " tensor_variable_2 = set_subtensor(tensor_variable_1, tensor_variable, scalar_constant)\n",
+ " # AllocEmpty{dtype='int32'}(1)\n",
+ " tensor_variable_3 = allocempty_1(tensor_constant_1)\n",
+ " # SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, [1], 1)\n",
+ " tensor_variable_4 = set_subtensor_1(tensor_variable_3, tensor_constant_2, scalar_constant)\n",
+ " # Scan{scan_fn, while_loop=False, inplace=all}(1000, SetSubtensor{:stop}.0, SetSubtensor{:stop}.0)\n",
+ " tensor_variable_5, tensor_variable_6 = scan(tensor_constant_3, tensor_variable_4, tensor_variable_2)\n",
+ " # Subtensor{i}(Scan{scan_fn, while_loop=False, inplace=all}.0, 0)\n",
+ " tensor_variable_7 = subtensor(tensor_variable_5, scalar_constant_1)\n",
+ " return (tensor_variable_7,)\n"
+ ]
+ }
+ ],
+ "execution_count": 9
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:20:58.113352Z",
+ "start_time": "2025-10-07T10:20:58.109693Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"allocempty\"].py_func.__source__)",
+ "id": "9995392081dcbffb",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "def allocempty(tensor_constant):\n",
+ " tensor_constant_item = to_scalar(tensor_constant)\n",
+ " scalar_shape = (tensor_constant_item, )\n",
+ " return np.empty(scalar_shape, dtype)\n",
+ " \n"
+ ]
+ }
+ ],
+ "execution_count": 10
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:20:58.162062Z",
+ "start_time": "2025-10-07T10:20:58.158525Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"set_subtensor\"].py_func.__source__)",
+ "id": "1f89dfa8b172fde9",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "def set_subtensor(tensor_variable, tensor_variable_1, scalar_constant):\n",
+ " z = tensor_variable\n",
+ " indices = (slice(None, scalar_constant, None),)\n",
+ " z[indices] = tensor_variable_1\n",
+ " return np.asarray(z)\n",
+ " \n"
+ ]
+ }
+ ],
+ "execution_count": 11
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:20:58.211349Z",
+ "start_time": "2025-10-07T10:20:58.207479Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"scan\"].py_func.__source__)",
+ "id": "648cd8952121141b",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "def scan(n_steps, outer_in_1, outer_in_2):\n",
+ "\n",
+ " outer_in_1_len = outer_in_1.shape[0]\n",
+ " outer_in_1_sitsot_storage = outer_in_1\n",
+ " outer_in_2_len = outer_in_2.shape[0]\n",
+ " outer_in_2_sitsot_storage = outer_in_2\n",
+ "\n",
+ " outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)\n",
+ " outer_in_2_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)\n",
+ "\n",
+ " i = 0\n",
+ " cond = np.array(False)\n",
+ " while i < n_steps and not cond.item():\n",
+ " outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]\n",
+ " outer_in_2_sitsot_storage_temp_scalar_0[()] = outer_in_2_sitsot_storage[(i) % outer_in_2_len]\n",
+ "\n",
+ " (inner_out_0, inner_out_1) = scan_inner_func(outer_in_1_sitsot_storage_temp_scalar_0, outer_in_2_sitsot_storage_temp_scalar_0)\n",
+ "\n",
+ " outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0\n",
+ " outer_in_2_sitsot_storage[(i + 1) % outer_in_2_len] = inner_out_1\n",
+ " i += 1\n",
+ "\n",
+ " if 1 < outer_in_1_len < (i + 1):\n",
+ " outer_in_1_sitsot_storage_shift = (i + 1) % (outer_in_1_len)\n",
+ " if outer_in_1_sitsot_storage_shift > 0:\n",
+ " outer_in_1_sitsot_storage_left = outer_in_1_sitsot_storage[:outer_in_1_sitsot_storage_shift]\n",
+ " outer_in_1_sitsot_storage_right = outer_in_1_sitsot_storage[outer_in_1_sitsot_storage_shift:]\n",
+ " outer_in_1_sitsot_storage = np.concatenate((outer_in_1_sitsot_storage_right, outer_in_1_sitsot_storage_left))\n",
+ " if 1 < outer_in_2_len < (i + 1):\n",
+ " outer_in_2_sitsot_storage_shift = (i + 1) % (outer_in_2_len)\n",
+ " if outer_in_2_sitsot_storage_shift > 0:\n",
+ " outer_in_2_sitsot_storage_left = outer_in_2_sitsot_storage[:outer_in_2_sitsot_storage_shift]\n",
+ " outer_in_2_sitsot_storage_right = outer_in_2_sitsot_storage[outer_in_2_sitsot_storage_shift:]\n",
+ " outer_in_2_sitsot_storage = np.concatenate((outer_in_2_sitsot_storage_right, outer_in_2_sitsot_storage_left))\n",
+ "\n",
+ " return outer_in_1_sitsot_storage, outer_in_2_sitsot_storage\n",
+ " \n"
+ ]
+ }
+ ],
+ "execution_count": 12
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:30:25.495418Z",
+ "start_time": "2025-10-07T10:30:25.481386Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "from pytensor.link.numba.dispatch.basic import tuple_setitem, to_scalar\n",
+ "\n",
+ "@numba.njit\n",
+ "def allocempty(s):\n",
+ " s_item = to_scalar(s)\n",
+ " scalar_shape = (s_item,)\n",
+ " return np.empty(scalar_shape, dtype=np.int32)\n",
+ "\n",
+ "@numba.njit\n",
+ "def subtensor(x, idx):\n",
+ " indices = (idx,)\n",
+ " z = x[indices]\n",
+ " return np.asarray(z)\n",
+ "\n",
+ "@numba.njit\n",
+ "def inner_scan_func(a, b):\n",
+ " res = a + b\n",
+ " return res, a\n",
+ "\n",
+ "@numba.njit\n",
+ "def scan_fib(n_steps, a_buf, b_buf):\n",
+ " a_buf_len = a_buf.shape[0]\n",
+ " b_buf_len = b_buf.shape[0]\n",
+ "\n",
+ " tmp_a_scalar = np.empty((), dtype=np.int32)\n",
+ " tmp_b_scalar = np.empty((), dtype=np.int32)\n",
+ "\n",
+ " i = 0\n",
+ " while i < n_steps:\n",
+ " tmp_a_scalar[()] = a_buf[i % a_buf_len]\n",
+ " tmp_b_scalar[()] = b_buf[i % b_buf_len]\n",
+ " next_a, next_b = inner_scan_func(tmp_a_scalar, tmp_b_scalar)\n",
+ " a_buf[(i + 1) % a_buf_len] = next_a\n",
+ " b_buf[(i + 1) % b_buf_len] = next_b\n",
+ " i += 1\n",
+ "\n",
+ " if 1 < a_buf_len < (i + 1):\n",
+ " a_buf_shift = (i + 1) % a_buf_len\n",
+ " if a_buf_shift > 0:\n",
+ " a_buf = np.concatenate((a_buf[a_buf_shift:], a_buf[:a_buf_shift]))\n",
+ " if 1 < b_buf_len < (i + 1):\n",
+ " b_buf_shift = (i + 1) % b_buf_len\n",
+ " if b_buf_shift > 0:\n",
+ " b_buf = np.concatenate((b_buf[b_buf_shift:], b_buf[:b_buf_shift]))\n",
+ "\n",
+ " return a_buf, b_buf\n",
+ "\n",
+ "@numba.njit\n",
+ "def set_subtensor(x, y, idx):\n",
+ " indices = (slice(None, idx, None),)\n",
+ " x[indices] = y\n",
+ " return np.asarray(x)\n",
+ "\n",
+ "@numba.njit\n",
+ "def dimshuffle(x):\n",
+ " old_shape = x.shape\n",
+ " old_strides = x.strides\n",
+ "\n",
+ " new_shape = (1,)\n",
+ " new_strides = (0,)\n",
+ " new_order = (-1,)\n",
+ " for i, o in enumerate(new_order):\n",
+ " if o != -1:\n",
+ " new_shape = tuple_setitem(new_shape, i, old_shape[o])\n",
+ " new_strides = tuple_setitem(new_strides, i, old_strides[o])\n",
+ "\n",
+ " return np.lib.stride_tricks.as_strided(x, shape=new_shape, strides=new_strides)\n",
+ " # return np.expand_dims(x, axis=0)\n",
+ "\n",
+ "@numba.njit\n",
+ "def comparable_fibonacci_numba(b):\n",
+ " a_buf = allocempty(np.array(1, dtype=np.int64))\n",
+ " # a_buf = np.empty(1, dtype=np.int32)\n",
+ " # a_buf[:1] = np.array([1], dtype=np.int32)\n",
+ " a_buf_set = set_subtensor(a_buf, np.array([1], dtype=np.int32), np.int64(1))\n",
+ "\n",
+ " b_buf = allocempty(np.array(2, dtype=np.int64))\n",
+ " # b_buf = np.empty(2, dtype=np.int32)\n",
+ " # b_buf[:1] = np.expand_dims(b, axis=0)\n",
+ " b_expanded = dimshuffle(b)\n",
+ " b_buf_set = set_subtensor(b_buf, b_expanded, np.int64(1))\n",
+ "\n",
+ " a_buf_updated, b_buf_updated = scan_fib(np.array(N_STEPS, np.int64), a_buf_set, b_buf_set)\n",
+ "\n",
+ " res = subtensor(a_buf_updated, np.uint8(0))\n",
+ "\n",
+ " return (res,)"
+ ],
+ "id": "bcefae049d4d2540",
+ "outputs": [],
+ "execution_count": 59
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:30:31.559493Z",
+ "start_time": "2025-10-07T10:30:30.263832Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "b = np.ones((), dtype=np.int32)\n",
+ "assert comparable_fibonacci_numba(b) == fibonacci_numba_scalar(b)"
+ ],
+ "id": "65887ebba21f46c3",
+ "outputs": [],
+ "execution_count": 60
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:30:35.999409Z",
+ "start_time": "2025-10-07T10:30:31.567997Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "%timeit comparable_fibonacci_numba(b)",
+ "id": "2e0aba9917097009",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "54.6 μs ± 1.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
+ ]
+ }
+ ],
+ "execution_count": 61
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-10-07T10:21:06.095536Z",
+ "start_time": "2025-10-07T10:21:06.093418Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "",
+ "id": "f7e8c11b24c366f1",
+ "outputs": [],
+ "execution_count": null
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}