lectures/kesten_processes.md: 160 additions & 9 deletions
@@ -19,6 +19,16 @@ kernelspec:
# Kesten Processes and Firm Dynamics
+
+ ```{admonition} GPU in use
+ :class: warning
+
+ This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and JAX for GPU programming.
+
+ Free GPUs are available on Google Colab. To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.
+
+ Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax#pip-installation-gpu-cuda) for installing JAX with GPU support.
+ ```

```{index} single: Linear State Space Models
```
@@ -34,6 +44,9 @@ tags: [hide-output]
---
!pip install quantecon
!pip install --upgrade yfinance
+ # If your machine has CUDA support, please follow the guide in the GPU warning above.
+ # Otherwise, run the line below:
+ !pip install --upgrade "jax[cpu]"
```
## Overview
@@ -673,14 +686,23 @@ s_init = 1.0 # initial condition for each firm
:class: dropdown
```

- Here's one solution. First we generate the observations:
+ Here's one solution in [JAX](https://python-programming.quantecon.org/jax_intro.html).
+
+ First, let's import the necessary modules and check which backend JAX is using.
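As a minimal sketch of what that import-and-check step might look like (the lecture's actual cell may differ), one can query JAX for its active backend:

```python
import jax
import jax.numpy as jnp

# Report which backend JAX selected (e.g. 'gpu' or 'cpu') and the devices it sees
print(jax.default_backend())
print(jax.devices())
```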
The plot produces a straight line, consistent with a Pareto tail.
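For reference, here is a self-contained sketch of such a log-log rank-size plot, using plain NumPy and Matplotlib rather than whatever helper the lecture itself uses; a Pareto tail shows up as an approximately straight line because log rank is then linear in log size:

```python
import numpy as np
import matplotlib.pyplot as plt

def rank_size_plot(data, c=0.001):
    """Plot log rank against log size for the largest 100*c percent of data."""
    tail = np.sort(data)[-int(len(data) * c):]   # keep only the largest draws
    ranks = np.arange(len(tail), 0, -1)          # rank 1 = the largest observation
    fig, ax = plt.subplots()
    ax.loglog(tail, ranks, "o", markersize=3)
    ax.set_xlabel("log size")
    ax.set_ylabel("log rank")
    plt.show()
```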
+
+ It is possible to further speed up our code by replacing the `for` loop with [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), which reduces loop overhead when the function is JIT-compiled.
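As a rough sketch of the idea (with illustrative parameters and a threshold update rule in the spirit of the firm-dynamics exercise; the lecture's actual solution may differ), the time loop can be handed to `lax.scan`:

```python
import jax
import jax.numpy as jnp
from jax import lax

M, T = 100_000, 500        # number of firms, number of periods (illustrative)
s_bar, s_init = 1.0, 1.0   # exit threshold and initial condition (illustrative)

@jax.jit
def generate_draws(key):
    """Simulate M firms for T periods, scanning over time instead of looping."""

    def update(s, step_key):
        ka, kb, ke = jax.random.split(step_key, 3)
        a = jnp.exp(-0.5 + 0.5 * jax.random.normal(ka, (M,)))
        b = jnp.exp(jax.random.normal(kb, (M,)))
        e = jnp.exp(jax.random.normal(ke, (M,)))
        # Firms below the threshold are replaced by new entrants of size e
        s_next = jnp.where(s < s_bar, e, a * s + b)
        return s_next, None   # new carry; no per-step output is saved

    step_keys = jax.random.split(key, T)
    s_final, _ = lax.scan(update, jnp.full(M, s_init), step_keys)
    return s_final

data = generate_draws(jax.random.PRNGKey(0))
```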
We can also use Numba with `for` loops to generate the observations (replicating the results we obtained with JAX).
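A hedged sketch of the Numba version (again with illustrative parameters matching the JAX sketch above, not necessarily the lecture's exact code):

```python
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def generate_draws_numba(M=100_000, T=500, s_bar=1.0, s_init=1.0):
    """Simulate M firms for T periods with explicit loops."""
    draws = np.empty(M)
    for m in prange(M):               # firms are independent, so parallelize
        s = s_init
        for t in range(T):
            a = np.exp(-0.5 + 0.5 * np.random.randn())
            b = np.exp(np.random.randn())
            e = np.exp(np.random.randn())
            s = e if s < s_bar else a * s + b
        draws[m] = s
    return draws

data = generate_draws_numba()
```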
+
+ The results will be slightly different, since pseudo-random number generation is implemented [differently in JAX](https://www.kaggle.com/code/aakashnain/tf-jax-tutorials-part-6-prng-in-jax/notebook).
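To illustrate the difference (a standalone comparison, not part of the lecture): NumPy mutates a hidden global state on every call, while JAX randomness is a pure function of an explicit key.

```python
import numpy as np
import jax

np.random.seed(0)
print(np.random.normal(), np.random.normal())          # two different draws

key = jax.random.PRNGKey(0)
print(jax.random.normal(key), jax.random.normal(key))  # identical draws
key1, key2 = jax.random.split(key)                     # split to get fresh keys
print(jax.random.normal(key1), jax.random.normal(key2))
```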