
Commit badf425

Commit message: misc
1 parent 9f0909d · commit badf425

1 file changed: +16 -8 lines changed

lectures/newton_method.md

Lines changed: 16 additions & 8 deletions
@@ -120,7 +120,7 @@ In other words, we seek a $k^* > 0$ such that $g(k^*)=k^*$.

Using pencil and paper to solve $g(k)=k$, you will be able to confirm that

-$$ k^* = \left(\frac{s * A}{δ}\right)^{1/(1 - α)} $$
+$$ k^* = \left(\frac{s A}{δ}\right)^{1/(1 - α)} $$
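
A quick check of this expression (a sketch assuming the Solow law of motion $g(k) = s A k^α + (1 - δ) k$ used earlier in the lecture): setting $g(k) = k$ gives

$$ s A k^α = δ k
\quad \Longrightarrow \quad
k^{1 - α} = \frac{s A}{δ}
\quad \Longrightarrow \quad
k^* = \left(\frac{s A}{δ}\right)^{1/(1 - α)} $$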

### Implementation

@@ -258,7 +258,8 @@ we start with a guess $x_0$ of the fixed
point and then update by solving for the fixed point of a tangent line at
$x_0$.

-To begin with, we recall that the first order approximation of $g$ at $x_0$ is
+To begin with, we recall that the first-order approximation of $g$ at $x_0$
+(i.e., the first-order Taylor approximation of $g$ at $x_0$) is
the function

```{math}
@@ -364,7 +365,7 @@ the problem of finding fixed points.

### Newton's Method for Zeros

-Let's suppose we want to find an $x$ such that $f(x)=0$ for some given
+Let's suppose we want to find an $x$ such that $f(x)=0$ for some smooth
function $f$ mapping real numbers to real numbers.

Suppose we have a guess $x_0$ and we want to update it to a new point $x_1$.
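
For concreteness, a sketch of the update this leads to (the standard Newton step for zeros, obtained by replacing $f$ with its first-order approximation at $x_0$ and solving for a zero):

$$ f(x_0) + f'(x_0)(x_1 - x_0) = 0
\quad \Longrightarrow \quad
x_1 = x_0 - \frac{f(x_0)}{f'(x_0)} $$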
@@ -737,14 +738,14 @@ With only slight modification, we can generalize [our previous attempt](first_ne
```{code-cell} python3
def newton(f, x_0, tol=1e-5, max_iter=10):
    x = x_0
-    iteration = jax.jit(lambda x: x - jnp.linalg.solve(jax.jacobian(f)(x), f(x)))
+    q = jax.jit(lambda x: x - jnp.linalg.solve(jax.jacobian(f)(x), f(x)))
    error = tol + 1
    n = 0
    while error > tol:
        n+=1
        if(n > max_iter):
            raise Exception('Max iteration reached without convergence')
-        y = iteration(x)
+        y = q(x)
        if(any(jnp.isnan(y))):
            raise Exception('Solution not found with NaN generated')
        error = jnp.linalg.norm(x - y)
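
# A minimal illustrative sketch, not from the lecture: the jitted map `q`
# above performs one multivariate Newton step, x <- x - J(x)^{-1} f(x), with
# the Jacobian J obtained by JAX automatic differentiation. The two-equation
# test system below is an assumption chosen only for illustration; its root
# is (1, 2).
import jax
import jax.numpy as jnp

def f(x):
    return x**2 - jnp.array([1.0, 4.0])

x = jnp.array([1.5, 1.5])
for _ in range(5):
    J = jax.jacobian(f)(x)               # 2x2 Jacobian at the current guess
    x = x - jnp.linalg.solve(J, f(x))    # solve J d = f(x), then x <- x - d
print(x)                                 # approximately [1. 2.]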
@@ -782,6 +783,8 @@ However, things will change when we move to higher dimensional problems.

Our next step is to investigate a large market with 5,000 goods.

+To handle this large problem we will use Google JAX.
+
The excess demand function is essentially the same, but now the matrix $A$ is $5000 \times 5000$ and the parameter vectors $b$ and $c$ are $5000 \times 1$.

@@ -800,7 +803,7 @@ b = jnp.ones(dim)
c = jnp.ones(dim)
```

-Here's the same demand function using `jax.numpy`:
+Here is essentially the same demand function we applied before, but now using `jax.numpy` for the calculations.

```{code-cell} python3
def e(p, A, b, c):
@@ -813,7 +816,9 @@ Here's our initial condition
init_p = jnp.ones(dim)
```

-Newton's method reaches a relatively small error within 10 seconds
+By leveraging the power of Newton's method, JAX-accelerated linear algebra,
+automatic differentiation, and a GPU, we obtain a relatively small error for
+this very large problem in just a few seconds:

```{code-cell} python3
%%time
@@ -824,7 +829,8 @@ p = newton(lambda p: e(p, A, b, c), init_p).block_until_ready()
np.max(np.abs(e(p, A, b, c)))
```

-With the same tolerance, the `root` function would cost minutes to run with jacobian supplied
+With the same tolerance, SciPy's `root` function takes much longer to run,
+even with the Jacobian supplied.


```{code-cell} python3
@@ -843,6 +849,8 @@ np.max(np.abs(e(p, A, b, c)))

The result is also less accurate.

+
+
## Exercises

```{exercise-start}
