diff --git a/lectures/mccall_model.md b/lectures/mccall_model.md index 202b9d591..35ca8460d 100644 --- a/lectures/mccall_model.md +++ b/lectures/mccall_model.md @@ -3,8 +3,10 @@ jupytext: text_representation: extension: .md format_name: myst + format_version: 0.13 + jupytext_version: 1.17.2 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) language: python name: python3 --- @@ -34,11 +36,10 @@ and the pros and cons as they themselves see them." -- Robert E. Lucas, Jr. In addition to what's in Anaconda, this lecture will need the following libraries: -```{code-cell} ipython ---- -tags: [hide-output] ---- -!pip install quantecon +```{code-cell} ipython3 +:tags: [hide-output] + +!pip install quantecon jax ``` ## Overview @@ -59,11 +60,14 @@ As we'll see, McCall's model is not only interesting in its own right but also a Let's start with some imports: -```{code-cell} ipython +```{code-cell} ipython3 import matplotlib.pyplot as plt import numpy as np -from numba import jit, float64 -from numba.experimental import jitclass +import numba +import jax +import jax.numpy as jnp +from typing import NamedTuple +from functools import partial import quantecon as qe from quantecon.distributions import BetaBinomial ``` @@ -91,9 +95,11 @@ At time $t$, our agent has two choices: The agent is infinitely lived and aims to maximize the expected discounted sum of earnings -$$ -\mathbb{E} \sum_{t=0}^{\infty} \beta^t y_t -$$ +```{math} +:label: obj_model + +{\mathbb E} \sum_{t=0}^\infty \beta^t u(y_t) +``` The constant $\beta$ lies in $(0, 1)$ and is called a **discount factor**. @@ -112,7 +118,7 @@ The worker faces a trade-off: * Waiting too long for a good offer is costly, since the future is discounted. * Accepting too early is costly, since better offers might arrive in the future. -To decide optimally in the face of this trade-off, we use dynamic programming. +To decide optimally in the face of this trade-off, we use [dynamic programming](https://dp.quantecon.org/). Dynamic programming can be thought of as a two-step procedure that @@ -135,10 +141,10 @@ To this end, let $v^*(w)$ be the total lifetime *value* accruing to an unemployed worker who enters the current period unemployed when the wage is $w \in \mathbb{W}$. -In particular, the agent has wage offer $w$ in hand. +(In particular, the agent has wage offer $w$ in hand and can accept or reject it.) More precisely, $v^*(w)$ denotes the value of the objective function -{eq}`objective` when an agent in this situation makes *optimal* decisions now +{eq}`obj_model` when an agent in this situation makes *optimal* decisions now and at all future points in time. Of course $v^*(w)$ is not trivial to calculate because we don't yet know @@ -163,7 +169,7 @@ v^*(w) for every possible $w$ in $\mathbb{W}$. -This important equation is a version of the **Bellman equation**, which is +This is a version of the **Bellman equation**, which is ubiquitous in economic dynamics and other fields involving planning over time. 
The intuition behind it is as follows: @@ -174,9 +180,12 @@ $$ \frac{w}{1 - \beta} = w + \beta w + \beta^2 w + \cdots $$ -* the second term inside the max operation is the **continuation value**, which is the lifetime payoff from rejecting the current offer and then behaving optimally in all subsequent periods +* the second term inside the max operation is the continuation value, which is + the lifetime payoff from rejecting the current offer and then behaving + optimally in all subsequent periods -If we optimize and pick the best of these two options, we obtain maximal lifetime value from today, given current offer $w$. +If we optimize and pick the best of these two options, we obtain maximal +lifetime value from today, given current offer $w$. But this is precisely $v^*(w)$, which is the left-hand side of {eq}`odu_pv`. @@ -193,7 +202,7 @@ All we have to do is select the maximal choice on the right-hand side of {eq}`od The optimal action is best thought of as a **policy**, which is, in general, a map from states to actions. -Given *any* $w$, we can read off the corresponding best choice (accept or +Given any $w$, we can read off the corresponding best choice (accept or reject) by picking the max on the right-hand side of {eq}`odu_pv`. Thus, we have a map from $\mathbb W$ to $\{0, 1\}$, with 1 meaning accept and 0 meaning reject. @@ -224,7 +233,7 @@ where \bar w := (1 - \beta) \left\{ c + \beta \sum_{w'} v^*(w') q (w') \right\} ``` -Here $\bar w$ (called the *reservation wage*) is a constant depending on +Here $\bar w$ (called the **reservation wage**) is a constant depending on $\beta, c$ and the wage distribution. The agent should accept if and only if the current wage offer exceeds the reservation wage. @@ -234,8 +243,7 @@ In view of {eq}`reswage`, we can compute this reservation wage if we can compute ## Computing the Optimal Policy: Take 1 -To put the above ideas into action, we need to compute the value function at -each possible state $w \in \mathbb W$. +To put the above ideas into action, we need to compute the value function at each $w \in \mathbb W$. To simplify notation, let's set @@ -245,8 +253,7 @@ $$ v^*(i) := v^*(w_i) $$ -The value function is then represented by the vector -$v^* = (v^*(i))_{i=1}^n$. +The value function is then represented by the vector $v^* = (v^*(i))_{i=1}^n$. In view of {eq}`odu_pv`, this vector satisfies the nonlinear system of equations @@ -298,8 +305,7 @@ The theory below elaborates on this point. What's the mathematics behind these ideas? -First, one defines a mapping $T$ from $\mathbb R^n$ to -itself via +First, one defines a mapping $T$ from $\mathbb R^n$ to itself via ```{math} :label: odu_pv3 @@ -316,11 +322,9 @@ itself via (A new vector $Tv$ is obtained from given vector $v$ by evaluating the r.h.s. at each $i$.) -The element $v_k$ in the sequence $\{v_k\}$ of successive -approximations corresponds to $T^k v$. +The element $v_k$ in the sequence $\{v_k\}$ of successive approximations corresponds to $T^k v$. -* This is $T$ applied $k$ times, starting at the initial guess - $v$ +* This is $T$ applied $k$ times, starting at the initial guess $v$ One can show that the conditions of the [Banach fixed point theorem](https://en.wikipedia.org/wiki/Banach_fixed-point_theorem) are satisfied by $T$ on $\mathbb R^n$. @@ -329,33 +333,32 @@ One implication is that $T$ has a unique fixed point in $\mathbb R^n$. * That is, a unique vector $\bar v$ such that $T \bar v = \bar v$. -Moreover, it's immediate from the definition of $T$ that this fixed -point is $v^*$. 
+Moreover, it's immediate from the definition of $T$ that this fixed point is $v^*$. A second implication of the Banach contraction mapping theorem is that -$\{ T^k v \}$ converges to the fixed point $v^*$ regardless of -$v$. +$\{ T^k v \}$ converges to the fixed point $v^*$ regardless of $v$. + ### Implementation Our default for $q$, the distribution of the state process, will be [Beta-binomial](https://en.wikipedia.org/wiki/Beta-binomial_distribution). -```{code-cell} python3 +```{code-cell} ipython3 n, a, b = 50, 200, 100 # default parameters -q_default = BetaBinomial(n, a, b).pdf() # default choice of q +q_default = jnp.array(BetaBinomial(n, a, b).pdf()) ``` Our default set of values for wages will be -```{code-cell} python3 +```{code-cell} ipython3 w_min, w_max = 10, 60 -w_default = np.linspace(w_min, w_max, n+1) +w_default = jnp.linspace(w_min, w_max, n+1) ``` Here's a plot of the probabilities of different wage outcomes: -```{code-cell} python3 +```{code-cell} ipython3 fig, ax = plt.subplots() ax.plot(w_default, q_default, '-o', label='$q(w(i))$') ax.set_xlabel('wages') @@ -364,60 +367,28 @@ ax.set_ylabel('probabilities') plt.show() ``` -We are going to use Numba to accelerate our code. +We will use [JAX](https://python-programming.quantecon.org/jax_intro.html) to write our code. -* See, in particular, the discussion of `@jitclass` in [our lecture on Numba](https://python-programming.quantecon.org/numba.html). +We'll use `NamedTuple` for our model class to maintain immutability, which works well with JAX's functional programming paradigm. -The following helps Numba by providing some type specifications. +Here's a class that stores the model parameters with default values. -```{code-cell} python3 -mccall_data = [ - ('c', float64), # unemployment compensation - ('β', float64), # discount factor - ('w', float64[::1]), # array of wage values, w[i] = wage at state i - ('q', float64[::1]) # array of probabilities -] +```{code-cell} ipython3 +class McCallModel(NamedTuple): + c: float = 25 # unemployment compensation + β: float = 0.99 # discount factor + w: jnp.ndarray = w_default # array of wage values, w[i] = wage at state i + q: jnp.ndarray = q_default # array of probabilities ``` -```{note} -Note the use of `[::1]` in the array type declarations above. - -This notation specifies that the arrays should be C-contiguous. - -This is important for performance, especially when using the `@` operator for matrix multiplication (e.g., `v @ q`). - -Without this specification, Numba might need to handle non-contiguous arrays, which can significantly slow down these operations. +We implement the Bellman operator $T$ from {eq}`odu_pv3` as follows -Try to replace `[::1]` with `[:]` and see what happens. -``` - -Here's a class that stores the data and computes the values of state-action pairs, -i.e. the value in the maximum bracket on the right hand side of the Bellman equation {eq}`odu_pv2p`, -given the current state and an arbitrary feasible action. - -Default parameter values are embedded in the class. - -```{code-cell} python3 -@jitclass(mccall_data) -class McCallModel: - - def __init__(self, c=25, β=0.99, w=w_default, q=q_default): - - self.c, self.β = c, β - self.w, self.q = w_default, q_default - - def state_action_values(self, i, v): - """ - The values of state-action pairs. 
- """ - # Simplify names - c, β, w, q = self.c, self.β, self.w, self.q - # Evaluate value for each state-action pair - # Consider action = accept or reject the current offer - accept = w[i] / (1 - β) - reject = c + β * (v @ q) - - return np.array([accept, reject]) +```{code-cell} ipython3 +def T(model: McCallModel, v: jnp.ndarray): + c, β, w, q = model + accept = w / (1 - β) + reject = c + β * v @ q + return jnp.maximum(accept, reject) ``` Based on these defaults, let's try plotting the first few approximate value functions @@ -425,87 +396,62 @@ in the sequence $\{ T^k v \}$. We will start from guess $v$ given by $v(i) = w(i) / (1 - β)$, which is the value of accepting at every given wage. -Here's a function to implement this: - -```{code-cell} python3 -def plot_value_function_seq(mcm, ax, num_plots=6): - """ - Plot a sequence of value functions. - - * mcm is an instance of McCallModel - * ax is an axes object that implements a plot method. - - """ - - n = len(mcm.w) - v = mcm.w / (1 - mcm.β) - v_next = np.empty_like(v) - for i in range(num_plots): - ax.plot(mcm.w, v, '-', alpha=0.4, label=f"iterate {i}") - # Update guess - for j in range(n): - v_next[j] = np.max(mcm.state_action_values(j, v)) - v[:] = v_next # copy contents into v - - ax.legend(loc='lower right') -``` - -Now let's create an instance of `McCallModel` and watch iterations $T^k v$ converge from below: +```{code-cell} ipython3 +model = McCallModel() +c, β, w, q = model +v = w / (1 - β) # Initial condition +fig, ax = plt.subplots() -```{code-cell} python3 -mcm = McCallModel() +num_plots = 6 +for i in range(num_plots): + ax.plot(w, v, '-', alpha=0.6, lw=2, label=f"iterate {i}") + v = T(model, v) -fig, ax = plt.subplots() +ax.legend(loc='lower right') ax.set_xlabel('wage') ax.set_ylabel('value') -plot_value_function_seq(mcm, ax) plt.show() ``` You can see that convergence is occurring: successive iterates are getting closer together. -Here's a more serious iteration effort to compute the limit, which continues until measured deviation between successive iterates is below tol. +Here's a more serious iteration effort to compute the limit, which continues +until measured deviation between successive iterates is below tol. Once we obtain a good approximation to the limit, we will use it to calculate the reservation wage. -We'll be using JIT compilation via Numba to turbocharge our loops. - -```{code-cell} python3 -@jit -def compute_reservation_wage(mcm, - max_iter=500, - tol=1e-6): - - # Simplify names - c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - - # == First compute the value function == # - - n = len(w) - v = w / (1 - β) # initial guess - v_next = np.empty_like(v) - j = 0 - error = tol + 1 - while j < max_iter and error > tol: - - for j in range(n): - v_next[j] = np.max(mcm.state_action_values(j, v)) - - error = np.max(np.abs(v_next - v)) - j += 1 - - v[:] = v_next # copy contents into v - - # == Now compute the reservation wage == # - - return (1 - β) * (c + β * (v @ q)) +```{code-cell} ipython3 +def compute_reservation_wage( + model: McCallModel, # instance containing default parameters + v_init: jnp.ndarray, # initial condition for iteration + tol: float=1e-6, # error tolerance + max_iter: int=500, # maximum number of iterations for loop + ): + "Computes the reservation wage in the McCall job search model." 
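+    # Iterate with the Bellman operator T, starting from v_init, until
+    # successive iterates differ by less than tol (or max_iter is reached).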
+    c, β, w, q = model
+    i = 0
+    error = tol + 1
+    v = v_init
+
+    while i < max_iter and error > tol:
+        v_next = T(model, v)
+        error = jnp.max(jnp.abs(v_next - v))
+        v = v_next
+        i += 1
+
+    res_wage = (1 - β) * (c + β * v @ q)
+    return v, res_wage
 ```
 
-The next line computes the reservation wage at default parameters
+The next cell computes the reservation wage at the default parameters:
 
-```{code-cell} python3
-compute_reservation_wage(mcm)
+```{code-cell} ipython3
+model = McCallModel()
+c, β, w, q = model
+v_init = w / (1 - β)    # initial guess
+v, res_wage = compute_reservation_wage(model, v_init)
+print(res_wage)
 ```
 
 ### Comparative Statics
 
@@ -516,48 +462,80 @@ parameters.
 
 In particular, let's look at what happens when we change $\beta$ and $c$.
 
-```{code-cell} python3
+As a first step, given that we'll use it many times, let's create a more
+efficient, jit-compiled version of the function that computes the reservation
+wage:
+
+```{code-cell} ipython3
+@jax.jit
+def compute_res_wage_jitted(
+    model: McCallModel,          # instance containing default parameters
+    v_init: jnp.ndarray,         # initial condition for iteration
+    tol: float=1e-6,             # error tolerance
+    max_iter: int=500,           # maximum number of iterations for loop
+    ):
+    c, β, w, q = model
+    i = 0
+    error = tol + 1
+    initial_state = v_init, i, error
+
+    def cond(loop_state):
+        v, i, error = loop_state
+        return jnp.logical_and(i < max_iter, error > tol)
+
+    def update(loop_state):
+        v, i, error = loop_state
+        v_next = T(model, v)
+        error = jnp.max(jnp.abs(v_next - v))
+        i += 1
+        new_loop_state = v_next, i, error
+        return new_loop_state
+
+    final_state = jax.lax.while_loop(cond, update, initial_state)
+    v, i, error = final_state
+
+    res_wage = (1 - β) * (c + β * v @ q)
+    return v, res_wage
+```
+
+Now we compute the reservation wage at each $c, \beta$ pair.
+
+```{code-cell} ipython3
 grid_size = 25
-R = np.empty((grid_size, grid_size))
+c_vals = jnp.linspace(10.0, 30.0, grid_size)
+β_vals = jnp.linspace(0.9, 0.99, grid_size)
 
-c_vals = np.linspace(10.0, 30.0, grid_size)
-β_vals = np.linspace(0.9, 0.99, grid_size)
+res_wage_matrix = np.empty((grid_size, grid_size))
 
+model = McCallModel()
+v_init = model.w / (1 - model.β)
 for i, c in enumerate(c_vals):
     for j, β in enumerate(β_vals):
-        mcm = McCallModel(c=c, β=β)
-        R[i, j] = compute_reservation_wage(mcm)
-```
+        model = McCallModel(c=c, β=β)
+        v, res_wage = compute_res_wage_jitted(model, v_init)
+        v_init = v
+        res_wage_matrix[i, j] = res_wage
 
-```{code-cell} python3
 fig, ax = plt.subplots()
-
-cs1 = ax.contourf(c_vals, β_vals, R.T, alpha=0.75)
-ctr1 = ax.contour(c_vals, β_vals, R.T)
-
+cs1 = ax.contourf(c_vals, β_vals, res_wage_matrix.T, alpha=0.75)
+ctr1 = ax.contour(c_vals, β_vals, res_wage_matrix.T)
 plt.clabel(ctr1, inline=1, fontsize=13)
 plt.colorbar(cs1, ax=ax)
-
-
 ax.set_title("reservation wage")
 ax.set_xlabel("$c$", fontsize=16)
 ax.set_ylabel("$β$", fontsize=16)
-
 ax.ticklabel_format(useOffset=False)
-
 plt.show()
 ```
 
-As expected, the reservation wage increases both with patience and with
-unemployment compensation.
+As expected, the reservation wage increases with both patience and unemployment compensation.
 
 (mm_op2)=
 ## Computing an Optimal Policy: Take 2
 
-The approach to dynamic programming just described is standard and
-broadly applicable.
+The approach to dynamic programming just described is standard and broadly applicable.
-But for our McCall search model there's also an easier way that circumvents the +But for our McCall search model there's also an easier way that circumvents the need to compute the value function. Let $h$ denote the continuation value: @@ -565,17 +543,14 @@ Let $h$ denote the continuation value: ```{math} :label: j1 -h -= c + \beta - \sum_{s'} v^*(s') q (s') -\quad + h = c + \beta \sum_{w'} v^*(w') q (w') ``` The Bellman equation can now be written as $$ -v^*(s') -= \max \left\{ \frac{w(s')}{1 - \beta}, \, h \right\} + v^*(w') + = \max \left\{ \frac{w'}{1 - \beta}, \, h \right\} $$ Substituting this last equation into {eq}`j1` gives @@ -583,13 +558,11 @@ Substituting this last equation into {eq}`j1` gives ```{math} :label: j2 -h -= c + \beta - \sum_{s' \in \mathbb S} - \max \left\{ - \frac{w(s')}{1 - \beta}, h - \right\} q (s') -\quad + h = c + \beta + \sum_{w' \in \mathbb W} + \max \left\{ + \frac{w'}{1 - \beta}, h + \right\} q (w') ``` This is a nonlinear equation that we can solve for $h$. @@ -605,10 +578,10 @@ Step 2: compute the update $h'$ via h' = c + \beta - \sum_{s' \in \mathbb S} + \sum_{w' \in \mathbb W} \max \left\{ - \frac{w(s')}{1 - \beta}, h - \right\} q (s') + \frac{w'}{1 - \beta}, h + \right\} q (w') \quad ``` @@ -622,32 +595,36 @@ The big difference here, however, is that we're iterating on a scalar $h$, rathe Here's an implementation: -```{code-cell} python3 -@jit -def compute_reservation_wage_two(mcm, - max_iter=500, - tol=1e-5): - - # Simplify names - c, β, w, q = mcm.c, mcm.β, mcm.w, mcm.q - - # == First compute h == # - - h = (w @ q) / (1 - β) +```{code-cell} ipython3 +@jax.jit +def compute_reservation_wage_two( + model: McCallModel, # instance containing default parameters + tol: float=1e-5, # error tolerance + max_iter: int=500, # maximum number of iterations for loop + ): + c, β, w, q = model + h = (w @ q) / (1 - β) # initial condition i = 0 error = tol + 1 - while i < max_iter and error > tol: - - s = np.maximum(w / (1 - β), h) - h_next = c + β * (s @ q) + initial_loop_state = i, h, error - error = np.abs(h_next - h) - i += 1 + def cond(loop_state): + i, h, error = loop_state + return jnp.logical_and(i < max_iter, error > tol) - h = h_next + def update(loop_state): + i, h, error = loop_state + s = jnp.maximum(w / (1 - β), h) + h_next = c + β * (s @ q) + error = jnp.abs(h_next - h) + i_next = i + 1 + new_loop_state = i_next, h_next, error + return new_loop_state - # == Now compute the reservation wage == # + final_state = jax.lax.while_loop(cond, update, initial_loop_state) + i, h, error = final_state + # Compute and return the reservation wage return (1 - β) * h ``` @@ -675,19 +652,25 @@ Plot mean unemployment duration as a function of $c$ in `c_vals`. :class: dropdown ``` -Here's one solution +Here's a solution using Numba. -```{code-cell} python3 -cdf = np.cumsum(q_default) +```{code-cell} ipython3 +# Convert JAX arrays to NumPy arrays for use with Numba +q_default_np = np.array(q_default) +w_default_np = np.array(w_default) +cdf = np.cumsum(q_default_np) -@jit +@numba.jit def compute_stopping_time(w_bar, seed=1234): - + """ + Compute stopping time by drawing wages until one exceeds w_bar. 
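+
+    A wage index is drawn each period with `qe.random.draw(cdf)`, which
+    samples from the cumulative distribution of the default wage offers.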
+ """ np.random.seed(seed) t = 1 while True: # Generate a wage draw - w = w_default[qe.random.draw(cdf)] + w = w_default_np[qe.random.draw(cdf)] + # Stop when the draw is above the reservation wage if w >= w_bar: stopping_time = t @@ -696,10 +679,14 @@ def compute_stopping_time(w_bar, seed=1234): t += 1 return stopping_time -@jit +@numba.jit(parallel=True) def compute_mean_stopping_time(w_bar, num_reps=100000): + """ + Generate a mean stopping time over `num_reps` repetitions by + drawing from `compute_stopping_time`. + """ obs = np.empty(num_reps) - for i in range(num_reps): + for i in numba.prange(num_reps): obs[i] = compute_stopping_time(w_bar, seed=i) return obs.mean() @@ -708,7 +695,72 @@ stop_times = np.empty_like(c_vals) for i, c in enumerate(c_vals): mcm = McCallModel(c=c) w_bar = compute_reservation_wage_two(mcm) - stop_times[i] = compute_mean_stopping_time(w_bar) + stop_times[i] = compute_mean_stopping_time(float(w_bar)) + +fig, ax = plt.subplots() + +ax.plot(c_vals, stop_times, label="mean unemployment duration") +ax.set(xlabel="unemployment compensation", ylabel="months") +ax.legend() + +plt.show() +``` + +And here's a solution using JAX. + +```{code-cell} ipython3 +cdf = jnp.cumsum(q_default) + +@jax.jit +def compute_stopping_time(w_bar, key): + """ + Compute stopping time by drawing wages until one exceeds `w_bar`. + """ + def update(loop_state): + t, key, accept = loop_state + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey) + w = w_default[jnp.searchsorted(cdf, u)] + accept = w >= w_bar + t = t + 1 + return t, key, accept + + def cond(loop_state): + _, _, accept = loop_state + return jnp.logical_not(accept) + + initial_loop_state = (0, key, False) + t_final, _, _ = jax.lax.while_loop(cond, update, initial_loop_state) + return t_final + + +@partial(jax.jit, static_argnames=('num_reps',)) +def compute_mean_stopping_time(w_bar, num_reps=100000, seed=1234): + """ + Generate a mean stopping time over `num_reps` repetitions by + drawing from `compute_stopping_time`. + """ + # Generate a key for each MC replication + key = jax.random.PRNGKey(seed) + keys = jax.random.split(key, num_reps) + + # Vectorize compute_stopping_time and evaluate across keys + compute_fn = jax.vmap(compute_stopping_time, in_axes=(None, 0)) + obs = compute_fn(w_bar, keys) + + # Return mean stopping time + return jnp.mean(obs) + +c_vals = jnp.linspace(10, 40, 25) + +def compute_stop_time_for_c(c): + """Compute mean stopping time for a given compensation value c.""" + model = McCallModel(c=c) + w_bar = compute_reservation_wage_two(model) + return compute_mean_stopping_time(w_bar) + +# Vectorize across all c values +stop_times = jax.vmap(compute_stop_time_for_c)(c_vals) fig, ax = plt.subplots() @@ -719,6 +771,8 @@ ax.legend() plt.show() ``` +At least for our hardware, Numba is faster on the CPU while JAX is faster on the GPU. 
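+
+To check the comparison on your own machine, here is a minimal timing sketch.
+It is illustrative only: it re-runs the vectorized JAX sweep defined above and
+reports wall-clock time, and the Numba loop can be timed in the same way by
+wrapping it in matching `time.time()` calls.
+
+```{code-cell} ipython3
+import time
+
+# Time the vectorized JAX sweep; block_until_ready() ensures the
+# asynchronous JAX computation finishes before the clock stops.
+start = time.time()
+stop_times_jax = jax.vmap(compute_stop_time_for_c)(c_vals)
+stop_times_jax.block_until_ready()
+print(f"JAX sweep took {time.time() - start:.2f} seconds")
+```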
+ ```{solution-end} ``` @@ -788,49 +842,46 @@ Once your code is working, investigate how the reservation wage changes with $c$ Here is one solution: -```{code-cell} python3 -mccall_data_continuous = [ - ('c', float64), # unemployment compensation - ('β', float64), # discount factor - ('σ', float64), # scale parameter in lognormal distribution - ('μ', float64), # location parameter in lognormal distribution - ('w_draws', float64[:]) # draws of wages for Monte Carlo -] - -@jitclass(mccall_data_continuous) -class McCallModelContinuous: - - def __init__(self, c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000): - - self.c, self.β, self.σ, self.μ = c, β, σ, μ - - # Draw and store shocks - np.random.seed(1234) - s = np.random.randn(mc_size) - self.w_draws = np.exp(μ+ σ * s) - - -@jit -def compute_reservation_wage_continuous(mcmc, max_iter=500, tol=1e-5): - - c, β, σ, μ, w_draws = mcmc.c, mcmc.β, mcmc.σ, mcmc.μ, mcmc.w_draws - - h = np.mean(w_draws) / (1 - β) # initial guess - i = 0 - error = tol + 1 - while i < max_iter and error > tol: - - integral = np.mean(np.maximum(w_draws / (1 - β), h)) +```{code-cell} ipython3 +class McCallModelContinuous(NamedTuple): + c: float # unemployment compensation + β: float # discount factor + σ: float # scale parameter in lognormal distribution + μ: float # location parameter in lognormal distribution + w_draws: jnp.ndarray # draws of wages for Monte Carlo + + +def create_mccall_continuous( + c=25, β=0.99, σ=0.5, μ=2.5, mc_size=1000, seed=1234 + ): + key = jax.random.PRNGKey(seed) + s = jax.random.normal(key, (mc_size,)) + w_draws = jnp.exp(μ + σ * s) + return McCallModelContinuous(c=c, β=β, σ=σ, μ=μ, w_draws=w_draws) + + +@jax.jit +def compute_reservation_wage_continuous(model, max_iter=500, tol=1e-5): + c, β, σ, μ, w_draws = model.c, model.β, model.σ, model.μ, model.w_draws + + h = jnp.mean(w_draws) / (1 - β) # initial guess + + def update(state): + h, i, error = state + integral = jnp.mean(jnp.maximum(w_draws / (1 - β), h)) h_next = c + β * integral - - error = np.abs(h_next - h) - i += 1 - - h = h_next - - # == Now compute the reservation wage == # - - return (1 - β) * h + error = jnp.abs(h_next - h) + return h_next, i + 1, error + + def cond(state): + h, i, error = state + return jnp.logical_and(i < max_iter, error > tol) + + initial_state = (h, 0, tol + 1) + h_final, _, _ = jax.lax.while_loop(cond, update, initial_state) + + # Now compute the reservation wage + return (1 - β) * h_final ``` Now we investigate how the reservation wage changes with $c$ and @@ -838,20 +889,25 @@ $\beta$. We will do this using a contour plot. -```{code-cell} python3 +```{code-cell} ipython3 grid_size = 25 -R = np.empty((grid_size, grid_size)) +c_vals = jnp.linspace(10.0, 30.0, grid_size) +β_vals = jnp.linspace(0.9, 0.99, grid_size) -c_vals = np.linspace(10.0, 30.0, grid_size) -β_vals = np.linspace(0.9, 0.99, grid_size) +def compute_R_element(c, β): + model = create_mccall_continuous(c=c, β=β) + return compute_reservation_wage_continuous(model) -for i, c in enumerate(c_vals): - for j, β in enumerate(β_vals): - mcmc = McCallModelContinuous(c=c, β=β) - R[i, j] = compute_reservation_wage_continuous(mcmc) +# Create meshgrid and vectorize computation +c_grid, β_grid = jnp.meshgrid(c_vals, β_vals, indexing='ij') +compute_R_vectorized = jax.vmap( + jax.vmap(compute_R_element, + in_axes=(None, 0)), + in_axes=(0, None)) +R = compute_R_vectorized(c_vals, β_vals) ``` -```{code-cell} python3 +```{code-cell} ipython3 fig, ax = plt.subplots() cs1 = ax.contourf(c_vals, β_vals, R.T, alpha=0.75)