
Commit ecd4e26

Merge pull request #301 from QuantEcon/kesten_jax
[kesten_processes] Rewrite Main Exercise Using JAX
2 parents: 81cd6a3 + 1fd8130

2 files changed: +161 -10 lines changed

lectures/kesten_processes.md

Lines changed: 160 additions & 9 deletions
@@ -19,6 +19,16 @@ kernelspec:
 
 # Kesten Processes and Firm Dynamics
 
+```{admonition} GPU in use
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and JAX for GPU programming.
+
+Free GPUs are available on Google Colab. To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.
+
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax#pip-installation-gpu-cuda) for installing JAX with GPU support.
+```
+
 ```{index} single: Linear State Space Models
 ```
 

@@ -34,6 +44,9 @@ tags: [hide-output]
 ---
 !pip install quantecon
 !pip install --upgrade yfinance
+# If your machine has CUDA support, please follow the guide in the GPU warning above.
+# Otherwise, run the line below:
+!pip install --upgrade "jax[CPU]"
 ```
 
 ## Overview
@@ -673,14 +686,23 @@ s_init = 1.0 # initial condition for each firm
 :class: dropdown
 ```
 
-Here's one solution. First we generate the observations:
+Here's one solution in [JAX](https://python-programming.quantecon.org/jax_intro.html).
+
+First let's import the necessary modules and check the backend for JAX:
 
 ```{code-cell} ipython3
-from numba import njit, prange
-from numpy.random import randn
+import jax
+import jax.numpy as jnp
+from jax import random
 
+# Check if JAX is using GPU
+print(f"jax backend: {jax.devices()[0].platform}")
+```
 
-@njit(parallel=True)
+Now we can generate the observations:
+
+```{code-cell} ipython3
+@jax.jit
 def generate_draws(μ_a=-0.5,
                    σ_a=0.1,
                    μ_b=0.0,
@@ -690,7 +712,136 @@ def generate_draws(μ_a=-0.5,
                    s_bar=1.0,
                    T=500,
                    M=1_000_000,
-                   s_init=1.0):
+                   s_init=1.0,
+                   seed=123):
+
+    key = random.PRNGKey(seed)
+    keys = random.split(key, 3)
+
+    # Generate arrays of random numbers
+    a_random = μ_a + σ_a * random.normal(keys[0], (T, M))
+    b_random = μ_b + σ_b * random.normal(keys[1], (T, M))
+    e_random = μ_e + σ_e * random.normal(keys[2], (T, M))
+
+    # Initialize the array of s values with the initial value
+    s = jnp.full((M, T+1), s_init)
+
+    # Perform the calculations in a vectorized manner for T periods
+    for t in range(T):
+        exp_a = jnp.exp(a_random[t, :])
+        exp_b = jnp.exp(b_random[t, :])
+        exp_e = jnp.exp(e_random[t, :])
+        s = s.at[:, t+1].set(jnp.where(s[:, t] < s_bar,
+                                       exp_e,
+                                       exp_a * s[:, t] + exp_b))
+
+    return s[:, -1]
+
+%time data = generate_draws().block_until_ready()
+```
+
+Since we applied `jax.jit` to the function, it runs even faster when we call it again:
+
+```{code-cell} ipython3
+%time data = generate_draws().block_until_ready()
+```
+
+Let's produce the rank-size plot and check the distribution:
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+
+rank_data, size_data = qe.rank_size(data, c=0.01)
+ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)
+ax.set_xlabel("log rank")
+ax.set_ylabel("log size")
+
+plt.show()
+```
+
+The plot produces a straight line, consistent with a Pareto tail.
+
+It is possible to further speed up our code by replacing the `for` loop with [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), which reduces the loop overhead in the compilation of the jitted function.
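
If `lax.scan` is unfamiliar, its carry/output pattern can be seen in a minimal sketch with a toy cumulative sum (an illustration only; the `step` function and values below are arbitrary, not part of the lecture's code):

```python
# Minimal sketch of lax.scan: the step function receives the current carry
# and one element of the scanned sequence, and returns (new carry, value to record).
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    new_carry = carry + x          # update the running total
    return new_carry, new_carry    # carry it forward and also record it

total, partial_sums = lax.scan(step, 0.0, jnp.arange(1.0, 6.0))
print(total)         # 15.0
print(partial_sums)  # [ 1.  3.  6. 10. 15.]
```

The `generate_draws_lax` function below follows the same pattern, carrying the vector of firm sizes through time while scanning over the three arrays of shocks.
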
+
+```{code-cell} ipython3
+from jax import lax
+
+@jax.jit
+def generate_draws_lax(μ_a=-0.5,
+                       σ_a=0.1,
+                       μ_b=0.0,
+                       σ_b=0.5,
+                       μ_e=0.0,
+                       σ_e=0.5,
+                       s_bar=1.0,
+                       T=500,
+                       M=1_000_000,
+                       s_init=1.0,
+                       seed=123):
+
+    key = random.PRNGKey(seed)
+    keys = random.split(key, 3)
+
+    # Generate random draws and initial values
+    a_random = μ_a + σ_a * random.normal(keys[0], (T, M))
+    b_random = μ_b + σ_b * random.normal(keys[1], (T, M))
+    e_random = μ_e + σ_e * random.normal(keys[2], (T, M))
+    s = jnp.full((M, ), s_init)
+
+    # Define the function for each update
+    def update_s(s, a_b_e_draws):
+        a, b, e = a_b_e_draws
+        res = jnp.where(s < s_bar,
+                        jnp.exp(e),
+                        jnp.exp(a) * s + jnp.exp(b))
+        return res, res
+
+    # Use lax.scan to perform the calculations on all states
+    s_final, _ = lax.scan(update_s, s, (a_random, b_random, e_random))
+    return s_final
+
+%time data = generate_draws_lax().block_until_ready()
+```
+
+Once compiled, the function is even faster:
+
+```{code-cell} ipython3
+%time data = generate_draws_lax().block_until_ready()
+```
+
+Here we produce the same rank-size plot:
+
+```{code-cell} ipython3
+fig, ax = plt.subplots()
+
+rank_data, size_data = qe.rank_size(data, c=0.01)
+ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5)
+ax.set_xlabel("log rank")
+ax.set_ylabel("log size")
+
+plt.show()
+```
+
+We can also use Numba with `for` loops to generate the observations (replicating the results we obtained with JAX).
+
+The results will be slightly different, since pseudo-random number generation is implemented [differently in JAX](https://www.kaggle.com/code/aakashnain/tf-jax-tutorials-part-6-prng-in-jax/notebook).
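
To see where that difference comes from, here is a minimal sketch (illustrative only, and using `numpy.random.default_rng` rather than the legacy `randn` that appears in the Numba cell below) contrasting JAX's explicit, splittable keys with NumPy's stateful generators:

```python
# Minimal sketch: JAX draws are a pure function of an explicit PRNG key,
# while NumPy-style generators carry hidden state, so the two libraries
# produce different streams even when seeded with the same number.
import numpy as np
from jax import random

key = random.PRNGKey(123)
key1, key2 = random.split(key)     # two independent streams from one key
print(random.normal(key1, (3,)))   # reproducible: depends only on key1
print(random.normal(key2, (3,)))   # a different, independent stream

rng = np.random.default_rng(123)   # stateful NumPy generator
print(rng.normal(size=3))          # not the same values as the JAX draws
```
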
+
+```{code-cell} ipython3
+from numba import njit, prange
+from numpy.random import randn
+
+@njit(parallel=True)
+def generate_draws_numba(μ_a=-0.5,
+                         σ_a=0.1,
+                         μ_b=0.0,
+                         σ_b=0.5,
+                         μ_e=0.0,
+                         σ_e=0.5,
+                         s_bar=1.0,
+                         T=500,
+                         M=1_000_000,
+                         s_init=1.0):
 
     draws = np.empty(M)
     for m in prange(M):
@@ -707,10 +858,12 @@ def generate_draws(μ_a=-0.5,
 
     return draws
 
-data = generate_draws()
+%time data = generate_draws_numba()
 ```
 
-Now we produce the rank-size plot:
+We can see that JAX and vectorization have sped up the computation significantly compared to the Numba version.
+
+We produce the rank-size plot again using this data, and it shows the same pattern we saw before:
 
 ```{code-cell} ipython3
 fig, ax = plt.subplots()
@@ -723,7 +876,5 @@ ax.set_ylabel("log size")
 
 plt.show()
 ```
-The plot produces a straight line, consistent with a Pareto tail.
-
 ```{solution-end}
 ```

lectures/status.md

Lines changed: 1 addition & 1 deletion
@@ -20,4 +20,4 @@ This table contains the latest execution statistics.
 
 These lectures are built on `linux` instances through `github actions` and `amazon web services (aws)` to
 enable access to a `gpu`. These lectures are built on a [p3.2xlarge](https://aws.amazon.com/ec2/instance-types/p3/)
-that has access to `8 vcpu's`, a `V100 NVIDIA Tesla GPU`, and `61 Gb` of memory.
+that has access to `8 vcpu's`, a `V100 NVIDIA Tesla GPU`, and `61 Gb` of memory.
