Skip to content

Commit 063367f

Browse files
committed
Add for-loop implementation
1 parent ff9d4c6 commit 063367f

File tree

1 file changed

+69
-2
lines changed

1 file changed

+69
-2
lines changed

lectures/wealth_dynamics.md

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,81 @@ def update_states_jax(arrays, wdy, size, rand_key):
332332
update_states_jax = jax.jit(update_states_jax, static_argnums=(2,))
333333
```
334334

335-
Here’s function to simulate the time series of wealth for individual households.
335+
Here’s function to simulate the time series of wealth for individual households using `for` loop and JAX.
336+
337+
```{code-cell} ipython3
338+
# Using JAX and for loop
339+
def wealth_time_series_for_loop_jax(w_0, n, wdy, size, rand_seed=1):
340+
"""
341+
Generate a single time series of length n for wealth given
342+
initial value w_0.
343+
344+
* This implementation uses for loop.
345+
346+
The initial persistent state z_0 for each household is drawn from
347+
the stationary distribution of the AR(1) process.
348+
349+
* wdy: NamedTuple Model
350+
* w_0: scalar/vector
351+
* n: int
352+
* size: size/shape of the w_0
353+
* rand_seed: int (Used to generate PRNG key)
354+
"""
355+
rand_key = jax.random.PRNGKey(rand_seed)
356+
rand_key, *subkey = jax.random.split(rand_key, n)
357+
358+
w_0 = jax.device_put(w_0).reshape(size)
359+
360+
z = wdy.z_mean + jnp.sqrt(wdy.z_var) * jax.random.normal(rand_key, shape=size)
361+
w = [w_0]
362+
for t in range(n-1):
363+
w_, z = update_states_jax((w[t], z), wdy, size, subkey[t])
364+
w.append(w_)
365+
return jnp.array(w)
366+
367+
# Create the jit function
368+
wealth_time_series_for_loop_jax = jax.jit(wealth_time_series_for_loop_jax, static_argnums=(1,3,))
369+
```
370+
371+
Let's try simulating the model at different parameter values and investigate the implications for the wealth distribution using the above function.
372+
373+
```{code-cell} ipython3
374+
wdy = create_wealth_model() # default model
375+
ts_length = 200
376+
size = (1,)
377+
```
378+
379+
```{code-cell} ipython3
380+
%%time
381+
382+
w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, ts_length, wdy, size)
383+
```
384+
385+
Running the above function again will be even faster because of JAX's JIT.
386+
387+
```{code-cell} ipython3
388+
%%time
389+
390+
# 2nd time is expected to be very fast because of JIT
391+
w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, ts_length, wdy, size)
392+
```
393+
394+
```{code-cell} ipython3
395+
fig, ax = plt.subplots()
396+
ax.plot(w_jax_result)
397+
plt.show()
398+
```
399+
400+
We can further try to optimize and speed up the compile time of the above function by replacing `for` loop with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html).
336401

337402
```{code-cell} ipython3
338403
def wealth_time_series_jax(w_0, n, wdy, size, rand_seed=1):
339404
"""
340405
Generate a single time series of length n for wealth given
341406
initial value w_0.
342407
408+
* This implementation uses for jax.lax.scan
409+
343410
The initial persistent state z_0 for each household is drawn from
344411
the stationary distribution of the AR(1) process.
345412
@@ -371,7 +438,7 @@ def wealth_time_series_jax(w_0, n, wdy, size, rand_seed=1):
371438
wealth_time_series_jax = jax.jit(wealth_time_series_jax, static_argnums=(1,3,))
372439
```
373440

374-
Let's try simulating the model at different parameter values and investigate the implications for the wealth distribution.
441+
Let's try simulating the model at different parameter values and investigate the implications for the wealth distribution and also observe the difference in time between `wealth_time_series_jax` and `wealth_time_series_for_loop_jax`.
375442

376443
```{code-cell} ipython3
377444
wdy = create_wealth_model() # default model

0 commit comments

Comments
 (0)