
Commit 9f48fa0

improve based on comments
1 parent 1351d01 commit 9f48fa0


lectures/newton_method.md

Lines changed: 37 additions & 39 deletions
@@ -34,11 +34,12 @@ We first consider an easy, one-dimensional fixed point problem where we know the

 Then we generalize Newton's method to multi-dimensional settings to solve market equilibrium with multiple goods.

+In each step, we will refine and improve our implementation and compare our results to alternative methods.
+
 We use the following imports in this lecture

 ```{code-cell} python3
 import numpy as np
-from numpy import exp, sqrt
 import matplotlib.pyplot as plt
 from collections import namedtuple
 from scipy.optimize import root
@@ -74,7 +75,7 @@ def create_solow_params(A=2.0, s=0.3, α=0.3, δ=0.4):
     return SolowParameters(A=A, s=s, α=α, δ=δ)
 ```

-The next two functions describe the [law of motion](motion_law) and the true fixed point $k^*$.
+The next two functions implement the law of motion [](motion_law) and the true fixed point $k^*$.

 ```{code-cell} python3
 def g(k, params):
@@ -206,7 +207,7 @@ def Dg(k, params):
     return α * A * s * k**(α-1) + (1 - δ)
 ```

-Here's a function $q$ representing the [formula for newtons' method above](newtons_method).
+Here's a function $q$ representing [](newtons_method).

 ```{code-cell} python3
 def q(k, params):
@@ -255,7 +256,7 @@ plot_trajectories(params)

 We can see that Newton's Method reaches convergence faster than the successive approximation.

-The above problem can be seen as a root-finding problem since the computation of a fixed point can be seen as approximating $x^*$ iteratively such that $g(x^*) - x^* = 0$.
+The above fixed-point calculation can be seen as a root-finding problem since the computation of a fixed point can be seen as approximating $x^*$ iteratively such that $g(x^*) - x^* = 0$.

 For one-dimensional root-finding problems, Newton's method iterates on:

@@ -266,15 +267,16 @@ x_{t+1} = x_t - \frac{ g(x_t) }{ g'(x_t) },
 \qquad x_0 \text{ given}
 ```

-This is also a more frequently used representation of Newton's method (in textbooks and online resources).
+The root-finding formula is also a more frequently used form of Newton's method.

 The following code implements the iteration

+(first_newton_attempt)=
 ```{code-cell} python3
 def newton(g, Dg, x_0, tol, params=params, maxIter=10):
     x = x_0

-    # Implement the one-dimensional Newton's method
+    # Implement the root-finding formula
     iteration = lambda x, params: x - g(x, params)/Dg(x, params)

     error = tol + 1
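Editor's note: the hunk above shows only the top of the updated `newton` routine. For reference, a minimal self-contained sketch of the same root-finding iteration applied to the Solow law of motion might look like the following; the loop details and stopping rule are assumptions for illustration, not the committed code.

```python
import numpy as np

# Solow parameters used earlier in the lecture
A, s, α, δ = 2.0, 0.3, 0.3, 0.4

def g(k):
    # law of motion for capital
    return A * s * k**α + (1 - δ) * k

def f(k):
    # a fixed point of g is a root of f
    return g(k) - k

def Df(k):
    # derivative of f
    return α * A * s * k**(α - 1) + (1 - δ) - 1

def newton_1d(f, Df, x_0, tol=1e-7, max_iter=10):
    # x_{t+1} = x_t - f(x_t) / f'(x_t)
    x = x_0
    for _ in range(max_iter):
        step = f(x) / Df(x)
        x -= step
        if abs(step) < tol:
            break
    return x

k_star = newton_1d(f, Df, x_0=0.8)
print(k_star, abs(g(k_star) - k_star))  # fixed point and residual
```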
@@ -310,7 +312,7 @@ The multi-dimensional variant will be left as an [exercise](newton_ex1).

 By observing the formula of Newton's method, it is easy to see the possibility to implement Newton's method using Jacobian when we move up the ladder to higher dimensions.

-This naturally leads us to use Newton's method to solve multi-dimensional problems for which we will use the powerful auto-differentiation functionality in `jax` to solve intricate calculations.
+This naturally leads us to use Newton's method to solve multi-dimensional problems for which we will use the powerful auto-differentiation functionality in JAX to do intricate calculations.

 ## Multivariate Newton’s Method

@@ -385,7 +387,7 @@ The function below calculates the excess demand for given parameters

 ```{code-cell} python3
 def e(p, A, b, c):
-    return exp(- A @ p) + c - b * sqrt(p)
+    return np.exp(- A @ p) + c - b * np.sqrt(p)
 ```


@@ -538,10 +540,10 @@ def jacobian(p, A, b, c):
     p_0, p_1 = p
     a_00, a_01 = A[0, :]
     a_10, a_11 = A[1, :]
-    j_00 = -a_00 * exp(-a_00 * p_0) - (b[0]/2) * p_0**(-1/2)
-    j_01 = -a_01 * exp(-a_01 * p_1)
-    j_10 = -a_10 * exp(-a_10 * p_0)
-    j_11 = -a_11 * exp(-a_11 * p_1) - (b[1]/2) * p_1**(-1/2)
+    j_00 = -a_00 * np.exp(-a_00 * p_0) - (b[0]/2) * p_0**(-1/2)
+    j_01 = -a_01 * np.exp(-a_01 * p_1)
+    j_10 = -a_10 * np.exp(-a_10 * p_0)
+    j_11 = -a_11 * np.exp(-a_11 * p_1) - (b[1]/2) * p_1**(-1/2)
     J = [[j_00, j_01],
          [j_10, j_11]]
     return np.array(J)
@@ -564,9 +566,7 @@ np.max(np.abs(e(p, A, b, c)))

 #### Using Newton's Method

-We can also use Newton's method to find the root.
-
-We are going to try to compute the equilibrium price using the multivariate version of Newton's method, which means iterating on the equation:
+Now let's compute the equilibrium price using the multivariate version of Newton's method:

 ```{math}
 :label: multi-newton
@@ -576,9 +576,9 @@ p_{n+1} = p_n - J_e(p_n)^{-1} e(p_n)

 starting from some initial guess of the price vector $p_0$. (Here $J_e(p_n)$ is the Jacobian of $e$ evaluated at $p_n$.)

-We use the `jax.jacobian()` function to auto-differentiate and calculate the jacobian.
+Instead of coding the Jacobian by hand, we use the `jax.jacobian()` function to auto-differentiate and calculate the Jacobian.

-With only slight modification, we can generalize our previous attempt to multi-dimensional problems
+With only slight modification, we can generalize [our previous attempt](first_newton_attempt) to multi-dimensional problems

 ```{code-cell} python3
 def newton(f, x_0, tol=1e-5, maxIter=10):
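Editor's note: the hunk cuts off right after the new `newton` signature. One self-contained way the `jax.jacobian`-based iteration could look is sketched below; the stopping rule and the small two-good parameter values are illustrative assumptions rather than the committed code.

```python
import jax
import jax.numpy as jnp

def newton(f, x_0, tol=1e-5, max_iter=10):
    # Multivariate Newton: x_{n+1} = x_n - J_f(x_n)^{-1} f(x_n),
    # with the Jacobian supplied by automatic differentiation
    jac = jax.jacobian(f)
    x = x_0
    for _ in range(max_iter):
        step = jnp.linalg.solve(jac(x), f(x))
        x = x - step
        if jnp.max(jnp.abs(step)) < tol:
            break
    return x

# Illustrative two-good excess demand (placeholder parameter values)
A = jnp.array([[0.5, 0.4],
               [0.8, 0.2]])
b = jnp.ones(2)
c = jnp.ones(2)
e = lambda p: jnp.exp(-A @ p) + c - b * jnp.sqrt(p)

p_star = newton(e, jnp.ones(2))
print(p_star, jnp.max(jnp.abs(e(p_star))))  # price vector and max excess demand
```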
@@ -616,11 +616,11 @@ p = newton(lambda p: e(p, A, b, c), init_p).block_until_ready()
 np.max(np.abs(e(p, A, b, c)))
 ```

-The error is almost 0.
+The result is very accurate.

 With the larger overhead, the speed is not better than the optimized `scipy` function.

-However, things will change slightly when we move to higher dimensional problems.
+However, things will change when we move to higher dimensional problems.


@@ -646,7 +646,7 @@ b = jnp.ones(dim)
 c = jnp.ones(dim)
 ```

-Here's the same demand function expressed in matrix syntax:
+Here's the same demand function using `jax.numpy`:

 ```{code-cell} python3
 def e(p, A, b, c):
@@ -659,7 +659,7 @@ Here's our initial condition
 init_p = jnp.ones(dim)
 ```

-Newton's method reaches a relatively small error within a minute
+Newton's method reaches a relatively small error within 10 seconds

 ```{code-cell} python3
 %%time
@@ -687,7 +687,7 @@ p = solution.x
 np.max(np.abs(e(p, A, b, c)))
 ```

-And the result is less accurate.
+The result is also less accurate.

 ## Exercises

@@ -722,7 +722,7 @@ $$

 - The computation of fixed point can be seen as computing $k^*$ such that $f(k^*) - k^* = 0$.

-- If you are unsure about your solution, you can start with the known solution to check your formula:
+- If you are unsure about your solution, you can start with the solved example:

 ```{math}
 A = \begin{pmatrix}
@@ -736,14 +736,10 @@ with $s = 0.3$, $α = 0.3$, and $δ = 0.4$ and starting value:


 ```{math}
-k_0 = \begin{pmatrix}
-1 \\
-1 \\
-1
-\end{pmatrix}
+k_0 = (1, 1, 1)
 ```

-The result should converge to the [solved solution in the one-dimensional problem](solved_k).
+The result should converge to the [analytical solution](solved_k).
 ````

 ```{exercise-end}
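Editor's note: the exercise's $A$ matrix lies outside the changed lines and is not reproduced here. A hedged sketch of how the hint can be wired up, using a purely hypothetical stand-in for $A$, might look like this:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the exercise's A; the real entries are not shown in this diff
A = jnp.array([[2.0, 0.2, 0.1],
               [0.1, 2.0, 0.2],
               [0.2, 0.1, 2.0]])
s, α, δ = 0.3, 0.3, 0.4

def multivariate_solow(k):
    # assumed multivariate law of motion: k_{t+1} = s A k_t^α + (1 - δ) k_t
    return s * (A @ k**α) + (1 - δ) * k

# A fixed point of the law of motion is a root of k -> multivariate_solow(k) - k,
# so a multivariate Newton routine can be applied directly
f = lambda k: multivariate_solow(k) - k
jac = jax.jacobian(f)

k = jnp.array([1.0, 1.0, 1.0])
for _ in range(20):
    k = k - jnp.linalg.solve(jac(k), f(k))
print(k, jnp.max(jnp.abs(f(k))))  # candidate fixed point and residual
```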
@@ -790,14 +786,14 @@ for init in initLs:
     attempt +=1
 ```

-We find that the results are invariant to the starting values given the well-defined property of this question. We can apply more a restrictive threshold for tolerance to achieve more accurate results.
+We find that the results are invariant to the starting values given the well-defined property of this question.

 But the number of iterations it takes to converge is dependent on the starting values.

 Substitute it back to the formulate to check our last result

 ```{code-cell} python3
-multivariate_solow(k)
+multivariate_solow(k) - k
 ```

 Note the error is very small.
@@ -822,10 +818,12 @@ init = jnp.repeat(1.0, 3)

 The result is very close to the ground truth but still slightly different.

-We can increase the precision of the floating point numbers and restrict the tolerance to obtain a more accurate approximation
+We can increase the precision of the floating point numbers and restrict the tolerance to obtain a more accurate approximation (see detailed discussion in the [lecture on JAX](https://python-programming.quantecon.org/jax_intro.html#differences))

 ```{code-cell} python3
-from jax.config import config; config.update("jax_enable_x64", True)
+from jax.config import config
+
+config.update("jax_enable_x64", True)

 init = init.astype('float64')

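Editor's note (not part of the commit): `from jax.config import config` has since been deprecated in newer JAX releases, where the same double-precision switch is typically written as below.

```python
import jax

# enable 64-bit floats globally; set this before creating any arrays
jax.config.update("jax_enable_x64", True)
```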
@@ -834,7 +832,7 @@ init = init.astype('float64')
            tol=1e-7).block_until_ready()
 ```

-We can see Newton's method steps towards a more accurate solution.
+We can see it steps towards a more accurate solution.

 ```{solution-end}
 ```
@@ -874,8 +872,8 @@ $$

 \begin{aligned}
 p1_{0} &= (5, 5, 5) \\
-p2_{0} &= (4.25, 4.25, 4.25) \\
-p3_{0} &= (1, 1, 1)
+p2_{0} &= (1, 1, 1) \\
+p3_{0} &= (4.5, 0.1, 4)
 \end{aligned}
 $$

@@ -885,7 +883,7 @@ Set the tolerance to $0.0$ for more accurate output.
 ```{hint}
 :class: dropdown

-Similar to [exercise 1](newton_ex1), enabling `float64` for `JAX` can improve the precision of our results.
+Similar to [exercise 1](newton_ex1), enabling `float64` for JAX can improve the precision of our results.
 ```


@@ -910,7 +908,7 @@ c = jnp.array([1.0, 1.0, 1.0])

 initLs = [jnp.repeat(5.0, 3),
           jnp.ones(3),
-          jnp.array([4.5, 0.1, 4])]
+          jnp.array([4.5, 0.1, 4.0])]
 ```

 Let’s run through each initial guess and check the output
@@ -933,7 +931,7 @@ We can find that Newton's method may fail for some starting values.

 Sometimes it may take a few initial guesses to achieve convergence.

-Substitute one result back to the formula to check our result
+Substitute the result back to the formula to check our result

 ```{code-cell} python3
 e(p, A, b, c)
