From 689fd06baef105b8f36a5273bff21fd085fafd32 Mon Sep 17 00:00:00 2001 From: kp992 Date: Fri, 5 Sep 2025 11:41:36 -0700 Subject: [PATCH 1/4] update the code to latest style --- lectures/kesten_processes.md | 97 +++++++++++++++--------------------- 1 file changed, 39 insertions(+), 58 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index 5fe95ad5..ad5edbb5 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.17.2 + jupytext_version: 1.16.7 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -32,7 +32,7 @@ In addition to JAX and Anaconda, this lecture will need the following libraries: ```{code-cell} ipython3 :tags: [hide-output] -!pip install quantecon +!pip install --upgrade quantecon ``` ## Overview @@ -40,8 +40,7 @@ In addition to JAX and Anaconda, this lecture will need the following libraries: This lecture describes Kesten processes, which are an important class of stochastic processes, and an application of firm dynamics. -The lecture draws on [an earlier QuantEcon lecture](https://python.quantecon.org/kesten_processes.html), -which uses Numba to accelerate the computations. +The lecture draws on {doc}`intermediate:kesten_processes`. In that earlier lecture you can find a more detailed discussion of the concepts involved. @@ -55,8 +54,6 @@ import quantecon as qe import jax import jax.numpy as jnp from jax import random -from jax import lax -from quantecon import tic, toc from typing import NamedTuple from functools import partial ``` @@ -136,7 +133,7 @@ We now study the implications of this specification. 
#### Heavy tails -If the conditions of the [Kesten--Goldie Theorem](https://python.quantecon.org/kesten_processes.html#the-kestengoldie-theorem) +If the conditions of the {doc}`Kesten--Goldie Theorem <intermediate:kesten_processes>` are satisfied, then {eq}`firm_dynam` implies that the firm size distribution will have Pareto tails. This matches empirical findings across many data sets. @@ -154,8 +151,7 @@ In this setting, firm dynamics can be expressed as (a_{t+1} s_t + b_{t+1}) \mathbb{1}\{s_t \geq \bar s\} ``` -The motivation behind and interpretation of [](firm_dynam_ee) can be found in -[our earlier Kesten process lecture](https://python.quantecon.org/kesten_processes.html). +The motivation behind and interpretation of [](firm_dynam_ee) can be found in {doc}`intermediate:kesten_processes`. What can we say about dynamics? @@ -180,12 +176,12 @@ Here's a class to store parameters: ```{code-cell} ipython3 class Firm(NamedTuple): - μ_a: float = -0.5 - σ_a: float = 0.1 - μ_b: float = 0.0 - σ_b: float = 0.5 - μ_e: float = 0.0 - σ_e: float = 0.5 + μ_a: float = -0.5 + σ_a: float = 0.1 + μ_b: float = 0.0 + σ_b: float = 0.5 + μ_e: float = 0.0 + σ_e: float = 0.5 s_bar: float = 1.0 ``` @@ -207,15 +203,13 @@ For sufficiently large `T`, the cross-section it returns (the cross-section at time `T`) corresponds to firm size distribution in (approximate) equilibrium. ```{code-cell} ipython3 -def generate_cross_section( - firm, M=500_000, T=500, s_init=1.0, seed=123 - ): +def generate_cross_section(firm, M=500_000, T=500, s_init=1.0, seed=123): μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm key = random.PRNGKey(seed) # Initialize the cross-section to a common value - s = jnp.full((M, ), s_init) + s = jnp.full((M,), s_init) # Perform updates on s for time t for t in range(T): @@ -235,17 +229,15 @@ Let's try running the code and generating a cross-section. 
```{code-cell} ipython3 firm = Firm() -tic() -data = generate_cross_section(firm).block_until_ready() -toc() +with qe.Timer(): + data = generate_cross_section(firm).block_until_ready() ``` We run the function again so we can see the speed without compile time. ```{code-cell} ipython3 -tic() -data = generate_cross_section(firm).block_until_ready() -toc() +with qe.Timer(): + data = generate_cross_section(firm).block_until_ready() ``` Let's produce the rank-size plot and check the distribution: @@ -254,7 +246,7 @@ Let's produce the rank-size plot and check the distribution: fig, ax = plt.subplots() rank_data, size_data = qe.rank_size(data, c=0.01) -ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5) +ax.loglog(rank_data, size_data, "o", markersize=3.0, alpha=0.5) ax.set_xlabel("log rank") ax.set_ylabel("log size") @@ -279,20 +271,18 @@ Here a the `lax.fori_loop` version: ```{code-cell} ipython3 @jax.jit -def generate_cross_section_lax( - firm, T=500, M=500_000, s_init=1.0, seed=123 - ): +def generate_cross_section_lax(firm, T=500, M=500_000, s_init=1.0, seed=123): μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm key = random.PRNGKey(seed) - + # Initial cross section - s = jnp.full((M, ), s_init) + s = jnp.full((M,), s_init) def update_cross_section(t, state): s, key = state key, *subkeys = jax.random.split(key, 4) - # Generate current random draws + # Generate current random draws a = μ_a + σ_a * random.normal(subkeys[0], (M,)) b = μ_b + σ_b * random.normal(subkeys[1], (M,)) e = μ_e + σ_e * random.normal(subkeys[2], (M,)) @@ -303,26 +293,22 @@ def generate_cross_section_lax( new_state = s, key return new_state - # Use fori_loop + # Use fori_loop initial_state = s, key - final_s, final_key = lax.fori_loop( - 0, T, update_cross_section, initial_state - ) + final_s, final_key = jax.lax.fori_loop(0, T, update_cross_section, initial_state) return final_s ``` Let's see if we get any speed gain ```{code-cell} ipython3 -tic() -data = 
generate_cross_section_lax(firm).block_until_ready() -toc() +with qe.Timer(): + data = generate_cross_section_lax(firm).block_until_ready() ``` ```{code-cell} ipython3 -tic() -data = generate_cross_section_lax(firm).block_until_ready() -toc() +with qe.Timer(): + data = generate_cross_section_lax(firm).block_until_ready() ``` Here we produce the same rank-size plot: @@ -331,12 +317,11 @@ Here we produce the same rank-size plot: fig, ax = plt.subplots() rank_data, size_data = qe.rank_size(data, c=0.01) -ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5) +ax.loglog(rank_data, size_data, "o", markersize=3.0, alpha=0.5) ax.set_xlabel("log rank") ax.set_ylabel("log size") plt.show() - ``` ## Exercises @@ -362,22 +347,20 @@ What are the pros and cons of this approach? ```{code-cell} ipython3 @jax.jit -def generate_cross_section_lax( - firm, T=500, M=500_000, s_init=1.0, seed=123 - ): +def generate_cross_section_lax(firm, T=500, M=500_000, s_init=1.0, seed=123): μ_a, σ_a, μ_b, σ_b, μ_e, σ_e, s_bar = firm key = random.PRNGKey(seed) subkey_1, subkey_2, subkey_3 = random.split(key, 3) - - # Generate entire sequence of random draws + + # Generate entire sequence of random draws a = μ_a + σ_a * random.normal(subkey_1, (T, M)) b = μ_b + σ_b * random.normal(subkey_2, (T, M)) e = μ_e + σ_e * random.normal(subkey_3, (T, M)) # Exponentiate them a, b, e = jax.tree.map(jnp.exp, (a, b, e)) # Initial cross section - s = jnp.full((M, ), s_init) + s = jnp.full((M,), s_init) def update_cross_section(t, s): # Pull out the t-th cross-section of shocks @@ -385,23 +368,21 @@ def generate_cross_section_lax( s = jnp.where(s < s_bar, e_t, a_t * s + b_t) return s - # Use lax.scan to perform the calculations on all states - s_final = lax.fori_loop(0, T, update_cross_section, s) + # Use lax.fori_loop to perform the calculations on all states + s_final = jax.lax.fori_loop(0, T, update_cross_section, s) return s_final ``` Here are the run times. 
```{code-cell} ipython3 -tic() -data = generate_cross_section_lax(firm).block_until_ready() -toc() +with qe.Timer(): + data = generate_cross_section_lax(firm).block_until_ready() ``` ```{code-cell} ipython3 -tic() -data = generate_cross_section_lax(firm).block_until_ready() -toc() +with qe.Timer(): + data = generate_cross_section_lax(firm).block_until_ready() ``` This method might or might not be faster. From b1590503cb665e3fa00d978a8974f9d7f6e3573b Mon Sep 17 00:00:00 2001 From: kp992 Date: Fri, 5 Sep 2025 11:42:40 -0700 Subject: [PATCH 2/4] fix jax.lax --- lectures/kesten_processes.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index ad5edbb5..ba13d8f3 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -256,7 +256,7 @@ plt.show() The plot produces a straight line, consistent with a Pareto tail. -#### Alternative implementation with `lax.fori_loop` +#### Alternative implementation with `jax.lax.fori_loop` Although we JIT-compiled some of the code above, we did not JIT-compile the `for` loop. @@ -264,10 +264,10 @@ we did not JIT-compile the `for` loop. Let's try squeezing out a bit more speed by -* replacing the `for` loop with [`lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html) and +* replacing the `for` loop with [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html) and * JIT-compiling the whole function. 
-Here a the `lax.fori_loop` version: +Here a the `jax.lax.fori_loop` version: ```{code-cell} ipython3 @jax.jit From 05e2a35d054620b9c83364938a322dbf6bd88e38 Mon Sep 17 00:00:00 2001 From: kp992 Date: Fri, 5 Sep 2025 11:47:27 -0700 Subject: [PATCH 3/4] fix typos --- lectures/kesten_processes.md | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/lectures/kesten_processes.md b/lectures/kesten_processes.md index ba13d8f3..8b9c8b6f 100644 --- a/lectures/kesten_processes.md +++ b/lectures/kesten_processes.md @@ -38,11 +38,11 @@ In addition to JAX and Anaconda, this lecture will need the following libraries: ## Overview This lecture describes Kesten processes, which are an important class of -stochastic processes, and an application of firm dynamics. +stochastic processes, and their application to firm dynamics. The lecture draws on {doc}`intermediate:kesten_processes`. -In that earlier lecture you can find a more detailed discussion of the concepts involved. +In that earlier lecture, you can find a more detailed discussion of the concepts involved. This lecture focuses on implementing the same computations in JAX. @@ -58,7 +58,7 @@ from typing import NamedTuple from functools import partial ``` -Let's check the GPU we are running +Let's check the GPU we are running on ```{code-cell} ipython3 !nvidia-smi @@ -82,19 +82,17 @@ sequences. We are interested in the dynamics of $\{X_t\}_{t \geq 0}$ when $X_0$ is given. -We will focus on the nonnegative scalar case, where $X_t$ takes values in $\mathbb R_+$. +We will focus on the nonnegative scalar case, where $X_t$ takes values in $\mathbb{R}_+$. In particular, we will assume that * the initial condition $X_0$ is nonnegative, -* $\{a_t\}_{t \geq 1}$ is a nonnegative IID stochastic process and +* $\{a_t\}_{t \geq 1}$ is a nonnegative IID stochastic process, and * $\{\eta_t\}_{t \geq 1}$ is another nonnegative IID stochastic process, independent of the first. 
- ### Application: firm dynamics -In this section we apply Kesten process theory to the study of firm dynamics. - +In this section, we apply Kesten process theory to the study of firm dynamics. #### Gibrat's law @@ -118,7 +116,7 @@ for some positive IID sequence $\{a_t\}$. Subsequent empirical research has shown that this specification is not accurate, particularly for small firms. -However, we can get close to the data by modifying {eq}`firm_dynam_gb` to +However, we can get closer to the data by modifying {eq}`firm_dynam_gb` to ```{math} :label: firm_dynam @@ -153,7 +151,7 @@ In this setting, firm dynamics can be expressed as The motivation behind and interpretation of [](firm_dynam_ee) can be found in {doc}`intermediate:kesten_processes`. -What can we say about dynamics? +What can we say about the dynamics? Although {eq}`firm_dynam_ee` is not a Kesten process, it does update in the same way as a Kesten process when $s_t$ is large. @@ -164,8 +162,8 @@ We can investigate this question via simulation and rank-size plots. The approach will be to -1. generate $M$ draws of $s_T$ when $M$ and $T$ are large and -1. plot the largest 1,000 of the resulting draws in a rank-size plot. +1. generate $M$ draws of $s_T$ when $M$ and $T$ are large, and +2. plot the largest 1,000 of the resulting draws in a rank-size plot. (The distribution of $s_T$ will be close to the stationary distribution when $T$ is large.) @@ -200,7 +198,7 @@ Now we write a for loop that repeatedly calls this function, to push a cross-section of firms forward in time. For sufficiently large `T`, the cross-section it returns (the cross-section at -time `T`) corresponds to firm size distribution in (approximate) equilibrium. +time `T`) corresponds to the firm size distribution in (approximate) equilibrium. ```{code-cell} ipython3 def generate_cross_section(firm, M=500_000, T=500, s_init=1.0, seed=123): @@ -255,7 +253,6 @@ plt.show() The plot produces a straight line, consistent with a Pareto tail. 
- #### Alternative implementation with `jax.lax.fori_loop` Although we JIT-compiled some of the code above, @@ -264,10 +261,10 @@ we did not JIT-compile the `for` loop. Let's try squeezing out a bit more speed by -* replacing the `for` loop with [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html) and +* replacing the `for` loop with [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), and * JIT-compiling the whole function. -Here a the `jax.lax.fori_loop` version: +Here is the `jax.lax.fori_loop` version: ```{code-cell} ipython3 @jax.jit From 054c7c96905b5a31a29873ac36632b0090d1181a Mon Sep 17 00:00:00 2001 From: kp992 Date: Fri, 5 Sep 2025 11:48:12 -0700 Subject: [PATCH 4/4] add sphinx mapping --- lectures/_config.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lectures/_config.yml b/lectures/_config.yml index 4f5b4445..fb6f506f 100644 --- a/lectures/_config.yml +++ b/lectures/_config.yml @@ -93,6 +93,10 @@ sphinx: macros: "argmax" : "arg\\,max" "argmin" : "arg\\,min" + intersphinx_mapping: + intermediate: + - "https://python.quantecon.org/" + - null mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js rediraffe_redirects: index_toc.md: intro.md