From 87c30a1378c2ccdd61a8d256a4a7fc5a5aa347f8 Mon Sep 17 00:00:00 2001 From: kp992 Date: Fri, 5 Sep 2025 12:13:05 -0700 Subject: [PATCH 01/10] Use jax in kesten and replace numba --- lectures/kesten_processes.md | 177 +++++++++++++++++++++++------------ 1 file changed, 116 insertions(+), 61 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index 1fe6e921b..88b8fa486 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -33,13 +33,12 @@ In addition to what's in Anaconda, this lecture will need the following librarie ```{code-cell} ipython3 :tags: [hide-output] -!pip install quantecon -!pip install --upgrade yfinance +!pip install --upgrade quantecon yfinance ``` ## Overview -{doc}`Previously ` we learned about linear scalar-valued stochastic processes (AR(1) models). +Previously in {doc}`intro:ar1_processes`, we learned about linear scalar-valued stochastic processes (AR(1) models). Now we generalize these linear models slightly by allowing the multiplicative coefficient to be stochastic. @@ -58,6 +57,7 @@ Let's start with some imports: import matplotlib.pyplot as plt import numpy as np import quantecon as qe +import yfinance as yf ``` The following two lines are only added to avoid a `FutureWarning` caused by @@ -65,13 +65,14 @@ compatibility issues between pandas and matplotlib. ```{code-cell} ipython3 from pandas.plotting import register_matplotlib_converters + register_matplotlib_converters() ``` Additional technical background related to this lecture can be found in the monograph of {cite}`buraczewski2016stochastic`. -## Kesten Processes +## Kesten processes ```{index} single: Kesten processes; heavy tails ``` @@ -97,7 +98,7 @@ In particular, we will assume that * $\{a_t\}_{t \geq 1}$ is a nonnegative IID stochastic process and * $\{\eta_t\}_{t \geq 1}$ is another nonnegative IID stochastic process, independent of the first. -### Example: GARCH Volatility +### Example: GARCH volatility The GARCH model is common in financial applications, where time series such as asset returns exhibit time varying volatility. @@ -107,18 +108,14 @@ Composite Index for the period 1st January 2006 to 1st November 2019. (ndcode)= ```{code-cell} ipython3 -import yfinance as yf - -s = yf.download('^IXIC', '2006-1-1', '2019-11-1', auto_adjust=False)['Adj Close'] +s = yf.download("^IXIC", "2006-1-1", "2019-11-1", auto_adjust=False)["Adj Close"] r = s.pct_change() fig, ax = plt.subplots() - ax.plot(r, alpha=0.7) - -ax.set_ylabel('returns', fontsize=12) -ax.set_xlabel('date', fontsize=12) +ax.set_ylabel("returns", fontsize=12) +ax.set_xlabel("date", fontsize=12) plt.show() ``` @@ -150,7 +147,7 @@ where $\{\zeta_t\}$ is again IID and independent of $\{\xi_t\}$. The volatility sequence $\{\sigma_t^2 \}$, which drives the dynamics of returns, is a Kesten process. -### Example: Wealth Dynamics +### Example: wealth dynamics Suppose that a given household saves a fixed fraction $s$ of its current wealth in every period. @@ -171,7 +168,7 @@ is a Kesten process. ### Stationarity -In earlier lectures, such as the one on {doc}`AR(1) processes `, we introduced the notion of a stationary distribution. +In earlier lectures, such as the one on {doc}`intro:ar1_processes`, we introduced the notion of a stationary distribution. In the present context, we can define a stationary distribution as follows: @@ -203,7 +200,7 @@ current state is drawn from $F^*$. The equality in {eq}`kp_stationary` states that this distribution is unchanged. -### Cross-Sectional Interpretation +### Cross-sectional interpretation There is an important cross-sectional interpretation of stationary distributions, discussed previously but worth repeating here. @@ -275,7 +272,7 @@ labor income has finite mean and $\mathbb E \ln R_t + \ln s < 0$. Under certain conditions, the stationary distribution of a Kesten process has a Pareto tail. -(See our {doc}`earlier lecture ` on heavy-tailed distributions for background.) +(See our {doc}`intro:heavy_tails` on heavy-tailed distributions for background.) This fact is significant for economics because of the prevalence of Pareto-tailed distributions. @@ -339,14 +336,16 @@ The spikes in the time series are visible in the following simulation, which gen μ = -0.5 σ = 1.0 + def kesten_ts(ts_length=100): x = np.zeros(ts_length) - for t in range(ts_length-1): + for t in range(ts_length - 1): a = np.exp(μ + σ * np.random.randn()) b = np.exp(np.random.randn()) - x[t+1] = a * x[t] + b + x[t + 1] = a * x[t] + b return x + fig, ax = plt.subplots() num_paths = 10 @@ -355,13 +354,13 @@ np.random.seed(12) for i in range(num_paths): ax.plot(kesten_ts()) -ax.set(xlabel='time', ylabel='$X_t$') +ax.set(xlabel="time", ylabel="$X_t$") plt.show() ``` ## Application: Firm Dynamics -As noted in our {doc}`lecture on heavy tails `, for common measures of firm size such as revenue or employment, the US firm size distribution exhibits a Pareto tail (see, e.g., {cite}`axtell2001zipf`, {cite}`gabaix2016power`). +As noted in our {doc}`intro:heavy_tails`, for common measures of firm size such as revenue or employment, the US firm size distribution exhibits a Pareto tail (see, e.g., {cite}`axtell2001zipf`, {cite}`gabaix2016power`). Let us try to explain this rather striking fact using the Kesten--Goldie Theorem. @@ -460,22 +459,24 @@ Here is one solution: years = 15 days = years * 250 + def garch_ts(ts_length=days): σ2 = 0 r = np.zeros(ts_length) - for t in range(ts_length-1): + for t in range(ts_length - 1): ξ = np.random.randn() σ2 = α_0 + σ2 * (α_1 * ξ**2 + β) r[t] = np.sqrt(σ2) * np.random.randn() return r + fig, ax = plt.subplots() np.random.seed(12) ax.plot(garch_ts(), alpha=0.7) -ax.set(xlabel='time', ylabel='$\\sigma_t^2$') +ax.set(xlabel="time", ylabel="$\\sigma_t^2$") plt.show() ``` @@ -653,16 +654,16 @@ In the simulation, assume that * the parameters are ```{code-cell} ipython3 -μ_a = -0.5 # location parameter for a -σ_a = 0.1 # scale parameter for a -μ_b = 0.0 # location parameter for b -σ_b = 0.5 # scale parameter for b -μ_e = 0.0 # location parameter for e -σ_e = 0.5 # scale parameter for e -s_bar = 1.0 # threshold -T = 500 # sampling date -M = 1_000_000 # number of firms -s_init = 1.0 # initial condition for each firm +μ_a = -0.5 # location parameter for a +σ_a = 0.1 # scale parameter for a +μ_b = 0.0 # location parameter for b +σ_b = 0.5 # scale parameter for b +μ_e = 0.0 # location parameter for e +σ_e = 0.5 # scale parameter for e +s_bar = 1.0 # threshold +T = 500 # sampling date +M = 1_000_000 # number of firms +s_init = 1.0 # initial condition for each firm ``` ```{exercise-end} @@ -676,37 +677,91 @@ Here's one solution. First we generate the observations: ```{code-cell} ipython3 -from numba import jit, prange -from numpy.random import randn - - -@jit(parallel=True) -def generate_draws(μ_a=-0.5, - σ_a=0.1, - μ_b=0.0, - σ_b=0.5, - μ_e=0.0, - σ_e=0.5, - s_bar=1.0, - T=500, - M=1_000_000, - s_init=1.0): - - draws = np.empty(M) - for m in prange(M): - s = s_init - for t in range(T): - if s < s_bar: - new_s = np.exp(μ_e + σ_e * randn()) - else: - a = np.exp(μ_a + σ_a * randn()) - b = np.exp(μ_b + σ_b * randn()) - new_s = a * s + b - s = new_s - draws[m] = s +import jax +import jax.numpy as jnp +from jax import random, vmap, jit + + +def generate_single_draw(key, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init): + """Generate a single draw using JAX's scan for the time loop.""" + + def step_fn(carry, t): + s, subkey = carry + subkey, new_subkey = random.split(subkey) + + # Generate random normal samples + rand_normal = random.normal(new_subkey) + + # Conditional logic using jnp.where + # If s < s_bar: new_s = exp(μ_e + σ_e * randn()) + # Else: new_s = a * s + b where a = exp(μ_a + σ_a * randn()), b = exp(μ_b + σ_b * randn()) + + # For the else branch, we need two random numbers + subkey, key1, key2 = random.split(subkey, 3) + rand_a = random.normal(key1) + rand_b = random.normal(key2) + + # Calculate both possible new values + new_s_under_bar = jnp.exp(μ_e + σ_e * rand_normal) + + a = jnp.exp(μ_a + σ_a * rand_a) + b = jnp.exp(μ_b + σ_b * rand_b) + new_s_over_bar = a * s + b + + # Choose based on condition + new_s = jnp.where(s < s_bar, new_s_under_bar, new_s_over_bar) + + return (new_s, subkey), new_s + + # Initial state: (s_init, key) + init_carry = (s_init, key) + + # Run the scan + final_carry, _ = jax.lax.scan(step_fn, init_carry, jnp.arange(T)) + + # Return final s value + return final_carry[0] + + +generate_single_draw = jax.jit(generate_single_draw, static_argnums=(8,)) +``` + +```{code-cell} ipython3 +@jit +def generate_draws( + key=random.PRNGKey(123), + μ_a=-0.5, + σ_a=0.1, + μ_b=0.0, + σ_b=0.5, + μ_e=0.0, + σ_e=0.5, + s_bar=1.0, + T=500, + M=1_000_000, + s_init=1.0, +): + """ + JAX-jit version of the generate_draws function. + Returns: + Array of M draws + """ + # Create M different random keys for parallel execution + keys = random.split(key, M) + + # Use vmap to parallelize over the M dimension + vectorized_single_draw = vmap( + generate_single_draw, + in_axes=(0, None, None, None, None, None, None, None, None, None), + ) + + draws = vectorized_single_draw(keys, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init) return draws +``` +```{code-cell} ipython3 +# Generate the observations data = generate_draws() ``` @@ -716,7 +771,7 @@ Now we produce the rank-size plot: fig, ax = plt.subplots() rank_data, size_data = qe.rank_size(data, c=0.01) -ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5) +ax.loglog(rank_data, size_data, "o", markersize=3.0, alpha=0.5) ax.set_xlabel("log rank") ax.set_ylabel("log size") From a451d228470b2481eb1e1e96751c311d4c87b68d Mon Sep 17 00:00:00 2001 From: kp992 Date: Fri, 5 Sep 2025 12:17:37 -0700 Subject: [PATCH 02/10] update section titles --- lectures/kesten_processes.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index 88b8fa486..b9acdaae3 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -70,7 +70,7 @@ register_matplotlib_converters() ``` Additional technical background related to this lecture can be found in the -monograph of {cite}`buraczewski2016stochastic`. +monograph by {cite}`buraczewski2016stochastic`. ## Kesten processes @@ -238,7 +238,7 @@ next period as it is this period. Since $y$ was chosen arbitrarily, the distribution is unchanged. -### Conditions for Stationarity +### Conditions for stationarity The Kesten process $X_{t+1} = a_{t+1} X_t + \eta_{t+1}$ does not always have a stationary distribution. @@ -267,7 +267,7 @@ As one application of this result, we see that the wealth process {eq}`wealth_dynam` will have a unique stationary distribution whenever labor income has finite mean and $\mathbb E \ln R_t + \ln s < 0$. -## Heavy Tails +## Heavy tails Under certain conditions, the stationary distribution of a Kesten process has a Pareto tail. @@ -276,7 +276,7 @@ a Pareto tail. This fact is significant for economics because of the prevalence of Pareto-tailed distributions. -### The Kesten--Goldie Theorem +### The Kesten--Goldie theorem To state the conditions under which the stationary distribution of a Kesten process has a Pareto tail, we first recall that a random variable is called **nonarithmetic** if its distribution is not concentrated on $\{\dots, -2t, -t, 0, t, 2t, \ldots \}$ for any $t \geq 0$. @@ -358,13 +358,13 @@ ax.set(xlabel="time", ylabel="$X_t$") plt.show() ``` -## Application: Firm Dynamics +## Application: firm dynamics As noted in our {doc}`intro:heavy_tails`, for common measures of firm size such as revenue or employment, the US firm size distribution exhibits a Pareto tail (see, e.g., {cite}`axtell2001zipf`, {cite}`gabaix2016power`). Let us try to explain this rather striking fact using the Kesten--Goldie Theorem. -### Gibrat's Law +### Gibrat's law It was postulated many years ago by Robert Gibrat {cite}`gibrat1931inegalites` that firm size evolves according to a simple rule whereby size next period is proportional to current size. @@ -411,7 +411,7 @@ In the exercises you are asked to show that {eq}`firm_dynam` is more consistent with the empirical findings presented above than Gibrat's law in {eq}`firm_dynam_gb`. -### Heavy Tails +### Heavy tails So what has this to do with Pareto tails? From 5c7573c37827f2cd155d0c5f28eeb2116f9d7a0f Mon Sep 17 00:00:00 2001 From: kp992 Date: Mon, 8 Sep 2025 21:54:04 -0700 Subject: [PATCH 03/10] fix code line length --- lectures/kesten_processes.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index b9acdaae3..e2c1871b5 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -108,7 +108,9 @@ Composite Index for the period 1st January 2006 to 1st November 2019. (ndcode)= ```{code-cell} ipython3 -s = yf.download("^IXIC", "2006-1-1", "2019-11-1", auto_adjust=False)["Adj Close"] +s = yf.download("^IXIC", "2006-1-1", "2019-11-1", auto_adjust=False)[ + "Adj Close" +] r = s.pct_change() @@ -694,7 +696,8 @@ def generate_single_draw(key, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_in # Conditional logic using jnp.where # If s < s_bar: new_s = exp(μ_e + σ_e * randn()) - # Else: new_s = a * s + b where a = exp(μ_a + σ_a * randn()), b = exp(μ_b + σ_b * randn()) + # Else: new_s = a * s + b + # where a = exp(μ_a + σ_a * randn()), b = exp(μ_b + σ_b * randn()) # For the else branch, we need two random numbers subkey, key1, key2 = random.split(subkey, 3) @@ -755,7 +758,9 @@ def generate_draws( in_axes=(0, None, None, None, None, None, None, None, None, None), ) - draws = vectorized_single_draw(keys, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init) + draws = vectorized_single_draw( + keys, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init + ) return draws ``` From cd4d5213e578099e747fb2c53224344d03c5d093 Mon Sep 17 00:00:00 2001 From: kp992 Date: Thu, 11 Sep 2025 20:03:54 -0700 Subject: [PATCH 04/10] apply suggestions --- lectures/kesten_processes.md | 38 ++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index e2c1871b5..2c66cea3b 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -36,6 +36,14 @@ In addition to what's in Anaconda, this lecture will need the following librarie !pip install --upgrade quantecon yfinance ``` +Later in the exercise, we will use JAX to optimize our code + +```{code-cell} ipython3 +:tags: [no-execute] + +!pip install --upgrade jax +``` + ## Overview Previously in {doc}`intro:ar1_processes`, we learned about linear scalar-valued stochastic processes (AR(1) models). @@ -60,15 +68,6 @@ import quantecon as qe import yfinance as yf ``` -The following two lines are only added to avoid a `FutureWarning` caused by -compatibility issues between pandas and matplotlib. - -```{code-cell} ipython3 -from pandas.plotting import register_matplotlib_converters - -register_matplotlib_converters() -``` - Additional technical background related to this lecture can be found in the monograph by {cite}`buraczewski2016stochastic`. @@ -344,7 +343,7 @@ def kesten_ts(ts_length=100): for t in range(ts_length - 1): a = np.exp(μ + σ * np.random.randn()) b = np.exp(np.random.randn()) - x[t + 1] = a * x[t] + b + x[t+1] = a * x[t] + b return x @@ -502,7 +501,7 @@ In what sense is this true (or false)? The empirical findings are that 1. small firms grow faster than large firms and -1. the growth rate of small firms is more volatile than that of large firms. +2. the growth rate of small firms is more volatile than that of large firms. Also, Gibrat's law is generally found to be a reasonable approximation for large firms than for small firms @@ -729,6 +728,17 @@ def generate_single_draw(key, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_in generate_single_draw = jax.jit(generate_single_draw, static_argnums=(8,)) ``` +```{code-cell} ipython3 +# Use vmap to vectorize over the first argument (key) +in_axes = [None] * 10 +in_axes[0] = 0 + +vectorized_single_draw = vmap( + generate_single_draw, + in_axes=in_axes, +) +``` + ```{code-cell} ipython3 @jit def generate_draws( @@ -752,12 +762,6 @@ def generate_draws( # Create M different random keys for parallel execution keys = random.split(key, M) - # Use vmap to parallelize over the M dimension - vectorized_single_draw = vmap( - generate_single_draw, - in_axes=(0, None, None, None, None, None, None, None, None, None), - ) - draws = vectorized_single_draw( keys, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init ) From cba480dc53733e1dd71fccdf3a9d8f5c3e21b1fb Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 17 Sep 2025 11:00:52 +1000 Subject: [PATCH 05/10] take random key out of the default value --- lectures/kesten_processes.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index 2c66cea3b..99a6a239e 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -742,7 +742,7 @@ vectorized_single_draw = vmap( ```{code-cell} ipython3 @jit def generate_draws( - key=random.PRNGKey(123), + seed=0, μ_a=-0.5, σ_a=0.1, μ_b=0.0, @@ -760,6 +760,7 @@ def generate_draws( Array of M draws """ # Create M different random keys for parallel execution + key = random.PRNGKey(seed) keys = random.split(key, M) draws = vectorized_single_draw( From 6e44c3c635e62da68356eb525568f3fa80be27ee Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 17 Sep 2025 11:22:22 +1000 Subject: [PATCH 06/10] attempt 2 to fix the pip install issue --- lectures/kesten_processes.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index 99a6a239e..7b80da8f1 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.7 + jupytext_version: 1.16.6 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -39,7 +39,7 @@ In addition to what's in Anaconda, this lecture will need the following librarie Later in the exercise, we will use JAX to optimize our code ```{code-cell} ipython3 -:tags: [no-execute] +:tags: [skip-execution] !pip install --upgrade jax ``` From 1adfc936005bc6644866570f78188f1ac9a67149 Mon Sep 17 00:00:00 2001 From: kp992 Date: Wed, 17 Sep 2025 17:01:20 -0700 Subject: [PATCH 07/10] merge jax install --- lectures/kesten_processes.md | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index 7b80da8f1..b0d9f884f 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -34,14 +34,8 @@ In addition to what's in Anaconda, this lecture will need the following librarie :tags: [hide-output] !pip install --upgrade quantecon yfinance -``` - -Later in the exercise, we will use JAX to optimize our code - -```{code-cell} ipython3 -:tags: [skip-execution] - -!pip install --upgrade jax +# Later in the exercise, we will use JAX to optimize our code +!pip install jax ``` ## Overview From 9bead7e22cf3ae0d285a01460f094b4824230d95 Mon Sep 17 00:00:00 2001 From: mmcky Date: Thu, 18 Sep 2025 10:16:58 +1000 Subject: [PATCH 08/10] remove comment --- lectures/kesten_processes.md | 1 - 1 file changed, 1 deletion(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index b0d9f884f..09d593f78 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -34,7 +34,6 @@ In addition to what's in Anaconda, this lecture will need the following librarie :tags: [hide-output] !pip install --upgrade quantecon yfinance -# Later in the exercise, we will use JAX to optimize our code !pip install jax ``` From 7d008329870f95ce9735ae203607a7bb248a501a Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 24 Sep 2025 15:42:24 +1000 Subject: [PATCH 09/10] update to use JAX entirely --- lectures/kesten_processes.md | 295 ++++++++++++++++++++++------------- 1 file changed, 183 insertions(+), 112 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index 09d593f78..810e3f855 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.6 + jupytext_version: 1.17.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -56,14 +56,53 @@ Let's start with some imports: ```{code-cell} ipython3 import matplotlib.pyplot as plt -import numpy as np import quantecon as qe import yfinance as yf +import jax +import jax.numpy as jnp +from jax import random, vmap, jit +from functools import partial +from typing import NamedTuple ``` Additional technical background related to this lecture can be found in the monograph by {cite}`buraczewski2016stochastic`. +We will use the following general-purpose function for generating time series paths + +```{code-cell} ipython3 +:tags: [hide-input] + +@partial(jax.jit, static_argnames=['f', 'num_steps']) +def generate_path(f, initial_state, num_steps, model, key): + """ + Generate a time series by repeatedly applying an update rule. + Given a map f, initial state x_0, and model parameters θ, this + function computes and returns the sequence {x_t}_{t=0}^{T-1} when + x_{t+1} = f(x_t, t, θ) + Args: + f: Update function mapping (x_t, t, model, key) -> x_{t+1} + initial_state: Initial state x_0 + num_steps: Number of time steps T to simulate + model: Model parameters + key: Random key for reproducible randomness + Returns: + Array of shape (dim(x), T) containing the time series path + [x_0, x_1, x_2, ..., x_{T-1}] + """ + def update_wrapper(carry, t): + """Wrapper function that adapts f for use with JAX scan.""" + state, subkey = carry + subkey, new_subkey = random.split(subkey) + next_state = f(state, t, model, new_subkey) + return (next_state, subkey), state + + # Initial carry: (initial_state, key) + init_carry = (initial_state, key) + _, path = jax.lax.scan(update_wrapper, init_carry, jnp.arange(num_steps)) + return path.T +``` + ## Kesten processes ```{index} single: Kesten processes; heavy tails @@ -327,26 +366,49 @@ This leads to spikes in the time series, which fill out the extreme right hand t The spikes in the time series are visible in the following simulation, which generates of 10 paths when $a_t$ and $b_t$ are lognormal. ```{code-cell} ipython3 -μ = -0.5 -σ = 1.0 +class KestenModel(NamedTuple): + """Parameters for Kesten process X_{t+1} = a_{t+1} X_t + η_{t+1}""" + μ: float = -0.5 # location parameter for log(a_t) + σ: float = 1.0 # scale parameter for log(a_t) -def kesten_ts(ts_length=100): - x = np.zeros(ts_length) - for t in range(ts_length - 1): - a = np.exp(μ + σ * np.random.randn()) - b = np.exp(np.random.randn()) - x[t+1] = a * x[t] + b - return x +@jax.jit +def kesten_update(current_x, time_step, model, key): + """ + Update function for Kesten process: X_{t+1} = a_{t+1} X_t + η_{t+1} + """ + # Split key for random number generation + key_a, key_η = random.split(key, 2) + # Generate random shocks + shock_a = random.normal(key_a) + shock_η = random.normal(key_η) + + # Compute a_t and η_t + a = jnp.exp(model.μ + model.σ * shock_a) + η = jnp.exp(shock_η) + + # Kesten process update + next_x = a * current_x + η + + return next_x fig, ax = plt.subplots() num_paths = 10 -np.random.seed(12) +model = KestenModel() for i in range(num_paths): - ax.plot(kesten_ts()) + key = random.PRNGKey(i) + + path = generate_path( + kesten_update, + initial_state=0.0, + num_steps=100, + model=model, + key=key + ) + ax.plot(path) ax.set(xlabel="time", ylabel="$X_t$") plt.show() @@ -446,31 +508,55 @@ While the time path differs, you should see bursts of high volatility. Here is one solution: ```{code-cell} ipython3 -α_0 = 1e-5 -α_1 = 0.1 -β = 0.9 +class GARCHModel(NamedTuple): + """Parameters for GARCH(1,1) volatility model""" + α_0: float = 1e-5 # constant term + α_1: float = 0.1 # coefficient on lagged squared shock + β: float = 0.9 # coefficient on lagged volatility years = 15 days = years * 250 +@jax.jit +def garch_update(current_state, time_step, model, key): + """Update function for GARCH(1,1) volatility and returns""" + σ2_current, r_previous = current_state + + # Split key for random number generation + key_xi, key_zeta = random.split(key, 2) + + # Generate random shocks + ξ = random.normal(key_xi) + ζ = random.normal(key_zeta) + + # Update volatility + σ2_next = model.α_0 + σ2_current * (model.α_1 * ξ**2 + model.β) -def garch_ts(ts_length=days): - σ2 = 0 - r = np.zeros(ts_length) - for t in range(ts_length - 1): - ξ = np.random.randn() - σ2 = α_0 + σ2 * (α_1 * ξ**2 + β) - r[t] = np.sqrt(σ2) * np.random.randn() - return r + # Generate return + r_current = jnp.sqrt(σ2_current) * ζ + return jnp.array([σ2_next, r_current]) fig, ax = plt.subplots() -np.random.seed(12) +key = random.PRNGKey(0) +model = GARCHModel() -ax.plot(garch_ts(), alpha=0.7) +# Initial state +initial_state = jnp.array([0.0, 0.0]) -ax.set(xlabel="time", ylabel="$\\sigma_t^2$") +path = generate_path( + garch_update, + initial_state=initial_state, + num_steps=days, + model=model, + key=key +) + +# Extract and plot returns +ax.plot(path[1, :], alpha=0.7) + +ax.set(xlabel="time", ylabel="returns") plt.show() ``` @@ -667,108 +753,93 @@ s_init = 1.0 # initial condition for each firm :class: dropdown ``` -Here's one solution. -First we generate the observations: - -```{code-cell} ipython3 -import jax -import jax.numpy as jnp -from jax import random, vmap, jit - - -def generate_single_draw(key, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init): - """Generate a single draw using JAX's scan for the time loop.""" - - def step_fn(carry, t): - s, subkey = carry - subkey, new_subkey = random.split(subkey) - - # Generate random normal samples - rand_normal = random.normal(new_subkey) - - # Conditional logic using jnp.where - # If s < s_bar: new_s = exp(μ_e + σ_e * randn()) - # Else: new_s = a * s + b - # where a = exp(μ_a + σ_a * randn()), b = exp(μ_b + σ_b * randn()) - - # For the else branch, we need two random numbers - subkey, key1, key2 = random.split(subkey, 3) - rand_a = random.normal(key1) - rand_b = random.normal(key2) +Here's one solution using the `generate_path` framework. - # Calculate both possible new values - new_s_under_bar = jnp.exp(μ_e + σ_e * rand_normal) +First, we define the firm productivity update function: - a = jnp.exp(μ_a + σ_a * rand_a) - b = jnp.exp(μ_b + σ_b * rand_b) - new_s_over_bar = a * s + b - - # Choose based on condition - new_s = jnp.where(s < s_bar, new_s_under_bar, new_s_over_bar) - - return (new_s, subkey), new_s - - # Initial state: (s_init, key) - init_carry = (s_init, key) - - # Run the scan - final_carry, _ = jax.lax.scan(step_fn, init_carry, jnp.arange(T)) - - # Return final s value - return final_carry[0] +```{code-cell} ipython3 +@jax.jit +def firm_product_update(current_product, time_step, model, key): + """ + Update firm productivity according to entry/exit dynamics. + If productivity is below threshold: firm exits and is replaced by new entrant + If productivity is above threshold: productivity evolves as Kesten process + """ + # Split key for random number generation + key_a, key_η, key_e = random.split(key, 3) + + # Generate random shocks + shock_a = random.normal(key_a) + shock_η = random.normal(key_η) + shock_e = random.normal(key_e) + + # Calculate potential new productivity values + # If firm exits (s_t < s_bar): replaced by new entrant + product_entrant = jnp.exp(model.μ_e + model.σ_e * shock_e) + + # If firm continues (s_t >= s_bar): Kesten process dynamics + a = jnp.exp(model.μ_a + model.σ_a * shock_a) + η = jnp.exp(model.μ_b + model.σ_b * shock_η) + product_incumbent = a * current_product + η + + # Apply entry/exit rule + new_product = jnp.where( + current_product < model.s_bar, + product_entrant, + product_incumbent + ) -generate_single_draw = jax.jit(generate_single_draw, static_argnums=(8,)) + return new_product ``` -```{code-cell} ipython3 -# Use vmap to vectorize over the first argument (key) -in_axes = [None] * 10 -in_axes[0] = 0 +Now we define a model container for parameters -vectorized_single_draw = vmap( - generate_single_draw, - in_axes=in_axes, -) +```{code-cell} ipython3 +class FirmDynamicsModel(NamedTuple): + """Parameters for firm dynamics with entry/exit""" + μ_a: float = -0.5 # location parameter for log(a_t) + σ_a: float = 0.1 # scale parameter for log(a_t) + μ_b: float = 0.0 # location parameter for log(η_t) + σ_b: float = 0.5 # scale parameter for log(η_t) + μ_e: float = 0.0 # location parameter for log(e_t) + σ_e: float = 0.5 # scale parameter for log(e_t) + s_bar: float = 1.0 # exit threshold ``` +Now we generate multiple firm trajectories in parallel + ```{code-cell} ipython3 -@jit -def generate_draws( - seed=0, - μ_a=-0.5, - σ_a=0.1, - μ_b=0.0, - σ_b=0.5, - μ_e=0.0, - σ_e=0.5, - s_bar=1.0, - T=500, - M=1_000_000, - s_init=1.0, -): - """ - JAX-jit version of the generate_draws function. - Returns: - Array of M draws - """ - # Create M different random keys for parallel execution +def generate_firm_distribution(model, + seed=0, M=1_000_000, T=500, s_init=1.0): + """Generate distribution of firm productivities after T periods.""" + + # Create random keys for each firm key = random.PRNGKey(seed) keys = random.split(key, M) - draws = vectorized_single_draw( - keys, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init - ) + @jax.jit + def single_firm_path(firm_key): + # Generate path and return final productivity + path = generate_path( + firm_product_update, + initial_state=s_init, + num_steps=T, + model=model, + key=firm_key + ) + return path[-1] - return draws -``` + # Apply to all firms in parallel + product_dist = vmap(single_firm_path)(keys) -```{code-cell} ipython3 -# Generate the observations -data = generate_draws() + return product_dist + +# Generate the data +data = generate_firm_distribution(FirmDynamicsModel()) ``` -Now we produce the rank-size plot: +Let's produce the rank-size plot ```{code-cell} ipython3 fig, ax = plt.subplots() From 4f4c0f50476d708705399b8a3dbca30066baa7d7 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Mon, 29 Sep 2025 22:49:49 +1000 Subject: [PATCH 10/10] Revert "update to use JAX entirely" This reverts commit 7d008329870f95ce9735ae203607a7bb248a501a. --- lectures/kesten_processes.md | 295 +++++++++++++---------------------- 1 file changed, 112 insertions(+), 183 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index 810e3f855..09d593f78 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.17.2 + jupytext_version: 1.16.6 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -56,53 +56,14 @@ Let's start with some imports: ```{code-cell} ipython3 import matplotlib.pyplot as plt +import numpy as np import quantecon as qe import yfinance as yf -import jax -import jax.numpy as jnp -from jax import random, vmap, jit -from functools import partial -from typing import NamedTuple ``` Additional technical background related to this lecture can be found in the monograph by {cite}`buraczewski2016stochastic`. -We will use the following general-purpose function for generating time series paths - -```{code-cell} ipython3 -:tags: [hide-input] - -@partial(jax.jit, static_argnames=['f', 'num_steps']) -def generate_path(f, initial_state, num_steps, model, key): - """ - Generate a time series by repeatedly applying an update rule. - Given a map f, initial state x_0, and model parameters θ, this - function computes and returns the sequence {x_t}_{t=0}^{T-1} when - x_{t+1} = f(x_t, t, θ) - Args: - f: Update function mapping (x_t, t, model, key) -> x_{t+1} - initial_state: Initial state x_0 - num_steps: Number of time steps T to simulate - model: Model parameters - key: Random key for reproducible randomness - Returns: - Array of shape (dim(x), T) containing the time series path - [x_0, x_1, x_2, ..., x_{T-1}] - """ - def update_wrapper(carry, t): - """Wrapper function that adapts f for use with JAX scan.""" - state, subkey = carry - subkey, new_subkey = random.split(subkey) - next_state = f(state, t, model, new_subkey) - return (next_state, subkey), state - - # Initial carry: (initial_state, key) - init_carry = (initial_state, key) - _, path = jax.lax.scan(update_wrapper, init_carry, jnp.arange(num_steps)) - return path.T -``` - ## Kesten processes ```{index} single: Kesten processes; heavy tails @@ -366,49 +327,26 @@ This leads to spikes in the time series, which fill out the extreme right hand t The spikes in the time series are visible in the following simulation, which generates of 10 paths when $a_t$ and $b_t$ are lognormal. ```{code-cell} ipython3 -class KestenModel(NamedTuple): - """Parameters for Kesten process X_{t+1} = a_{t+1} X_t + η_{t+1}""" - μ: float = -0.5 # location parameter for log(a_t) - σ: float = 1.0 # scale parameter for log(a_t) +μ = -0.5 +σ = 1.0 -@jax.jit -def kesten_update(current_x, time_step, model, key): - """ - Update function for Kesten process: X_{t+1} = a_{t+1} X_t + η_{t+1} - """ - # Split key for random number generation - key_a, key_η = random.split(key, 2) +def kesten_ts(ts_length=100): + x = np.zeros(ts_length) + for t in range(ts_length - 1): + a = np.exp(μ + σ * np.random.randn()) + b = np.exp(np.random.randn()) + x[t+1] = a * x[t] + b + return x - # Generate random shocks - shock_a = random.normal(key_a) - shock_η = random.normal(key_η) - - # Compute a_t and η_t - a = jnp.exp(model.μ + model.σ * shock_a) - η = jnp.exp(shock_η) - - # Kesten process update - next_x = a * current_x + η - - return next_x fig, ax = plt.subplots() num_paths = 10 -model = KestenModel() +np.random.seed(12) for i in range(num_paths): - key = random.PRNGKey(i) - - path = generate_path( - kesten_update, - initial_state=0.0, - num_steps=100, - model=model, - key=key - ) - ax.plot(path) + ax.plot(kesten_ts()) ax.set(xlabel="time", ylabel="$X_t$") plt.show() @@ -508,55 +446,31 @@ While the time path differs, you should see bursts of high volatility. Here is one solution: ```{code-cell} ipython3 -class GARCHModel(NamedTuple): - """Parameters for GARCH(1,1) volatility model""" - α_0: float = 1e-5 # constant term - α_1: float = 0.1 # coefficient on lagged squared shock - β: float = 0.9 # coefficient on lagged volatility +α_0 = 1e-5 +α_1 = 0.1 +β = 0.9 years = 15 days = years * 250 -@jax.jit -def garch_update(current_state, time_step, model, key): - """Update function for GARCH(1,1) volatility and returns""" - σ2_current, r_previous = current_state - - # Split key for random number generation - key_xi, key_zeta = random.split(key, 2) - - # Generate random shocks - ξ = random.normal(key_xi) - ζ = random.normal(key_zeta) - - # Update volatility - σ2_next = model.α_0 + σ2_current * (model.α_1 * ξ**2 + model.β) - # Generate return - r_current = jnp.sqrt(σ2_current) * ζ +def garch_ts(ts_length=days): + σ2 = 0 + r = np.zeros(ts_length) + for t in range(ts_length - 1): + ξ = np.random.randn() + σ2 = α_0 + σ2 * (α_1 * ξ**2 + β) + r[t] = np.sqrt(σ2) * np.random.randn() + return r - return jnp.array([σ2_next, r_current]) fig, ax = plt.subplots() -key = random.PRNGKey(0) -model = GARCHModel() +np.random.seed(12) -# Initial state -initial_state = jnp.array([0.0, 0.0]) +ax.plot(garch_ts(), alpha=0.7) -path = generate_path( - garch_update, - initial_state=initial_state, - num_steps=days, - model=model, - key=key -) - -# Extract and plot returns -ax.plot(path[1, :], alpha=0.7) - -ax.set(xlabel="time", ylabel="returns") +ax.set(xlabel="time", ylabel="$\\sigma_t^2$") plt.show() ``` @@ -753,93 +667,108 @@ s_init = 1.0 # initial condition for each firm :class: dropdown ``` -Here's one solution using the `generate_path` framework. - -First, we define the firm productivity update function: +Here's one solution. +First we generate the observations: ```{code-cell} ipython3 -@jax.jit -def firm_product_update(current_product, time_step, model, key): - """ - Update firm productivity according to entry/exit dynamics. +import jax +import jax.numpy as jnp +from jax import random, vmap, jit - If productivity is below threshold: firm exits and is replaced by new entrant - If productivity is above threshold: productivity evolves as Kesten process - """ - # Split key for random number generation - key_a, key_η, key_e = random.split(key, 3) - - # Generate random shocks - shock_a = random.normal(key_a) - shock_η = random.normal(key_η) - shock_e = random.normal(key_e) - - # Calculate potential new productivity values - # If firm exits (s_t < s_bar): replaced by new entrant - product_entrant = jnp.exp(model.μ_e + model.σ_e * shock_e) - - # If firm continues (s_t >= s_bar): Kesten process dynamics - a = jnp.exp(model.μ_a + model.σ_a * shock_a) - η = jnp.exp(model.μ_b + model.σ_b * shock_η) - product_incumbent = a * current_product + η - - # Apply entry/exit rule - new_product = jnp.where( - current_product < model.s_bar, - product_entrant, - product_incumbent - ) - return new_product -``` +def generate_single_draw(key, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init): + """Generate a single draw using JAX's scan for the time loop.""" -Now we define a model container for parameters + def step_fn(carry, t): + s, subkey = carry + subkey, new_subkey = random.split(subkey) -```{code-cell} ipython3 -class FirmDynamicsModel(NamedTuple): - """Parameters for firm dynamics with entry/exit""" - μ_a: float = -0.5 # location parameter for log(a_t) - σ_a: float = 0.1 # scale parameter for log(a_t) - μ_b: float = 0.0 # location parameter for log(η_t) - σ_b: float = 0.5 # scale parameter for log(η_t) - μ_e: float = 0.0 # location parameter for log(e_t) - σ_e: float = 0.5 # scale parameter for log(e_t) - s_bar: float = 1.0 # exit threshold -``` + # Generate random normal samples + rand_normal = random.normal(new_subkey) + + # Conditional logic using jnp.where + # If s < s_bar: new_s = exp(μ_e + σ_e * randn()) + # Else: new_s = a * s + b + # where a = exp(μ_a + σ_a * randn()), b = exp(μ_b + σ_b * randn()) + + # For the else branch, we need two random numbers + subkey, key1, key2 = random.split(subkey, 3) + rand_a = random.normal(key1) + rand_b = random.normal(key2) + + # Calculate both possible new values + new_s_under_bar = jnp.exp(μ_e + σ_e * rand_normal) + + a = jnp.exp(μ_a + σ_a * rand_a) + b = jnp.exp(μ_b + σ_b * rand_b) + new_s_over_bar = a * s + b + + # Choose based on condition + new_s = jnp.where(s < s_bar, new_s_under_bar, new_s_over_bar) + + return (new_s, subkey), new_s + + # Initial state: (s_init, key) + init_carry = (s_init, key) + + # Run the scan + final_carry, _ = jax.lax.scan(step_fn, init_carry, jnp.arange(T)) -Now we generate multiple firm trajectories in parallel + # Return final s value + return final_carry[0] + + +generate_single_draw = jax.jit(generate_single_draw, static_argnums=(8,)) +``` ```{code-cell} ipython3 -def generate_firm_distribution(model, - seed=0, M=1_000_000, T=500, s_init=1.0): - """Generate distribution of firm productivities after T periods.""" +# Use vmap to vectorize over the first argument (key) +in_axes = [None] * 10 +in_axes[0] = 0 - # Create random keys for each firm +vectorized_single_draw = vmap( + generate_single_draw, + in_axes=in_axes, +) +``` + +```{code-cell} ipython3 +@jit +def generate_draws( + seed=0, + μ_a=-0.5, + σ_a=0.1, + μ_b=0.0, + σ_b=0.5, + μ_e=0.0, + σ_e=0.5, + s_bar=1.0, + T=500, + M=1_000_000, + s_init=1.0, +): + """ + JAX-jit version of the generate_draws function. + Returns: + Array of M draws + """ + # Create M different random keys for parallel execution key = random.PRNGKey(seed) keys = random.split(key, M) - @jax.jit - def single_firm_path(firm_key): - # Generate path and return final productivity - path = generate_path( - firm_product_update, - initial_state=s_init, - num_steps=T, - model=model, - key=firm_key - ) - return path[-1] - - # Apply to all firms in parallel - product_dist = vmap(single_firm_path)(keys) + draws = vectorized_single_draw( + keys, μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar, T, s_init + ) - return product_dist + return draws +``` -# Generate the data -data = generate_firm_distribution(FirmDynamicsModel()) +```{code-cell} ipython3 +# Generate the observations +data = generate_draws() ``` -Let's produce the rank-size plot +Now we produce the rank-size plot: ```{code-cell} ipython3 fig, ax = plt.subplots()