@@ -264,7 +264,7 @@ acknowledging that low wealth households tend to save very little.
264264
265265## Implementation using JAX
266266
267- Let's define a Model to represent the wealth dynamics.
267+ Let's define a model to represent the wealth dynamics.
268268
269269``` {code-cell} ipython3
270270# NamedTuple Model
@@ -332,7 +332,7 @@ def update_states_jax(arrays, wdy, size, rand_key):
332332update_states_jax = jax.jit(update_states_jax, static_argnums=(2,))
333333```
334334
335- Here’s function to simulate the time series of wealth for individual households using ` for ` loop and JAX.
335+ Here’s function to simulate the time series of wealth for individual households using a ` for ` loop and JAX.
336336
337337``` {code-cell} ipython3
338338# Using JAX and for loop
@@ -341,7 +341,7 @@ def wealth_time_series_for_loop_jax(w_0, n, wdy, size, rand_seed=1):
341341 Generate a single time series of length n for wealth given
342342 initial value w_0.
343343
344- * This implementation uses for loop.
344+ * This implementation uses a for loop.
345345
346346 The initial persistent state z_0 for each household is drawn from
347347 the stationary distribution of the AR(1) process.
@@ -379,7 +379,7 @@ size = (1,)
379379``` {code-cell} ipython3
380380%%time
381381
382- w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, ts_length, wdy, size)
382+ w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready()
383383```
384384
385385Running the above function again will be even faster because of JAX's JIT.
@@ -388,7 +388,7 @@ Running the above function again will be even faster because of JAX's JIT.
388388%%time
389389
390390# 2nd time is expected to be very fast because of JIT
391- w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, ts_length, wdy, size)
391+ w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready()
392392```
393393
394394``` {code-cell} ipython3
@@ -449,7 +449,7 @@ size = (1,)
449449``` {code-cell} ipython3
450450%%time
451451
452- w_jax_result = wealth_time_series_jax(wdy.y_mean, ts_length, wdy, size)
452+ w_jax_result = wealth_time_series_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready()
453453```
454454
455455Running the above function again will be even faster because of JAX's JIT.
@@ -458,7 +458,7 @@ Running the above function again will be even faster because of JAX's JIT.
458458%%time
459459
460460# 2nd time is expected to be very fast because of JIT
461- w_jax_result = wealth_time_series_jax(wdy.y_mean, ts_length, wdy, size)
461+ w_jax_result = wealth_time_series_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready()
462462```
463463
464464``` {code-cell} ipython3
@@ -591,6 +591,9 @@ class WealthDynamics:
591591Here's function to simulate the time series of wealth for in individual households.
592592
593593``` {code-cell} ipython3
594+ ---
595+ tags: [hide-input]
596+ ---
594597@njit
595598def wealth_time_series(wdy, w_0, n):
596599 """
@@ -660,21 +663,30 @@ the implications for the wealth distribution.
660663
661664### Time Series
662665
663- Let's look at the wealth dynamics of an individual household.
666+ Let's look at the wealth dynamics of an individual household using numba .
664667
665668``` {code-cell} ipython3
669+ ---
670+ tags: [hide-input]
671+ ---
666672wdy = WealthDynamics()
667673
668674ts_length = 200
669675```
670676
671677``` {code-cell} ipython3
678+ ---
679+ tags: [hide-input]
680+ ---
672681%%time
673682
674683w = wealth_time_series(wdy, wdy.y_mean, ts_length)
675684```
676685
677686``` {code-cell} ipython3
687+ ---
688+ tags: [hide-input]
689+ ---
678690%%time
679691
680692# Check the time for 2nd execution
@@ -794,15 +806,6 @@ We will look at this again via the Gini coefficient immediately below, but
794806first consider the following image of our system resources when the code above
795807is executing:
796808
797- (htop_again)=
798- ``` {figure} /_static/lecture_specific/wealth_dynamics/htop_again.png
799- :scale: 80
800- ```
801-
802- Notice how effectively Numba has implemented multithreading for this routine:
803- all 8 CPUs on our workstation are running at maximum capacity (even though
804- four of them are virtual).
805-
806809Since the code is both efficiently JIT compiled and fully parallelized, it's
807810close to impossible to make this sequence of tasks run faster without changing
808811hardware.
0 commit comments