From c227f9c60304b445bd332eac3abed27ec55856e7 Mon Sep 17 00:00:00 2001 From: Kenko LI Date: Mon, 20 Oct 2025 17:27:54 +0900 Subject: [PATCH 1/7] update markov_asset --- lectures/markov_asset.md | 166 +++++++++++++++++++++------------------ 1 file changed, 91 insertions(+), 75 deletions(-) diff --git a/lectures/markov_asset.md b/lectures/markov_asset.md index 1739dde0f..0c4264df0 100644 --- a/lectures/markov_asset.md +++ b/lectures/markov_asset.md @@ -75,7 +75,10 @@ Let's start with some imports: import matplotlib.pyplot as plt import numpy as np import quantecon as qe -from numpy.linalg import eigvals, solve +import jax +import jax.numpy as jnp +from jax.numpy.linalg import eigvals, solve +from typing import NamedTuple ``` ## {index}`Pricing Models ` @@ -91,7 +94,7 @@ Let $\{d_t\}_{t \geq 0}$ be a stream of dividends Let's look at some equations that we expect to hold for prices of assets under ex-dividend contracts (we will consider cum-dividend pricing in the exercises). -### Risk-Neutral Pricing +### Risk-neutral pricing ```{index} single: Pricing Models; Risk-Neutral ``` @@ -116,7 +119,7 @@ Here ${\mathbb E}_t [y]$ denotes the best forecast of $y$, conditioned on inform More precisely, ${\mathbb E}_t [y]$ is the mathematical expectation of $y$ conditional on information available at time $t$. -### Pricing with Random Discount Factor +### Pricing with random discount factor ```{index} single: Pricing Models; Risk Aversion ``` @@ -145,7 +148,7 @@ This is because such assets pay well when funds are more urgently wanted. We give examples of how the stochastic discount factor has been modeled below. -### Asset Pricing and Covariances +### Asset pricing and covariances Recall that, from the definition of a conditional covariance ${\rm cov}_t (x_{t+1}, y_{t+1})$, we have @@ -174,7 +177,7 @@ Equation {eq}`lteeqs102` asserts that the covariance of the stochastic discount We give examples of some models of stochastic discount factors that have been proposed later in this lecture and also in a [later lecture](https://python-advanced.quantecon.org/lucas_model.html). -### The Price-Dividend Ratio +### The price-dividend ratio Aside from prices, another quantity of interest is the **price-dividend ratio** $v_t := p_t / d_t$. @@ -190,7 +193,7 @@ v_t = {\mathbb E}_t \left[ m_{t+1} \frac{d_{t+1}}{d_t} (1 + v_{t+1}) \right] Below we'll discuss the implication of this equation. -## Prices in the Risk-Neutral Case +## Prices in the risk-neutral case What can we say about price dynamics on the basis of the models described above? @@ -203,7 +206,7 @@ For now we'll study the risk-neutral case in which the stochastic discount fac We'll focus on how an asset price depends on a dividend process. -### Example 1: Constant Dividends +### Example 1: constant dividends The simplest case is risk-neutral price of a constant, non-random dividend stream $d_t = d > 0$. @@ -234,7 +237,7 @@ This is the equilibrium price in the constant dividend case. Indeed, simple algebra shows that setting $p_t = \bar p$ for all $t$ satisfies the difference equation $p_t = \beta (d + p_{t+1})$. -### Example 2: Dividends with Deterministic Growth Paths +### Example 2: dividends with deterministic growth paths Consider a growing, non-random dividend process $d_{t+1} = g d_t$ where $0 < g \beta < 1$. @@ -267,7 +270,7 @@ $$ This is called the *Gordon formula*. (mass_mg)= -### Example 3: Markov Growth, Risk-Neutral Pricing +### Example 3: Markov growth, risk-neutral pricing Next, we consider a dividend process @@ -310,14 +313,14 @@ The next figure shows a simulation, where * $\{X_t\}$ evolves as a discretized AR1 process produced using {ref}`Tauchen's method `. * $g_t = \exp(X_t)$, so that $\ln g_t = X_t$ is the growth rate. -```{code-cell} ipython +```{code-cell} ipython3 n = 7 mc = qe.tauchen(n, 0.96, 0.25) sim_length = 80 -x_series = mc.simulate(sim_length, init=np.median(mc.state_values)) -g_series = np.exp(x_series) -d_series = np.cumprod(g_series) # Assumes d_0 = 1 +x_series = mc.simulate(sim_length, init=jnp.median(mc.state_values)) +g_series = jnp.exp(x_series) +d_series = jnp.cumprod(g_series) # Assumes d_0 = 1 series = [x_series, g_series, d_series, np.log(d_series)] labels = ['$X_t$', '$g_t$', '$d_t$', r'$\log \, d_t$'] @@ -330,7 +333,7 @@ plt.tight_layout() plt.show() ``` -#### Pricing Formula +#### Pricing formula To obtain asset prices in this setting, let's adapt our analysis from the case of deterministic growth. @@ -400,18 +403,18 @@ As before, we'll generate $\{X_t\}$ as a {ref}`discretized AR1 process Here's the code, including a test of the spectral radius condition -```{code-cell} python3 +```{code-cell} ipython3 n = 25 # Size of state space β = 0.9 mc = qe.tauchen(n, 0.96, 0.02) -K = mc.P * np.exp(mc.state_values) +K = mc.P * jnp.exp(mc.state_values) warning_message = "Spectral radius condition fails" -assert np.max(np.abs(eigvals(K))) < 1 / β, warning_message +assert jnp.max(jnp.abs(eigvals(K))) < 1 / β, warning_message -I = np.identity(n) -v = solve(I - β * K, β * K @ np.ones(n)) +I = jnp.identity(n) +v = solve(I - β * K, β * K @ jnp.ones(n)) fig, ax = plt.subplots() ax.plot(mc.state_values, v, 'g-o', lw=2, alpha=0.7, label='$v$') @@ -440,7 +443,7 @@ We'll price several distinct assets, including * A consol (a type of bond issued by the UK government in the 19th century) * Call options on a consol -### Pricing a Lucas Tree +### Pricing a Lucas tree ```{index} single: Finite Markov Asset Pricing; Lucas Tree ``` @@ -539,46 +542,51 @@ v = (I - \beta J)^{-1} \beta J {\mathbb 1} We will define a function tree_price to compute $v$ given parameters stored in the class AssetPriceModel -```{code-cell} python3 -class AssetPriceModel: +```{code-cell} ipython3 +class AssetPriceModel(NamedTuple): """ A class that stores the primitives of the asset pricing model. Parameters ---------- - β : scalar, float - Discount factor mc : MarkovChain - Contains the transition matrix and set of state values for the state - process - γ : scalar(float) - Coefficient of risk aversion + Contains the transition matrix and set of state values g : callable The function mapping states to growth rates - + β : float + Discount factor + γ : float + Coefficient of risk aversion + n: int + The number of states + """ + mc: qe.MarkovChain + g: callable + β: float + γ: float + n: int + + +def create_ap_model(mc=None, g=jnp.exp, β=0.96, γ=2.0): + """Create an AssetPriceModel class""" + if mc is None: + n, ρ, σ = 25, 0.9, 0.02 + mc = qe.tauchen(n, ρ, σ) + else: + mc = mc + n = mc.P.shape[0] + + return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n) + + +def test_stability(Q, β): """ - def __init__(self, β=0.96, mc=None, γ=2.0, g=np.exp): - self.β, self.γ = β, γ - self.g = g - - # A default process for the Markov chain - if mc is None: - self.ρ = 0.9 - self.σ = 0.02 - self.mc = qe.tauchen(n, self.ρ, self.σ) - else: - self.mc = mc - - self.n = self.mc.P.shape[0] - - def test_stability(self, Q): - """ - Stability test for a given matrix Q. - """ - sr = np.max(np.abs(eigvals(Q))) - if not sr < 1 / self.β: - msg = f"Spectral radius condition failed with radius = {sr}" - raise ValueError(msg) + Stability test for a given matrix Q. + """ + sr = np.max(np.abs(eigvals(Q))) + if not sr < 1 / β: + msg = f"Spectral radius condition failed with radius = {sr}" + raise ValueError(msg) def tree_price(ap): @@ -601,11 +609,11 @@ def tree_price(ap): J = P * ap.g(y)**(1 - γ) # Make sure that a unique solution exists - ap.test_stability(J) + test_stability(J, β) # Compute v - I = np.identity(ap.n) - Ones = np.ones(ap.n) + I = jnp.identity(ap.n) + Ones = jnp.ones(ap.n) v = solve(I - β * J, β * J @ Ones) return v @@ -614,16 +622,16 @@ def tree_price(ap): Here's a plot of $v$ as a function of the state for several values of $\gamma$, with a positively correlated Markov process and $g(x) = \exp(x)$ -```{code-cell} python3 +```{code-cell} ipython3 γs = [1.2, 1.4, 1.6, 1.8, 2.0] -ap = AssetPriceModel() +ap = create_ap_model() states = ap.mc.state_values fig, ax = plt.subplots() for γ in γs: - ap.γ = γ - v = tree_price(ap) + tem_ap = create_ap_model(mc=ap.mc, g=ap.g, β=ap.β, γ=γ) + v = tree_price(tem_ap) ax.plot(states, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$") ax.set_title('Price-dividend ratio as a function of the state') @@ -706,7 +714,7 @@ p = (I - \beta M)^{-1} \beta M \zeta {\mathbb 1} The above is implemented in the function consol_price. -```{code-cell} python3 +```{code-cell} ipython3 def consol_price(ap, ζ): """ Computes price of a consol bond with payoff ζ @@ -715,7 +723,7 @@ def consol_price(ap, ζ): ---------- ap: AssetPriceModel An instance of AssetPriceModel containing primitives - + ζ : scalar(float) Coupon of the console @@ -723,18 +731,17 @@ def consol_price(ap, ζ): ------- p : array_like(float) Console bond prices - """ # Simplify names, set up matrices β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values M = P * ap.g(y)**(- γ) # Make sure that a unique solution exists - ap.test_stability(M) + test_stability(M, β) # Compute price - I = np.identity(ap.n) - Ones = np.ones(ap.n) + I = jnp.identity(ap.n) + Ones = jnp.ones(ap.n) p = solve(I - β * M, β * ζ * M @ Ones) return p @@ -812,7 +819,7 @@ Start at some initial $w$ and iterate with $T$ to convergence . We can find the solution with the following function call_option -```{code-cell} python3 +```{code-cell} ipython3 def call_option(ap, ζ, p_s, ϵ=1e-7): """ Computes price of a call option on a consol bond. @@ -828,7 +835,7 @@ def call_option(ap, ζ, p_s, ϵ=1e-7): p_s : scalar(float) Strike price - ϵ : scalar(float), optional(default=1e-8) + ϵ : scalar(float), optional(default=1e-7) Tolerance for infinite horizon problem Returns @@ -842,26 +849,35 @@ def call_option(ap, ζ, p_s, ϵ=1e-7): M = P * ap.g(y)**(- γ) # Make sure that a unique consol price exists - ap.test_stability(M) + test_stability(M, β) # Compute option price p = consol_price(ap, ζ) - w = np.zeros(ap.n) + w = jnp.zeros(ap.n) error = ϵ + 1 - while error > ϵ: + + def step(state): + w, error = state # Maximize across columns - w_new = np.maximum(β * M @ w, p - p_s) + w_new = jnp.maximum(β * M @ w, p - p_s) # Find maximal difference of each component and update - error = np.amax(np.abs(w - w_new)) - w = w_new + error_new = jnp.amax(jnp.abs(w - w_new)) + return (w_new, error_new) - return w + # Check whether converged + def cond(state): + _, error = state + return error > ϵ + + final_w, _ = jax.lax.while_loop(cond, step, (w, error)) + + return final_w ``` Here's a plot of $w$ compared to the consol price when $P_S = 40$ -```{code-cell} python3 -ap = AssetPriceModel(β=0.9) +```{code-cell} ipython3 +ap = create_ap_model(β=0.9) ζ = 1.0 strike_price = 40 From 0317e4a9f51861b3102dec1a15759f3f46ec04f3 Mon Sep 17 00:00:00 2001 From: Kenko LI Date: Wed, 22 Oct 2025 15:36:53 +0900 Subject: [PATCH 2/7] modified: lectures/markov_asset.md --- lectures/markov_asset.md | 53 ++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/lectures/markov_asset.md b/lectures/markov_asset.md index 0c4264df0..298d43433 100644 --- a/lectures/markov_asset.md +++ b/lectures/markov_asset.md @@ -35,6 +35,16 @@ kernelspec: "Asset pricing is all about covariances" -- Lars Peter Hansen ``` +```{admonition} GPU +:class: warning + +This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and JAX for GPU programming. + +Free GPUs are available on Google Colab. To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU. + +Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support. If you would like to install JAX running on the `cpu` only you can use `pip install jax[cpu]` +``` + In addition to what's in Anaconda, this lecture will need the following libraries: ```{code-cell} ipython @@ -976,10 +986,12 @@ $$ Consider the following primitives -```{code-cell} python3 +```{code-cell} ipython3 n = 5 # Size of State Space -P = np.full((n, n), 0.0125) -P[range(n), range(n)] += 1 - P.sum(1) +P = jnp.full((n, n), 0.0125) +P = P.at[jnp.arange(n), jnp.arange(n)].set( + P[jnp.arange(n), jnp.arange(n)] + 1 - P.sum(1) + ) # State values of the Markov chain s = np.array([0.95, 0.975, 1.0, 1.025, 1.05]) γ = 2.0 @@ -1004,11 +1016,13 @@ Do the same for First, let's enter the parameters: -```{code-cell} python3 +```{code-cell} ipython3 n = 5 -P = np.full((n, n), 0.0125) -P[range(n), range(n)] += 1 - P.sum(1) -s = np.array([0.95, 0.975, 1.0, 1.025, 1.05]) # State values +P = jnp.full((n, n), 0.0125) +P = P.at[jnp.arange(n), jnp.arange(n)].set( + P[jnp.arange(n), jnp.arange(n)] + 1 - P.sum(1) + ) +s = jnp.array([0.95, 0.975, 1.0, 1.025, 1.05]) # State values mc = qe.MarkovChain(P, state_values=s) γ = 2.0 @@ -1020,27 +1034,27 @@ p_s = 150.0 Next, we'll create an instance of `AssetPriceModel` to feed into the functions -```{code-cell} python3 -apm = AssetPriceModel(β=β, mc=mc, γ=γ, g=lambda x: x) +```{code-cell} ipython3 +apm = create_ap_model(mc=mc, g=lambda x: x, β=β, γ=γ) ``` Now we just need to call the relevant functions on the data: -```{code-cell} python3 +```{code-cell} ipython3 tree_price(apm) ``` -```{code-cell} python3 +```{code-cell} ipython3 consol_price(apm, ζ) ``` -```{code-cell} python3 +```{code-cell} ipython3 call_option(apm, ζ, p_s) ``` Let's show the last two functions as a plot -```{code-cell} python3 +```{code-cell} ipython3 fig, ax = plt.subplots() ax.plot(s, consol_price(apm, ζ), label='consol') ax.plot(s, call_option(apm, ζ, p_s), label='call option') @@ -1101,7 +1115,7 @@ Is one higher than the other? Can you give intuition? Here's a suitable function: -```{code-cell} python3 +```{code-cell} ipython3 def finite_horizon_call_option(ap, ζ, p_s, k): """ Computes k period option value. @@ -1111,15 +1125,16 @@ def finite_horizon_call_option(ap, ζ, p_s, k): M = P * ap.g(y)**(- γ) # Make sure that a unique solution exists - ap.test_stability(M) - + test_stability(M, β) # Compute option price p = consol_price(ap, ζ) - w = np.zeros(ap.n) - for i in range(k): + def step(i, w): # Maximize across columns - w = np.maximum(β * M @ w, p - p_s) + w = jnp.maximum(β * M @ w, p - p_s) + return w + + w = jax.lax.fori_loop(0, k, step, jnp.zeros(ap.n)) return w ``` From 17396978b84468c0e83f50365659e6762ebf90ab Mon Sep 17 00:00:00 2001 From: Kenko LI Date: Wed, 22 Oct 2025 17:09:33 +0900 Subject: [PATCH 3/7] add metadata cell for figures --- lectures/markov_asset.md | 45 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/lectures/markov_asset.md b/lectures/markov_asset.md index 298d43433..90e20c315 100644 --- a/lectures/markov_asset.md +++ b/lectures/markov_asset.md @@ -324,6 +324,13 @@ The next figure shows a simulation, where * $g_t = \exp(X_t)$, so that $\ln g_t = X_t$ is the growth rate. ```{code-cell} ipython3 +--- +mystnb: + figure: + caption: | + State, growth, and dividend simulation + name: fig_markov_sim +--- n = 7 mc = qe.tauchen(n, 0.96, 0.25) sim_length = 80 @@ -414,6 +421,13 @@ As before, we'll generate $\{X_t\}$ as a {ref}`discretized AR1 process Here's the code, including a test of the spectral radius condition ```{code-cell} ipython3 +--- +mystnb: + figure: + caption: | + Price-dividend ratio risk-neutral case + name: fig_pdv_neutral +--- n = 25 # Size of state space β = 0.9 mc = qe.tauchen(n, 0.96, 0.02) @@ -633,6 +647,13 @@ Here's a plot of $v$ as a function of the state for several values of $\gamma$, with a positively correlated Markov process and $g(x) = \exp(x)$ ```{code-cell} ipython3 +--- +mystnb: + figure: + caption: | + Lucas tree prices for varying risk aversion + name: fig_lucas_gamma +--- γs = [1.2, 1.4, 1.6, 1.8, 2.0] ap = create_ap_model() states = ap.mc.state_values @@ -644,7 +665,6 @@ for γ in γs: v = tree_price(tem_ap) ax.plot(states, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$") -ax.set_title('Price-dividend ratio as a function of the state') ax.set_ylabel("price-dividend ratio") ax.set_xlabel("state") ax.legend(loc='upper right') @@ -887,6 +907,13 @@ def call_option(ap, ζ, p_s, ϵ=1e-7): Here's a plot of $w$ compared to the consol price when $P_S = 40$ ```{code-cell} ipython3 +--- +mystnb: + figure: + caption: | + Consol price and call option value + name: fig_consol_call +--- ap = create_ap_model(β=0.9) ζ = 1.0 strike_price = 40 @@ -1055,6 +1082,13 @@ call_option(apm, ζ, p_s) Let's show the last two functions as a plot ```{code-cell} ipython3 +--- +mystnb: + figure: + caption: | + Consol and call option exercise two comparison + name: fig_ex2_prices +--- fig, ax = plt.subplots() ax.plot(s, consol_price(apm, ζ), label='consol') ax.plot(s, call_option(apm, ζ, p_s), label='call option') @@ -1141,7 +1175,14 @@ def finite_horizon_call_option(ap, ζ, p_s, k): Now let's compute the option values at `k=5` and `k=25` -```{code-cell} python3 +```{code-cell} ipython3 +--- +mystnb: + figure: + caption: | + Finite horizon call option values + name: fig_ex3_finite +--- fig, ax = plt.subplots() for k in [5, 25]: w = finite_horizon_call_option(apm, ζ, p_s, k) From 76c64b2a050b0036945d58d4e20ad5b441f75fba Mon Sep 17 00:00:00 2001 From: Kenko LI Date: Wed, 22 Oct 2025 17:27:57 +0900 Subject: [PATCH 4/7] delete metadata cell in solutions --- lectures/markov_asset.md | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/lectures/markov_asset.md b/lectures/markov_asset.md index 90e20c315..a7e04a010 100644 --- a/lectures/markov_asset.md +++ b/lectures/markov_asset.md @@ -1082,13 +1082,6 @@ call_option(apm, ζ, p_s) Let's show the last two functions as a plot ```{code-cell} ipython3 ---- -mystnb: - figure: - caption: | - Consol and call option exercise two comparison - name: fig_ex2_prices ---- fig, ax = plt.subplots() ax.plot(s, consol_price(apm, ζ), label='consol') ax.plot(s, call_option(apm, ζ, p_s), label='call option') @@ -1176,13 +1169,6 @@ def finite_horizon_call_option(ap, ζ, p_s, k): Now let's compute the option values at `k=5` and `k=25` ```{code-cell} ipython3 ---- -mystnb: - figure: - caption: | - Finite horizon call option values - name: fig_ex3_finite ---- fig, ax = plt.subplots() for k in [5, 25]: w = finite_horizon_call_option(apm, ζ, p_s, k) From 02dfe6e50889b1fc45843b2e3c6fa91dc1e6ebac Mon Sep 17 00:00:00 2001 From: Kenko LI Date: Mon, 27 Oct 2025 13:31:16 +0900 Subject: [PATCH 5/7] modify document link --- lectures/markov_asset.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lectures/markov_asset.md b/lectures/markov_asset.md index a7e04a010..ef6c8d3c2 100644 --- a/lectures/markov_asset.md +++ b/lectures/markov_asset.md @@ -185,7 +185,7 @@ It is useful to regard equation {eq}`lteeqs102` as a generalization of equatio Equation {eq}`lteeqs102` asserts that the covariance of the stochastic discount factor with the one period payout $d_{t+1} + p_{t+1}$ is an important determinant of the price $p_t$. -We give examples of some models of stochastic discount factors that have been proposed later in this lecture and also in a [later lecture](https://python-advanced.quantecon.org/lucas_model.html). +We give examples of some models of stochastic discount factors that have been proposed later in this lecture and also in a {doc}`later lecture`. ### The price-dividend ratio @@ -496,7 +496,7 @@ m_{t+1} = \beta \frac{u'(c_{t+1})}{u'(c_t)} where $u$ is a concave utility function and $c_t$ is time $t$ consumption of a representative consumer. -(A derivation of this expression is given in a [later lecture](https://python-advanced.quantecon.org/lucas_model.html)) +(A derivation of this expression is given in a {doc}`later lecture`) Assume the existence of an endowment that follows growth process {eq}`mass_fmce`. From 211ec24cdcea153cab421f4c96090ed8d0caeed1 Mon Sep 17 00:00:00 2001 From: Kenko LI Date: Mon, 27 Oct 2025 16:56:50 +0900 Subject: [PATCH 6/7] Made functions more JAX-compatible --- lectures/markov_asset.md | 73 +++++++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 23 deletions(-) diff --git a/lectures/markov_asset.md b/lectures/markov_asset.md index ef6c8d3c2..93fb6d5d2 100644 --- a/lectures/markov_asset.md +++ b/lectures/markov_asset.md @@ -88,6 +88,7 @@ import quantecon as qe import jax import jax.numpy as jnp from jax.numpy.linalg import eigvals, solve +from jax.experimental import checkify from typing import NamedTuple ``` @@ -567,6 +568,20 @@ We will define a function tree_price to compute $v$ given parameters stored in the class AssetPriceModel ```{code-cell} ipython3 +class MarkovChain(NamedTuple): + """ + A class that stores the primitives of a Markov chain. + Parameters + ---------- + P : jnp.ndarray + Transition matrix + state_values : jnp.ndarray + The values associated with each state + """ + P: jnp.ndarray + state_values: jnp.ndarray + + class AssetPriceModel(NamedTuple): """ A class that stores the primitives of the asset pricing model. @@ -584,34 +599,42 @@ class AssetPriceModel(NamedTuple): n: int The number of states """ - mc: qe.MarkovChain + mc: MarkovChain g: callable β: float γ: float n: int -def create_ap_model(mc=None, g=jnp.exp, β=0.96, γ=2.0): - """Create an AssetPriceModel class""" - if mc is None: - n, ρ, σ = 25, 0.9, 0.02 - mc = qe.tauchen(n, ρ, σ) - else: - mc = mc - n = mc.P.shape[0] +def create_ap_model(g=jnp.exp, β=0.96, γ=2.0): + """Create an AssetPriceModel class using standard Markov chain.""" + n, ρ, σ = 25, 0.9, 0.02 + qe_mc = qe.tauchen(n, ρ, σ) + P = jnp.array(qe_mc.P) + state_values = jnp.array(qe_mc.state_values) + mc = MarkovChain(P=P, state_values=state_values) return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n) +def create_customized_ap_model(mc: MarkovChain, g=jnp.exp, β=0.96, γ=2.0): + """Create an AssetPriceModel class using a customized Markov chain.""" + n = mc.P.shape[0] + return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n) + + def test_stability(Q, β): - """ - Stability test for a given matrix Q. - """ - sr = np.max(np.abs(eigvals(Q))) - if not sr < 1 / β: - msg = f"Spectral radius condition failed with radius = {sr}" - raise ValueError(msg) + """Stability test for a given matrix Q.""" + sr = jnp.max(jnp.abs(eigvals(Q))) + checkify.check( + sr < 1 / β, + "Spectral radius condition failed with radius = {sr}", sr=sr + ) + return sr + +# Wrap the check function to be JIT-safe +test_stability = checkify.checkify(test_stability, errors=checkify.user_checks) def tree_price(ap): """ @@ -633,7 +656,8 @@ def tree_price(ap): J = P * ap.g(y)**(1 - γ) # Make sure that a unique solution exists - test_stability(J, β) + err, out = test_stability(J, β) + err.throw() # Compute v I = jnp.identity(ap.n) @@ -661,7 +685,7 @@ states = ap.mc.state_values fig, ax = plt.subplots() for γ in γs: - tem_ap = create_ap_model(mc=ap.mc, g=ap.g, β=ap.β, γ=γ) + tem_ap = create_customized_ap_model(mc=ap.mc, β=ap.β, γ=γ) v = tree_price(tem_ap) ax.plot(states, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$") @@ -767,7 +791,8 @@ def consol_price(ap, ζ): M = P * ap.g(y)**(- γ) # Make sure that a unique solution exists - test_stability(M, β) + err, _ = test_stability(M, β) + err.throw() # Compute price I = jnp.identity(ap.n) @@ -879,7 +904,8 @@ def call_option(ap, ζ, p_s, ϵ=1e-7): M = P * ap.g(y)**(- γ) # Make sure that a unique consol price exists - test_stability(M, β) + err, _ = test_stability(M, β) + err.throw() # Compute option price p = consol_price(ap, ζ) @@ -887,7 +913,7 @@ def call_option(ap, ζ, p_s, ϵ=1e-7): error = ϵ + 1 def step(state): - w, error = state + w, _ = state # Maximize across columns w_new = jnp.maximum(β * M @ w, p - p_s) # Find maximal difference of each component and update @@ -1062,7 +1088,7 @@ Next, we'll create an instance of `AssetPriceModel` to feed into the functions ```{code-cell} ipython3 -apm = create_ap_model(mc=mc, g=lambda x: x, β=β, γ=γ) +apm = create_customized_ap_model(mc=mc, g=lambda x: x, β=β, γ=γ) ``` Now we just need to call the relevant functions on the data: @@ -1152,7 +1178,8 @@ def finite_horizon_call_option(ap, ζ, p_s, k): M = P * ap.g(y)**(- γ) # Make sure that a unique solution exists - test_stability(M, β) + err, _ = test_stability(M, β) + err.throw() # Compute option price p = consol_price(ap, ζ) From 2cfda1acaa1e9db487b20dc606789632ab1c9cad Mon Sep 17 00:00:00 2001 From: Kenko LI Date: Wed, 29 Oct 2025 17:57:23 +0900 Subject: [PATCH 7/7] jit computation functions --- lectures/markov_asset.md | 149 ++++++++++++++++++++++----------------- 1 file changed, 84 insertions(+), 65 deletions(-) diff --git a/lectures/markov_asset.md b/lectures/markov_asset.md index 93fb6d5d2..b8b533d98 100644 --- a/lectures/markov_asset.md +++ b/lectures/markov_asset.md @@ -63,7 +63,7 @@ An asset is a claim on one or more future payoffs. The spot price of an asset depends primarily on -* the anticipated income stream +* the anticipated income stream * attitudes about risk * rates of time preference @@ -75,7 +75,7 @@ We also look at creating and pricing *derivative* assets that repackage income s Key tools for the lecture are -* Markov processses +* Markov processes * formulas for predicting future values of functions of a Markov state * a formula for predicting the discounted sum of future values of a Markov state @@ -83,7 +83,6 @@ Let's start with some imports: ```{code-cell} ipython import matplotlib.pyplot as plt -import numpy as np import quantecon as qe import jax import jax.numpy as jnp @@ -151,7 +150,7 @@ for some **stochastic discount factor** $m_{t+1}$. Here the fixed discount factor $\beta$ in {eq}`rnapex` has been replaced by the random variable $m_{t+1}$. -How anticipated future payoffs are evaluated now depends on statistical properties of $m_{t+1}$. +How anticipated future payoffs are evaluated now depends on statistical properties of $m_{t+1}$. The stochastic discount factor can be specified to capture the idea that assets that tend to have good payoffs in bad states of the world are valued more highly than other assets whose payoffs don't behave that way. @@ -177,12 +176,12 @@ If we apply this definition to the asset pricing equation {eq}`lteeqs0` we obtai p_t = {\mathbb E}_t m_{t+1} {\mathbb E}_t (d_{t+1} + p_{t+1}) + {\rm cov}_t (m_{t+1}, d_{t+1}+ p_{t+1}) ``` -It is useful to regard equation {eq}`lteeqs102` as a generalization of equation {eq}`rnapex` +It is useful to regard equation {eq}`lteeqs102` as a generalization of equation {eq}`rnapex` -* In equation {eq}`rnapex`, the stochastic discount factor $m_{t+1} = \beta$, a constant. +* In equation {eq}`rnapex`, the stochastic discount factor $m_{t+1} = \beta$, a constant. * In equation {eq}`rnapex`, the covariance term ${\rm cov}_t (m_{t+1}, d_{t+1}+ p_{t+1})$ is zero because $m_{t+1} = \beta$. * In equation {eq}`rnapex`, ${\mathbb E}_t m_{t+1}$ can be interpreted as the reciprocal of the one-period risk-free gross interest rate. -* When $m_{t+1}$ covaries more negatively with the payout $p_{t+1} + d_{t+1}$, the price of the asset is lower. +* When $m_{t+1}$ covaries more negatively with the payout $p_{t+1} + d_{t+1}$, the price of the asset is lower. Equation {eq}`lteeqs102` asserts that the covariance of the stochastic discount factor with the one period payout $d_{t+1} + p_{t+1}$ is an important determinant of the price $p_t$. @@ -213,9 +212,9 @@ The answer to this question depends on 1. the process we specify for dividends 1. the stochastic discount factor and how it correlates with dividends -For now we'll study the risk-neutral case in which the stochastic discount factor is constant. +For now we'll study the risk-neutral case in which the stochastic discount factor is constant. -We'll focus on how an asset price depends on a dividend process. +We'll focus on how an asset price depends on a dividend process. ### Example 1: constant dividends @@ -340,7 +339,7 @@ x_series = mc.simulate(sim_length, init=jnp.median(mc.state_values)) g_series = jnp.exp(x_series) d_series = jnp.cumprod(g_series) # Assumes d_0 = 1 -series = [x_series, g_series, d_series, np.log(d_series)] +series = [x_series, g_series, d_series, jnp.log(d_series)] labels = ['$X_t$', '$g_t$', '$d_t$', r'$\log \, d_t$'] fig, axes = plt.subplots(2, 2) @@ -564,8 +563,8 @@ Assuming that the spectral radius of $J$ is strictly less than $\beta^{-1}$, thi v = (I - \beta J)^{-1} \beta J {\mathbb 1} ``` -We will define a function tree_price to compute $v$ given parameters stored in -the class AssetPriceModel +We will define a function `tree_price` to compute $v$ given parameters stored in +the class `AssetPriceModel` ```{code-cell} ipython3 class MarkovChain(NamedTuple): @@ -578,8 +577,8 @@ class MarkovChain(NamedTuple): state_values : jnp.ndarray The values associated with each state """ - P: jnp.ndarray - state_values: jnp.ndarray + P: jax.Array + state_values: jax.Array class AssetPriceModel(NamedTuple): @@ -590,20 +589,17 @@ class AssetPriceModel(NamedTuple): ---------- mc : MarkovChain Contains the transition matrix and set of state values - g : callable - The function mapping states to growth rates + G : jax.Array + The vector form of the function mapping states to growth rates β : float Discount factor γ : float Coefficient of risk aversion - n: int - The number of states """ mc: MarkovChain - g: callable + G: jax.Array β: float γ: float - n: int def create_ap_model(g=jnp.exp, β=0.96, γ=2.0): @@ -612,15 +608,16 @@ def create_ap_model(g=jnp.exp, β=0.96, γ=2.0): qe_mc = qe.tauchen(n, ρ, σ) P = jnp.array(qe_mc.P) state_values = jnp.array(qe_mc.state_values) + G = g(state_values) mc = MarkovChain(P=P, state_values=state_values) - return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n) + return AssetPriceModel(mc=mc, G=G, β=β, γ=γ) def create_customized_ap_model(mc: MarkovChain, g=jnp.exp, β=0.96, γ=2.0): """Create an AssetPriceModel class using a customized Markov chain.""" - n = mc.P.shape[0] - return AssetPriceModel(mc=mc, g=g, β=β, γ=γ, n=n) + G = g(mc.state_values) + return AssetPriceModel(mc=mc, G=G, β=β, γ=γ) def test_stability(Q, β): @@ -633,9 +630,6 @@ def test_stability(Q, β): return sr -# Wrap the check function to be JIT-safe -test_stability = checkify.checkify(test_stability, errors=checkify.user_checks) - def tree_price(ap): """ Computes the price-dividend ratio of the Lucas tree. @@ -649,22 +643,24 @@ def tree_price(ap): ------- v : array_like(float) Lucas tree price-dividend ratio - """ # Simplify names, set up matrices - β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values - J = P * ap.g(y)**(1 - γ) + β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G + J = P * G**(1 - γ) # Make sure that a unique solution exists - err, out = test_stability(J, β) - err.throw() + test_stability(J, β) # Compute v - I = jnp.identity(ap.n) - Ones = jnp.ones(ap.n) + n = J.shape[0] + I = jnp.identity(n) + Ones = jnp.ones(n) v = solve(I - β * J, β * J @ Ones) return v + +# Wrap the function to be safely jitted +tree_price_jit = jax.jit(checkify.checkify(tree_price)) ``` Here's a plot of $v$ as a function of the state for several values of $\gamma$, @@ -685,8 +681,12 @@ states = ap.mc.state_values fig, ax = plt.subplots() for γ in γs: - tem_ap = create_customized_ap_model(mc=ap.mc, β=ap.β, γ=γ) - v = tree_price(tem_ap) + tem_ap = create_customized_ap_model(ap.mc, γ=γ) + # checkify returns a tuple + # err indicates whether errors happened + err, v = tree_price_jit(tem_ap) + # Stop if errors raised + err.throw() ax.plot(states, v, lw=2, alpha=0.6, label=rf"$\gamma = {γ}$") ax.set_ylabel("price-dividend ratio") @@ -766,7 +766,7 @@ yields the solution p = (I - \beta M)^{-1} \beta M \zeta {\mathbb 1} ``` -The above is implemented in the function consol_price. +The above is implemented in the function `consol_price`. ```{code-cell} ipython3 def consol_price(ap, ζ): @@ -787,19 +787,22 @@ def consol_price(ap, ζ): Console bond prices """ # Simplify names, set up matrices - β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values - M = P * ap.g(y)**(- γ) + β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G + M = P * G**(- γ) # Make sure that a unique solution exists - err, _ = test_stability(M, β) - err.throw() + test_stability(M, β) # Compute price - I = jnp.identity(ap.n) - Ones = jnp.ones(ap.n) + n = M.shape[0] + I = jnp.identity(n) + Ones = jnp.ones(n) p = solve(I - β * M, β * ζ * M @ Ones) return p + +# Wrap the function to be safely jitted +consol_price_jit = jax.jit(checkify.checkify(consol_price)) ``` ### Pricing an Option to Purchase the Consol @@ -870,9 +873,9 @@ T w = \max \{ \beta M w,\; p - p_S {\mathbb 1} \} $$ -Start at some initial $w$ and iterate with $T$ to convergence . +Start at some initial $w$ and iterate with $T$ to convergence. -We can find the solution with the following function call_option +We can find the solution with the following function `call_option` ```{code-cell} ipython3 def call_option(ap, ζ, p_s, ϵ=1e-7): @@ -900,16 +903,17 @@ def call_option(ap, ζ, p_s, ϵ=1e-7): """ # Simplify names, set up matrices - β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values - M = P * ap.g(y)**(- γ) + β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G + M = P * G**(- γ) # Make sure that a unique consol price exists - err, _ = test_stability(M, β) - err.throw() + test_stability(M, β) # Compute option price p = consol_price(ap, ζ) - w = jnp.zeros(ap.n) + err.throw() + n = M.shape[0] + w = jnp.zeros(n) error = ϵ + 1 def step(state): @@ -928,6 +932,8 @@ def call_option(ap, ζ, p_s, ϵ=1e-7): final_w, _ = jax.lax.while_loop(cond, step, (w, error)) return final_w + +call_option_jit = jax.jit(checkify.checkify(call_option)) ``` Here's a plot of $w$ compared to the consol price when $P_S = 40$ @@ -945,8 +951,10 @@ ap = create_ap_model(β=0.9) strike_price = 40 x = ap.mc.state_values -p = consol_price(ap, ζ) -w = call_option(ap, ζ, strike_price) +err, p = consol_price_jit(ap, ζ) +err.throw() +err, w = call_option_jit(ap, ζ, strike_price) +err.throw() fig, ax = plt.subplots() ax.plot(x, p, 'b-', lw=2, label='consol price') @@ -1046,7 +1054,7 @@ P = P.at[jnp.arange(n), jnp.arange(n)].set( P[jnp.arange(n), jnp.arange(n)] + 1 - P.sum(1) ) # State values of the Markov chain -s = np.array([0.95, 0.975, 1.0, 1.025, 1.05]) +s = jnp.array([0.95, 0.975, 1.0, 1.025, 1.05]) γ = 2.0 β = 0.94 ``` @@ -1076,7 +1084,7 @@ P = P.at[jnp.arange(n), jnp.arange(n)].set( P[jnp.arange(n), jnp.arange(n)] + 1 - P.sum(1) ) s = jnp.array([0.95, 0.975, 1.0, 1.025, 1.05]) # State values -mc = qe.MarkovChain(P, state_values=s) +mc = MarkovChain(P=P, state_values=s) γ = 2.0 β = 0.94 @@ -1094,23 +1102,29 @@ apm = create_customized_ap_model(mc=mc, g=lambda x: x, β=β, γ=γ) Now we just need to call the relevant functions on the data: ```{code-cell} ipython3 -tree_price(apm) +err, v = tree_price_jit(apm) +err.throw() +print(v) ``` ```{code-cell} ipython3 -consol_price(apm, ζ) +err, p = consol_price_jit(apm, ζ) +err.throw() +print(p) ``` ```{code-cell} ipython3 -call_option(apm, ζ, p_s) +err, w = call_option_jit(apm, ζ, p_s) +err.throw() +print(w) ``` Let's show the last two functions as a plot ```{code-cell} ipython3 fig, ax = plt.subplots() -ax.plot(s, consol_price(apm, ζ), label='consol') -ax.plot(s, call_option(apm, ζ, p_s), label='call option') +ax.plot(s, p, label='consol') +ax.plot(s, w, label='call option') ax.legend() plt.show() ``` @@ -1169,28 +1183,32 @@ Is one higher than the other? Can you give intuition? Here's a suitable function: ```{code-cell} ipython3 -def finite_horizon_call_option(ap, ζ, p_s, k): +def finite_call_option(ap, ζ, p_s, k): """ Computes k period option value. """ # Simplify names, set up matrices - β, γ, P, y = ap.β, ap.γ, ap.mc.P, ap.mc.state_values - M = P * ap.g(y)**(- γ) + β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G + M = P * G**(- γ) # Make sure that a unique solution exists - err, _ = test_stability(M, β) - err.throw() + test_stability(M, β) # Compute option price p = consol_price(ap, ζ) + n = M.shape[0] def step(i, w): # Maximize across columns w = jnp.maximum(β * M @ w, p - p_s) return w - w = jax.lax.fori_loop(0, k, step, jnp.zeros(ap.n)) + w = jax.lax.fori_loop(0, k, step, jnp.zeros(n)) return w + +finite_call_option_jit = jax.jit( + checkify.checkify(finite_call_option) + ) ``` Now let's compute the option values at `k=5` and `k=25` @@ -1198,7 +1216,8 @@ Now let's compute the option values at `k=5` and `k=25` ```{code-cell} ipython3 fig, ax = plt.subplots() for k in [5, 25]: - w = finite_horizon_call_option(apm, ζ, p_s, k) + err, w = finite_call_option_jit(apm, ζ, p_s, k) + err.throw() ax.plot(s, w, label=rf'$k = {k}$') ax.legend() plt.show()