Skip to content

Commit e5e369e

Browse files
jstacclaude
andcommitted
Fix simulate_household to use embedded y function
- Add local y function inside simulate_household to replace removed y_jax - Maintains consistency with refactoring pattern used in K_numpy and K - All tests pass successfully 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 21f3573 commit e5e369e

File tree

1 file changed

+19
-33
lines changed

1 file changed

+19
-33
lines changed

lectures/ifp_egm.md

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -406,11 +406,6 @@ def create_ifp(r=0.01,
406406
η_draws = np.random.randn(shock_draw_size)
407407
assert R * β < 1, "Stability condition violated."
408408
return IFPNumPy(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws)
409-
410-
# Set y(z, η) = exp(a_y * η + z * b_y)
411-
@numba.jit
412-
def y(z, η, a_y, b_y):
413-
return np.exp(a_y * η + z * b_y)
414409
```
415410

416411
### Solver
@@ -488,8 +483,11 @@ def K_numpy(
488483
next_c = np.interp(next_a, ae_vals[:, k], c_vals[:, k])
489484
# Add to the inner sum
490485
inner_sum += u_prime(next_c)
491-
# Average over η draws and weight by transition probability
492-
expectation += (inner_sum / len(η_draws)) * Π[j, k]
486+
# Average over η draws to approximate the integral
487+
# ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' when z' = z_grid[k]
488+
inner_mean_k = (inner_sum / len(η_draws))
489+
# Weight by transition probability and add to the expectation
490+
expectation += inner_mean_k * Π[j, k]
493491
# Calculate updated c_{ij} values
494492
new_c_vals[i, j] = u_prime_inv(β * R * expectation)
495493
@@ -597,17 +595,6 @@ def create_ifp(r=0.01,
597595
η_draws = jax.random.normal(key, (shock_draw_size,))
598596
assert R * β < 1, "Stability condition violated."
599597
return IFP(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws)
600-
601-
# Set y(z, η) = exp(a_y * η + z * b_y)
602-
def y_jax(z, η, a_y, b_y):
603-
return jnp.exp(a_y * η + z * b_y)
604-
605-
# Utility functions for JAX (can't use numba-jitted versions)
606-
def u_prime_jax(c, γ):
607-
return c**(-γ)
608-
609-
def u_prime_inv_jax(c, γ):
610-
return c**(-1/γ)
611598
```
612599

613600

@@ -651,6 +638,7 @@ def K(
651638
# For each k (future z state), compute the integral over η
652639
def compute_expectation_k(k):
653640
z_prime = z_grid[k]
641+
654642
# For each η draw, compute u'(σ(R * s_i + y(z', η), z'))
655643
def compute_for_eta(η):
656644
next_a = R * s[i] + y(z_prime, η)
@@ -659,18 +647,13 @@ def K(
659647
# Return u'(σ(R * s_i + y(z', η), z'))
660648
return u_prime(next_c)
661649
662-
# Compute average over all η draws using vmap
663-
compute_all_eta = jax.vmap(compute_for_eta)
664-
marginal_utils = compute_all_eta(η_draws)
665-
# Return the average (Monte Carlo approximation of the integral)
666-
return jnp.mean(marginal_utils)
667-
668-
# Compute ∫ u'(σ(...)) φ(η) dη for all k via vmap
669-
exp_over_eta = jax.vmap(compute_expectation_k)
670-
expectations_k = exp_over_eta(jnp.arange(n_z))
650+
# Average over η draws to approximate the integral
651+
# ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' when z' = z_grid[k]
652+
return jnp.mean(jax.vmap(compute_for_eta)(η_draws))
671653
672654
# Compute expectation: Σ_k [∫ u'(σ(...)) φ(η) dη] * Π[j, k]
673-
expectation = jnp.sum(expectations_k * Π[j, :])
655+
expectations = jax.vmap(compute_expectation_k)(jnp.arange(n_z))
656+
expectation = jnp.sum(expectations * Π[j, :])
674657
675658
# Invert to get consumption c_{ij} at (s_i, z_j)
676659
return u_prime_inv(β * R * expectation)
@@ -918,6 +901,9 @@ def simulate_household(
918901
R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp
919902
n_z = len(z_grid)
920903
904+
def y(z, η):
905+
return jnp.exp(a_y * η + z * b_y)
906+
921907
# Create interpolation function for consumption policy
922908
σ = lambda a, z_idx: jnp.interp(a, ae_vals[:, z_idx], c_vals[:, z_idx])
923909
@@ -932,7 +918,7 @@ def simulate_household(
932918
η_key = jax.random.fold_in(key, 2*t + 1)
933919
η = jax.random.normal(η_key)
934920
# Update assets: a' = R * (a - c) + Y'
935-
a_next = R * (a - σ(a, z_idx)) + y_jax(z_next, η, a_y, b_y)
921+
a_next = R * (a - σ(a, z_idx)) + y(z_next, η)
936922
# Return updated state
937923
return a_next, z_next_idx
938924
@@ -1109,14 +1095,14 @@ for r in r_vals:
11091095
ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
11101096
c_vals_init = ae_vals_init
11111097
c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init)
1112-
assets = compute_asset_stationary(c_vals, ae_vals, ifp,
1113-
num_households=50_000, T=500)
1098+
assets = compute_asset_stationary(
1099+
c_vals, ae_vals, ifp, num_households=50_000, T=500
1100+
)
11141101
gini = gini_coefficient(assets)
11151102
top1 = top_share(assets, p=0.01)
11161103
gini_vals.append(gini)
11171104
top1_vals.append(top1)
1118-
print(f' Gini: {gini:.4f}, Top 1%: {top1:.4f}')
1119-
# Start next round with last solution
1105+
# Use last solution as initial conditions for the policy solver
11201106
c_vals_init = c_vals
11211107
ae_vals_init = ae_vals
11221108
```

0 commit comments

Comments
 (0)