Skip to content

Commit 5d324f5

Browse files
authored
Refactor the SGD method (#100)
* Refactor SGD method * Test that l2^2 norm is implemented correctly * Write tests for sgd minimiser * Refactor integration test to now use SGD method * Use actual norm-function rather than hacky piece-together * Tidy docstring * Qualify convergence condition default * Drop extra SGD info in test return * Apply code review suggestions
1 parent f3b8135 commit 5d324f5

File tree

6 files changed

+256
-30
lines changed

6 files changed

+256
-30
lines changed

src/causalprog/solvers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Solvers for Causal Problems."""

src/causalprog/solvers/sgd.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Minimisation via Stochastic Gradient Descent."""
2+
3+
from collections.abc import Callable
4+
from copy import deepcopy
5+
6+
import jax
7+
import jax.numpy as jnp
8+
import numpy.typing as npt
9+
import optax
10+
11+
from causalprog.utils.norms import PyTree, l2_normsq
12+
13+
14+
def stochastic_gradient_descent(
15+
obj_fn: Callable[[PyTree], npt.ArrayLike],
16+
initial_guess: PyTree,
17+
*,
18+
convergence_criteria: Callable[[PyTree, PyTree], npt.ArrayLike] | None = None,
19+
fn_args: tuple | None = None,
20+
fn_kwargs: dict | None = None,
21+
learning_rate: float = 1.0e-1,
22+
maxiter: int = 1000,
23+
optimiser: optax.GradientTransformationExtraArgs | None = None,
24+
tolerance: float = 1.0e-8,
25+
) -> tuple[PyTree, npt.ArrayLike, npt.ArrayLike, int]:
26+
"""
27+
Minimise a function of one argument using Stochastic Gradient Descent (SGD).
28+
29+
The `obj_fn` provided will be minimised over its first argument. If you wish to
30+
minimise a function over a different argument, or multiple arguments, wrap it in a
31+
suitable `lambda` expression that has the correct call signature. For example, to
32+
minimise a function `f(x, y, z)` over `y` and `z`, use
33+
`g = lambda yz, x: f(x, yz[0], yz[1])`, and pass `g` in as `obj_fn`. Note that
34+
you will also need to provide a constant value for `x` via `fn_args` or `fn_kwargs`.
35+
36+
The `fn_args` and `fn_kwargs` keys can be used to supply additional parameters that
37+
need to be passed to `obj_fn`, but which should be held constant.
38+
39+
SGD terminates when the `convergence_criteria` is found to be smaller than the
40+
`tolerance`. That is, when
41+
`convergence_criteria(objective_value, gradient_value) <= tolerance` is found to
42+
be `True`, the algorithm considers a minimum to have been found. The default
43+
condition under which the algorithm terminates is when the norm of the gradient
44+
at the current argument value is smaller than the provided `tolerance`.
45+
46+
The optimiser to use can be selected by passing in a suitable `optax` optimiser
47+
via the `optimiser` command. By default, `optax.adams` is used with the supplied
48+
`learning_rate`. Providing an explicit value for `optimiser` will result in the
49+
`learning_rate` argument being ignored.
50+
51+
Args:
52+
obj_fn: Function to be minimised over its first argument.
53+
initial_guess: Initial guess for the minimising argument.
54+
convergence_criteria: The quantity that will be tested against `tolerance`, to
55+
determine whether the method has converged to a minimum. It should be a
56+
`callable` that takes the current value of `obj_fn` as its 1st argument, and
57+
the current value of the gradient of `obj_fn` as its 2nd argument. The
58+
default criteria is the l2-norm of the gradient.
59+
fn_args: Positional arguments to be passed to `obj_fn`, and held constant.
60+
fn_kwargs: Keyword arguments to be passed to `obj_fn`, and held constant.
61+
learning_rate: Default learning rate (or step size) to use when using the
62+
default `optimiser`. No effect if `optimiser` is provided explicitly.
63+
maxiter: Maximum number of iterations to perform. An error will be reported if
64+
this number of iterations is exceeded.
65+
optimiser: The `optax` optimiser to use during the update step.
66+
tolerance: `tolerance` used when determining if a minimum has been found.
67+
68+
Returns:
69+
Minimising argument of `obj_fn`.
70+
Value of `obj_fn` at the minimum.
71+
Gradient of `obj_fn` at the minimum.
72+
Number of iterations performed.
73+
74+
"""
75+
if not fn_args:
76+
fn_args = ()
77+
if not fn_kwargs:
78+
fn_kwargs = {}
79+
if not convergence_criteria:
80+
convergence_criteria = lambda _, dx: jnp.sqrt(l2_normsq(dx)) # noqa: E731
81+
if not optimiser:
82+
optimiser = optax.adam(learning_rate)
83+
84+
def objective(x: npt.ArrayLike) -> npt.ArrayLike:
85+
return obj_fn(x, *fn_args, **fn_kwargs)
86+
87+
def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
88+
return convergence_criteria(x, dx) < tolerance
89+
90+
gradient = jax.grad(objective)
91+
opt_state = optimiser.init(initial_guess)
92+
93+
current_params = deepcopy(initial_guess)
94+
gradient_value = gradient(current_params)
95+
for i in range(maxiter):
96+
updates, opt_state = optimiser.update(gradient_value, opt_state)
97+
current_params = optax.apply_updates(current_params, updates)
98+
99+
objective_value = objective(current_params)
100+
gradient_value = gradient(current_params)
101+
102+
if is_converged(objective_value, gradient_value):
103+
return current_params, objective_value, gradient_value, i + 1
104+
105+
msg = f"Did not converge after {i + 1} iterations."
106+
raise RuntimeError(msg)

src/causalprog/utils/norms.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Misc collection of norm-like functions for PyTree structures."""
2+
3+
from typing import TypeVar
4+
5+
import jax
6+
import numpy.typing as npt
7+
8+
PyTree = TypeVar("PyTree")
9+
10+
11+
def l2_normsq(x: PyTree) -> npt.ArrayLike:
12+
"""
13+
Square of the l2-norm of a PyTree.
14+
15+
This is effectively "sum(elements**2 in leaf for leaf in x)".
16+
"""
17+
leaves, _ = jax.tree_util.tree_flatten(x)
18+
return sum(jax.numpy.sum(leaf**2) for leaf in leaves)

tests/test_integration/test_two_normal_example.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
import sys
21
from collections.abc import Callable
32

43
import jax
54
import jax.numpy as jnp
6-
import optax
75
import pytest
86

97
from causalprog.causal_problem.causal_problem import CausalProblem
108
from causalprog.causal_problem.components import CausalEstimand, Constraint
119
from causalprog.graph import Graph
10+
from causalprog.solvers.sgd import stochastic_gradient_descent
11+
from causalprog.utils.norms import l2_normsq
1212

1313

1414
@pytest.mark.parametrize(
@@ -92,9 +92,9 @@ def test_two_normal_example(
9292
# We'll be seeking stationary points of the Lagrangian, using the
9393
# naive approach of minimising the norm of its gradient. We will need to
9494
# ensure we "converge" to a minimum value suitably close to 0.
95-
def objective(params, l_mult, key):
96-
v = jax.grad(lagrangian, argnums=(0, 1))(params, l_mult, key)
97-
return sum(value**2 for value in v[0].values()) + (v[1] ** 2).sum()
95+
def objective(x, key):
96+
v = jax.grad(lagrangian, argnums=(0, 1))(*x, rng_key=key)
97+
return l2_normsq(v)
9898

9999
# Choose a starting guess that is at the optimal solution, in the hopes that
100100
# SGD converges quickly. We almost certainly will not have this luxury in general.
@@ -107,31 +107,17 @@ def objective(params, l_mult, key):
107107
}
108108
l_mult = jnp.atleast_1d(lagrange_mult_sol)
109109

110-
# Setup SGD optimiser
111-
optimiser = optax.adam(adams_learning_rate)
112-
opt_state = optimiser.init((params, l_mult))
113-
114-
# Run optimisation loop on gradient of the Lagrangian
115-
converged = False
116-
for _ in range(maxiter):
117-
# Actual iteration loop
118-
grads = jax.jacobian(objective, argnums=(0, 1))(params, l_mult, rng_key)
119-
updates, opt_state = optimiser.update(grads, opt_state)
120-
params, l_mult = optax.apply_updates((params, l_mult), updates)
121-
122-
# Convergence "check" and progress update
123-
objective_value = objective(params, l_mult, rng_key)
124-
sys.stdout.write(
125-
f"{_}, F_val={objective_value:.4e}, "
126-
f"mu_ux={params['mean']:.4e}, "
127-
f"nu_x={params['cov2']:.4e}, "
128-
f"lambda={l_mult[0]:.4e}\n"
129-
)
130-
if jnp.abs(objective_value) <= minimisation_tolerance:
131-
converged = True
132-
break
133-
134-
assert converged, f"Did not converge, final objective value: {objective_value}"
110+
opt_params, _, _, _ = stochastic_gradient_descent(
111+
objective,
112+
(params, l_mult),
113+
convergence_criteria=lambda x, _: jnp.abs(x),
114+
fn_kwargs={"key": rng_key},
115+
learning_rate=adams_learning_rate,
116+
maxiter=maxiter,
117+
tolerance=minimisation_tolerance,
118+
)
119+
# Unpack concatenated arguments
120+
params, l_mult = opt_params
135121

136122
# The lagrangian is independent of nu_x, thus it should not have changed value.
137123
assert jnp.isclose(params["cov2"], nu_x_starting_value), (

tests/test_solvers/test_sgd.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from collections.abc import Callable
2+
from typing import Any
3+
4+
import jax
5+
import jax.numpy as jnp
6+
import numpy.typing as npt
7+
import pytest
8+
9+
from causalprog.solvers.sgd import stochastic_gradient_descent
10+
from causalprog.utils.norms import PyTree
11+
12+
13+
@pytest.mark.parametrize(
14+
(
15+
"obj_fn",
16+
"initial_guess",
17+
"expected",
18+
"kwargs_to_sgd",
19+
),
20+
[
21+
pytest.param(
22+
lambda x: (x**2).sum(),
23+
jnp.atleast_1d(1.0),
24+
jnp.atleast_1d(0.0),
25+
None,
26+
id="Deterministic x**2",
27+
),
28+
pytest.param(
29+
lambda x: (x**2).sum(),
30+
jnp.atleast_1d(10.0),
31+
RuntimeError("Did not converge after 1 iterations"),
32+
{"maxiter": 1},
33+
id="Reaches iteration limit",
34+
),
35+
pytest.param(
36+
lambda x: (x**2).sum(),
37+
jnp.atleast_1d(1.0),
38+
jnp.atleast_1d(0.9),
39+
{
40+
"convergence_criteria": lambda x, _: jnp.abs(x.sum()),
41+
"tolerance": 1.0e0,
42+
"learning_rate": 1e-1,
43+
},
44+
id="Converge on function value less than 1",
45+
),
46+
pytest.param(
47+
lambda x, a: ((x - a) ** 2).sum(),
48+
jnp.atleast_1d(1.0),
49+
jnp.atleast_1d(2.0),
50+
{
51+
"fn_args": (2.0,),
52+
},
53+
id="Fix positional argument",
54+
),
55+
pytest.param(
56+
lambda x, *, a: ((x - a) ** 2).sum(),
57+
jnp.atleast_1d(1.0),
58+
jnp.atleast_1d(2.0),
59+
{
60+
"fn_kwargs": {"a": 2.0},
61+
},
62+
id="Fix keyword argument",
63+
),
64+
],
65+
)
66+
def test_sgd(
67+
obj_fn: Callable[[PyTree], npt.ArrayLike],
68+
initial_guess: PyTree,
69+
kwargs_to_sgd: dict[str, Any],
70+
expected: PyTree | Exception,
71+
raises_context,
72+
) -> None:
73+
"""Test the SGD method on a (deterministic) problem.
74+
75+
This is just an assurance check that all the components of the method are working
76+
as intended. In each test case, we minimise (a variation of) x**2, changing the
77+
options that we pass to the SGD solver.
78+
"""
79+
if not kwargs_to_sgd:
80+
kwargs_to_sgd = {}
81+
82+
if isinstance(expected, Exception):
83+
with raises_context(expected):
84+
stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
85+
else:
86+
result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)[0]
87+
88+
assert jax.tree_util.tree_all(
89+
jax.tree_util.tree_map(jax.numpy.allclose, result, expected)
90+
)

tests/test_utils/test_norms.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from collections.abc import Callable
2+
3+
import numpy as np
4+
import pytest
5+
6+
from causalprog.utils.norms import PyTree, l2_normsq
7+
8+
9+
@pytest.mark.parametrize(
10+
("pt", "norm", "expected_value"),
11+
[
12+
pytest.param(1.0, l2_normsq, 1.0, id="l2^2, scalar"),
13+
pytest.param(
14+
np.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array"
15+
),
16+
pytest.param(
17+
{"a": 1.0, "b": (np.arange(3), [2.0, (-1.0, 0.0)])},
18+
l2_normsq,
19+
1.0 + (np.arange(3) ** 2).sum() + 4.0 + 1.0,
20+
id="l2^2, PyTree",
21+
),
22+
],
23+
)
24+
def test_norm_value(pt: PyTree, norm: Callable[[PyTree], float], expected_value: float):
25+
assert np.allclose(norm(pt), expected_value)

0 commit comments

Comments
 (0)