@@ -258,7 +258,8 @@ we start with a guess $x_0$ of the fixed
 point and then update by solving for the fixed point of a tangent line at
 $x_0$.
 
-To begin with, we recall that the first order approximation of $g$ at $x_0$ is
+To begin with, we recall that the first-order approximation of $g$ at $x_0$
+(i.e., the first order Taylor approximation of $g$ at $x_0$) is
 the function
 
 ```{math}
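For context on this hunk: the first-order approximation it refers to is the tangent function $\hat g(x) := g(x_0) + g'(x_0)(x - x_0)$; the math display itself is cut off in this excerpt. Solving $\hat g(x_1) = x_1$ for $x_1$ gives the fixed-point update the surrounding text describes, sketched here under the assumption that $g'(x_0) \neq 1$:

```{math}
x_1 = \frac{g(x_0) - g'(x_0) x_0}{1 - g'(x_0)}
```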
@@ -364,7 +365,7 @@ the problem of finding fixed points.
 
 ### Newton's Method for Zeros
 
-Let's suppose we want to find an $x$ such that $f(x)=0$ for some given
+Let's suppose we want to find an $x$ such that $f(x)=0$ for some smooth
 function $f$ mapping real numbers to real numbers.
 
 Suppose we have a guess $x_0$ and we want to update it to a new point $x_1$.
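For reference, the update rule this subsection goes on to derive is the standard one-dimensional Newton step, obtained by replacing $f$ with its tangent line at $x_0$ and solving for the zero (valid when $f'(x_0) \neq 0$):

```{math}
x_1 = x_0 - \frac{f(x_0)}{f'(x_0)}
```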
@@ -737,14 +738,14 @@ With only slight modification, we can generalize [our previous attempt](first_ne
 ```{code-cell} python3
 def newton(f, x_0, tol=1e-5, max_iter=10):
     x = x_0
-    iteration = jax.jit(lambda x: x - jnp.linalg.solve(jax.jacobian(f)(x), f(x)))
+    q = jax.jit(lambda x: x - jnp.linalg.solve(jax.jacobian(f)(x), f(x)))
     error = tol + 1
     n = 0
     while error > tol:
         n+=1
         if(n > max_iter):
             raise Exception('Max iteration reached without convergence')
-        y = iteration(x)
+        y = q(x)
         if(any(jnp.isnan(y))):
             raise Exception('Solution not found with NaN generated')
         error = jnp.linalg.norm(x - y)
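The hunk above cuts off before the end of the loop, so here is a self-contained sketch of the routine in use. The closing lines of the loop (`x = y` and the `return`), the imports, the 64-bit precision setting, and the two-equation test system are reconstructions and assumptions for illustration, not lines taken from the lecture.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)   # assumed: double precision for the linear solves

def newton(f, x_0, tol=1e-5, max_iter=10):
    x = x_0
    # One Newton step: solve J(x) d = f(x), then move to x - d.
    q = jax.jit(lambda x: x - jnp.linalg.solve(jax.jacobian(f)(x), f(x)))
    error = tol + 1
    n = 0
    while error > tol:
        n += 1
        if n > max_iter:
            raise Exception('Max iteration reached without convergence')
        y = q(x)
        if jnp.any(jnp.isnan(y)):
            raise Exception('Solution not found with NaN generated')
        error = jnp.linalg.norm(x - y)
        x = y                      # reconstructed: advance the iterate
    return x                       # reconstructed: return the approximate root

def test_f(x):
    # Hypothetical two-equation system with a root at (1, 1).
    return jnp.array([x[0]**2 + x[1]**2 - 2.0, x[0] - x[1]])

print(newton(test_f, jnp.array([2.0, 0.5])))   # converges to approximately [1. 1.]
```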
@@ -782,6 +783,8 @@ However, things will change when we move to higher dimensional problems.
 
 Our next step is to investigate a large market with 5,000 goods.
 
+To handle this large problem we will use Google JAX.
+
 The excess demand function is essentially the same, but now the matrix $A$ is $5000 \times 5000$ and the parameter vectors $b$ and $c$ are $5000 \times 1$.
 
 
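Since the construction of the $5000 \times 5000$ matrix $A$ does not appear in this excerpt, the following is only a plausible setup sketch in JAX; the random scaling of `A` and the seed are assumptions, while `b` and `c` match the next hunk.

```python
import jax
import jax.numpy as jnp

dim = 5_000

# Assumed construction: the lecture's actual recipe for A is not shown in
# this diff.  A random matrix scaled by 1/dim keeps the system well behaved.
key = jax.random.PRNGKey(0)
A = jax.random.uniform(key, (dim, dim)) / dim

b = jnp.ones(dim)   # matches the next hunk
c = jnp.ones(dim)   # matches the next hunk
```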
@@ -800,7 +803,7 @@ b = jnp.ones(dim)
 c = jnp.ones(dim)
 ```
 
-Here's the same demand function using `jax.numpy`:
+Here is essentially the same demand function we applied before, but now using `jax.numpy` for the calculations.
 
 ```{code-cell} python3
 def e(p, A, b, c):
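The body of `e` is cut off at the end of this hunk. As a sketch, an excess demand function of the kind the lecture works with could look like the following; the exponential-demand-minus-square-root-supply form is an assumption for illustration, not a quotation of the lecture's code.

```python
import jax.numpy as jnp

def e(p, A, b, c):
    # Excess demand = demand minus supply at price vector p (illustrative form).
    return jnp.exp(- A @ p) + c - b * jnp.sqrt(p)
```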
@@ -813,7 +816,9 @@ Here's our initial condition
 init_p = jnp.ones(dim)
 ```
 
-Newton's method reaches a relatively small error within 10 seconds
+By leveraging the power of Newton's method, JAX accelerated linear algebra,
+automatic differentiation, and a GPU, we obtain a relatively small error for
+this very large problem in just a few seconds:
 
 ```{code-cell} python3
 %%time
@@ -824,7 +829,8 @@ p = newton(lambda p: e(p, A, b, c), init_p).block_until_ready()
 np.max(np.abs(e(p, A, b, c)))
 ```
 
-With the same tolerance, the `root` function would cost minutes to run with jacobian supplied
+With the same tolerance, SciPy's `root` function takes much longer to run,
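The rewritten sentence is truncated by the diff view. For reference, the comparison the old sentence describes ("the `root` function ... with jacobian supplied") would look roughly like the sketch below, reusing `e`, `A`, `b`, `c`, `init_p`, and `jax` from the earlier cells; the method defaults and tolerance are assumptions.

```python
import numpy as np
from scipy.optimize import root

# Assumed comparison: solve the same system with SciPy's root, supplying the
# Jacobian computed by JAX's automatic differentiation.
solution = root(lambda p: np.asarray(e(p, A, b, c)),
                init_p,
                jac=lambda p: np.asarray(jax.jacobian(e)(p, A, b, c)),
                tol=1e-5)               # assumed: same tolerance as `newton`

p_root = solution.x
print(np.max(np.abs(e(p_root, A, b, c))))
```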