diff --git a/doc/gallery/scan/benchmark_backends.ipynb b/doc/gallery/scan/benchmark_backends.ipynb new file mode 100644 index 0000000000..4d55a80da1 --- /dev/null +++ b/doc/gallery/scan/benchmark_backends.ipynb @@ -0,0 +1,7171 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9571ee33", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "73cc811c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import jax.numpy as jnp\n", + "from numba import jit\n", + "import numba\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "from pytensor.compile.mode import get_mode\n", + "import timeit\n", + "import jax\n", + "import math\n", + "from jax.scipy.special import gammaln\n", + "from functools import partial\n", + "\n", + "import plotly.graph_objects as go\n", + "from plotly.subplots import make_subplots\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d5fc1350", + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.io as pio\n", + "pio.renderers.default = \"notebook\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "daa0969b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pytensor version: 0+untagged.31345.g422eafd.dirty\n", + "jax version: 0.7.2\n", + "numba version: 0.62.1\n" + ] + } + ], + "source": [ + "print(\"pytensor version:\", pytensor.__version__)\n", + "print(\"jax version:\", jax.__version__)\n", + "print(\"numba version:\", numba.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8294718d", + "metadata": {}, + "outputs": [], + "source": [ + "class Benchmarker:\n", + " \"\"\"\n", + " Benchmark a set of functions by timing execution and summarizing statistics.\n", + "\n", + " Parameters\n", + " ----------\n", + " functions : list of callables\n", + " List of callables to benchmark.\n", + " names : list of str, optional\n", + " Names corresponding to each function. Default is ['func_0', 'func_1', ...].\n", + " number : int or None, optional\n", + " Number of loops per timing. If None, auto-calibrated via Timer.autorange().\n", + " Default is None.\n", + " repeat : int, optional\n", + " Number of repeats for timing. Default is 7.\n", + " target_time : float, optional\n", + " Target duration in seconds for auto-calibration. Default is 0.2.\n", + "\n", + " Attributes\n", + " ----------\n", + " results : dict\n", + " Mapping from function names to a dict with keys:\n", + " - 'raw_us': numpy.ndarray of raw timings in microseconds\n", + " - 'loops': number of loops used per timing\n", + "\n", + " Methods\n", + " -------\n", + " run()\n", + " Auto-calibrate (if needed) and run timings for all functions.\n", + " summary(unit='us') -> pandas.DataFrame\n", + " Return a summary DataFrame with statistics converted to the given unit.\n", + " raw(name=None) -> dict or numpy.ndarray\n", + " Return raw timing data in microseconds for a specific function or all.\n", + " _convert_times(times, unit) -> numpy.ndarray\n", + " Convert an array of times from microseconds to the specified unit.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, functions, names=None, number=None, min_rounds=5, max_time=1.0, target_time=0.2\n", + " ):\n", + " self.functions = functions\n", + " self.names = names or [f\"func_{i}\" for i in range(len(functions))]\n", + " self.number = number\n", + " self.min_rounds = min_rounds\n", + " self.max_time = max_time\n", + " self.target_time = target_time\n", + " self.results = {}\n", + "\n", + " def run(self, inputs: dict[str, dict]):\n", + " \"\"\"\n", + " Auto-calibrate loop count and sample rounds if needed, then time each function.\n", + " \"\"\"\n", + " for name, func in zip(self.names, self.functions):\n", + " for input_name, kwargs in inputs.items():\n", + " timer = timeit.Timer(partial(func, **kwargs))\n", + "\n", + " # Calibrate loops\n", + " if self.number is None:\n", + " loops, calib_time = timer.autorange()\n", + " else:\n", + " loops = self.number\n", + " calib_time = timer.timeit(number=loops)\n", + "\n", + " # Determine rounds based on max_time and min_rounds\n", + " if self.max_time is not None:\n", + " rounds = max(self.min_rounds, int(np.ceil(self.max_time / calib_time)))\n", + " else:\n", + " rounds = self.min_rounds\n", + "\n", + " raw_round_times = np.array(timer.repeat(repeat=rounds, number=loops))\n", + "\n", + " # Convert to microseconds per single execution\n", + " raw_us = raw_round_times / loops * 1e6\n", + "\n", + " self.results[(name, input_name)] = {\n", + " \"raw_us\": raw_us,\n", + " \"loops\": loops,\n", + " \"rounds\": rounds,\n", + " }\n", + "\n", + " def summary(self, unit=\"us\"):\n", + " \"\"\"\n", + " Summarize benchmark statistics in a DataFrame.\n", + "\n", + " Parameters\n", + " ----------\n", + " unit : {'us', 'ms', 'ns', 's'}, optional\n", + " Unit for output times. 'us' means microseconds, 'ms' milliseconds,\n", + " 'ns' nanoseconds, 's' seconds. Default is 'us'.\n", + "\n", + " Returns\n", + " -------\n", + " pandas.DataFrame\n", + " Summary with columns:\n", + " Name, Loops, Min, Max, Mean, StdDev, Median, IQR (all in given unit),\n", + " OPS (Kops/unit), Samples.\n", + " \"\"\"\n", + " records = []\n", + " indexes = []\n", + " for name, data in self.results.items():\n", + " raw_us = data[\"raw_us\"]\n", + " # Convert to target unit\n", + " times = self._convert_times(raw_us, unit)\n", + " if isinstance(name, tuple) and len(name) > 1:\n", + " indexes.append(name)\n", + " elif isinstance(name, tuple) and len(name) == 1:\n", + " indexes.append(name[0])\n", + " else:\n", + " indexes.append(name)\n", + "\n", + " stats = {\n", + " \"Loops\": data[\"loops\"],\n", + " f\"Min ({unit})\": np.min(times),\n", + " f\"Max ({unit})\": np.max(times),\n", + " f\"Mean ({unit})\": np.mean(times),\n", + " f\"StdDev ({unit})\": np.std(times),\n", + " f\"Median ({unit})\": np.median(times),\n", + " f\"IQR ({unit})\": np.percentile(times, 75) - np.percentile(times, 25),\n", + " \"OPS (Kops/s)\": 1e3 / (np.mean(raw_us)),\n", + " \"Samples\": len(raw_us),\n", + " }\n", + " records.append(stats)\n", + "\n", + " if all(isinstance(idx, tuple) for idx in indexes):\n", + " index = pd.MultiIndex.from_tuples(indexes)\n", + " else:\n", + " index = pd.Index(indexes)\n", + " return pd.DataFrame(records, index=index)\n", + "\n", + " def raw(self, name=None):\n", + " \"\"\"\n", + " Get raw timing data in microseconds.\n", + "\n", + " Parameters\n", + " ----------\n", + " name : str, optional\n", + " If given, returns the raw_us array for that function. Otherwise returns\n", + " a dict of all raw results.\n", + "\n", + " Returns\n", + " -------\n", + " numpy.ndarray or dict\n", + " \"\"\"\n", + " if name:\n", + " return self.results.get(name, {}).get(\"raw_us\")\n", + " return {n: d[\"raw_us\"] for n, d in self.results.items()}\n", + "\n", + " def _convert_times(self, times, unit):\n", + " \"\"\"\n", + " Convert an array of times from microseconds to the specified unit.\n", + "\n", + " Parameters\n", + " ----------\n", + " times : array-like\n", + " Times in microseconds.\n", + " unit : {'us', 'ms', 'ns', 's'}\n", + " Target unit: 'us' microseconds, 'ms' milliseconds,\n", + " 'ns' nanoseconds, 's' seconds.\n", + "\n", + " Returns\n", + " -------\n", + " numpy.ndarray\n", + " Converted times.\n", + "\n", + " Raises\n", + " ------\n", + " ValueError\n", + " If `unit` is not one of the supported options.\n", + " \"\"\"\n", + " unit = unit.lower()\n", + " if unit == \"us\":\n", + " factor = 1.0\n", + " elif unit == \"ms\":\n", + " factor = 1e-3\n", + " elif unit == \"ns\":\n", + " factor = 1e3\n", + " elif unit == \"s\":\n", + " factor = 1e-6\n", + " else:\n", + " raise ValueError(f\"Unsupported unit: {unit}\")\n", + " return times * factor" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "799a759b", + "metadata": {}, + "outputs": [], + "source": [ + "# def block(func):\n", + "# def inner(*args, **kwargs):\n", + "# return jax.block_until_ready(func(*args, **kwargs))\n", + "# return inner\n", + "\n", + "\n", + "# I believe the above will only work if the return is a single jnp.array (at least that is what AI thinks) I would appreciate your insights here, Adrian.\n", + "\n", + "def block(func):\n", + " def inner(*args, **kwargs):\n", + " result = func(*args, **kwargs)\n", + " # Recursively block on all JAX arrays in result\n", + " jax.tree_util.tree_map(lambda x: x.block_until_ready() if hasattr(x, \"block_until_ready\") else None, result)\n", + " return result\n", + " return inner\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6e76a860", + "metadata": {}, + "outputs": [], + "source": [ + "# Set Pytensor to use float32\n", + "pytensor.config.floatX = \"float32\"" + ] + }, + { + "cell_type": "markdown", + "id": "1c3eb543", + "metadata": {}, + "source": [ + "# Introduction" + ] + }, + { + "cell_type": "markdown", + "id": "e066c9c4", + "metadata": {}, + "source": [ + "# Baby Steps" + ] + }, + { + "cell_type": "markdown", + "id": "962c851c", + "metadata": {}, + "source": [ + "## Fibonacci Algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1c8d1654", + "metadata": {}, + "outputs": [], + "source": [ + "N_STEPS = 10_000\n", + "\n", + "b_symbolic = pt.vector(\"b\", dtype=\"int32\", shape=(1,))\n", + "\n", + "def step(a, b):\n", + " return a + b, a\n", + "\n", + "(outputs_a, outputs_b), _ = pytensor.scan(\n", + " fn=step,\n", + " outputs_info=[pt.ones(1, dtype=\"int32\"), b_symbolic],\n", + " n_steps=N_STEPS\n", + ")\n", + "\n", + "# compile function returning final a\n", + "fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)\n", + "fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bf48d6bf", + "metadata": {}, + "outputs": [], + "source": [ + "@jit(nopython=True)\n", + "def fibonacci_numba(b):\n", + " a = np.ones(1, dtype=np.int32)\n", + " for _ in range(N_STEPS):\n", + " a[0], b[0] = a[0] + b[0], a[0]\n", + " return a[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a5cb2f54", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def fibonacci_jax(b):\n", + " a = jnp.ones(1, dtype=np.int32)\n", + " \n", + " def step(_, carry):\n", + " a, b = carry\n", + " return (a + b, a)\n", + " \n", + " a = jax.lax.fori_loop(0, N_STEPS, step, (a, b))\n", + "\n", + " return a" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "63bfdffe", + "metadata": {}, + "outputs": [], + "source": [ + "fibonacci_bench = Benchmarker(\n", + " functions=[fibonacci_pytensor, fibonacci_numba, fibonacci_jax, fibonacci_pytensor_numba], \n", + " names=['fibonacci_pytensor', 'fibonacci_numba', 'fibonacci_jax', 'fibonacci_pytensor_numba'],\n", + " number=10\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "65bf994e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
fibonacci_pytensorfibonacci_inputs105314.3457986147.6000005749.518199241.1961055631.999997395.7270960.17392819
fibonacci_numbafibonacci_inputs108.47500111.3624969.3035790.6538859.1395490.234378107.48551614
fibonacci_jaxfibonacci_inputs1021.82920037.45830023.6626353.21499822.0874962.30417542.26072132
fibonacci_pytensor_numbafibonacci_inputs10222.412497235.400000229.2433405.403845231.95420510.2000024.3621775
\n", + "
" + ], + "text/plain": [ + " Loops Min (us) Max (us) \\\n", + "fibonacci_pytensor fibonacci_inputs 10 5314.345798 6147.600000 \n", + "fibonacci_numba fibonacci_inputs 10 8.475001 11.362496 \n", + "fibonacci_jax fibonacci_inputs 10 21.829200 37.458300 \n", + "fibonacci_pytensor_numba fibonacci_inputs 10 222.412497 235.400000 \n", + "\n", + " Mean (us) StdDev (us) \\\n", + "fibonacci_pytensor fibonacci_inputs 5749.518199 241.196105 \n", + "fibonacci_numba fibonacci_inputs 9.303579 0.653885 \n", + "fibonacci_jax fibonacci_inputs 23.662635 3.214998 \n", + "fibonacci_pytensor_numba fibonacci_inputs 229.243340 5.403845 \n", + "\n", + " Median (us) IQR (us) \\\n", + "fibonacci_pytensor fibonacci_inputs 5631.999997 395.727096 \n", + "fibonacci_numba fibonacci_inputs 9.139549 0.234378 \n", + "fibonacci_jax fibonacci_inputs 22.087496 2.304175 \n", + "fibonacci_pytensor_numba fibonacci_inputs 231.954205 10.200002 \n", + "\n", + " OPS (Kops/s) Samples \n", + "fibonacci_pytensor fibonacci_inputs 0.173928 19 \n", + "fibonacci_numba fibonacci_inputs 107.485516 14 \n", + "fibonacci_jax fibonacci_inputs 42.260721 32 \n", + "fibonacci_pytensor_numba fibonacci_inputs 4.362177 5 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fibonacci_bench.run(\n", + " inputs={\n", + " \"fibonacci_inputs\": {\"b\": np.ones(1, dtype=np.int32)},\n", + " }\n", + ")\n", + "fibonacci_bench.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "7bbb35a2", + "metadata": {}, + "source": [ + "## Element-wise multiplication Algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0b076f77", + "metadata": {}, + "outputs": [], + "source": [ + "# Not sure this makes a difference\n", + "a_symbolic = pt.vector(\"a\", dtype=\"int32\")\n", + "b_symbolic = pt.vector(\"b\", dtype=\"int32\")\n", + "N_STEPS = 1000\n", + "\n", + "def step(a_element, b_element):\n", + " return a_element * b_element\n", + "\n", + "c, _ = pytensor.scan(\n", + " fn=step,\n", + " sequences=[a_symbolic, b_symbolic],\n", + " n_steps=N_STEPS\n", + ")\n", + "\n", + "# compile function returning final a\n", + "c_mode = get_mode(\"FAST_RUN\").excluding(\"scan_push_out_seq\")\n", + "elementwise_multiply_pytensor = pytensor.function([a_symbolic, b_symbolic], c, trust_input=True, mode=c_mode)\n", + "\n", + "numba_mode = get_mode(\"NUMBA\").excluding(\"scan_push_out_seq\")\n", + "elementwise_multiply_pytensor_numba = pytensor.function([a_symbolic, b_symbolic], c, mode=numba_mode, trust_input=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "14327cbf", + "metadata": {}, + "outputs": [], + "source": [ + "@jit(nopython=True)\n", + "def elementwise_multiply_numba(a, b):\n", + " n = a.shape[0]\n", + " c = np.empty(n, dtype=a.dtype)\n", + " for i in range(n):\n", + " c[i] = a[i] * b[i]\n", + " return c" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "13715c9b", + "metadata": {}, + "outputs": [], + "source": [ + "@block\n", + "@jax.jit\n", + "def elementwise_multiply_jax(a, b):\n", + " n = a.shape[0]\n", + " c_init = jnp.empty(n, dtype=a.dtype)\n", + " def step(i, c):\n", + " return jax.lax.dynamic_update_index_in_dim(c, a[i] * b[i], i, axis=0)\n", + " \n", + " c = jax.lax.fori_loop(0, n, step, c_init)\n", + " return c" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "d43c4f9c", + "metadata": {}, + "outputs": [], + "source": [ + "a = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)\n", + "b = np.random.normal(0, 1, (N_STEPS)).astype(np.int32)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "f0f4ede5", + "metadata": {}, + "outputs": [], + "source": [ + "elem_mult_bench = Benchmarker(\n", + " functions=[elementwise_multiply_pytensor, elementwise_multiply_numba, elementwise_multiply_jax, elementwise_multiply_pytensor_numba], \n", + " names=['elementwise_multiply_pytensor', 'elementwise_multiply_numba', 'elementwise_multiply_jax', 'elementwise_multiply_pytensor_numba'],\n", + " number=10\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "cdab8946", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
elementwise_multiply_pytensorelem_mult_inputs10431.329099496.695802465.7878819.956242466.14999912.6301222.146900218
elementwise_multiply_numbaelem_mult_inputs100.3625000.6042010.3833430.0447770.3749970.0083752608.63211526
elementwise_multiply_jaxelem_mult_inputs107.77500111.6374988.6936050.9144308.2833490.964597115.02708360
elementwise_multiply_pytensor_numbaelem_mult_inputs1028.79159931.77080129.5041611.13814928.9584040.22499633.8935255
\n", + "
" + ], + "text/plain": [ + " Loops Min (us) \\\n", + "elementwise_multiply_pytensor elem_mult_inputs 10 431.329099 \n", + "elementwise_multiply_numba elem_mult_inputs 10 0.362500 \n", + "elementwise_multiply_jax elem_mult_inputs 10 7.775001 \n", + "elementwise_multiply_pytensor_numba elem_mult_inputs 10 28.791599 \n", + "\n", + " Max (us) Mean (us) \\\n", + "elementwise_multiply_pytensor elem_mult_inputs 496.695802 465.787881 \n", + "elementwise_multiply_numba elem_mult_inputs 0.604201 0.383343 \n", + "elementwise_multiply_jax elem_mult_inputs 11.637498 8.693605 \n", + "elementwise_multiply_pytensor_numba elem_mult_inputs 31.770801 29.504161 \n", + "\n", + " StdDev (us) \\\n", + "elementwise_multiply_pytensor elem_mult_inputs 9.956242 \n", + "elementwise_multiply_numba elem_mult_inputs 0.044777 \n", + "elementwise_multiply_jax elem_mult_inputs 0.914430 \n", + "elementwise_multiply_pytensor_numba elem_mult_inputs 1.138149 \n", + "\n", + " Median (us) IQR (us) \\\n", + "elementwise_multiply_pytensor elem_mult_inputs 466.149999 12.630122 \n", + "elementwise_multiply_numba elem_mult_inputs 0.374997 0.008375 \n", + "elementwise_multiply_jax elem_mult_inputs 8.283349 0.964597 \n", + "elementwise_multiply_pytensor_numba elem_mult_inputs 28.958404 0.224996 \n", + "\n", + " OPS (Kops/s) Samples \n", + "elementwise_multiply_pytensor elem_mult_inputs 2.146900 218 \n", + "elementwise_multiply_numba elem_mult_inputs 2608.632115 26 \n", + "elementwise_multiply_jax elem_mult_inputs 115.027083 60 \n", + "elementwise_multiply_pytensor_numba elem_mult_inputs 33.893525 5 " + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "elem_mult_bench.run(\n", + " inputs={\n", + " \"elem_mult_inputs\": {\"a\": a, \"b\": b},\n", + " }\n", + ")\n", + "elem_mult_bench.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "25a87710", + "metadata": {}, + "source": [ + "# Changepoint Detection Algorithms" + ] + }, + { + "cell_type": "markdown", + "id": "16b2b312", + "metadata": {}, + "source": [ + "## Cumulative Sum (CUSUM) Algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "c1226cfc", + "metadata": {}, + "outputs": [], + "source": [ + "@jit(nopython=True)\n", + "def cusum_adaptive_numba(x, alpha=0.01, k=0.5, h=5.0):\n", + " \"\"\"\n", + " Two-sided CUSUM with adaptive exponential moving average baseline.\n", + " \n", + " Parameters\n", + " ----------\n", + " x: np.ndarray\n", + " input signal\n", + " alpha: float\n", + " EMA smoothing factor (0 < alpha <= 1)\n", + " k: float\n", + " slack to avoid small changes triggering alarms\n", + " h: float\n", + " threshold for raising an alarm\n", + " \n", + " Returns\n", + " -------\n", + " s_pos: np.ndarray\n", + " upper CUSUM stats\n", + " s_neg: np.ndarray\n", + " lower CUSUM stats\n", + " mu_t: np.ndarray\n", + " evolving baseline estimate\n", + " alarms_pos: np.ndarray\n", + " alarms for upward changes\n", + " alarms_neg: np.ndarray\n", + " alarms for downward changes\n", + " \"\"\"\n", + " n = x.shape[0]\n", + "\n", + " s_pos = np.zeros(n, dtype=np.float32)\n", + " s_neg = np.zeros(n, dtype=np.float32)\n", + " mu_t = np.zeros(n, dtype=np.float32)\n", + " alarms_pos = np.zeros(n, dtype=np.bool_)\n", + " alarms_neg = np.zeros(n, dtype=np.bool_)\n", + "\n", + " # Initialization\n", + " mu_t[0] = x[0]\n", + "\n", + " for i in range(1, n):\n", + " # Update baseline (EMA)\n", + " mu_t[i] = alpha * x[i] + (1 - alpha) * mu_t[i-1]\n", + "\n", + " # Update CUSUM stats\n", + " s_pos[i] = max(0.0, s_pos[i-1] + x[i] - mu_t[i] - k)\n", + " s_neg[i] = max(0.0, s_neg[i-1] - (x[i] - mu_t[i]) - k)\n", + "\n", + " # Alarms\n", + " alarms_pos = s_pos > h\n", + " alarms_neg = s_neg > h\n", + "\n", + " return s_pos, s_neg, mu_t, alarms_pos, alarms_neg" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "14937a2a", + "metadata": {}, + "outputs": [], + "source": [ + "@block\n", + "@jax.jit\n", + "def cusum_adaptive_jax(x, alpha=0.01, k=0.5, h=5.0):\n", + " \"\"\"\n", + " Two-sided CUSUM with adaptive exponential moving average baseline.\n", + " \n", + " Parameters\n", + " ----------\n", + " x: jnp.ndarray\n", + " input signal\n", + " alpha: float\n", + " EMA smoothing factor (0 < alpha <= 1)\n", + " k: float\n", + " slack to avoid small changes triggering alarms\n", + " h: float\n", + " threshold for raising an alarm\n", + " \n", + " Returns\n", + " -------\n", + " s_pos: jnp.ndarray\n", + " upper CUSUM stats\n", + " s_neg: jnp.ndarray\n", + " lower CUSUM stats\n", + " mu_t: jnp.ndarray\n", + " evolving baseline estimate\n", + " alarms_pos: jnp.ndarray\n", + " alarms for upward changes\n", + " alarms_neg: jnp.ndarray\n", + " alarms for downward changes\n", + " \"\"\"\n", + " def body(carry, x_t):\n", + " s_pos_prev, s_neg_prev, mu_prev = carry\n", + " \n", + " # Update EMA baseline\n", + " mu_t = alpha * x_t + (1 - alpha) * mu_prev\n", + " \n", + " # Update CUSUMs using updated baseline\n", + " s_pos = jnp.maximum(0.0, s_pos_prev + x_t - mu_t - k)\n", + " s_neg = jnp.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)\n", + " \n", + " new_carry = (s_pos, s_neg, mu_t)\n", + " output = (s_pos, s_neg, mu_t)\n", + " return new_carry, output\n", + "\n", + " # Initialize: CUSUMs at 0, initial mean = first sample\n", + " s0 = (0.0, 0.0, x[0])\n", + " _, (s_pos_vals, s_neg_vals, mu_vals) = jax.lax.scan(body, s0, x)\n", + " \n", + " # Thresholding\n", + " alarms_pos = s_pos_vals > h\n", + " alarms_neg = s_neg_vals > h\n", + "\n", + " return s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "815967a4", + "metadata": {}, + "outputs": [], + "source": [ + "x_symbolic = pt.vector(\"x\")\n", + "alpha_symbolic = pt.scalar(\"alpha\")\n", + "k_symbolic = pt.scalar(\"k\")\n", + "h_symbolic = pt.scalar(\"h\")\n", + "\n", + "N_STEPS = 100 # Fixing this just incase but I don't think this is changing anything when we have sequences as input to scan\n", + "\n", + "def step(x_t, s_pos_prev, s_neg_prev, mu_prev, alpha, k):\n", + " # Update EMA baseline\n", + " mu_t = alpha * x_t + (1 - alpha) * mu_prev\n", + " \n", + " # Update CUSUMs using updated baseline\n", + " s_pos = pt.maximum(0.0, s_pos_prev + x_t - mu_t - k)\n", + " s_neg = pt.maximum(0.0, s_neg_prev - (x_t - mu_t) - k)\n", + " \n", + " return s_pos, s_neg, mu_t\n", + "\n", + "\n", + "(s_pos_vals, s_neg_vals, mu_vals), updates = pytensor.scan(\n", + " fn=step,\n", + " outputs_info=[pt.constant(0., dtype=\"float32\"), pt.constant(0., dtype=\"float32\"), x_symbolic[0]],\n", + " non_sequences=[alpha_symbolic, k_symbolic],\n", + " sequences=[x_symbolic],\n", + " n_steps=N_STEPS\n", + ")\n", + "\n", + "# Thresholding\n", + "alarms_pos = s_pos_vals > h_symbolic\n", + "alarms_neg = s_neg_vals > h_symbolic\n", + "\n", + "cusum_adaptive_pytensor = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], trust_input=True)\n", + "\n", + "cusum_adaptive_pytensor_numba = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode=\"NUMBA\", trust_input=True)\n", + "cusum_adaptive_pytensor_jax = pytensor.function([x_symbolic, alpha_symbolic, k_symbolic, h_symbolic], [s_pos_vals, s_neg_vals, mu_vals, alarms_pos, alarms_neg], mode=\"JAX\", trust_input=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "6c892129", + "metadata": {}, + "outputs": [], + "source": [ + "xs1 = np.random.normal(80, 20, size=(int(N_STEPS/2)))\n", + "xs2 = np.random.normal(50, 20, size=(int(N_STEPS/2)))\n", + "xs = np.concat((xs1, xs2))\n", + "xs = xs.astype(np.float32)\n", + "xs = xs.astype(np.float32)\n", + "xs_std = (xs - xs.mean()) / xs.std()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "d21873d6", + "metadata": {}, + "outputs": [], + "source": [ + "cusum_bench = Benchmarker(\n", + " functions=[cusum_adaptive_pytensor, cusum_adaptive_numba, cusum_adaptive_jax, cusum_adaptive_pytensor_numba, block(cusum_adaptive_pytensor_jax)], \n", + " names=['cusum_adaptive_pytensor', 'cusum_adaptive_numba', 'cusum_adaptive_jax', 'cusum_adaptive_pytensor_numba', 'cusum_adaptive_pytensor_jax'],\n", + " number=10\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d8eab72d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
cusum_adaptive_pytensorcusum_inputs10127.283402161.174999137.0549164.574608136.8167005.3750067.296345697
cusum_adaptive_numbacusum_inputs101.6582962.2957971.9880850.2230862.0416980.372902502.9965647
cusum_adaptive_jaxcusum_inputs1013.97500219.08750314.7587641.19670814.2333970.58544767.75635139
cusum_adaptive_pytensor_numbacusum_inputs1021.27079626.60829922.3883002.11069221.3166000.12079844.6661885
cusum_adaptive_pytensor_jaxcusum_inputs1021.49160137.29169822.7061852.75677521.7875001.47920844.04086333
\n", + "
" + ], + "text/plain": [ + " Loops Min (us) Max (us) \\\n", + "cusum_adaptive_pytensor cusum_inputs 10 127.283402 161.174999 \n", + "cusum_adaptive_numba cusum_inputs 10 1.658296 2.295797 \n", + "cusum_adaptive_jax cusum_inputs 10 13.975002 19.087503 \n", + "cusum_adaptive_pytensor_numba cusum_inputs 10 21.270796 26.608299 \n", + "cusum_adaptive_pytensor_jax cusum_inputs 10 21.491601 37.291698 \n", + "\n", + " Mean (us) StdDev (us) \\\n", + "cusum_adaptive_pytensor cusum_inputs 137.054916 4.574608 \n", + "cusum_adaptive_numba cusum_inputs 1.988085 0.223086 \n", + "cusum_adaptive_jax cusum_inputs 14.758764 1.196708 \n", + "cusum_adaptive_pytensor_numba cusum_inputs 22.388300 2.110692 \n", + "cusum_adaptive_pytensor_jax cusum_inputs 22.706185 2.756775 \n", + "\n", + " Median (us) IQR (us) \\\n", + "cusum_adaptive_pytensor cusum_inputs 136.816700 5.375006 \n", + "cusum_adaptive_numba cusum_inputs 2.041698 0.372902 \n", + "cusum_adaptive_jax cusum_inputs 14.233397 0.585447 \n", + "cusum_adaptive_pytensor_numba cusum_inputs 21.316600 0.120798 \n", + "cusum_adaptive_pytensor_jax cusum_inputs 21.787500 1.479208 \n", + "\n", + " OPS (Kops/s) Samples \n", + "cusum_adaptive_pytensor cusum_inputs 7.296345 697 \n", + "cusum_adaptive_numba cusum_inputs 502.996564 7 \n", + "cusum_adaptive_jax cusum_inputs 67.756351 39 \n", + "cusum_adaptive_pytensor_numba cusum_inputs 44.666188 5 \n", + "cusum_adaptive_pytensor_jax cusum_inputs 44.040863 33 " + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cusum_bench.run(\n", + " inputs={\n", + " \"cusum_inputs\": {\"x\": xs, \"alpha\": 0.1, \"k\": 0.5, \"h\": 3.5},\n", + " }\n", + ")\n", + "cusum_bench.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3e9c3339", + "metadata": {}, + "outputs": [], + "source": [ + "outputs = cusum_adaptive_numba(xs_std, alpha=0.1, k=0.5, h=3.5)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "b85d0c0e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = go.Figure()\n", + "fig.add_traces(\n", + " [\n", + " go.Scatter(\n", + " x = np.arange(len(xs)),\n", + " y = xs_std,\n", + " name=\"series\"\n", + " ),\n", + " go.Scatter(\n", + " x = np.arange(len(xs)),\n", + " y = outputs[0],\n", + " name=\"cum. positive devs.\"\n", + " ),\n", + " go.Scatter(\n", + " x = np.arange(len(xs)),\n", + " y = outputs[1],\n", + " name=\"cum. negative devs.\"\n", + " ),\n", + " go.Scatter(\n", + " x = np.arange(len(xs)),\n", + " y = outputs[2],\n", + " name=\"Exp. Mean\"\n", + " ),\n", + " go.Scatter(\n", + " x = np.arange(len(xs)),\n", + " y = outputs[3].astype(np.float16),\n", + " name=\"positive alarms\"\n", + " ),\n", + " go.Scatter(\n", + " x = np.arange(len(xs)),\n", + " y = outputs[4].astype(np.float16),\n", + " name=\"negative alarms\"\n", + " ),\n", + " \n", + " ]\n", + ")\n", + "fig.update_layout(\n", + " title = dict(\n", + " text = \"CUSUM Change Point Detection Algorithm\"\n", + " ),\n", + " xaxis=dict(\n", + " title = \"Time Index\"\n", + " ),\n", + " yaxis=dict(\n", + " title = \"Standardized Series Scaled\"\n", + " ),\n", + " legend=dict(\n", + " yanchor=\"top\",\n", + " y=1.1,\n", + " xanchor=\"left\",\n", + " x=0,\n", + " orientation=\"h\"\n", + " ),\n", + " template=\"plotly_dark\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c2ce03ea", + "metadata": {}, + "source": [ + "## Pruned Exact Linear Time (PELT) Algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "eab4cbb9", + "metadata": {}, + "outputs": [], + "source": [ + "@jit(nopython=True)\n", + "def segment_cost_numba(S1, S2, i, j):\n", + " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n", + " n = j - i\n", + " sum_x = S1[j] - S1[i]\n", + " sum_x2 = S2[j] - S2[i]\n", + " if n > 0:\n", + " return sum_x2 - (sum_x ** 2) / n\n", + " else:\n", + " return np.inf\n", + "\n", + "@jit(nopython=True)\n", + "def pelt_numba(x, beta=10.0):\n", + " \"\"\"\n", + " Pruned Exact Linear Time algorithm for change point detection\n", + "\n", + " Parameters\n", + " ----------\n", + " x: np.ndarray\n", + " The timeseries signal\n", + " beta: float\n", + " Penalty of segmenting the series\n", + "\n", + " Returns\n", + " -------\n", + " C: np.ndarray\n", + " The best costs up to segment t\n", + " last_change: np.ndarray\n", + " The last change point up to segment t\n", + " \"\"\"\n", + " n = len(x)\n", + "\n", + " # cumulative sums for cost\n", + " S1 = np.empty(n+1, dtype=np.float32)\n", + " S2 = np.empty(n+1, dtype=np.float32)\n", + " S1[0], S2[0] = 0.0, 0.0\n", + " for i in range(1, n+1):\n", + " S1[i] = S1[i-1] + x[i-1]\n", + " S2[i] = S2[i-1] + x[i-1]**2\n", + "\n", + " # DP arrays\n", + " C = np.full((n+1,), np.inf)\n", + " C[0] = -beta\n", + " last_change = np.full((n+1,), -1)\n", + " min_size = 3\n", + "\n", + " for t in range(1, n+1):\n", + " costs = np.full(n, np.inf)\n", + " for s in range(n):\n", + " if s < t and (t - s) >= min_size:\n", + " costs[s] = C[s] + segment_cost_numba(S1, S2, s, t) + beta\n", + " best_s = np.argmin(costs)\n", + " C[t] = costs[best_s]\n", + " last_change[t] = best_s\n", + "\n", + " return C, last_change" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "e6997389", + "metadata": {}, + "outputs": [], + "source": [ + "def segment_cost_jax(S1, S2, i, j):\n", + " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n", + " n = j - i\n", + " sum_x = S1[j] - S1[i]\n", + " sum_x2 = S2[j] - S2[i]\n", + " return jnp.where(n > 0, sum_x2 - (sum_x ** 2) / n, jnp.inf)\n", + "\n", + "@block\n", + "@jax.jit\n", + "def pelt_jax(x, beta=10.0):\n", + " \"\"\"\n", + " Pruned Exact Linear Time algorithm for change point detection\n", + "\n", + " Parameters\n", + " ----------\n", + " x: np.ndarray\n", + " The timeseries signal\n", + " beta: float\n", + " Penalty of segmenting the series\n", + "\n", + " Returns\n", + " -------\n", + " C: jnp.ndarray\n", + " The best costs up to segment t\n", + " last_change: jnp.ndarray\n", + " The last change point up to segment t\n", + " \"\"\"\n", + " n = len(x)\n", + "\n", + " # cumulative sums for cost\n", + " S1 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x)])\n", + " S2 = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(x**2)])\n", + "\n", + " # DP arrays\n", + " C = jnp.full((n+1,), jnp.inf)\n", + " C = C.at[0].set(-beta)\n", + " last_change = jnp.full((n+1,), -1)\n", + " min_size = 3\n", + "\n", + " s_all = jnp.arange(n) # all possible candidates\n", + "\n", + " def body(t, carry):\n", + " C, last_change = carry\n", + "\n", + " # Compute cost for all s < t, mask invalid\n", + " # valid = s_all < t & ((t - s_all) >= min_size)\n", + " \n", + " valid = (s_all < t) & ((t - s_all) >= min_size)\n", + " costs = jnp.where(\n", + " valid,\n", + " C[s_all] + segment_cost_jax(S1, S2, s_all, t) + beta,\n", + " jnp.inf\n", + " )\n", + "\n", + " best_s = jnp.argmin(costs)\n", + " C = C.at[t].set(costs[best_s])\n", + " last_change = last_change.at[t].set(best_s)\n", + " return C, last_change\n", + "\n", + " C, last_change = jax.lax.fori_loop(1, n+1, body, (C, last_change))\n", + " return C, last_change" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "094b9e8e", + "metadata": {}, + "outputs": [], + "source": [ + "def segment_cost_pytensor(S1, S2, i, j):\n", + " \"\"\"Cost of segment x[i:j], SSE around mean\"\"\"\n", + " n = j - i\n", + " sum_x = S1[j] - S1[i]\n", + " sum_x2 = S2[j] - S2[i]\n", + " return pt.switch(\n", + " pt.gt(n, 0),\n", + " sum_x2 - (sum_x ** 2) / n,\n", + " np.inf\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "03e5e927", + "metadata": {}, + "outputs": [], + "source": [ + "x_symbolic = pt.vector(\"x\")\n", + "beta_symbolic = pt.scalar(\"beta\")\n", + "n = x_symbolic.shape[0]\n", + "N_STEPS=100\n", + "\n", + "# cumulative sums for cost\n", + "S1 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic)])\n", + "S2 = pt.concatenate([pt.as_tensor([0.0]), pt.cumsum(x_symbolic**2)])\n", + "\n", + "# DP arrays\n", + "C_init = pt.alloc(np.inf, n+1)\n", + "C_init = pt.set_subtensor(C_init[0], -beta_symbolic)\n", + "last_change_init = pt.alloc(-1, n+1)\n", + "\n", + "s_all = pt.arange(n) # candidate change points\n", + "min_size = 3\n", + "\n", + "def step(t, C_prev, last_change_prev, S1, S2, beta_symbolic, s_all):\n", + " # valid = (s_all < t) & ((t - s_all) >= min_size)\n", + " valid = pt.and_(pt.lt(s_all, t), pt.ge(t - s_all, min_size))\n", + "\n", + " # compute costs for all candidates\n", + " costs, _ = pytensor.scan(\n", + " fn=lambda s: pt.switch(\n", + " valid[s],\n", + " C_prev[s] + segment_cost_pytensor(S1, S2, s, t) + beta_symbolic,\n", + " np.inf\n", + " ),\n", + " sequences=[pt.arange(n)]\n", + " )\n", + " costs = costs.flatten()\n", + "\n", + " best_s = pt.argmin(costs, axis=0)\n", + " C_new = pt.set_subtensor(C_prev[t], costs[best_s])\n", + " last_change_new = pt.set_subtensor(last_change_prev[t], best_s)\n", + "\n", + " return C_new, last_change_new\n", + "\n", + "(C_vals, last_change_vals), _ = pytensor.scan(\n", + " fn=step,\n", + " sequences=[pt.arange(1, n+1)],\n", + " outputs_info=[C_init, last_change_init],\n", + " non_sequences=[S1, S2, beta_symbolic, s_all],\n", + " n_steps=N_STEPS # Added fixed iterations here\n", + ")\n", + "\n", + "pelt_pytensor = pytensor.function([x_symbolic, beta_symbolic], [C_vals[-1], last_change_vals[-1]], trust_input=True)\n", + "pelt_pytensor_numba = pytensor.function(inputs=[x_symbolic, beta_symbolic], outputs=[C_vals[-1], last_change_vals[-1]], mode=\"NUMBA\", trust_input=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "69aea9a2", + "metadata": {}, + "outputs": [], + "source": [ + "pelt_bench = Benchmarker(\n", + " functions=[pelt_pytensor, pelt_numba, pelt_jax, pelt_pytensor_numba], \n", + " names=['pelt_pytensor', 'pelt_numba', 'pelt_jax', 'pelt_pytensor_numba'],\n", + " number=10\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "5aee4a18", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
pelt_pytensorpelt_inputs1011072.00830011740.05830011258.694433190.10373311197.529198142.7542010.0888209
pelt_numbapelt_inputs1018.57500221.35410319.5058021.11041218.7666041.62490251.2667985
pelt_jaxpelt_inputs1064.74999877.74590067.1793842.99124866.4416992.02182714.88551920
pelt_pytensor_numbapelt_inputs102188.9291992497.1125012334.866661123.0317742302.750002231.9750030.4282905
\n", + "
" + ], + "text/plain": [ + " Loops Min (us) Max (us) \\\n", + "pelt_pytensor pelt_inputs 10 11072.008300 11740.058300 \n", + "pelt_numba pelt_inputs 10 18.575002 21.354103 \n", + "pelt_jax pelt_inputs 10 64.749998 77.745900 \n", + "pelt_pytensor_numba pelt_inputs 10 2188.929199 2497.112501 \n", + "\n", + " Mean (us) StdDev (us) Median (us) \\\n", + "pelt_pytensor pelt_inputs 11258.694433 190.103733 11197.529198 \n", + "pelt_numba pelt_inputs 19.505802 1.110412 18.766604 \n", + "pelt_jax pelt_inputs 67.179384 2.991248 66.441699 \n", + "pelt_pytensor_numba pelt_inputs 2334.866661 123.031774 2302.750002 \n", + "\n", + " IQR (us) OPS (Kops/s) Samples \n", + "pelt_pytensor pelt_inputs 142.754201 0.088820 9 \n", + "pelt_numba pelt_inputs 1.624902 51.266798 5 \n", + "pelt_jax pelt_inputs 2.021827 14.885519 20 \n", + "pelt_pytensor_numba pelt_inputs 231.975003 0.428290 5 " + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pelt_bench.run(\n", + " inputs={\n", + " \"pelt_inputs\": {\"x\": xs_std, \"beta\": 2. * np.log(len(xs_std))},\n", + " }\n", + ")\n", + "pelt_bench.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "47cae530", + "metadata": {}, + "outputs": [], + "source": [ + "outputs = pelt_numba(xs_std, 2. * np.log(len(xs_std)))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "7b395b06", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_pelt_diagnostics(x, cps, C):\n", + " \"\"\"\n", + " Diagnostic plots for PELT changepoint detection.\n", + " \n", + " Args:\n", + " x: 1D array, original time series\n", + " C: 1D array, cumulative DP cost from pelt()\n", + " cps: list of changepoint indices (sorted ascending)\n", + " \"\"\"\n", + " n = len(x)\n", + " cps_full = [0] + cps + [n]\n", + "\n", + " # Segment means, std, SSE\n", + " segment_means = []\n", + " segment_stds = []\n", + " segment_costs = []\n", + " for start, end in zip(cps_full[:-1], cps_full[1:]):\n", + " seg = x[start:end]\n", + " mean = np.mean(seg)\n", + " std = np.std(seg)\n", + " cost = np.sum((seg - mean)**2)\n", + " segment_means.append(mean)\n", + " segment_stds.append(std)\n", + " segment_costs.append(cost)\n", + "\n", + " # Step function for segment mean\n", + " mean_step = np.zeros(n)\n", + " for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):\n", + " mean_step[start:end] = segment_means[i]\n", + "\n", + " # Step function for segment std\n", + " std_step = np.zeros(n)\n", + " for i, (start, end) in enumerate(zip(cps_full[:-1], cps_full[1:])):\n", + " std_step[start:end] = segment_stds[i]\n", + "\n", + " if len(x) < 20:\n", + " title1 = \"Warning: Sample size is small - Detected Changepoints\"\n", + " else:\n", + " title1 = \"Detected Changepoints\"\n", + "\n", + " fig = make_subplots(\n", + " rows=4, \n", + " cols=1,\n", + " subplot_titles=(title1, \"Average Shifts\", \"Variability Shifts\", \"Cumulative Cost\")\n", + " )\n", + "\n", + " fig.add_trace(\n", + " go.Scatter(\n", + " x = np.arange(len(x)),\n", + " y = x,\n", + " line_color=\"royalblue\",\n", + " name = \"Actuals\",\n", + " mode=\"lines\",\n", + " showlegend=False,\n", + " hovertemplate=\"Time Point: %{x}
Actual: %{y}\"\n", + " ),\n", + " row=1, col=1\n", + " )\n", + "\n", + " for cp in cps:\n", + " fig.add_vline(x=cp, line_dash='dash', line_color=\"red\", row=1, col=1)\n", + "\n", + " fig.add_trace(\n", + " go.Scatter(\n", + " x = np.arange(len(x)),\n", + " y = x,\n", + " name = \"Actuals\",\n", + " mode=\"lines\",\n", + " line_color=\"rgba(105, 105, 105, 0.25)\",\n", + " showlegend=False,\n", + " hoverinfo=\"skip\"\n", + " ),\n", + " row=2, col=1\n", + " )\n", + "\n", + " fig.add_trace(\n", + " go.Scatter(\n", + " x = np.arange(len(x)),\n", + " y = mean_step,\n", + " name = \"Average\",\n", + " line_color=\"royalblue\",\n", + " showlegend=False,\n", + " hovertemplate=\"Time Point: %{x}
Average: %{y}\"\n", + " ),\n", + " row=2, col=1\n", + " )\n", + "\n", + " fig.add_trace(\n", + " go.Scatter(\n", + " x = np.arange(len(x)),\n", + " y = std_step,\n", + " name = \"Standard Deviation\",\n", + " line_color=\"royalblue\",\n", + " showlegend=False,\n", + " hovertemplate=\"Time Point: %{x}
Standard Deviation: %{y}\"\n", + " ),\n", + " row=3, col=1\n", + " )\n", + "\n", + " fig.add_trace(\n", + " go.Scatter(\n", + " x = np.arange(len(x)),\n", + " y = C,\n", + " name = \"Cumulative Cost\",\n", + " line_color=\"royalblue\",\n", + " showlegend=False,\n", + " hovertemplate=\"Time Point: %{x}
Cost: %{y}\"\n", + " ),\n", + " row=4, col=1\n", + " )\n", + "\n", + " for cp in cps:\n", + " fig.add_vline(x=cp, line_dash='dash', line_color=\"red\", row=4, col=1)\n", + "\n", + " return fig.update_layout(height=1000, width=1200, template=\"plotly_dark\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "e1de7df5", + "metadata": {}, + "outputs": [], + "source": [ + "def get_changepoints(last_change, n):\n", + " \"\"\"\n", + " Backtrack changepoints from last_change array.\n", + " \n", + " Args:\n", + " last_change: array from pelt()\n", + " n: length of input series\n", + "\n", + " Returns:\n", + " list of changepoint indices (sorted ascending)\n", + " \"\"\"\n", + " cps = []\n", + " t = n\n", + " while t > 0:\n", + " s = int(last_change[t])\n", + " if s <= 0:\n", + " break\n", + " cps.append(s)\n", + " t = s\n", + " return list(reversed(cps))" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "b73d80ac", + "metadata": {}, + "outputs": [], + "source": [ + "cps = get_changepoints(outputs[1], n=len(xs_std))" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "e2e376de", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_pelt_diagnostics(xs, cps, outputs[0])" + ] + }, + { + "cell_type": "markdown", + "id": "09f2e235", + "metadata": {}, + "source": [ + "# Kalman Filter Algorithms" + ] + }, + { + "cell_type": "markdown", + "id": "1c42336f", + "metadata": {}, + "source": [ + "## Linear Gaussian Kalman Filter" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "5004eeab", + "metadata": {}, + "outputs": [], + "source": [ + "@jit(nopython=True)\n", + "def atrocious_kalman_filter_numba(z, F, H, Q, R, x0, P0):\n", + " \"\"\"\n", + " This implementation of the Kalman filter is Atrocious and in standard Python would be a \n", + " BIG NO-NO. That being said this version SIGNIFICANTLY reduces Numba Compilation time. \n", + " \n", + " Linear Gaussian Kalman filter algorithm\n", + "\n", + " Parameters\n", + " ----------\n", + " z: np.ndarray\n", + " shape (T, m) - observations\n", + " F: np.ndarray\n", + " state transition matrix - shape (n, n)\n", + " H: np.ndarray\n", + " observation/design matrix - shape (m, n)\n", + " Q: np.ndarray\n", + " process noise covariance - shape (n, n)\n", + " R: np.ndarray\n", + " observation noise covariance - shape (m, m)\n", + " x0: np.ndarray\n", + " initial state mean - shape (n,)\n", + " P0: np.ndarray\n", + " initial state covariance - shape (n, n)\n", + "\n", + " Returns\n", + " -------\n", + " xs: np.ndarray\n", + " shape (T, n) - filtered state means\n", + " Ps: np.ndarray\n", + " shape (T, n, n) - filtered state covariances\n", + " \"\"\"\n", + " T = z.shape[0]\n", + " m = z.shape[1]\n", + " n = x0.shape[0]\n", + "\n", + " xs = np.empty((T, n), dtype=np.float32)\n", + " Ps = np.empty((T, n, n), dtype=np.float32)\n", + "\n", + " # local working arrays\n", + " x = np.empty(n, dtype=np.float32)\n", + " for i in range(n):\n", + " x[i] = x0[i]\n", + " P = np.empty((n, n), dtype=np.float32)\n", + " for i in range(n):\n", + " for j in range(n):\n", + " P[i, j] = P0[i, j]\n", + "\n", + " # temporary matrices/vectors\n", + " x_pred = np.empty((T, n), dtype=np.float32)\n", + " P_pred = np.empty((T, n, n), dtype=np.float32)\n", + " y = np.empty(m, dtype=np.float32)\n", + " S = np.empty((m, m), dtype=np.float32)\n", + " K = np.empty((n, m), dtype=np.float32)\n", + " I_n = np.eye(n, dtype=np.float32)\n", + "\n", + " for t in range(T):\n", + " # === Predict ===\n", + " # x_pred = F @ x\n", + " for i in range(n):\n", + " s = 0.0\n", + " for j in range(n):\n", + " s += F[i, j] * x[j]\n", + " x_pred[t, i] = s\n", + "\n", + " # P_pred = F @ P @ F.T + Q\n", + " # temp = F @ P\n", + " temp = np.empty((n, n), dtype=np.float32)\n", + " for i in range(n):\n", + " for j in range(n):\n", + " s = 0.0\n", + " for k in range(n):\n", + " s += F[i, k] * P[k, j]\n", + " temp[i, j] = s\n", + " # P_pred = temp @ F.T\n", + " for i in range(n):\n", + " for j in range(n):\n", + " s = 0.0\n", + " for k in range(n):\n", + " s += temp[i, k] * F[j, k] # F.T[k, j] = F[j, k]\n", + " P_pred[t, i, j] = s + Q[i, j]\n", + "\n", + " # === Update ===\n", + " # y = z[t] - H @ x_pred\n", + " for i in range(m):\n", + " s = 0.0\n", + " for j in range(n):\n", + " s += H[i, j] * x_pred[t, j]\n", + " y[i] = z[t, i] - s\n", + "\n", + " # S = H @ P_pred @ H.T + R\n", + " # temp2 = H @ P_pred\n", + " temp2 = np.empty((m, n), dtype=np.float32)\n", + " for i in range(m):\n", + " for j in range(n):\n", + " s = 0.0\n", + " for k in range(n):\n", + " s += H[i, k] * P_pred[t, k, j]\n", + " temp2[i, j] = s\n", + " # S = temp2 @ H.T\n", + " for i in range(m):\n", + " for j in range(m):\n", + " s = 0.0\n", + " for k in range(n):\n", + " s += temp2[i, k] * H[j, k] # H.T[k,j] = H[j,k]\n", + " S[i, j] = s + R[i, j]\n", + "\n", + " # K = P_pred @ H.T @ inv(S)\n", + " # first compute P_pred @ H.T -> (n, m)\n", + " P_Ht = np.empty((n, m), dtype=np.float32)\n", + " for i in range(n):\n", + " for j in range(m):\n", + " s = 0.0\n", + " for k in range(n):\n", + " s += P_pred[t, i, k] * H[j, k] # H.T[k,j] = H[j,k]\n", + " P_Ht[i, j] = s\n", + "\n", + " # invert S\n", + " S_inv = np.linalg.inv(S)\n", + "\n", + " # K = P_Ht @ S_inv (n,m) @ (m,m) -> (n,m)\n", + " for i in range(n):\n", + " for j in range(m):\n", + " s = 0.0\n", + " for k in range(m):\n", + " s += P_Ht[i, k] * S_inv[k, j]\n", + " K[i, j] = s\n", + "\n", + " # x = x_pred + K @ y\n", + " for i in range(n):\n", + " s = 0.0\n", + " for j in range(m):\n", + " s += K[i, j] * y[j]\n", + " x[i] = x_pred[t, i] + s\n", + "\n", + " # P = (I - K H) P_pred\n", + " # compute (I - K H)\n", + " KH = np.empty((n, n), dtype=np.float32)\n", + " for i in range(n):\n", + " for j in range(n):\n", + " s = 0.0\n", + " for k in range(m):\n", + " s += K[i, k] * H[k, j]\n", + " KH[i, j] = s\n", + "\n", + " I_minus_KH = np.empty((n, n), dtype=np.float32)\n", + " for i in range(n):\n", + " for j in range(n):\n", + " I_minus_KH[i, j] = I_n[i, j] - KH[i, j]\n", + "\n", + " # P = I_minus_KH @ P_pred\n", + " for i in range(n):\n", + " for j in range(n):\n", + " s = 0.0\n", + " for k in range(n):\n", + " s += I_minus_KH[i, k] * P_pred[t, k, j]\n", + " P[i, j] = s\n", + "\n", + " # store results\n", + " for i in range(n):\n", + " xs[t, i] = x[i]\n", + " for i in range(n):\n", + " for j in range(n):\n", + " Ps[t, i, j] = P[i, j]\n", + "\n", + " return xs, Ps, x_pred, P_pred\n" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "25bcb14e", + "metadata": {}, + "outputs": [], + "source": [ + "@jit(nopython=True)\n", + "def kalman_filter_numba(z, F, H, Q, R, x0, P0):\n", + " \"\"\"\n", + " Linear Gaussian Kalman filter algorithm\n", + "\n", + " Parameters\n", + " ----------\n", + " z: np.ndarray\n", + " shape (T, m) - observations\n", + " F: np.ndarray\n", + " state transition matrix - shape (n, n)\n", + " H: np.ndarray\n", + " observation/design matrix - shape (m, n)\n", + " Q: np.ndarray\n", + " process noise covariance - shape (n, n)\n", + " R: np.ndarray\n", + " observation noise covariance - shape (m, m)\n", + " x0: np.ndarray\n", + " initial state mean - shape (n,)\n", + " P0: np.ndarray\n", + " initial state covariance - shape (n, n)\n", + "\n", + " Returns\n", + " -------\n", + " xs: np.ndarray\n", + " shape (T, n) - filtered state means\n", + " Ps: np.ndarray\n", + " shape (T, n, n) - filtered state covariances\n", + " \"\"\"\n", + " T, m = z.shape\n", + " n = x0.shape[0]\n", + "\n", + " xs = np.zeros((T, n), dtype=np.float32)\n", + " Ps = np.zeros((T, n, n), dtype=np.float32)\n", + "\n", + " x_pred = np.zeros((T, n), dtype=np.float32)\n", + " P_pred = np.zeros((T, n, n), dtype=np.float32)\n", + "\n", + " x = x0.copy()\n", + " P = P0.copy()\n", + "\n", + " I = np.eye(n, dtype=np.float32)\n", + "\n", + " for t in range(T):\n", + " # --- Predict ---\n", + " x_pred[t] = F @ x\n", + " P_pred[t] = F @ P @ F.T + Q\n", + "\n", + " # --- Update ---\n", + " y = z[t] - H @ x_pred[t]\n", + " S = H @ P_pred[t] @ H.T + R\n", + " K = P_pred[t] @ H.T @ np.linalg.inv(S)\n", + "\n", + " x = x_pred[t] + K @ y\n", + " P = (I - K @ H) @ P_pred[t]\n", + "\n", + " xs[t] = x\n", + " Ps[t] = P\n", + "\n", + " return xs, Ps, x_pred, P_pred" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "cd715dfb", + "metadata": {}, + "outputs": [], + "source": [ + "@block\n", + "@jax.jit\n", + "def kalman_filter_jax(z, F, H, Q, R, x0, P0):\n", + " \"\"\"\n", + " Linear Gaussian Kalman filter algorithm\n", + "\n", + " Parameters\n", + " ----------\n", + " z: np.ndarray\n", + " shape (T, m) - observations\n", + " F: np.ndarray\n", + " state transition matrix - shape (n, n)\n", + " H: np.ndarray\n", + " observation/design matrix - shape (m, n)\n", + " Q: np.ndarray\n", + " process noise covariance - shape (n, n)\n", + " R: np.ndarray\n", + " observation noise covariance - shape (m, m)\n", + " x0: np.ndarray\n", + " initial state mean - shape (n,)\n", + " P0: np.ndarray\n", + " initial state covariance - shape (n, n)\n", + "\n", + " Returns\n", + " -------\n", + " xs: jnp.ndarray\n", + " shape (T, n) - filtered state means\n", + " Ps: jnp.ndarray\n", + " shape (T, n, n) - filtered state covariances\n", + " \"\"\"\n", + "\n", + " n = x0.shape[0]\n", + " I = jnp.eye(n)\n", + " X_pred_init = jnp.zeros((1,))\n", + " P_pred_init = jnp.zeros((1, 1,))\n", + "\n", + " def step(carry, z_t):\n", + " x, P, _, _ = carry\n", + "\n", + " # --- Predict ---\n", + " x_pred = F @ x\n", + " P_pred = F @ P @ F.T + Q\n", + "\n", + " # --- Update ---\n", + " y = z_t - H @ x_pred\n", + " S = H @ P_pred @ H.T + R\n", + " K = P_pred @ H.T @ jnp.linalg.inv(S)\n", + "\n", + " x_new = x_pred + K @ y\n", + " P_new = (I - K @ H) @ P_pred\n", + "\n", + " return (x_new, P_new, x_pred, P_pred), (x_new, P_new, x_pred, P_pred)\n", + "\n", + " # run scan\n", + " (_, _, _, _), (xs, Ps, x_pred, P_pred) = jax.lax.scan(step, (x0, P0, X_pred_init, P_pred_init), z)\n", + "\n", + " return xs, Ps, x_pred, P_pred" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "5af37c40", + "metadata": {}, + "outputs": [], + "source": [ + "z_symbolic = pt.matrix(\"z\")\n", + "F_symbolic = pt.matrix(\"F\")\n", + "H_symbolic = pt.matrix(\"H\")\n", + "Q_symbolic = pt.matrix(\"Q\")\n", + "R_symbolic = pt.matrix(\"R\")\n", + "x0_symbolic = pt.vector(\"x0\")\n", + "P0_symbolic = pt.matrix(\"P0\")\n", + "\n", + "n = x0_symbolic.shape[0]\n", + "I = pt.eye(n)\n", + "X_pred_init = pt.zeros_like(x0_symbolic)\n", + "P_pred_init = pt.zeros_like(P0_symbolic)\n", + "\n", + "N_STEPS = 500\n", + "\n", + "def step(z_t, x, P, x_pred, P_pred, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I):\n", + "\n", + " # --- Predict ---\n", + " x_pred = F_symbolic @ x\n", + " P_pred = F_symbolic @ P @ F_symbolic.T + Q_symbolic\n", + "\n", + " # --- Update ---\n", + " y = z_t - H_symbolic @ x_pred\n", + " S = H_symbolic @ P_pred @ H_symbolic.T + R_symbolic\n", + " K = P_pred @ H_symbolic.T @ pt.linalg.inv(S)\n", + "\n", + " x_new = x_pred + K @ y\n", + " P_new = (I - K @ H_symbolic) @ P_pred\n", + "\n", + " return x_new, P_new, x_pred, P_pred\n", + "\n", + "# run scan\n", + "(xs, Ps, x_pred, P_pred), _ = pytensor.scan(\n", + " fn=step,\n", + " outputs_info=[x0_symbolic, P0_symbolic, X_pred_init, P_pred_init],\n", + " sequences=[z_symbolic],\n", + " non_sequences=[F_symbolic, H_symbolic, Q_symbolic, R_symbolic, I],\n", + " n_steps=N_STEPS\n", + ")\n", + "\n", + "kalman_filter_pytensor = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], trust_input=True)\n", + "\n", + "kalman_filter_pytensor_numba = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode=\"NUMBA\", trust_input=True)\n", + "kalman_filter_pytensor_jax = pytensor.function([z_symbolic, F_symbolic, H_symbolic, Q_symbolic, R_symbolic, x0_symbolic, P0_symbolic], [xs, Ps, x_pred, P_pred], mode=\"JAX\", trust_input=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "3c9e7254", + "metadata": {}, + "outputs": [], + "source": [ + "T = 500\n", + "F = np.array([[1.0]]).astype(np.float32)\n", + "H = np.array([[1.0]]).astype(np.float32)\n", + "Q = np.array([[0.01]]).astype(np.float32)\n", + "R = np.array([[0.1]]).astype(np.float32)\n", + "x0 = np.array([0.0]).astype(np.float32)\n", + "P0 = np.array([[1.0]]).astype(np.float32)\n", + "\n", + "true = 1.0\n", + "z = (true + 0.4*np.random.randn(T)).reshape(T, 1).astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "00472afd", + "metadata": {}, + "outputs": [], + "source": [ + "kalman_filter_bench = Benchmarker(\n", + " functions=[kalman_filter_pytensor, atrocious_kalman_filter_numba, kalman_filter_numba, kalman_filter_jax, kalman_filter_pytensor_numba, block(kalman_filter_pytensor_jax)], \n", + " names=['kalman_filter_pytensor', 'atrocious_kalman_filter_numba', 'kalman_filter_numba', 'kalman_filter_jax', 'kalman_filter_pytensor_numba', 'kalman_filter_pytensor_jax'],\n", + " number=10\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "68ec703d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/tmp7loid6xb:33: NumbaWarning:\n", + "\n", + "\u001b[1m\u001b[1mCannot cache compiled function \"scan\" as it uses dynamic globals (such as ctypes pointers and large global arrays)\u001b[0m\u001b[0m\n", + "\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
kalman_filter_pytensorkalman_filter_inputs105683.2750047487.9000026232.157842453.5796906133.425003294.8082980.16045817
atrocious_kalman_filter_numbakalman_filter_inputs10305.670901337.154098318.39748011.714092313.77080017.3166043.1407285
kalman_filter_numbakalman_filter_inputs10610.762503686.358300657.73082026.215497670.71249921.3458031.5203795
kalman_filter_jaxkalman_filter_inputs10300.266704360.233302323.62564815.903232319.32080220.6103523.08999019
kalman_filter_pytensor_numbakalman_filter_inputs10748.704199764.174998759.8116815.623608761.8792011.0000021.3161165
kalman_filter_pytensor_jaxkalman_filter_inputs10297.354202347.566698322.99527318.748797321.34999936.7676753.09602122
\n", + "
" + ], + "text/plain": [ + " Loops Min (us) \\\n", + "kalman_filter_pytensor kalman_filter_inputs 10 5683.275004 \n", + "atrocious_kalman_filter_numba kalman_filter_inputs 10 305.670901 \n", + "kalman_filter_numba kalman_filter_inputs 10 610.762503 \n", + "kalman_filter_jax kalman_filter_inputs 10 300.266704 \n", + "kalman_filter_pytensor_numba kalman_filter_inputs 10 748.704199 \n", + "kalman_filter_pytensor_jax kalman_filter_inputs 10 297.354202 \n", + "\n", + " Max (us) Mean (us) \\\n", + "kalman_filter_pytensor kalman_filter_inputs 7487.900002 6232.157842 \n", + "atrocious_kalman_filter_numba kalman_filter_inputs 337.154098 318.397480 \n", + "kalman_filter_numba kalman_filter_inputs 686.358300 657.730820 \n", + "kalman_filter_jax kalman_filter_inputs 360.233302 323.625648 \n", + "kalman_filter_pytensor_numba kalman_filter_inputs 764.174998 759.811681 \n", + "kalman_filter_pytensor_jax kalman_filter_inputs 347.566698 322.995273 \n", + "\n", + " StdDev (us) Median (us) \\\n", + "kalman_filter_pytensor kalman_filter_inputs 453.579690 6133.425003 \n", + "atrocious_kalman_filter_numba kalman_filter_inputs 11.714092 313.770800 \n", + "kalman_filter_numba kalman_filter_inputs 26.215497 670.712499 \n", + "kalman_filter_jax kalman_filter_inputs 15.903232 319.320802 \n", + "kalman_filter_pytensor_numba kalman_filter_inputs 5.623608 761.879201 \n", + "kalman_filter_pytensor_jax kalman_filter_inputs 18.748797 321.349999 \n", + "\n", + " IQR (us) OPS (Kops/s) \\\n", + "kalman_filter_pytensor kalman_filter_inputs 294.808298 0.160458 \n", + "atrocious_kalman_filter_numba kalman_filter_inputs 17.316604 3.140728 \n", + "kalman_filter_numba kalman_filter_inputs 21.345803 1.520379 \n", + "kalman_filter_jax kalman_filter_inputs 20.610352 3.089990 \n", + "kalman_filter_pytensor_numba kalman_filter_inputs 1.000002 1.316116 \n", + "kalman_filter_pytensor_jax kalman_filter_inputs 36.767675 3.096021 \n", + "\n", + " Samples \n", + "kalman_filter_pytensor kalman_filter_inputs 17 \n", + "atrocious_kalman_filter_numba kalman_filter_inputs 5 \n", + "kalman_filter_numba kalman_filter_inputs 5 \n", + "kalman_filter_jax kalman_filter_inputs 19 \n", + "kalman_filter_pytensor_numba kalman_filter_inputs 5 \n", + "kalman_filter_pytensor_jax kalman_filter_inputs 22 " + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kalman_filter_bench.run(\n", + " inputs={\n", + " \"kalman_filter_inputs\": {\"z\": z, \"F\": F, \"H\": H, \"Q\": Q, \"R\": R, \"x0\": x0, \"P0\": P0},\n", + " }\n", + ")\n", + "kalman_filter_bench.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "37526f86", + "metadata": {}, + "outputs": [], + "source": [ + "xs, Ps, x_pred, P_pred = kalman_filter_jax(z, F, H, Q, R, x0, P0)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "231140ba", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_pred_intervals(z, x_pred, P_pred, H, R, zscore=1.96):\n", + " T = z.shape[0]\n", + " m = H.shape[0]\n", + " mean = np.zeros((T, m))\n", + " lower = np.zeros((T, m))\n", + " upper = np.zeros((T, m))\n", + " outside = np.zeros(T, dtype=np.bool_)\n", + "\n", + " for t in range(T):\n", + " mean[t] = H @ x_pred[t]\n", + " S = H @ P_pred[t] @ H.T + R\n", + " std = np.sqrt(np.diag(S))\n", + " lower[t] = mean[t] - zscore * std\n", + " upper[t] = mean[t] + zscore * std\n", + "\n", + " # check coverage of actual obs\n", + " outside[t] = np.any((z[t] < lower[t]) | (z[t] > upper[t]))\n", + "\n", + " coverage = 1 - outside.mean()\n", + " return mean, lower, upper, coverage\n" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "ff739003", + "metadata": {}, + "outputs": [], + "source": [ + "mean, lower, upper, coverage = compute_pred_intervals(z, x_pred, P_pred, H, R)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "7546c95c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "np.float64(0.886)" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "coverage" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "6b765a37", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig= go.Figure()\n", + "fig.add_traces(\n", + " [\n", + " go.Scatter(\n", + " x = np.arange(T),\n", + " y = z.ravel(),\n", + " mode=\"markers\",\n", + " marker_color = \"royalblue\",\n", + " name = \"actuals\"\n", + " ),\n", + " go.Scatter(\n", + " x = np.arange(T),\n", + " y = xs.ravel(),\n", + " mode = \"lines\",\n", + " marker_color = \"orange\",\n", + " name = \"filtered mean\"\n", + " ),\n", + " go.Scatter(\n", + " name=\"\", \n", + " x=np.arange(T), \n", + " y=upper.ravel(), \n", + " mode=\"lines\", \n", + " marker=dict(color=\"#eb8c34\"), \n", + " line=dict(width=0), \n", + " legendgroup=\"95% CI\",\n", + " showlegend=False\n", + " ),\n", + " go.Scatter(\n", + " name=\"95% CI\", \n", + " x=np.arange(T), \n", + " y=lower.ravel(), \n", + " mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n", + " line=dict(width=0), \n", + " legendgroup=\"95% CI\", \n", + " fill='tonexty', \n", + " fillcolor='rgba(235, 140, 52, 0.2)'\n", + " ),\n", + "\n", + " ]\n", + ")\n", + "fig.update_layout(\n", + " xaxis=dict(\n", + " title = \"Time Index\",\n", + " ),\n", + " yaxis=dict(\n", + " title = \"y\"\n", + " ),\n", + " template = \"plotly_dark\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "052b0194", + "metadata": {}, + "source": [ + "## Non-linear Kalman Filter" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "6b088d6b", + "metadata": {}, + "outputs": [], + "source": [ + "@jit(nopython=True)\n", + "def loglik_poisson_numba(s, y):\n", + " \"\"\"Poisson Log Likelihood\"\"\"\n", + " mu = np.exp(s)\n", + " return y * np.log(mu + 1e-30) - mu - math.lgamma(y + 1.0) # numba does not support scipy.special gammaln\n", + "\n", + "@jit(nopython=True)\n", + "def particle_filter_1d_predict_numba(A, Q, x0_mean, x0_std, ys, N=1000, seed=2):\n", + " \"\"\"\n", + " 1D particle filter.\n", + " \n", + " Parameters\n", + " ----------\n", + " A: float\n", + " State transition\n", + " Q: float\n", + " Process covariance\n", + " x0_mean: float\n", + " Prior mean for the latent state\n", + " x0_std: float\n", + " Prior standard deviation \n", + " ys: np.ndarray\n", + " observations\n", + " N: int\n", + " number of particles\n", + " seed: int\n", + " rng seed for reproducibility\n", + "\n", + " Returns\n", + " -------\n", + " filtered_means: np.ndarray\n", + " The filtered mean for the latent state \n", + " filtered_vars: np.ndarray\n", + " The filtered variance for the latent state\n", + " pred_means: np.ndarray\n", + " observation predicted mean \n", + " \"\"\"\n", + " np.random.seed(seed)\n", + " T = ys.shape[0]\n", + " particles = np.random.normal(x0_mean, x0_std, size=N)\n", + " weights = np.ones(N) / N\n", + "\n", + " filtered_means = np.zeros(T)\n", + " filtered_vars = np.zeros(T)\n", + " pred_means = np.zeros(T)\n", + "\n", + " for t in range(T):\n", + " y = ys[t]\n", + "\n", + " # propagate (vectorized)\n", + " particles = A * particles + np.random.normal(0, np.sqrt(Q), size=N)\n", + "\n", + " # update weights\n", + " logw = np.zeros(N)\n", + " for i in range(N):\n", + " logw[i] = loglik_poisson_numba(particles[i], y)\n", + " logw = logw - np.max(logw)\n", + " weights *= np.exp(logw)\n", + " weights /= np.sum(weights) + 1e-12\n", + "\n", + " # filtered moments\n", + " mean_t = np.sum(weights * particles)\n", + " var_t = np.sum(weights * (particles - mean_t) ** 2)\n", + "\n", + " # predictive mean\n", + " pred_mean = np.sum(weights * np.exp(particles))\n", + "\n", + " filtered_means[t] = mean_t\n", + " filtered_vars[t] = var_t\n", + " pred_means[t] = pred_mean\n", + "\n", + " # resample (multinomial resampling) because numba doesn't support np.random.choice\n", + " cumulative_sum = np.cumsum(weights)\n", + " cumulative_sum[-1] = 1.0 # guard against rounding error\n", + " indices = np.searchsorted(cumulative_sum, np.random.rand(N))\n", + "\n", + " particles = particles[indices]\n", + " weights = np.ones(N) / N\n", + "\n", + " return filtered_means, filtered_vars, pred_means" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "a468c403", + "metadata": {}, + "outputs": [], + "source": [ + "# Had to fix the loglikelihood and key to use benchmarker as is\n", + "def loglik_poisson_jax(s, y):\n", + " \"\"\"Poisson Log Likelihood\"\"\"\n", + " mu = jnp.exp(s)\n", + " return y * jnp.log(mu + 1e-30) - mu - gammaln(y + 1.0)\n", + "\n", + "@block\n", + "@partial(jax.jit, static_argnums=5)\n", + "def particle_filter_1d_predict_jax(\n", + " A, Q, x0_mean, x0_std, ys, N=1000,\n", + "):\n", + " \"\"\"\n", + " 1D particle filter.\n", + " \n", + " Parameters\n", + " ----------\n", + " A: float\n", + " State transition\n", + " Q: float\n", + " Process covariance\n", + " x0_mean: float\n", + " Prior mean for the latent state\n", + " x0_std: float\n", + " Prior standard deviation \n", + " ys: np.ndarray\n", + " observations\n", + " loglik_fn: function\n", + " The log likelihood function\n", + " key: \n", + " JAX prng key\n", + " N: int\n", + " number of particles\n", + "\n", + " Returns\n", + " -------\n", + " filtered_means: jnp.ndarray\n", + " The filtered mean for the latent state \n", + " filtered_vars: jnp.ndarray\n", + " The filtered variance for the latent state\n", + " pred_means: jnp.ndarray\n", + " observation predicted mean \n", + " \"\"\"\n", + " key = jax.random.PRNGKey(0)\n", + " T = ys.shape[0]\n", + " particles = jax.random.normal(key, (N,)) * x0_std + x0_mean # init particles from gaussian priors\n", + " weights = jnp.ones(N) / N # particle weights, all particles equally likely prior\n", + "\n", + " def body_fun(carry, t):\n", + " particles, weights, key = carry\n", + " y = ys[t]\n", + "\n", + " # propagate\n", + " key, subkey = jax.random.split(key)\n", + " particles = A * particles + jax.random.normal(subkey, (N,)) * jnp.sqrt(Q) # state transition model\n", + "\n", + " # update weights\n", + " logw = jax.vmap(lambda x: loglik_poisson_jax(x, y))(particles) # update particles in parallel\n", + " logw = logw - jnp.max(logw) # avoid overflow\n", + " weights = weights * jnp.exp(logw) # old weights times the likelihood\n", + " weights /= jnp.sum(weights) + 1e-12 # normalize so that weights sum to 1\n", + "\n", + " # filtered moments\n", + " mean_t = jnp.sum(weights * particles) # posterior mean of latent state\n", + " var_t = jnp.sum(weights * (particles - mean_t)**2) # posterior variance of latent state\n", + "\n", + " # predictive mean\n", + " pred_mean = jnp.sum(weights * jnp.exp(particles))\n", + "\n", + " # resample to prevent dominant particles\n", + " key, subkey = jax.random.split(key)\n", + " indices = jax.random.choice(subkey, N, p=weights, shape=(N,))\n", + " particles = particles[indices]\n", + " weights = jnp.ones(N) / N\n", + "\n", + " carry = (particles, weights, key)\n", + " out = (mean_t, var_t, pred_mean)\n", + " return carry, out\n", + "\n", + " _, outputs = jax.lax.scan(body_fun, (particles, weights, key), jnp.arange(T))\n", + " return outputs\n" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "15fd4a0c", + "metadata": {}, + "outputs": [], + "source": [ + "from pytensor.tensor.random.utils import RandomStream\n", + "\n", + "# Random stream for PyTensor\n", + "srng = RandomStream(seed=42)\n", + "\n", + "# Poisson log-likelihood\n", + "def loglik_poisson_pytensor(s, y):\n", + " mu = pt.exp(s)\n", + " return y.flatten() * pt.log(mu + 1e-30) - mu - pt.gammaln(y.flatten() + 1.0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "59525da5", + "metadata": {}, + "outputs": [], + "source": [ + "ys_symbolic = pt.vector(\"ys\")\n", + "x0_mean_symbolic = pt.scalar(\"x0_mean\")\n", + "x0_std_symbolic = pt.scalar(\"x0_std\")\n", + "A_symbolic = pt.scalar(\"A\")\n", + "Q_symbolic = pt.scalar(\"Q\")\n", + "N_symbolic = pt.scalar(\"N\", dtype='int64')\n", + "\n", + "N_STEPS = 300\n", + "\n", + "# Initialize particles and weights\n", + "particles_init = srng.normal(size=(N_symbolic,)) * x0_std_symbolic + x0_mean_symbolic\n", + "weights_init = pt.ones((N_symbolic,)) / N_symbolic \n", + "\n", + "# Step function for scan\n", + "def step(y_t, particles_prev, weights_prev, A_symbolic, Q_symbolic):\n", + " # Propagate particles\n", + " particles_prop = A_symbolic * particles_prev + srng.normal(size=(N_symbolic,)) * pt.sqrt(Q_symbolic)\n", + "\n", + " # Update weights\n", + " # logw = pt.stack([loglik_poisson_pytensor(p, y_t) for p in particles_prop])\n", + " logw = loglik_poisson_pytensor(particles_prop, y_t)\n", + " logw_stable = logw - pt.max(logw)\n", + " w_unnorm = weights_prev * pt.exp(logw_stable)\n", + " w = w_unnorm / (pt.sum(w_unnorm) + 1e-12) \n", + "\n", + " # Filtered moments\n", + " mean_t = pt.sum(w * particles_prop)\n", + " var_t = pt.sum(w * (particles_prop - mean_t) ** 2)\n", + " pred_mean = pt.sum(w * pt.exp(particles_prop))\n", + "\n", + " # Resample particles\n", + " idx = srng.choice(size=(N_symbolic,), a=N_symbolic, p=w) \n", + " particles_resampled = particles_prop[idx]\n", + " weights_resampled = pt.ones((N_symbolic,)) / N_symbolic\n", + "\n", + " # Return flat tuple\n", + " return particles_resampled, weights_resampled, mean_t, var_t, pred_mean\n", + "\n", + "# first two are recurrent, rest are collected\n", + "outputs_info = [\n", + " particles_init,\n", + " weights_init,\n", + " None,\n", + " None,\n", + " None\n", + "]\n", + "\n", + "(particles_seq, weights_seq, means_seq, vars_seq, preds_seq), updates = pytensor.scan(\n", + " fn=step,\n", + " sequences=[ys_symbolic],\n", + " outputs_info=outputs_info,\n", + " non_sequences=[A_symbolic, Q_symbolic],\n", + " n_steps=N_STEPS\n", + ")\n", + "\n", + "particle_filter_1d_predict_pytensor = pytensor.function(\n", + " [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],\n", + " [means_seq, vars_seq, preds_seq],\n", + " updates=updates,\n", + " no_default_updates=True,\n", + " trust_input=True\n", + ")\n", + "\n", + "particle_filter_1d_predict_pytensor_numba = pytensor.function(\n", + " [A_symbolic, Q_symbolic, x0_mean_symbolic, x0_std_symbolic, ys_symbolic, N_symbolic],\n", + " [means_seq, vars_seq, preds_seq],\n", + " updates=updates,\n", + " no_default_updates=True,\n", + " mode=\"NUMBA\", \n", + " trust_input=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "c033a1d0", + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(0)\n", + "T = 300\n", + "A = 0.95\n", + "Q = 0.05\n", + "rng = np.random.RandomState(1)\n", + "\n", + "target_mean = 10.0\n", + "latent_var = Q / (1 - A**2)\n", + "x0_mean = np.log(target_mean) - 0.5 * latent_var\n", + "x0_std = 1.0\n", + "\n", + "# Simulate latent\n", + "x = np.zeros(T)\n", + "x[0] = rng.normal() * np.sqrt(latent_var) + x0_mean\n", + "for t in range(1, T):\n", + " x[t] = A * x[t-1] + rng.normal() * np.sqrt(Q)\n", + "\n", + "ys = np.array(rng.poisson(np.exp(x)), dtype=np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "2a9cbfa5", + "metadata": {}, + "outputs": [], + "source": [ + "nonlinear_kalman_filter_bench = Benchmarker(\n", + " functions=[particle_filter_1d_predict_pytensor, particle_filter_1d_predict_numba, particle_filter_1d_predict_jax, particle_filter_1d_predict_pytensor_numba,], \n", + " names=['particle_filter_1d_predict_pytensor', 'particle_filter_1d_predict_numba', 'particle_filter_1d_predict_jax', 'particle_filter_1d_predict_pytensor_numba',],\n", + " number=5 # This takes a while to run reducing number of loops\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "c782c42b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
LoopsMin (us)Max (us)Mean (us)StdDev (us)Median (us)IQR (us)OPS (Kops/s)Samples
particle_filter_1d_predict_pytensorkalman_filter_inputs5696693.225007742831.816606716956.18836217644.519347708739.15000128481.5166030.0013955
particle_filter_1d_predict_numbakalman_filter_inputs546922.80840849887.73319648122.8366801115.92171748200.7417941708.1164060.0207805
particle_filter_1d_predict_jaxkalman_filter_inputs530683.13340231434.32498931081.088318299.72888030930.791609498.8751960.0321745
particle_filter_1d_predict_pytensor_numbakalman_filter_inputs5655818.125000702576.024993680401.39504018440.632807684621.25840634198.9999990.0014705
\n", + "
" + ], + "text/plain": [ + " Loops \\\n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 5 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 5 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 5 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 \n", + "\n", + " Min (us) \\\n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 696693.225007 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 46922.808408 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 30683.133402 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 655818.125000 \n", + "\n", + " Max (us) \\\n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 742831.816606 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 49887.733196 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 31434.324989 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 702576.024993 \n", + "\n", + " Mean (us) \\\n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 716956.188362 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 48122.836680 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 31081.088318 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 680401.395040 \n", + "\n", + " StdDev (us) \\\n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 17644.519347 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 1115.921717 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 299.728880 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 18440.632807 \n", + "\n", + " Median (us) \\\n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 708739.150001 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 48200.741794 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 30930.791609 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 684621.258406 \n", + "\n", + " IQR (us) \\\n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 28481.516603 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 1708.116406 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 498.875196 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 34198.999999 \n", + "\n", + " OPS (Kops/s) \\\n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 0.001395 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 0.020780 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 0.032174 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 0.001470 \n", + "\n", + " Samples \n", + "particle_filter_1d_predict_pytensor kalman_filter_inputs 5 \n", + "particle_filter_1d_predict_numba kalman_filter_inputs 5 \n", + "particle_filter_1d_predict_jax kalman_filter_inputs 5 \n", + "particle_filter_1d_predict_pytensor_numba kalman_filter_inputs 5 " + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nonlinear_kalman_filter_bench.run(\n", + " inputs={\n", + " \"kalman_filter_inputs\": {\"A\": A, \"Q\": Q, \"x0_mean\": x0_mean, \"x0_std\": x0_std, \"ys\": ys, \"N\": 2000},\n", + " }\n", + ")\n", + "nonlinear_kalman_filter_bench.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "29fe0cc4", + "metadata": {}, + "source": [ + "Slightly different estimates because I couldn't reproduce 1:1 " + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "2bd23897", + "metadata": {}, + "outputs": [], + "source": [ + "filtered_means, filtered_vars, pred_means = particle_filter_1d_predict_numba(\n", + " A, Q, x0_mean, x0_std, ys, N=2000, seed=2\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "7caac5a5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = make_subplots(\n", + " rows=2, cols=1,\n", + " subplot_titles=(\"Observation Predictions\", \"Latent State Estimation\"),\n", + " vertical_spacing=0.07,\n", + " shared_xaxes=True\n", + ")\n", + "\n", + "fig.add_traces(\n", + " [\n", + " go.Scatter(\n", + " x = np.arange(T),\n", + " y = ys,\n", + " mode = \"markers\",\n", + " marker_color = \"cornflowerblue\",\n", + " name = \"actuals\"\n", + " ),\n", + " go.Scatter(\n", + " x = np.arange(T),\n", + " y = pred_means,\n", + " mode = \"lines\",\n", + " marker_color = \"#eb8c34\",\n", + " name = \"predicted mean\"\n", + " ),\n", + " go.Scatter(\n", + " name=\"\", \n", + " x=np.arange(T), \n", + " y=pred_means + 2*jnp.sqrt(pred_means), \n", + " mode=\"lines\", \n", + " marker=dict(color=\"#eb8c34\"), \n", + " line=dict(width=0), \n", + " legendgroup=\"predicted mean 95% CI\",\n", + " showlegend=False\n", + " ),\n", + " go.Scatter(\n", + " name=\"predicted mean 95% CI\", \n", + " x=np.arange(T), \n", + " y=pred_means - 2*jnp.sqrt(pred_means), \n", + " mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n", + " line=dict(width=0), \n", + " legendgroup=\"predicted mean 95% CI\", \n", + " fill='tonexty', \n", + " fillcolor='rgba(235, 140, 52, 0.2)'\n", + " ),\n", + " ],\n", + " rows=1, cols=1\n", + ")\n", + "\n", + "fig.add_traces(\n", + " [\n", + " go.Scatter(\n", + " x = np.arange(T),\n", + " y = x,\n", + " mode = \"lines\",\n", + " marker_color = \"cornflowerblue\",\n", + " name = \"true latent state\"\n", + " ),\n", + " go.Scatter(\n", + " x = np.arange(T),\n", + " y = filtered_means,\n", + " mode = \"lines\",\n", + " marker_color = \"#eb8c34\",\n", + " name = \"filtered state mean\"\n", + " ),\n", + " go.Scatter(\n", + " name=\"\", \n", + " x=np.arange(T), \n", + " y=filtered_means + 2*jnp.sqrt(filtered_vars), \n", + " mode=\"lines\", \n", + " marker=dict(color=\"#eb8c34\"), \n", + " line=dict(width=0), \n", + " legendgroup=\"filtered state mean 95% CI\",\n", + " showlegend=False\n", + " ),\n", + " go.Scatter(\n", + " name=\"filtered state mean 95% CI\", \n", + " x=np.arange(T), \n", + " y=filtered_means - 2*jnp.sqrt(filtered_vars), \n", + " mode=\"lines\", marker=dict(color=\"#eb8c34\"), \n", + " line=dict(width=0), \n", + " legendgroup=\"filtered state mean 95% CI\", \n", + " fill='tonexty', \n", + " fillcolor='rgba(235, 140, 52, 0.2)'\n", + " ),\n", + " ],\n", + " rows=2, cols=1\n", + ")\n", + "\n", + "for i, yaxis in enumerate(fig.select_yaxes(), 1):\n", + " legend_name = f\"legend{i}\"\n", + " fig.update_layout({legend_name: dict(y=yaxis.domain[1], yanchor=\"top\")}, showlegend=True)\n", + " fig.update_traces(row=i, legend=legend_name)\n", + "\n", + "fig.update_layout(height=1000, width=1200, template=\"plotly_dark\")\n", + "\n", + "fig.update_layout(\n", + " legend1=dict(\n", + " yanchor=\"top\",\n", + " y=1.0,\n", + " xanchor=\"left\",\n", + " x=0,\n", + " orientation=\"h\"\n", + " ),\n", + " legend2=dict(\n", + " yanchor=\"top\",\n", + " y=.465,\n", + " xanchor=\"left\",\n", + " x=0,\n", + " orientation=\"h\"\n", + " ),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48ccc984", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytensor-dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/gallery/scan/numba_fib_scan.ipynb b/doc/gallery/scan/numba_fib_scan.ipynb new file mode 100644 index 0000000000..9b1b81c5e2 --- /dev/null +++ b/doc/gallery/scan/numba_fib_scan.ipynb @@ -0,0 +1,573 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-10-07T10:20:23.561430Z", + "start_time": "2025-10-07T10:20:21.620124Z" + } + }, + "source": [ + "import pytensor\n", + "import pytensor.tensor as pt\n", + "import numpy as np\n", + "\n", + "N_STEPS = 1000\n", + "\n", + "b_symbolic = pt.scalar(\"b_symbolic\", dtype=\"int32\")\n", + "\n", + "def step(a, b):\n", + " return a + b, a\n", + "\n", + "(outputs_a, outputs_b), _ = pytensor.scan(\n", + " fn=step,\n", + " outputs_info=[pt.constant(1, dtype=\"int32\"), b_symbolic],\n", + " n_steps=N_STEPS\n", + ")\n", + "\n", + "# compile function returning final a\n", + "fibonacci_pytensor = pytensor.function([b_symbolic], outputs_a[-1], trust_input=True)\n", + "fibonacci_pytensor_numba = pytensor.function([b_symbolic], outputs_a[-1], mode='NUMBA', trust_input=True)" + ], + "outputs": [], + "execution_count": 1 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:32:18.289971Z", + "start_time": "2025-10-07T10:32:18.284515Z" + } + }, + "cell_type": "code", + "source": [ + "import numba\n", + "\n", + "@numba.njit\n", + "def fibonacci_numba_scalar(b):\n", + " b = b.copy()\n", + " a = np.ones((), dtype=np.int32)\n", + " for _ in range(N_STEPS):\n", + " a[()], b[()] = a[()] + b[()], a[()]\n", + " return a\n", + "\n", + "@numba.njit\n", + "def fibonacci_numba_array(b):\n", + " a = np.ones((), dtype=np.int32)\n", + " for _ in range(N_STEPS):\n", + " a, b = np.asarray(a + b), a\n", + " return a" + ], + "id": "b1d657d366647ada", + "outputs": [], + "execution_count": 66 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:32:19.607842Z", + "start_time": "2025-10-07T10:32:19.423324Z" + } + }, + "cell_type": "code", + "source": [ + "b = np.ones((), dtype=np.int32)\n", + "assert fibonacci_numba_array(b) == fibonacci_numba_scalar(b)" + ], + "id": "7f45c87d259852e6", + "outputs": [], + "execution_count": 67 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:32:22.705191Z", + "start_time": "2025-10-07T10:32:20.090353Z" + } + }, + "cell_type": "code", + "source": "%timeit fibonacci_numba_scalar(b)", + "id": "b01c8978960c6e3d", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.21 μs ± 20.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" + ] + } + ], + "execution_count": 68 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:32:25.876514Z", + "start_time": "2025-10-07T10:32:23.122275Z" + } + }, + "cell_type": "code", + "source": "%timeit fibonacci_numba_array(b)", + "id": "bfc8794b219db03e", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "32.8 μs ± 2.48 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + ] + } + ], + "execution_count": 69 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "assert fibonacci_pytensor(b) == fibonacci_numba_scalar(b)\n", + "assert fibonacci_pytensor_numba(b) == fibonacci_numba_scalar(b)" + ], + "id": "a2185c1de1297a11" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:29:44.724064Z", + "start_time": "2025-10-07T10:29:42.655693Z" + } + }, + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.49 ms ± 327 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "execution_count": 54, + "source": "%timeit fibonacci_pytensor(b)", + "id": "f1e8bb6a0c673c8f" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:29:58.922566Z", + "start_time": "2025-10-07T10:29:44.752331Z" + } + }, + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "175 μs ± 6.13 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + ] + } + ], + "execution_count": 55, + "source": "%timeit fibonacci_pytensor_numba(b)", + "id": "17cd2859b4c6d3bd" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:30:11.832294Z", + "start_time": "2025-10-07T10:29:59.016709Z" + } + }, + "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "158 μs ± 706 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + ] + } + ], + "execution_count": 56, + "source": "%timeit fibonacci_pytensor_numba.vm.jit_fn(b)", + "id": "6deb056f63953a42" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:20:58.015849Z", + "start_time": "2025-10-07T10:20:58.006831Z" + } + }, + "cell_type": "code", + "source": "fibonacci_pytensor_numba.dprint(print_type=True, print_memory_map=True)", + "id": "17580448648fdbcf", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Subtensor{i} [id A] v={0: [0]} 6\n", + " ├─ Scan{scan_fn, while_loop=False, inplace=all}.0 [id B] d={0: [1], 1: [2]} 5\n", + " │ ├─ 1000 [id C] \n", + " │ ├─ SetSubtensor{:stop} [id D] d={0: [0]} 4\n", + " │ │ ├─ AllocEmpty{dtype='int32'} [id E] 3\n", + " │ │ │ └─ 1 [id F] \n", + " │ │ ├─ [1] [id G] \n", + " │ │ └─ 1 [id H] \n", + " │ └─ SetSubtensor{:stop} [id I] d={0: [0]} 2\n", + " │ ├─ AllocEmpty{dtype='int32'} [id J] 1\n", + " │ │ └─ 2 [id K] \n", + " │ ├─ ExpandDims{axis=0} [id L] v={0: [0]} 0\n", + " │ │ └─ b_symbolic [id M] \n", + " │ └─ 1 [id H] \n", + " └─ 0 [id N] \n", + "\n", + "Inner graphs:\n", + "\n", + "Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1], 1: [2]}\n", + " ← Add [id O] \n", + " ├─ *0- [id P] -> [id D]\n", + " └─ *1- [id Q] -> [id I]\n", + " ← *0- [id P] -> [id D]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 8 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:20:58.063585Z", + "start_time": "2025-10-07T10:20:58.059985Z" + } + }, + "cell_type": "code", + "source": "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__source__)", + "id": "f9806651f5146bbd", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "def numba_funcified_fgraph(b_symbolic):\n", + " # ExpandDims{axis=0}(b_symbolic)\n", + " tensor_variable = dimshuffle(b_symbolic)\n", + " # AllocEmpty{dtype='int32'}(2)\n", + " tensor_variable_1 = allocempty(tensor_constant)\n", + " # SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, ExpandDims{axis=0}.0, 1)\n", + " tensor_variable_2 = set_subtensor(tensor_variable_1, tensor_variable, scalar_constant)\n", + " # AllocEmpty{dtype='int32'}(1)\n", + " tensor_variable_3 = allocempty_1(tensor_constant_1)\n", + " # SetSubtensor{:stop}(AllocEmpty{dtype='int32'}.0, [1], 1)\n", + " tensor_variable_4 = set_subtensor_1(tensor_variable_3, tensor_constant_2, scalar_constant)\n", + " # Scan{scan_fn, while_loop=False, inplace=all}(1000, SetSubtensor{:stop}.0, SetSubtensor{:stop}.0)\n", + " tensor_variable_5, tensor_variable_6 = scan(tensor_constant_3, tensor_variable_4, tensor_variable_2)\n", + " # Subtensor{i}(Scan{scan_fn, while_loop=False, inplace=all}.0, 0)\n", + " tensor_variable_7 = subtensor(tensor_variable_5, scalar_constant_1)\n", + " return (tensor_variable_7,)\n" + ] + } + ], + "execution_count": 9 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:20:58.113352Z", + "start_time": "2025-10-07T10:20:58.109693Z" + } + }, + "cell_type": "code", + "source": "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"allocempty\"].py_func.__source__)", + "id": "9995392081dcbffb", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "def allocempty(tensor_constant):\n", + " tensor_constant_item = to_scalar(tensor_constant)\n", + " scalar_shape = (tensor_constant_item, )\n", + " return np.empty(scalar_shape, dtype)\n", + " \n" + ] + } + ], + "execution_count": 10 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:20:58.162062Z", + "start_time": "2025-10-07T10:20:58.158525Z" + } + }, + "cell_type": "code", + "source": "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"set_subtensor\"].py_func.__source__)", + "id": "1f89dfa8b172fde9", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "def set_subtensor(tensor_variable, tensor_variable_1, scalar_constant):\n", + " z = tensor_variable\n", + " indices = (slice(None, scalar_constant, None),)\n", + " z[indices] = tensor_variable_1\n", + " return np.asarray(z)\n", + " \n" + ] + } + ], + "execution_count": 11 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:20:58.211349Z", + "start_time": "2025-10-07T10:20:58.207479Z" + } + }, + "cell_type": "code", + "source": "print(fibonacci_pytensor_numba.vm.jit_fn.py_func.__globals__[\"scan\"].py_func.__source__)", + "id": "648cd8952121141b", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "def scan(n_steps, outer_in_1, outer_in_2):\n", + "\n", + " outer_in_1_len = outer_in_1.shape[0]\n", + " outer_in_1_sitsot_storage = outer_in_1\n", + " outer_in_2_len = outer_in_2.shape[0]\n", + " outer_in_2_sitsot_storage = outer_in_2\n", + "\n", + " outer_in_1_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)\n", + " outer_in_2_sitsot_storage_temp_scalar_0 = np.empty((), dtype=np.int32)\n", + "\n", + " i = 0\n", + " cond = np.array(False)\n", + " while i < n_steps and not cond.item():\n", + " outer_in_1_sitsot_storage_temp_scalar_0[()] = outer_in_1_sitsot_storage[(i) % outer_in_1_len]\n", + " outer_in_2_sitsot_storage_temp_scalar_0[()] = outer_in_2_sitsot_storage[(i) % outer_in_2_len]\n", + "\n", + " (inner_out_0, inner_out_1) = scan_inner_func(outer_in_1_sitsot_storage_temp_scalar_0, outer_in_2_sitsot_storage_temp_scalar_0)\n", + "\n", + " outer_in_1_sitsot_storage[(i + 1) % outer_in_1_len] = inner_out_0\n", + " outer_in_2_sitsot_storage[(i + 1) % outer_in_2_len] = inner_out_1\n", + " i += 1\n", + "\n", + " if 1 < outer_in_1_len < (i + 1):\n", + " outer_in_1_sitsot_storage_shift = (i + 1) % (outer_in_1_len)\n", + " if outer_in_1_sitsot_storage_shift > 0:\n", + " outer_in_1_sitsot_storage_left = outer_in_1_sitsot_storage[:outer_in_1_sitsot_storage_shift]\n", + " outer_in_1_sitsot_storage_right = outer_in_1_sitsot_storage[outer_in_1_sitsot_storage_shift:]\n", + " outer_in_1_sitsot_storage = np.concatenate((outer_in_1_sitsot_storage_right, outer_in_1_sitsot_storage_left))\n", + " if 1 < outer_in_2_len < (i + 1):\n", + " outer_in_2_sitsot_storage_shift = (i + 1) % (outer_in_2_len)\n", + " if outer_in_2_sitsot_storage_shift > 0:\n", + " outer_in_2_sitsot_storage_left = outer_in_2_sitsot_storage[:outer_in_2_sitsot_storage_shift]\n", + " outer_in_2_sitsot_storage_right = outer_in_2_sitsot_storage[outer_in_2_sitsot_storage_shift:]\n", + " outer_in_2_sitsot_storage = np.concatenate((outer_in_2_sitsot_storage_right, outer_in_2_sitsot_storage_left))\n", + "\n", + " return outer_in_1_sitsot_storage, outer_in_2_sitsot_storage\n", + " \n" + ] + } + ], + "execution_count": 12 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:30:25.495418Z", + "start_time": "2025-10-07T10:30:25.481386Z" + } + }, + "cell_type": "code", + "source": [ + "from pytensor.link.numba.dispatch.basic import tuple_setitem, to_scalar\n", + "\n", + "@numba.njit\n", + "def allocempty(s):\n", + " s_item = to_scalar(s)\n", + " scalar_shape = (s_item,)\n", + " return np.empty(scalar_shape, dtype=np.int32)\n", + "\n", + "@numba.njit\n", + "def subtensor(x, idx):\n", + " indices = (idx,)\n", + " z = x[indices]\n", + " return np.asarray(z)\n", + "\n", + "@numba.njit\n", + "def inner_scan_func(a, b):\n", + " res = a + b\n", + " return res, a\n", + "\n", + "@numba.njit\n", + "def scan_fib(n_steps, a_buf, b_buf):\n", + " a_buf_len = a_buf.shape[0]\n", + " b_buf_len = b_buf.shape[0]\n", + "\n", + " tmp_a_scalar = np.empty((), dtype=np.int32)\n", + " tmp_b_scalar = np.empty((), dtype=np.int32)\n", + "\n", + " i = 0\n", + " while i < n_steps:\n", + " tmp_a_scalar[()] = a_buf[i % a_buf_len]\n", + " tmp_b_scalar[()] = b_buf[i % b_buf_len]\n", + " next_a, next_b = inner_scan_func(tmp_a_scalar, tmp_b_scalar)\n", + " a_buf[(i + 1) % a_buf_len] = next_a\n", + " b_buf[(i + 1) % b_buf_len] = next_b\n", + " i += 1\n", + "\n", + " if 1 < a_buf_len < (i + 1):\n", + " a_buf_shift = (i + 1) % a_buf_len\n", + " if a_buf_shift > 0:\n", + " a_buf = np.concatenate((a_buf[a_buf_shift:], a_buf[:a_buf_shift]))\n", + " if 1 < b_buf_len < (i + 1):\n", + " b_buf_shift = (i + 1) % b_buf_len\n", + " if b_buf_shift > 0:\n", + " b_buf = np.concatenate((b_buf[b_buf_shift:], b_buf[:b_buf_shift]))\n", + "\n", + " return a_buf, b_buf\n", + "\n", + "@numba.njit\n", + "def set_subtensor(x, y, idx):\n", + " indices = (slice(None, idx, None),)\n", + " x[indices] = y\n", + " return np.asarray(x)\n", + "\n", + "@numba.njit\n", + "def dimshuffle(x):\n", + " old_shape = x.shape\n", + " old_strides = x.strides\n", + "\n", + " new_shape = (1,)\n", + " new_strides = (0,)\n", + " new_order = (-1,)\n", + " for i, o in enumerate(new_order):\n", + " if o != -1:\n", + " new_shape = tuple_setitem(new_shape, i, old_shape[o])\n", + " new_strides = tuple_setitem(new_strides, i, old_strides[o])\n", + "\n", + " return np.lib.stride_tricks.as_strided(x, shape=new_shape, strides=new_strides)\n", + " # return np.expand_dims(x, axis=0)\n", + "\n", + "@numba.njit\n", + "def comparable_fibonacci_numba(b):\n", + " a_buf = allocempty(np.array(1, dtype=np.int64))\n", + " # a_buf = np.empty(1, dtype=np.int32)\n", + " # a_buf[:1] = np.array([1], dtype=np.int32)\n", + " a_buf_set = set_subtensor(a_buf, np.array([1], dtype=np.int32), np.int64(1))\n", + "\n", + " b_buf = allocempty(np.array(2, dtype=np.int64))\n", + " # b_buf = np.empty(2, dtype=np.int32)\n", + " # b_buf[:1] = np.expand_dims(b, axis=0)\n", + " b_expanded = dimshuffle(b)\n", + " b_buf_set = set_subtensor(b_buf, b_expanded, np.int64(1))\n", + "\n", + " a_buf_updated, b_buf_updated = scan_fib(np.array(N_STEPS, np.int64), a_buf_set, b_buf_set)\n", + "\n", + " res = subtensor(a_buf_updated, np.uint8(0))\n", + "\n", + " return (res,)" + ], + "id": "bcefae049d4d2540", + "outputs": [], + "execution_count": 59 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:30:31.559493Z", + "start_time": "2025-10-07T10:30:30.263832Z" + } + }, + "cell_type": "code", + "source": [ + "b = np.ones((), dtype=np.int32)\n", + "assert comparable_fibonacci_numba(b) == fibonacci_numba_scalar(b)" + ], + "id": "65887ebba21f46c3", + "outputs": [], + "execution_count": 60 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:30:35.999409Z", + "start_time": "2025-10-07T10:30:31.567997Z" + } + }, + "cell_type": "code", + "source": "%timeit comparable_fibonacci_numba(b)", + "id": "2e0aba9917097009", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "54.6 μs ± 1.28 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" + ] + } + ], + "execution_count": 61 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-07T10:21:06.095536Z", + "start_time": "2025-10-07T10:21:06.093418Z" + } + }, + "cell_type": "code", + "source": "", + "id": "f7e8c11b24c366f1", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}