Skip to content

Commit 5f8d7da

Browse files
authored
FIX: update for compatibility for jax>=0.3.2 (#214)
* FIX: update for compatibility for jax>=0.3.2 * minor update to note text * install jax from conda #205
1 parent 4f31434 commit 5f8d7da

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

lectures/back_prop.md

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ kernelspec:
1616
```{code-cell} ipython3
1717
:tags: [hide-output]
1818
19-
!pip install --upgrade jax jaxlib
19+
!conda install -y -c conda-forge jax jaxlib
2020
!conda install -y -c plotly plotly plotly-orca retrying
2121
```
2222

@@ -302,7 +302,6 @@ import jax.numpy as jnp
302302
from jax import grad, jit, jacfwd, vmap
303303
from jax import random
304304
import jax
305-
from jax.ops import index_update
306305
import plotly.graph_objects as go
307306
```
308307
@@ -329,25 +328,25 @@ def compute_xδw_seq(params, x):
329328
330329
h = jax.nn.sigmoid
331330
332-
xs = index_update(xs, 0, x)
331+
xs = xs.at[0].set(x)
333332
for i, (w, b) in enumerate(params[:-1]):
334333
output = w * xs[i] + b
335334
activation = h(output[0, 0])
336335
337336
# Store elements
338-
δ = index_update(δ, i, grad(h)(output[0, 0]))
339-
ws = index_update(ws, i, w[0, 0])
340-
bs = index_update(bs, i, b[0])
341-
xs = index_update(xs, i+1, activation)
337+
δ = δ.at[i].set(grad(h)(output[0, 0]))
338+
ws = ws.at[i].set(w[0, 0])
339+
bs = bs.at[i].set(b[0])
340+
xs = xs.at[i+1].set(activation)
342341
343342
final_w, final_b = params[-1]
344343
preds = final_w * xs[-2] + final_b
345344
346345
# Store elements
347-
δ = index_update(δ, -1, 1.)
348-
ws = index_update(ws, -1, final_w[0, 0])
349-
bs = index_update(ws, -1, final_b[0])
350-
xs = index_update(xs, -1, preds[0, 0])
346+
δ = δ.at[-1].set(1.)
347+
ws = ws.at[-1].set(final_w[0, 0])
348+
bs = bs.at[-1].set(final_b[0])
349+
xs = xs.at[-1].set(preds[0, 0])
351350
352351
return xs, δ, ws, bs
353352
@@ -595,8 +594,8 @@ from jax.lib import xla_bridge
595594
print(xla_bridge.get_backend().platform)
596595
```
597596
598-
```{note} Cloud Enivronment
599-
This lecture site is built in a server environment that doesn't have access to a `gpu`
597+
```{note}
598+
**Cloud Environment:** This lecture site is built in a server environment that doesn't have access to a `gpu`
600599
If you run this lecture locally this lets you know where your code is being executed, either
601600
via the `cpu` or the `gpu`
602601
```

0 commit comments

Comments
 (0)