@@ -406,11 +406,6 @@ def create_ifp(r=0.01,
406406 η_draws = np.random.randn(shock_draw_size)
407407 assert R * β < 1, "Stability condition violated."
408408 return IFPNumPy(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws)
409-
410- # Set y(z, η) = exp(a_y * η + z * b_y)
411- @numba.jit
412- def y(z, η, a_y, b_y):
413- return np.exp(a_y * η + z * b_y)
414409```
415410
416411### Solver
@@ -488,8 +483,11 @@ def K_numpy(
488483 next_c = np.interp(next_a, ae_vals[:, k], c_vals[:, k])
489484 # Add to the inner sum
490485 inner_sum += u_prime(next_c)
491- # Average over η draws and weight by transition probability
492- expectation += (inner_sum / len(η_draws)) * Π[j, k]
486+ # Average over η draws to approximate the integral
487+ # ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' when z' = z_grid[k]
488+ inner_mean_k = (inner_sum / len(η_draws))
489+ # Weight by transition probability and add to the expectation
490+ expectation += inner_mean_k * Π[j, k]
493491 # Calculate updated c_{ij} values
494492 new_c_vals[i, j] = u_prime_inv(β * R * expectation)
495493
@@ -597,17 +595,6 @@ def create_ifp(r=0.01,
597595 η_draws = jax.random.normal(key, (shock_draw_size,))
598596 assert R * β < 1, "Stability condition violated."
599597 return IFP(R, β, γ, Π, z_grid, s, a_y, b_y, η_draws)
600-
601- # Set y(z, η) = exp(a_y * η + z * b_y)
602- def y_jax(z, η, a_y, b_y):
603- return jnp.exp(a_y * η + z * b_y)
604-
605- # Utility functions for JAX (can't use numba-jitted versions)
606- def u_prime_jax(c, γ):
607- return c**(-γ)
608-
609- def u_prime_inv_jax(c, γ):
610- return c**(-1/γ)
611598```
612599
613600
@@ -651,6 +638,7 @@ def K(
651638 # For each k (future z state), compute the integral over η
652639 def compute_expectation_k(k):
653640 z_prime = z_grid[k]
641+
654642 # For each η draw, compute u'(σ(R * s_i + y(z', η), z'))
655643 def compute_for_eta(η):
656644 next_a = R * s[i] + y(z_prime, η)
@@ -659,18 +647,13 @@ def K(
659647 # Return u'(σ(R * s_i + y(z', η), z'))
660648 return u_prime(next_c)
661649
662- # Compute average over all η draws using vmap
663- compute_all_eta = jax.vmap(compute_for_eta)
664- marginal_utils = compute_all_eta(η_draws)
665- # Return the average (Monte Carlo approximation of the integral)
666- return jnp.mean(marginal_utils)
667-
668- # Compute ∫ u'(σ(...)) φ(η) dη for all k via vmap
669- exp_over_eta = jax.vmap(compute_expectation_k)
670- expectations_k = exp_over_eta(jnp.arange(n_z))
650+ # Average over η draws to approximate the integral
651+ # ∫ u'(σ(R s_i + y(z', η'), z')) φ(η') dη' when z' = z_grid[k]
652+ return jnp.mean(jax.vmap(compute_for_eta)(η_draws))
671653
672654 # Compute expectation: Σ_k [∫ u'(σ(...)) φ(η) dη] * Π[j, k]
673- expectation = jnp.sum(expectations_k * Π[j, :])
655+ expectations = jax.vmap(compute_expectation_k)(jnp.arange(n_z))
656+ expectation = jnp.sum(expectations * Π[j, :])
674657
675658 # Invert to get consumption c_{ij} at (s_i, z_j)
676659 return u_prime_inv(β * R * expectation)
@@ -918,6 +901,9 @@ def simulate_household(
918901 R, β, γ, Π, z_grid, s, a_y, b_y, η_draws = ifp
919902 n_z = len(z_grid)
920903
904+ def y(z, η):
905+ return jnp.exp(a_y * η + z * b_y)
906+
921907 # Create interpolation function for consumption policy
922908 σ = lambda a, z_idx: jnp.interp(a, ae_vals[:, z_idx], c_vals[:, z_idx])
923909
@@ -932,7 +918,7 @@ def simulate_household(
932918 η_key = jax.random.fold_in(key, 2*t + 1)
933919 η = jax.random.normal(η_key)
934920 # Update assets: a' = R * (a - c) + Y'
935- a_next = R * (a - σ(a, z_idx)) + y_jax (z_next, η, a_y, b_y )
921+ a_next = R * (a - σ(a, z_idx)) + y (z_next, η)
936922 # Return updated state
937923 return a_next, z_next_idx
938924
@@ -1109,14 +1095,14 @@ for r in r_vals:
11091095 ae_vals_init = s[:, None] * jnp.ones(len(z_grid))
11101096 c_vals_init = ae_vals_init
11111097 c_vals, ae_vals = solve_model(ifp, c_vals_init, ae_vals_init)
1112- assets = compute_asset_stationary(c_vals, ae_vals, ifp,
1113- num_households=50_000, T=500)
1098+ assets = compute_asset_stationary(
1099+ c_vals, ae_vals, ifp, num_households=50_000, T=500
1100+ )
11141101 gini = gini_coefficient(assets)
11151102 top1 = top_share(assets, p=0.01)
11161103 gini_vals.append(gini)
11171104 top1_vals.append(top1)
1118- print(f' Gini: {gini:.4f}, Top 1%: {top1:.4f}')
1119- # Start next round with last solution
1105+ # Use last solution as initial conditions for the policy solver
11201106 c_vals_init = c_vals
11211107 ae_vals_init = ae_vals
11221108```
0 commit comments