@@ -16,7 +16,7 @@ kernelspec:
1616``` {code-cell} ipython3
1717:tags: [hide-output]
1818
19- !pip install --upgrade jax jaxlib
19+ !conda install -y -c conda-forge jax jaxlib
2020!conda install -y -c plotly plotly plotly-orca retrying
2121```
2222
@@ -302,7 +302,6 @@ import jax.numpy as jnp
302302from jax import grad, jit, jacfwd, vmap
303303from jax import random
304304import jax
305- from jax.ops import index_update
306305import plotly.graph_objects as go
307306```
308307
@@ -329,25 +328,25 @@ def compute_xδw_seq(params, x):
329328
330329 h = jax.nn.sigmoid
331330
332- xs = index_update(xs, 0, x)
331+ xs = xs.at[0].set( x)
333332 for i, (w, b) in enumerate(params[:-1]):
334333 output = w * xs[i] + b
335334 activation = h(output[0, 0])
336335
337336 # Store elements
338- δ = index_update(δ, i, grad(h)(output[0, 0]))
339- ws = index_update(ws, i, w[0, 0])
340- bs = index_update(bs, i, b[0])
341- xs = index_update(xs, i+1, activation)
337+ δ = δ.at[i].set( grad(h)(output[0, 0]))
338+ ws = ws.at[i].set( w[0, 0])
339+ bs = bs.at[i].set( b[0])
340+ xs = xs.at[ i+1].set( activation)
342341
343342 final_w, final_b = params[-1]
344343 preds = final_w * xs[-2] + final_b
345344
346345 # Store elements
347- δ = index_update(δ, -1, 1.)
348- ws = index_update(ws, -1, final_w[0, 0])
349- bs = index_update(ws, -1, final_b[0])
350- xs = index_update(xs, -1, preds[0, 0])
346+ δ = δ.at[-1].set( 1.)
347+ ws = ws.at[-1].set( final_w[0, 0])
348+ bs = bs.at[-1].set( final_b[0])
349+ xs = xs.at[-1].set( preds[0, 0])
351350
352351 return xs, δ, ws, bs
353352
@@ -595,8 +594,8 @@ from jax.lib import xla_bridge
595594print(xla_bridge.get_backend().platform)
596595```
597596
598- ```{note} Cloud Enivronment
599- This lecture site is built in a server environment that doesn't have access to a `gpu`
597+ ```{note}
598+ **Cloud Environment:** This lecture site is built in a server environment that doesn't have access to a `gpu`
600599If you run this lecture locally this lets you know where your code is being executed, either
601600via the `cpu` or the `gpu`
602601```
0 commit comments