@@ -19,7 +19,7 @@ kernelspec:
1919</div>
2020```
2121
22- # {index}` IFP III: The Endogenous Grid Method <single: IFP III: The Endogenous Grid Method>`
22+ # {index}` The Income Fluctuation Problem III: The Endogenous Grid Method <single: The Income Fluctuation Problem III: The Endogenous Grid Method>`
2323
2424``` {contents} Contents
2525:depth: 2
@@ -424,7 +424,9 @@ def K_numpy(
424424 for k in range(n_z):
425425 # Set up the function a -> σ(a, z_k)
426426 σ = lambda a: np.interp(a, ae_vals[:, k], c_vals[:, k])
427+ # Calculate σ(R s_i + y(z_k), z_k)
427428 next_c = σ(R * s[i] + y(z_grid[k]))
429+ # Add to the sum that forms the expectation
428430 expectation += u_prime(next_c, γ) * Π[j, k]
429431 # Calculate updated c_{ij} values
430432 new_c_vals[i, j] = u_prime_inv(β * R * expectation, γ)
@@ -548,22 +550,26 @@ def K(
548550 n_a = len(s)
549551 n_z = len(z_grid)
550552
551- # Function to compute consumption for one (i, j) pair where i >= 1
552553 def compute_c_ij(i, j):
554+ " Function to compute consumption for one (i, j) pair where i >= 1. "
553555
554- # For each k, compute u'(σ(R * s_i + y(z_k), z_k))
556+ # First set up a function that takes s_i as given and, for each k in the indices
557+ # of z_grid, computes the term u'(σ(R * s_i + y(z_k), z_k))
555558 def mu(k):
556559 next_a = R * s[i] + y(z_grid[k])
557- # Interpolate to get consumption at next_a in state k
560+ # Interpolate to get σ(R * s_i + y(z_k), z_k)
558561 next_c = jnp.interp(next_a, ae_vals[:, k], c_vals[:, k])
562+ # Return the final quantity u'(σ(R * s_i + y(z_k), z_k))
559563 return u_prime(next_c, γ)
560564
561565 # Compute u'(σ(R * s_i + y(z_k), z_k)) at all k via vmap
562566 mu_vectorized = jax.vmap(mu)
563567 marginal_utils = mu_vectorized(jnp.arange(n_z))
568+
564569 # Compute expectation: Σ_k u'(σ(...)) * Π[j, k]
565570 expectation = jnp.sum(marginal_utils * Π[j, :])
566- # Invert to get consumption
571+
572+ # Invert to get consumption c_{ij} at (s_i, z_j)
567573 return u_prime_inv(β * R * expectation, γ)
568574
569575 # Set up index grids for vmap computation of all c_{ij}
@@ -646,9 +652,11 @@ print(f"Maximum difference in consumption policy: {max_c_diff:.2e}")
646652print(f"Maximum difference in asset grid: {max_ae_diff:.2e}")
647653```
648654
649- The maximum differences are on the order of $10^{-15}$ or smaller, which is essentially machine precision for 64-bit floating point arithmetic.
655+ The maximum differences are on the order of $10^{-15}$ or smaller, which is
656+ essentially machine precision for 64-bit floating point arithmetic.
650657
651- This confirms that our JAX implementation produces identical results to the NumPy version, validating the correctness of our vectorized JAX code.
658+ This confirms that our JAX implementation produces identical results to the
659+ NumPy version, validating the correctness of our vectorized JAX code.
652660
653661Here's a plot of the optimal policy for each $z$ state
654662
@@ -663,7 +671,8 @@ plt.show()
663671
664672### Dynamics
665673
666- To begin to understand the long run asset levels held by households under the default parameters, let's look at the
674+ To begin to understand the long run asset levels held by households under the
675+ default parameters, let's look at the
66767645 degree diagram showing the law of motion for assets under the optimal consumption policy.
668677
669678``` {code-cell} ipython3
@@ -741,69 +750,70 @@ plt.show()
741750
742751This looks pretty good.
743752
753+ ## Simulation
744754
745- ## Exercises
746-
747- ``` {exercise}
748- :label: ifp_egm_ex1
749-
750- Let's consider how the interest rate affects consumption.
755+ Let's return to the default model and study the stationary distribution of assets.
751756
752- * Step `r` through `np.linspace(0, 0.016, 4)`.
753- * Other than `r`, hold all parameters at their default values.
754- * Plot consumption against assets for income shock fixed at the smallest value.
757+ Our plan is to run a large number of households forward for $T$ periods and then
758+ histogram the cross-sectional distribution of assets.
755759
756- Your figure should show that, for this model, higher interest rates
757- suppress consumption (because they encourage more savings).
760+ Set ` num_households=50_000, T=500 ` .
758761```
759762
760- ``` {solution-start} ifp_egm_ex1
763+ ```{solution-start} ifp_egm_ex2
761764:class: dropdown
762765```
763766
764- Here's one solution:
765-
766- ``` {code-cell} ipython3
767- # With β=0.96, we need R*β < 1, so r < 0.0416
768- r_vals = np.linspace(0, 0.04, 4)
769-
770- fig, ax = plt.subplots()
771- for r_val in r_vals:
772- ifp = create_ifp(r=r_val)
773- R, β, γ, Π, z_grid, s = ifp
774- c_vals_init = s[:, None] * jnp.ones(len(z_grid))
775- c_vals, ae_vals = solve_model(ifp, c_vals_init)
776- ax.plot(ae_vals[:, 0], c_vals[:, 0], label=f'$r = {r_val:.3f}$')
777-
778- ax.set(xlabel='asset level', ylabel='consumption (low income)')
779- ax.legend()
780- plt.show()
781- ```
767+ First we write a function to run a single household forward in time and record
768+ the final value of assets.
782769
783- ``` {solution-end}
784- `` `
770+ The function takes a solution pair ` c_vals ` and ` ae_vals ` , understanding them
771+ as representing an optimal policy associated with a given model ` ifp `
785772
773+ ``` {code-cell} ipython3
774+ @jax.jit
775+ def simulate_household(
776+ key, a_0, z_idx_0, c_vals, ae_vals, ifp, num_households, T
777+ ):
778+ """
779+ Simulates num_households households for T periods to approximate
780+ the stationary distribution of assets.
786781
787- ``` {exercise}
788- :label: ifp_egm_ex2
782+ - key is the state of the random number generator
783+ - ifp is an instance of IFP
784+ - c_vals, ae_vals are the optimal consumption policy, endogenous grid for ifp
789785
790- Let's approximate the stationary distribution by simulation.
786+ """
787+ R, β, γ, Π, z_grid, s = ifp
788+ n_z = len(z_grid)
791789
792- Run a large number of households forward for $T$ periods and then histogram the
793- cross-sectional distribution of assets.
790+ # Create interpolation function for consumption policy
791+ σ = lambda a, z_idx: jnp.interp(a, ae_vals[:, z_idx], c_vals[:, z_idx])
794792
795- Set `num_households=50_000, T=500`.
796- ```
793+ # Simulate forward T periods
794+ def update(state, t):
795+ a, z_idx = state
796+ c = σ(a, z_idx)
797+ # Draw next shock z' from Π[z, z']
798+ current_key = jax.random.fold_in(t, key)
799+ z_next_idx = jax.random.choice(current_key, n_z, p=Π[z_idx])
800+ z_next = z_grid[z_next_idx]
801+ # Update assets: a' = R * (a - c) + Y'
802+ a_next = R * (a - c) + y(z_next)
803+ # Return updated state
804+ return a_next, z_next_idx
797805
798- ``` {solution-start} ifp_egm_ex2
799- :class: dropdown
806+ initial_state = a_0, z_idx_0
807+ final_state = jax.lax.fori_loop(0, T, update, initial_state)
808+ a_final, _ = final_state
809+ return a_final
800810```
801811
802- First we write a function to simulate many households in parallel using JAX .
812+ Now we write a function to simulate many households in parallel.
803813
804814``` {code-cell} ipython3
805815def compute_asset_stationary(
806- ifp, c_vals, ae_vals, num_households=50_000, T=500, seed=1234
816+ c_vals, ae_vals, ifp , num_households=50_000, T=500, seed=1234
807817 ):
808818 """
809819 Simulates num_households households for T periods to approximate
@@ -815,6 +825,7 @@ def compute_asset_stationary(
815825 ifp is an instance of IFP
816826 c_vals, ae_vals are the consumption policy and endogenous grid from
817827 solve_model
828+
818829 """
819830 R, β, γ, Π, z_grid, s = ifp
820831 n_z = len(z_grid)
@@ -823,38 +834,19 @@ def compute_asset_stationary(
823834 # Interpolate on the endogenous grid
824835 σ = lambda a, z_idx: jnp.interp(a, ae_vals[:, z_idx], c_vals[:, z_idx])
825836
826- # Simulate one household forward
827- def simulate_one_household(key):
828-
829- # Random initial state (a, z)
830- key1, key2, key3 = jax.random.split(key, 3)
831- z_idx = jax.random.choice(key1, n_z)
832- # Start with random assets drawn from [0, savings_grid_max/2]
833- a = jax.random.uniform(key3, minval=0.0, maxval=s[-1]/2)
834-
835- # Simulate forward T periods
836- def step(state, key_t):
837- a, z_idx = state
838- # Consume based on current state
839- c = σ(a, z_idx)
840- # Draw next shock
841- z_next_idx = jax.random.choice(key_t, n_z, p=Π[z_idx])
842- # Update assets: a' = R*(a - c) + Y'
843- z_next = z_grid[z_next_idx]
844- a_next = R * (a - c) + y(z_next)
845- return (a_next, z_next_idx), None
846-
847- keys = jax.random.split(key2, T)
848- initial_state = a, z_idx
849- final_state, _ = jax.lax.scan(step, initial_state, keys)
850- a_final, _ = final_state
851- return a_final
837+ # Start with assets = savings_grid_max / 2
838+ a_0_vector = jnp.full(num_households, s[-1] / 2)
839+ # Initialize the exogenous state of each household
840+ z_idx_0_vector = jnp.zeros(num_households).astype(jnp.int32)
852841
853842 # Vectorize over many households
854843 key = jax.random.PRNGKey(seed)
855844 keys = jax.random.split(key, num_households)
856- sim_all_households = jax.vmap(simulate_one_household)
857- assets = sim_all_households(keys)
845+ # Vectorize simulate_household in (key, a_0, z_idx_0)
846+ sim_all_households = jax.vmap(
847+ simulate_household, axes=(0, 0, 0, None, None, None, None, None)
848+ )
849+ assets = sim_all_households(keys, a_0_vector, z_idx_0_vector)
858850
859851 return np.array(assets)
860852```
@@ -874,13 +866,55 @@ ax.set(xlabel='assets')
874866plt.show()
875867```
876868
877- The shape of the asset distribution is unrealistic.
869+ The shape of the asset distribution is completely unrealistic!
878870
879871Here it is left skewed when in reality it has a long right tail.
880872
881873In a {doc}` subsequent lecture <ifp_advanced> ` we will rectify this by adding
882874more realistic features to the model.
883875
876+
877+
878+
879+
880+ ## Exercises
881+
882+ ``` {exercise}
883+ :label: ifp_egm_ex1
884+
885+ Let's consider how the interest rate affects consumption.
886+
887+ * Step `r` through `np.linspace(0, 0.016, 4)`.
888+ * Other than `r`, hold all parameters at their default values.
889+ * Plot consumption against assets for income shock fixed at the smallest value.
890+
891+ Your figure should show that, for this model, higher interest rates
892+ suppress consumption (because they encourage more savings).
893+ ```
894+
895+ ``` {solution-start} ifp_egm_ex1
896+ :class: dropdown
897+ ```
898+
899+ Here's one solution:
900+
901+ ``` {code-cell} ipython3
902+ # With β=0.96, we need R*β < 1, so r < 0.0416
903+ r_vals = np.linspace(0, 0.04, 4)
904+
905+ fig, ax = plt.subplots()
906+ for r_val in r_vals:
907+ ifp = create_ifp(r=r_val)
908+ R, β, γ, Π, z_grid, s = ifp
909+ c_vals_init = s[:, None] * jnp.ones(len(z_grid))
910+ c_vals, ae_vals = solve_model(ifp, c_vals_init)
911+ ax.plot(ae_vals[:, 0], c_vals[:, 0], label=f'$r = {r_val:.3f}$')
912+
913+ ax.set(xlabel='asset level', ylabel='consumption (low income)')
914+ ax.legend()
915+ plt.show()
916+ ```
917+
884918``` {solution-end}
885919```
886920
@@ -890,7 +924,7 @@ more realistic features to the model.
890924:label: ifp_egm_ex3
891925```
892926
893- Following on from exercises 1 and 2 , let's look at how savings and aggregate
927+ Following on from Exercises 1 , let's look at how savings and aggregate
894928asset holdings vary with the interest rate
895929
896930``` {note}
@@ -905,12 +939,10 @@ shocks.
905939Your task is to investigate how this measure of aggregate capital varies with
906940the interest rate.
907941
908- Following tradition, put the price (i.e., interest rate) on the vertical axis.
909-
910- On the horizontal axis put aggregate capital, computed as the mean of the
911- stationary distribution given the interest rate.
942+ Intuition suggests that a higher interest rate should encourage capital
943+ formation --- test this.
912944
913- Use
945+ For the interest rate grid, use
914946
915947``` {code-cell} ipython3
916948M = 12
0 commit comments