From a225ed9b5bb818c46dacd832bb5df875e19d73f3 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Tue, 16 Sep 2025 14:27:32 +1000 Subject: [PATCH 1/6] update career to JAX --- lectures/career.md | 322 ++++++++++++++++++++++++--------------------- 1 file changed, 169 insertions(+), 153 deletions(-) diff --git a/lectures/career.md b/lectures/career.md index b826ae9dc..671f34c15 100644 --- a/lectures/career.md +++ b/lectures/career.md @@ -3,10 +3,12 @@ jupytext: text_representation: extension: .md format_name: myst + format_version: 0.13 + jupytext_version: 1.17.2 kernelspec: - display_name: Python 3 - language: python name: python3 + display_name: Python 3 (ipykernel) + language: python --- (career)= @@ -29,13 +31,24 @@ kernelspec: In addition to what's in Anaconda, this lecture will need the following libraries: -```{code-cell} ipython ---- -tags: [hide-output] ---- +```{code-cell} ipython3 +:tags: [hide-output] + !pip install quantecon ``` +```{admonition} GPU acceleration +:class: warning + +This lecture uses JAX for hardware acceleration and automatic differentiation. + +For faster execution, consider running this lecture on a GPU. + +You can access free GPUs on [Google Colab](https://colab.research.google.com/) by selecting "Runtime → Change runtime type → Hardware accelerator → GPU" from the menu. + +To install JAX with GPU support locally, please consult the [JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html). +``` + ## Overview Next, we study a computational problem concerning career and job choices. @@ -46,11 +59,12 @@ This exposition draws on the presentation in {cite}`Ljungqvist2012`, section 6.5 We begin with some imports: -```{code-cell} ipython +```{code-cell} ipython3 import matplotlib.pyplot as plt -import numpy as np +import jax.numpy as jnp +import jax import quantecon as qe -from numba import jit, prange +from typing import NamedTuple from quantecon.distributions import BetaBinomial from scipy.special import binom, beta from mpl_toolkits.mplot3d.axes3d import Axes3D @@ -163,11 +177,11 @@ Nice properties: Here's a figure showing the effect on the pmf of different shape parameters when $n=50$. -```{code-cell} python3 +```{code-cell} ipython3 def gen_probs(n, a, b): - probs = np.zeros(n+1) - for k in range(n+1): - probs[k] = binom(n, k) * beta(k + a, n - k + b) / beta(a, b) + probs = jnp.zeros(n+1) + k_vals = jnp.arange(n+1) + probs = jnp.array([binom(n, k) * beta(k + a, n - k + b) / beta(a, b) for k in range(n+1)]) return probs n = 50 @@ -183,161 +197,163 @@ plt.show() ## Implementation -We will first create a class `CareerWorkerProblem` which will hold the -default parameterizations of the model and an initial guess for the value function. - -```{code-cell} python3 -class CareerWorkerProblem: - - def __init__(self, - B=5.0, # Upper bound - β=0.95, # Discount factor - grid_size=50, # Grid size - F_a=1, - F_b=1, - G_a=1, - G_b=1): - - self.β, self.grid_size, self.B = β, grid_size, B - - self.θ = np.linspace(0, B, grid_size) # Set of θ values - self.ϵ = np.linspace(0, B, grid_size) # Set of ϵ values - - self.F_probs = BetaBinomial(grid_size - 1, F_a, F_b).pdf() - self.G_probs = BetaBinomial(grid_size - 1, G_a, G_b).pdf() - self.F_mean = self.θ @ self.F_probs - self.G_mean = self.ϵ @ self.G_probs - - # Store these parameters for str and repr methods - self._F_a, self._F_b = F_a, F_b - self._G_a, self._G_b = G_a, G_b +We will first create a JAX-compatible model structure using `NamedTuple` to store +the model parameters and computed distributions. 
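+
+Before writing the model down, here is a quick aside (a minimal sketch, separate
+from the model code that follows): JAX treats `NamedTuple` instances as pytrees,
+which is why such a container of parameters and arrays can be passed directly
+into `jax.jit`-compiled functions.
+
+```{code-cell} ipython3
+# Illustration only: a NamedTuple instance passes through jax.jit as a pytree
+class TinyModel(NamedTuple):
+    β: float          # scalar leaf
+    θ: jnp.ndarray    # array leaf
+
+tiny = TinyModel(β=0.95, θ=jnp.linspace(0, 5, 3))
+
+@jax.jit
+def discounted_sum(m):
+    # access fields by name inside the compiled function
+    return m.β * jnp.sum(m.θ)
+
+discounted_sum(tiny)
+```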
+ +```{code-cell} ipython3 +class CareerWorkerProblem(NamedTuple): + β: float # Discount factor + B: float # Upper bound + grid_size: int # Grid size + θ: jnp.ndarray # Set of θ values + ε: jnp.ndarray # Set of ε values + F_probs: jnp.ndarray # Probabilities for F distribution + G_probs: jnp.ndarray # Probabilities for G distribution + F_mean: float # Mean of F distribution + G_mean: float # Mean of G distribution + +def create_career_worker_problem(B=5.0, β=0.95, grid_size=50, + F_a=1, F_b=1, G_a=1, G_b=1): + """ + Factory function to create a CareerWorkerProblem instance. + """ + θ = jnp.linspace(0, B, grid_size) # Set of θ values + ε = jnp.linspace(0, B, grid_size) # Set of ε values + + F_probs = jnp.array(BetaBinomial(grid_size - 1, F_a, F_b).pdf()) + G_probs = jnp.array(BetaBinomial(grid_size - 1, G_a, G_b).pdf()) + F_mean = θ @ F_probs + G_mean = ε @ G_probs + + return CareerWorkerProblem( + β=β, B=B, grid_size=grid_size, + θ=θ, ε=ε, + F_probs=F_probs, G_probs=G_probs, + F_mean=F_mean, G_mean=G_mean + ) ``` -The following function takes an instance of `CareerWorkerProblem` and returns -the corresponding Bellman operator $T$ and the greedy policy function. +The following functions implement the Bellman operator $T$ and the greedy policy function +using JAX. In this model, $T$ is defined by $Tv(\theta, \epsilon) = \max\{I, II, III\}$, where $I$, $II$ and $III$ are as given in {eq}`eyes`. -```{code-cell} python3 -def operator_factory(cw, parallel_flag=True): - +```{code-cell} ipython3 +@jax.jit +def bellman_operator(model, v): """ - Returns jitted versions of the Bellman operator and the - greedy policy function - - cw is an instance of ``CareerWorkerProblem`` + The Bellman operator for the career choice model. """ + θ, ε, β = model.θ, model.ε, model.β + F_probs, G_probs = model.F_probs, model.G_probs + F_mean, G_mean = model.F_mean, model.G_mean - θ, ϵ, β = cw.θ, cw.ϵ, cw.β - F_probs, G_probs = cw.F_probs, cw.G_probs - F_mean, G_mean = cw.F_mean, cw.G_mean + # Vectorized computation + # Broadcasting θ and ε to create all combinations + θ_grid, ε_grid = jnp.meshgrid(θ, ε, indexing='ij') - @jit(parallel=parallel_flag) - def T(v): - "The Bellman operator" + # Option 1: Stay put + v1 = θ_grid + ε_grid + β * v - v_new = np.empty_like(v) + # Option 2: New job (keep θ, new ε) + # For each θ[i], compute expected value over new ε + ev_new_job = jnp.dot(v, G_probs) # Expected value for each θ + v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis] - for i in prange(len(v)): - for j in prange(len(v)): - v1 = θ[i] + ϵ[j] + β * v[i, j] # Stay put - v2 = θ[i] + G_mean + β * v[i, :] @ G_probs # New job - v3 = G_mean + F_mean + β * F_probs @ v @ G_probs # New life - v_new[i, j] = max(v1, v2, v3) + # Option 3: New life (new θ and new ε) + # Expected value over both θ and ε + ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs)) + v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life) - return v_new + return jnp.maximum(jnp.maximum(v1, v2), v3) - @jit - def get_greedy(v): - "Computes the v-greedy policy" +@jax.jit +def get_greedy_policy(model, v): + """ + Computes the greedy policy given the value function. 
+ * Policy function where 1=stay put, 2=new job, 3=new life + """ + θ, ε, β = model.θ, model.ε, model.β + F_probs, G_probs = model.F_probs, model.G_probs + F_mean, G_mean = model.F_mean, model.G_mean - σ = np.empty(v.shape) + # Vectorized computation + # Broadcasting θ and ε to create all combinations + θ_grid, ε_grid = jnp.meshgrid(θ, ε, indexing='ij') - for i in range(len(v)): - for j in range(len(v)): - v1 = θ[i] + ϵ[j] + β * v[i, j] - v2 = θ[i] + G_mean + β * v[i, :] @ G_probs - v3 = G_mean + F_mean + β * F_probs @ v @ G_probs - if v1 > max(v2, v3): - action = 1 - elif v2 > max(v1, v3): - action = 2 - else: - action = 3 - σ[i, j] = action + # Option 1: Stay put + v1 = θ_grid + ε_grid + β * v - return σ + # Option 2: New job (keep θ, new ε) + ev_new_job = jnp.dot(v, G_probs) # Expected value for each θ + v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis] - return T, get_greedy -``` + # Option 3: New life (new θ and new ε) + ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs)) + v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life) -Lastly, `solve_model` will take an instance of `CareerWorkerProblem` and -iterate using the Bellman operator to find the fixed point of the Bellman equation. + # Stack the value arrays and find argmax along first axis + values = jnp.stack([v1, v2, v3], axis=0) -```{code-cell} python3 -def solve_model(cw, - use_parallel=True, - tol=1e-4, - max_iter=1000, - verbose=True, - print_skip=25): + # +1 because actions are 1, 2, 3 not 0, 1, 2 + policy = jnp.argmax(values, axis=0) + 1 - T, _ = operator_factory(cw, parallel_flag=use_parallel) + return policy +``` - # Set up loop - v = np.full((cw.grid_size, cw.grid_size), 100.) # Initial guess - i = 0 +Lastly, `solve_model` will take an instance of `CareerWorkerProblem` and +iterate using the Bellman operator to find the fixed point of the Bellman equation. + +```{code-cell} ipython3 +def solve_model(model, tol=1e-4, max_iter=1000): + """ + Solve the career choice model using JAX. 
+ """ + # Initial guess + v = jnp.full((model.grid_size, model.grid_size), 100.0) error = tol + 1 + i = 0 while i < max_iter and error > tol: - v_new = T(v) - error = np.max(np.abs(v - v_new)) - i += 1 - if verbose and i % print_skip == 0: - print(f"Error at iteration {i} is {error}.") + v_new = bellman_operator(model, v) + error = jnp.max(jnp.abs(v_new - v)) v = v_new + i += 1 - if error > tol: - print("Failed to converge!") - - elif verbose: - print(f"\nConverged in {i} iterations.") - - return v_new + return v ``` Here's the solution to the model -- an approximate value function -```{code-cell} python3 -cw = CareerWorkerProblem() -T, get_greedy = operator_factory(cw) -v_star = solve_model(cw, verbose=False) -greedy_star = get_greedy(v_star) +```{code-cell} ipython3 +model = create_career_worker_problem() +v_star = solve_model(model) +greedy_star = get_greedy_policy(model, v_star) fig = plt.figure(figsize=(8, 6)) ax = fig.add_subplot(111, projection='3d') -tg, eg = np.meshgrid(cw.θ, cw.ϵ) +tg, eg = jnp.meshgrid(model.θ, model.ε) ax.plot_surface(tg, eg, v_star.T, cmap=cm.jet, alpha=0.5, linewidth=0.25) -ax.set(xlabel='θ', ylabel='ϵ', zlim=(150, 200)) +ax.set(xlabel='θ', ylabel='ε', zlim=(150, 200)) ax.view_init(ax.elev, 225) plt.show() ``` And here is the optimal policy -```{code-cell} python3 +```{code-cell} ipython3 fig, ax = plt.subplots(figsize=(6, 6)) -tg, eg = np.meshgrid(cw.θ, cw.ϵ) +tg, eg = jnp.meshgrid(model.θ, model.ε) lvls = (0.5, 1.5, 2.5, 3.5) ax.contourf(tg, eg, greedy_star.T, levels=lvls, cmap=cm.winter, alpha=0.5) ax.contour(tg, eg, greedy_star.T, colors='k', levels=lvls, linewidths=2) -ax.set(xlabel='θ', ylabel='ϵ') +ax.set(xlabel='θ', ylabel='ε') ax.text(1.8, 2.5, 'new life', fontsize=14) ax.text(4.5, 2.5, 'new job', fontsize=14, rotation='vertical') ax.text(4.0, 4.5, 'stay put', fontsize=14) @@ -392,35 +408,39 @@ In reading the code, recall that `optimal_policy[i, j]` = policy at $(\theta_i, \epsilon_j)$ = either 1, 2 or 3; meaning 'stay put', 'new job' and 'new life'. -```{code-cell} python3 -F = np.cumsum(cw.F_probs) -G = np.cumsum(cw.G_probs) -v_star = solve_model(cw, verbose=False) -T, get_greedy = operator_factory(cw) -greedy_star = get_greedy(v_star) +```{code-cell} ipython3 +model = create_career_worker_problem() +F = jnp.cumsum(model.F_probs) +G = jnp.cumsum(model.G_probs) +v_star = solve_model(model) +greedy_star = get_greedy_policy(model, v_star) -def gen_path(optimal_policy, F, G, t=20): +def gen_path(optimal_policy, F, G, model, t=20): i = j = 0 θ_index = [] - ϵ_index = [] + ε_index = [] for t in range(t): if optimal_policy[i, j] == 1: # Stay put pass - elif greedy_star[i, j] == 2: # New job + elif optimal_policy[i, j] == 2: # New job j = qe.random.draw(G) else: # New life i, j = qe.random.draw(F), qe.random.draw(G) θ_index.append(i) - ϵ_index.append(j) - return cw.θ[θ_index], cw.ϵ[ϵ_index] + ε_index.append(j) + + # Convert lists to JAX arrays for indexing + θ_indices = jnp.array(θ_index) + ε_indices = jnp.array(ε_index) + return model.θ[θ_indices], model.ε[ε_indices] fig, axes = plt.subplots(2, 1, figsize=(10, 8)) for ax in axes: - θ_path, ϵ_path = gen_path(greedy_star, F, G) - ax.plot(ϵ_path, label='ϵ') + θ_path, ε_path = gen_path(greedy_star, F, G, model) + ax.plot(ε_path, label='ε') ax.plot(θ_path, label='θ') ax.set_ylim(0, 6) @@ -464,15 +484,13 @@ Repeat the exercise with $\beta=0.99$ and interpret the change. 
The median for the original parameterization can be computed as follows -```{code-cell} python3 -cw = CareerWorkerProblem() -F = np.cumsum(cw.F_probs) -G = np.cumsum(cw.G_probs) -T, get_greedy = operator_factory(cw) -v_star = solve_model(cw, verbose=False) -greedy_star = get_greedy(v_star) +```{code-cell} ipython3 +model = create_career_worker_problem() +F = jnp.cumsum(model.F_probs) +G = jnp.cumsum(model.G_probs) +v_star = solve_model(model) +greedy_star = get_greedy_policy(model, v_star) -@jit def passage_time(optimal_policy, F, G): t = 0 i = j = 0 @@ -485,19 +503,18 @@ def passage_time(optimal_policy, F, G): i, j = qe.random.draw(F), qe.random.draw(G) t += 1 -@jit(parallel=True) def median_time(optimal_policy, F, G, M=25000): - samples = np.empty(M) - for i in prange(M): - samples[i] = passage_time(optimal_policy, F, G) - return np.median(samples) + samples = [] + for i in range(M): + samples.append(passage_time(optimal_policy, F, G)) + return jnp.median(jnp.array(samples)) median_time(greedy_star, F, G) ``` To compute the median with $\beta=0.99$ instead of the default -value $\beta=0.95$, replace `cw = CareerWorkerProblem()` with -`cw = CareerWorkerProblem(β=0.99)`. +value $\beta=0.95$, replace `model = create_career_worker_problem()` with +`model = create_career_worker_problem(β=0.99)`. The medians are subject to randomness but should be about 7 and 14 respectively. @@ -520,18 +537,17 @@ figure -- interpret. Here is one solution -```{code-cell} python3 -cw = CareerWorkerProblem(G_a=100, G_b=100) -T, get_greedy = operator_factory(cw) -v_star = solve_model(cw, verbose=False) -greedy_star = get_greedy(v_star) +```{code-cell} ipython3 +model = create_career_worker_problem(G_a=100, G_b=100) +v_star = solve_model(model) +greedy_star = get_greedy_policy(model, v_star) fig, ax = plt.subplots(figsize=(6, 6)) -tg, eg = np.meshgrid(cw.θ, cw.ϵ) +tg, eg = jnp.meshgrid(model.θ, model.ε) lvls = (0.5, 1.5, 2.5, 3.5) ax.contourf(tg, eg, greedy_star.T, levels=lvls, cmap=cm.winter, alpha=0.5) ax.contour(tg, eg, greedy_star.T, colors='k', levels=lvls, linewidths=2) -ax.set(xlabel='θ', ylabel='ϵ') +ax.set(xlabel='θ', ylabel='ε') ax.text(1.8, 2.5, 'new life', fontsize=14) ax.text(4.5, 1.5, 'new job', fontsize=14, rotation='vertical') ax.text(4.0, 4.5, 'stay put', fontsize=14) From 62daf27f6705ec7c3e3d105329e81c6c68e255ca Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Tue, 16 Sep 2025 16:11:42 +1000 Subject: [PATCH 2/6] update --- lectures/career.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lectures/career.md b/lectures/career.md index 671f34c15..6c7f9d651 100644 --- a/lectures/career.md +++ b/lectures/career.md @@ -410,8 +410,8 @@ $(\theta_i, \epsilon_j)$ = either 1, 2 or 3; meaning 'stay put', ```{code-cell} ipython3 model = create_career_worker_problem() -F = jnp.cumsum(model.F_probs) -G = jnp.cumsum(model.G_probs) +F = np.array(jnp.cumsum(model.F_probs)) +G = np.array(jnp.cumsum(model.G_probs)) v_star = solve_model(model) greedy_star = get_greedy_policy(model, v_star) @@ -486,8 +486,8 @@ The median for the original parameterization can be computed as follows ```{code-cell} ipython3 model = create_career_worker_problem() -F = jnp.cumsum(model.F_probs) -G = jnp.cumsum(model.G_probs) +F = np.array(jnp.cumsum(model.F_probs)) +G = np.array(jnp.cumsum(model.G_probs)) v_star = solve_model(model) greedy_star = get_greedy_policy(model, v_star) From 6232e1877ae90d3e154ea5b66cf58123a81802c0 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Tue, 16 Sep 2025 20:54:04 +1000 
Subject: [PATCH 3/6] update

---
 lectures/career.md | 127 +++++++++++++++++++++++----------------------
 1 file changed, 65 insertions(+), 62 deletions(-)

diff --git a/lectures/career.md b/lectures/career.md
index 6c7f9d651..f4432989e 100644
--- a/lectures/career.md
+++ b/lectures/career.md
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.17.2
+    jupytext_version: 1.17.1
 kernelspec:
   name: python3
   display_name: Python 3 (ipykernel)
@@ -37,18 +37,6 @@ In addition to what's in Anaconda, this lecture will need the following librarie
 !pip install quantecon
 ```

-```{admonition} GPU acceleration
-:class: warning
-
-This lecture uses JAX for hardware acceleration and automatic differentiation.
-
-For faster execution, consider running this lecture on a GPU.
-
-You can access free GPUs on [Google Colab](https://colab.research.google.com/) by selecting "Runtime → Change runtime type → Hardware accelerator → GPU" from the menu.
-
-To install JAX with GPU support locally, please consult the [JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html).
-```
-
 ## Overview

 Next, we study a computational problem concerning career and job choices.
@@ -63,7 +51,7 @@ We begin with some imports:
 import matplotlib.pyplot as plt
 import jax.numpy as jnp
 import jax
-import quantecon as qe
+import jax.random as jr
 from typing import NamedTuple
 from quantecon.distributions import BetaBinomial
 from scipy.special import binom, beta
@@ -391,7 +379,7 @@ In particular, modulo randomness, reproduce the following figure (where the hori
 ```{hint}
 :class: dropdown

-To generate the draws from the distributions $F$ and $G$, use `quantecon.random.draw()`.
+To generate the draws from the distributions $F$ and $G$, draw uniforms with `jax.random` (imported as `jr`) and map them to grid indices with `jnp.searchsorted` applied to the cumulative distributions.
``` ```{exercise-end} @@ -410,41 +398,51 @@ $(\theta_i, \epsilon_j)$ = either 1, 2 or 3; meaning 'stay put', ```{code-cell} ipython3 model = create_career_worker_problem() -F = np.array(jnp.cumsum(model.F_probs)) -G = np.array(jnp.cumsum(model.G_probs)) +F = jnp.cumsum(jnp.asarray(model.F_probs)) +G = jnp.cumsum(jnp.asarray(model.G_probs)) v_star = solve_model(model) -greedy_star = get_greedy_policy(model, v_star) - -def gen_path(optimal_policy, F, G, model, t=20): - i = j = 0 - θ_index = [] - ε_index = [] - for t in range(t): - if optimal_policy[i, j] == 1: # Stay put - pass - - elif optimal_policy[i, j] == 2: # New job - j = qe.random.draw(G) - - else: # New life - i, j = qe.random.draw(F), qe.random.draw(G) - θ_index.append(i) - ε_index.append(j) - - # Convert lists to JAX arrays for indexing - θ_indices = jnp.array(θ_index) - ε_indices = jnp.array(ε_index) - return model.θ[θ_indices], model.ε[ε_indices] +greedy_star = jnp.asarray(get_greedy_policy(model, v_star)) +def draw_from_cdf(key, cdf): + u = jr.uniform(key) + return jnp.searchsorted(cdf, u, side="left") +def gen_path(optimal_policy, F, G, model, t=20, key=None): + if key is None: + key = jr.PRNGKey(0) + i = 0 + j = 0 + theta_idx = [] + eps_idx = [] + for _ in range(t): + a = optimal_policy[i, j] + key, k1, k2 = jr.split(key, 3) + if a == 1: # Stay put + pass + elif a == 2: # New job + j = draw_from_cdf(k1, G) + else: # New life + i = draw_from_cdf(k1, F) + j = draw_from_cdf(k2, G) + theta_idx.append(i) + eps_idx.append(j) + + theta_idx = jnp.array(theta_idx, dtype=jnp.int32) + eps_idx = jnp.array(eps_idx, dtype=jnp.int32) + return model.θ[theta_idx], model.ε[eps_idx], key + +key = jr.PRNGKey(42) fig, axes = plt.subplots(2, 1, figsize=(10, 8)) + for ax in axes: - θ_path, ε_path = gen_path(greedy_star, F, G, model) + key, subkey = jr.split(key) + θ_path, ε_path, _ = gen_path(greedy_star, F, G, model, key=subkey) ax.plot(ε_path, label='ε') ax.plot(θ_path, label='θ') ax.set_ylim(0, 6) + ax.legend(loc='upper right') -plt.legend() +plt.tight_layout() plt.show() ``` @@ -486,28 +484,33 @@ The median for the original parameterization can be computed as follows ```{code-cell} ipython3 model = create_career_worker_problem() -F = np.array(jnp.cumsum(model.F_probs)) -G = np.array(jnp.cumsum(model.G_probs)) +F = jnp.cumsum(jnp.asarray(model.F_probs)) +G = jnp.cumsum(jnp.asarray(model.G_probs)) v_star = solve_model(model) -greedy_star = get_greedy_policy(model, v_star) - -def passage_time(optimal_policy, F, G): - t = 0 - i = j = 0 - while True: - if optimal_policy[i, j] == 1: # Stay put - return t - elif optimal_policy[i, j] == 2: # New job - j = qe.random.draw(G) - else: # New life - i, j = qe.random.draw(F), qe.random.draw(G) - t += 1 - -def median_time(optimal_policy, F, G, M=25000): - samples = [] - for i in range(M): - samples.append(passage_time(optimal_policy, F, G)) - return jnp.median(jnp.array(samples)) +greedy_star = jnp.asarray(get_greedy_policy(model, v_star)) + +def passage_time(optimal_policy, F, G, key): + def cond(state): + i, j, t, key = state + return optimal_policy[i, j] != 1 + + def body(state): + i, j, t, key = state + a = optimal_policy[i, j] + key, k1, k2 = jr.split(key, 3) + new_j = draw_from_cdf(k1, G) + new_i = draw_from_cdf(k2, F) + i = jnp.where(a == 3, new_i, i) + j = jnp.where((a == 2) | (a == 3), new_j, j) + return i, j, t + 1, key + + i, j, t, _ = jax.lax.while_loop(cond, body, (0, 0, 0, key)) + return t + +def median_time(optimal_policy, F, G, M=25000, seed=0): + keys = jr.split(jr.PRNGKey(seed), M) + times = 
jax.vmap(lambda k: passage_time(optimal_policy, F, G, k))(keys) + return jnp.median(times) median_time(greedy_star, F, G) ``` From 1ededf4e9b00b8278dcda1188e5178b3ec22b1b7 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 17 Sep 2025 13:04:48 +1000 Subject: [PATCH 4/6] updates --- lectures/career.md | 73 +++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/lectures/career.md b/lectures/career.md index f4432989e..b9016574f 100644 --- a/lectures/career.md +++ b/lectures/career.md @@ -4,11 +4,11 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.17.1 + jupytext_version: 1.16.6 kernelspec: - name: python3 display_name: Python 3 (ipykernel) language: python + name: python3 --- (career)= @@ -34,7 +34,15 @@ In addition to what's in Anaconda, this lecture will need the following librarie ```{code-cell} ipython3 :tags: [hide-output] -!pip install quantecon +!pip install --upgrade quantecon +``` + +We also need to install JAX to run this lecture + +```{code-cell} ipython3 +:tags: [skip-execution] + +!pip install -U jax ``` ## Overview @@ -51,7 +59,7 @@ We begin with some imports: import matplotlib.pyplot as plt import jax.numpy as jnp import jax -import jax.random as jr +import jax.random as jr from typing import NamedTuple from quantecon.distributions import BetaBinomial from scipy.special import binom, beta @@ -59,7 +67,7 @@ from mpl_toolkits.mplot3d.axes3d import Axes3D from matplotlib import cm ``` -### Model Features +### Model features * Career and job within career both chosen to maximize expected discounted wage flow. * Infinite horizon dynamic programming with two state variables. @@ -71,7 +79,7 @@ In what follows we distinguish between a career and a job, where * a *career* is understood to be a general field encompassing many possible jobs, and * a *job* is understood to be a position with a particular firm -For workers, wages can be decomposed into the contribution of job and career +For workers, wages can be decomposed into the contributions of job and career * $w_t = \theta_t + \epsilon_t$, where * $\theta_t$ is the contribution of career at time $t$ @@ -134,14 +142,14 @@ Evidently $I$, $II$ and $III$ correspond to "stay put", "new job" and "new life" As in {cite}`Ljungqvist2012`, section 6.5, we will focus on a discrete version of the model, parameterized as follows: * both $\theta$ and $\epsilon$ take values in the set - `np.linspace(0, B, grid_size)` --- an even grid of points between + `jnp.linspace(0, B, grid_size)` --- an even grid of points between $0$ and $B$ inclusive * `grid_size = 50` * `B = 5` * `β = 0.95` The distributions $F$ and $G$ are discrete distributions -generating draws from the grid points `np.linspace(0, B, grid_size)`. +generating draws from the grid points `jnp.linspace(0, B, grid_size)`. A very useful family of discrete distributions is the Beta-binomial family, with probability mass function @@ -229,31 +237,34 @@ $I$, $II$ and $III$ are as given in {eq}`eyes`. ```{code-cell} ipython3 @jax.jit -def bellman_operator(model, v): - """ - The Bellman operator for the career choice model. 
- """ - θ, ε, β = model.θ, model.ε, model.β - F_probs, G_probs = model.F_probs, model.G_probs - F_mean, G_mean = model.F_mean, model.G_mean - - # Vectorized computation - # Broadcasting θ and ε to create all combinations - θ_grid, ε_grid = jnp.meshgrid(θ, ε, indexing='ij') - +def Q(θ_grid, ε_grid, β, v, F_probs, G_probs, F_mean, G_mean): # Option 1: Stay put v1 = θ_grid + ε_grid + β * v # Option 2: New job (keep θ, new ε) - # For each θ[i], compute expected value over new ε ev_new_job = jnp.dot(v, G_probs) # Expected value for each θ v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis] # Option 3: New life (new θ and new ε) - # Expected value over both θ and ε ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs)) v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life) + return v1, v2, v3 + +@jax.jit +def bellman_operator(model, v): + """ + The Bellman operator for the career choice model. + """ + θ, ε, β = model.θ, model.ε, model.β + F_probs, G_probs = model.F_probs, model.G_probs + F_mean, G_mean = model.F_mean, model.G_mean + + v1, v2, v3 = Q( + *jnp.meshgrid(θ, ε, indexing='ij'), + β, v, F_probs, G_probs, F_mean, G_mean + ) + return jnp.maximum(jnp.maximum(v1, v2), v3) @jax.jit @@ -266,20 +277,10 @@ def get_greedy_policy(model, v): F_probs, G_probs = model.F_probs, model.G_probs F_mean, G_mean = model.F_mean, model.G_mean - # Vectorized computation - # Broadcasting θ and ε to create all combinations - θ_grid, ε_grid = jnp.meshgrid(θ, ε, indexing='ij') - - # Option 1: Stay put - v1 = θ_grid + ε_grid + β * v - - # Option 2: New job (keep θ, new ε) - ev_new_job = jnp.dot(v, G_probs) # Expected value for each θ - v2 = θ_grid + G_mean + β * ev_new_job[:, jnp.newaxis] - - # Option 3: New life (new θ and new ε) - ev_new_life = jnp.dot(F_probs, jnp.dot(v, G_probs)) - v3 = jnp.full_like(v, G_mean + F_mean + β * ev_new_life) + v1, v2, v3 = Q( + *jnp.meshgrid(θ, ε, indexing='ij'), + β, v, F_probs, G_probs, F_mean, G_mean + ) # Stack the value arrays and find argmax along first axis values = jnp.stack([v1, v2, v3], axis=0) From 5842d1e2bd81ee712f7a9ab0a20ac512f443b26d Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 17 Sep 2025 16:00:07 +1000 Subject: [PATCH 5/6] use cpu --- lectures/career.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lectures/career.md b/lectures/career.md index b9016574f..b57f39b10 100644 --- a/lectures/career.md +++ b/lectures/career.md @@ -65,6 +65,9 @@ from quantecon.distributions import BetaBinomial from scipy.special import binom, beta from mpl_toolkits.mplot3d.axes3d import Axes3D from matplotlib import cm + +# Set JAX to use CPU +jax.config.update('jax_platform_name', 'cpu') ``` ### Model features From 8707aab2a78ee1e153da00dbdc62f12202a6932b Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Thu, 18 Sep 2025 10:02:55 +1000 Subject: [PATCH 6/6] update jax install --- lectures/career.md | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/lectures/career.md b/lectures/career.md index b57f39b10..605c9a6a8 100644 --- a/lectures/career.md +++ b/lectures/career.md @@ -34,15 +34,7 @@ In addition to what's in Anaconda, this lecture will need the following librarie ```{code-cell} ipython3 :tags: [hide-output] -!pip install --upgrade quantecon -``` - -We also need to install JAX to run this lecture - -```{code-cell} ipython3 -:tags: [skip-execution] - -!pip install -U jax +!pip install quantecon jax ``` ## Overview