
Commit 604d707

Authored by Smit-create, HumphreyYang and mmcky
Update jax box at top of lecture, update CUDA to 11.8 (#319)
* Fix jax upgrade with GPU
* check using cuda
* try cudnn82
* revert back to default
* check cudnn installation status
* try docker/cuda=11.8
* remove failing cmd
* set jax==0.4.2
* adjust pip install and moved the jax[cpu] into the gpu warning box

Co-authored-by: Humphrey Yang <u6474961@anu.edu.au>
Co-authored-by: mmcky <mamckay@gmail.com>
1 parent 0f73d25 commit 604d707

File tree: 3 files changed (+17, -22 lines)

* .github/workflows/ci.yml
* lectures/kesten_processes.md
* lectures/wealth_dynamics.md


.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ jobs:
     needs: deploy-runner
     runs-on: [self-hosted, cml-gpu]
     container:
-      image: docker://nvidia/cuda:11.2.1-devel-ubuntu20.04
+      image: docker://nvidia/cuda:11.8.0-devel-ubuntu22.04
       options: --gpus all
     steps:
       - uses: actions/checkout@v3
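
Updating the container image to CUDA 11.8 only changes what the CI job runs on; whether JAX actually picks up the GPU is a separate runtime question. A minimal sanity check, assuming JAX is installed inside the container (the calls are standard JAX API; running them in this workflow is an illustration, not part of the commit):

```python
# Hypothetical post-upgrade check: confirm JAX found the CUDA backend
# inside the nvidia/cuda:11.8.0 container.
import jax

print(jax.default_backend())   # 'gpu' if the CUDA backend loaded, else 'cpu'
print(jax.devices())           # the devices JAX will dispatch to, e.g. one GPU
```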

lectures/kesten_processes.md

Lines changed: 13 additions & 16 deletions
(In the hunks below, several removed/added line pairs look identical; those pairs differ only in whitespace.)

@@ -19,14 +19,14 @@ kernelspec:
 
 # Kesten Processes and Firm Dynamics
 
-```{admonition} GPU in use
+```{admonition} GPU
 :class: warning
 
-This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and JAX for GPU programming.
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and JAX for GPU programming.
 
 Free GPUs are available on Google Colab. To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
 
-Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support. If you would like to install jax running on the `cpu` only you can use `pip install jax[cpu]`
 ```
 
 ```{index} single: Linear State Space Models
@@ -44,9 +44,6 @@ tags: [hide-output]
 ---
 !pip install quantecon
 !pip install --upgrade yfinance
-# If your machine has CUDA support, please follow the guide in GPU Warning.
-# Otherwise, run the line below:
-!pip install --upgrade "jax[CPU]"
 ```
 
 ## Overview
@@ -686,7 +683,7 @@ s_init = 1.0 # initial condition for each firm
 :class: dropdown
 ```
 
-Here's one solution in [JAX](https://python-programming.quantecon.org/jax_intro.html).
+Here's one solution in [JAX](https://python-programming.quantecon.org/jax_intro.html).
 
 First let's import the necessary modules and check the backend for JAX
 
@@ -731,10 +728,10 @@ def generate_draws(μ_a=-0.5,
         exp_a = jnp.exp(a_random[t, :])
         exp_b = jnp.exp(b_random[t, :])
         exp_e = jnp.exp(e_random[t, :])
-        s = s.at[:, t+1].set(jnp.where(s[:, t] < s_bar,
+        s = s.at[:, t+1].set(jnp.where(s[:, t] < s_bar,
                                        exp_e,
                                        exp_a * s[:, t] + exp_b))
-
+
     return s[:, -1]
 
 %time data = generate_draws().block_until_ready()
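
The `s.at[:, t+1].set(...)` line in the hunk above uses JAX's functional update idiom: JAX arrays are immutable, so `.at[...].set(...)` returns a new array instead of writing in place. A self-contained sketch of the pattern, with toy shapes rather than the lecture's parameters:

```python
import jax.numpy as jnp

s = jnp.zeros((3, 4))                  # JAX arrays are immutable
col = jnp.array([1.0, 2.0, 3.0])

# .at[...].set(...) returns a new array with column 1 replaced;
# jnp.where picks elementwise between the two branches.
s = s.at[:, 1].set(jnp.where(col < 2.0, col, -col))
print(s)   # column 1 is [1., -2., -3.]
```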
@@ -761,7 +758,7 @@ plt.show()
 
 The plot produces a straight line, consistent with a Pareto tail.
 
-It is possible to further speed up our code by replacing the `for` loop with [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
+It is possible to further speed up our code by replacing the `for` loop with [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
 to reduce the loop overhead in the compilation of the jitted function
 
 ```{code-cell} ipython3
@@ -779,10 +776,10 @@ def generate_draws_lax(μ_a=-0.5,
                        M=1_000_000,
                        s_init=1.0,
                        seed=123):
-
+
     key = random.PRNGKey(seed)
     keys = random.split(key, 3)
-
+
     # Generate random draws and initial values
     a_random = μ_a + σ_a * random.normal(keys[0], (T, M))
     b_random = μ_b + σ_b * random.normal(keys[1], (T, M))
@@ -792,11 +789,11 @@
     # Define the function for each update
     def update_s(s, a_b_e_draws):
         a, b, e = a_b_e_draws
-        res = jnp.where(s < s_bar,
-                        jnp.exp(e),
+        res = jnp.where(s < s_bar,
+                        jnp.exp(e),
                         jnp.exp(a) * s + jnp.exp(b))
         return res, res
-
+
     # Use lax.scan to perform the calculations on all states
     s_final, _ = lax.scan(update_s, s, (a_random, b_random, e_random))
     return s_final
@@ -877,4 +874,4 @@ plt.show()
 ```
 
 ```{solution-end}
-```
+```
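
On the `lax.scan` change above: `lax.scan` folds a step function over the leading axis of its inputs while threading a carry, so the Python-level `for` loop never appears in the jitted trace. A minimal sketch of the carry/output contract, with placeholder dynamics rather than the lecture's model:

```python
import jax.numpy as jnp
from jax import lax

def step(s, x):
    s_new = 0.9 * s + x    # placeholder update rule
    return s_new, s_new    # (next carry, per-step output)

xs = jnp.arange(5.0)                             # one input per time step
s_final, path = lax.scan(step, jnp.array(1.0), xs)
print(s_final)   # carry after the last step
print(path)      # all intermediate states
```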

lectures/wealth_dynamics.md

Lines changed: 3 additions & 5 deletions
@@ -19,14 +19,14 @@ kernelspec:
 
 # Wealth Distribution Dynamics
 
-```{admonition} GPU in use
+```{admonition} GPU
 :class: warning
 
 This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and JAX for GPU programming.
 
 Free GPUs are available on Google Colab. To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
 
-Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax#pip-installation-gpu-cuda) for installing JAX with GPU support.
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support. If you would like to install jax running on the `cpu` only you can use `pip install jax[cpu]`
 ```
 
 ```{contents} Contents
@@ -40,9 +40,7 @@ In addition to what's in Anaconda, this lecture will need the following librarie
 tags: [hide-output]
 ---
 !pip install quantecon
-# If your machine has CUDA support, please follow the guide in GPU Warning.
-# Otherwise, run the line below:
-!pip install --upgrade "jax[CPU]"
+!pip install myst-nb
 ```
 
 ## Overview
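
The revised admonition points CPU-only users at `pip install jax[cpu]`. A quick way to confirm what such an install produced (a sketch; note that some shells, e.g. zsh, require quoting the extra as `"jax[cpu]"`):

```python
# After `pip install "jax[cpu]"`, JAX should report the CPU backend.
import jax

assert jax.default_backend() == "cpu", "a GPU/TPU build is active"
print(jax.devices())   # the available CPU device(s)
```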
