Skip to content

Commit 975be52

Browse files
committed
Remove hidden tags and use glue
1 parent e43c237 commit 975be52

File tree

1 file changed

+18
-42
lines changed

1 file changed

+18
-42
lines changed

lectures/wealth_dynamics.md

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ from numba.experimental import jitclass
8181
import jax
8282
import jax.numpy as jnp
8383
from collections import namedtuple
84+
from myst_nb import glue
8485
```
8586

8687
Let's check the backend used by JAX and the devices available
@@ -377,18 +378,18 @@ size = (1,)
377378
```
378379

379380
```{code-cell} ipython3
380-
%%time
381-
381+
qe.tic()
382382
w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready()
383+
qe.toc()
383384
```
384385

385386
Running the above function again will be even faster because of JAX's JIT.
386387

387388
```{code-cell} ipython3
388-
%%time
389-
389+
qe.tic()
390390
# 2nd time is expected to be very fast because of JIT
391391
w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready()
392+
qe.toc()
392393
```
393394

394395
```{code-cell} ipython3
@@ -448,17 +449,18 @@ size = (1,)
448449

449450
```{code-cell} ipython3
450451
%%time
451-
452+
qe.tic()
452453
w_jax_result = wealth_time_series_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready()
454+
glue("wealth_time_series_jax_time_1", qe.toc())
453455
```
454456

455457
Running the above function again will be even faster because of JAX's JIT.
456458

457459
```{code-cell} ipython3
458-
%%time
459-
460+
qe.tic()
460461
# 2nd time is expected to be very fast because of JIT
461462
w_jax_result = wealth_time_series_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready()
463+
glue("wealth_time_series_jax_time_2", qe.toc())
462464
```
463465

464466
```{code-cell} ipython3
@@ -495,9 +497,7 @@ update_cross_section_jax = jax.jit(update_cross_section_jax, static_argnums=(1,3
495497
Here's some type information to help Numba.
496498

497499
```{code-cell} ipython3
498-
---
499-
tags: [hide-input]
500-
---
500+
501501
wealth_dynamics_data = [
502502
('w_hat', float64), # savings parameter
503503
('s_0', float64), # savings parameter
@@ -521,9 +521,7 @@ Here's a class that stores instance data and implements methods that update
521521
the aggregate state and household wealth.
522522

523523
```{code-cell} ipython3
524-
---
525-
tags: [hide-input]
526-
---
524+
527525
@jitclass(wealth_dynamics_data)
528526
class WealthDynamics:
529527
@@ -591,9 +589,7 @@ class WealthDynamics:
591589
Here's function to simulate the time series of wealth for in individual households.
592590

593591
```{code-cell} ipython3
594-
---
595-
tags: [hide-input]
596-
---
592+
597593
@njit
598594
def wealth_time_series(wdy, w_0, n):
599595
"""
@@ -622,9 +618,7 @@ Now here's function to simulate a cross section of households forward in time.
622618
Note the use of parallelization to speed up computation.
623619

624620
```{code-cell} ipython3
625-
---
626-
tags: [hide-input]
627-
---
621+
628622
@njit(parallel=True)
629623
def update_cross_section(wdy, w_distribution, shift_length=500):
630624
"""
@@ -666,34 +660,25 @@ the implications for the wealth distribution.
666660
Let's look at the wealth dynamics of an individual household using numba.
667661

668662
```{code-cell} ipython3
669-
---
670-
tags: [hide-input]
671-
---
672663
wdy = WealthDynamics()
673664
674665
ts_length = 200
675666
```
676667

677668
```{code-cell} ipython3
678-
---
679-
tags: [hide-input]
680-
---
681-
%%time
682-
669+
qe.tic()
683670
w = wealth_time_series(wdy, wdy.y_mean, ts_length)
671+
glue("wealth_time_series_time_1", qe.toc())
684672
```
685673

686674
```{code-cell} ipython3
687-
---
688-
tags: [hide-input]
689-
---
690-
%%time
691-
675+
qe.tic()
692676
# Check the time for 2nd execution
693677
w = wealth_time_series(wdy, wdy.y_mean, ts_length)
678+
glue("wealth_time_series_time_2", qe.toc())
694679
```
695680

696-
Notice the time difference between the `wealth_time_series` and `wealth_time_series_jax`
681+
Notice the time difference between the `wealth_time_series`: {glue:}`wealth_time_series_time_1` and `wealth_time_series_jax`: {glue:}`wealth_time_series_jax_time_1`
697682

698683

699684

@@ -733,9 +718,6 @@ def generate_lorenz_and_gini_jax(wdy, num_households=100_000, T=500):
733718
The following function uses the numba implementation
734719

735720
```{code-cell} ipython3
736-
---
737-
tags: [hide-input]
738-
---
739721
# Uses numba
740722
def generate_lorenz_and_gini(wdy, num_households=100_000, T=500):
741723
"""
@@ -780,9 +762,6 @@ plt.show()
780762
Now let's try to run the same code snippet but using the numba version.
781763

782764
```{code-cell} ipython3
783-
---
784-
tags: [hide-input]
785-
---
786765
%%time
787766
788767
fig, ax = plt.subplots()
@@ -849,9 +828,6 @@ plt.show()
849828
Using numba, we get,
850829

851830
```{code-cell} ipython3
852-
---
853-
tags: [hide-input]
854-
---
855831
%%time
856832
857833
fig, ax = plt.subplots()

0 commit comments

Comments
 (0)