
Commit 9dd7a58

Use objects for solver outcomes, not errors and tuples (#101)
* Container class for solver results
* SGD uses solver_result return value
* Update typehint return type for SGD
* Fix SGD tests
* Fix two normal example test
* Rename attribute to more sensible name
1 parent 5d324f5 commit 9dd7a58

File tree

4 files changed, +75 -22 lines changed

src/causalprog/solvers/sgd.py

Lines changed: 25 additions & 11 deletions
@@ -8,6 +8,7 @@
 import numpy.typing as npt
 import optax
 
+from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq
 
 
@@ -22,7 +23,7 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
-) -> tuple[PyTree, npt.ArrayLike, npt.ArrayLike, int]:
+) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).
 
@@ -87,20 +88,33 @@ def objective(x: npt.ArrayLike) -> npt.ArrayLike:
     def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         return convergence_criteria(x, dx) < tolerance
 
-    gradient = jax.grad(objective)
-    opt_state = optimiser.init(initial_guess)
+    converged = False
 
+    opt_state = optimiser.init(initial_guess)
     current_params = deepcopy(initial_guess)
-    gradient_value = gradient(current_params)
-    for i in range(maxiter):
-        updates, opt_state = optimiser.update(gradient_value, opt_state)
-        current_params = optax.apply_updates(current_params, updates)
+    gradient = jax.grad(objective)
 
+    for _ in range(maxiter + 1):
         objective_value = objective(current_params)
         gradient_value = gradient(current_params)
 
-        if is_converged(objective_value, gradient_value):
-            return current_params, objective_value, gradient_value, i + 1
+        if converged := is_converged(objective_value, gradient_value):
+            break
+
+        updates, opt_state = optimiser.update(gradient_value, opt_state)
+        current_params = optax.apply_updates(current_params, updates)
 
-    msg = f"Did not converge after {i + 1} iterations."
-    raise RuntimeError(msg)
+    iters_used = _
+    reason_msg = (
+        f"Did not converge after {iters_used} iterations" if not converged else ""
+    )
+
+    return SolverResult(
+        fn_args=current_params,
+        grad_val=gradient_value,
+        iters=iters_used,
+        maxiter=maxiter,
+        obj_val=objective_value,
+        reason=reason_msg,
+        successful=converged,
+    )
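
For orientation, here is a minimal sketch (not part of the commit) of how the reworked return value can be consumed, using the same quadratic objective and keyword arguments that appear in the tests further down:

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent

# Minimise a simple quadratic, starting well away from the optimum.
result = stochastic_gradient_descent(
    lambda x: (x**2).sum(),
    jnp.atleast_1d(10.0),
    maxiter=1000,
    tolerance=1.0e-8,
)

# Non-convergence is no longer signalled by RuntimeError; it is recorded on
# the returned SolverResult instead.
if result.successful:
    print(f"Converged after {result.iters} iterations: {result.fn_args}")
else:
    print(result.reason)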
src/causalprog/solvers/solver_result.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+"""Container class for outputs from solver methods."""
+
+from dataclasses import dataclass
+
+import numpy.typing as npt
+
+from causalprog.utils.norms import PyTree
+
+
+@dataclass(frozen=True)
+class SolverResult:
+    """
+    Container class for outputs from solver methods.
+
+    Instances of this class provide a container for useful information that
+    comes out of running one of the solver methods on a causal problem.
+
+    Attributes:
+        fn_args: Argument to the objective function at final iteration (the solution,
+            if `successful is `True`).
+        grad_val: Value of the gradient of the objective function at the `fn_args`.
+        iters: Number of iterations performed.
+        maxiter: Maximum number of iterations the solver was permitted to perform.
+        obj_val: Value of the objective function at `fn_args`.
+        reason: Human-readable string explaining success or reasons for solver failure.
+        successful: `True` if solver converged, in which case `fn_args` is the
+            argument to the objective function at the solution of the problem being
+            solved. `False` otherwise.
+
+    """
+
+    fn_args: PyTree
+    grad_val: PyTree
+    iters: int
+    maxiter: int
+    obj_val: npt.ArrayLike
+    reason: str
+    successful: bool
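
As a quick illustration (again not part of the commit), the frozen dataclass can be constructed directly and read back field by field; being frozen, any attempt to mutate a field raises dataclasses.FrozenInstanceError:

import jax.numpy as jnp

from causalprog.solvers.solver_result import SolverResult

# Hand-built result mirroring what the SGD loop above returns on success.
res = SolverResult(
    fn_args=jnp.atleast_1d(0.0),
    grad_val=jnp.atleast_1d(0.0),
    iters=12,
    maxiter=1000,
    obj_val=jnp.atleast_1d(0.0),
    reason="",
    successful=True,
)

assert res.successful
assert res.iters <= res.maxiter
# res.iters = 13  # would raise dataclasses.FrozenInstanceError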

tests/test_integration/test_two_normal_example.py

Lines changed: 4 additions & 2 deletions
@@ -107,7 +107,7 @@ def objective(x, key):
     }
     l_mult = jnp.atleast_1d(lagrange_mult_sol)
 
-    opt_params, _, _, _ = stochastic_gradient_descent(
+    result = stochastic_gradient_descent(
         objective,
         (params, l_mult),
         convergence_criteria=lambda x, _: jnp.abs(x),
@@ -116,8 +116,10 @@ def objective(x, key):
         maxiter=maxiter,
         tolerance=minimisation_tolerance,
     )
+    assert result.successful, "SGD did not converge."
+
     # Unpack concatenated arguments
-    params, l_mult = opt_params
+    params, l_mult = result.fn_args
 
     # The lagrangian is independent of nu_x, thus it should not have changed value.
     assert jnp.isclose(params["cov2"], nu_x_starting_value), (
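
Note that fn_args keeps the PyTree structure of the initial guess, which is what lets the test above unpack a (params, l_mult) tuple straight off the result. A minimal sketch of the same pattern, with made-up values:

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent

# Tuple-valued initial guess; fn_args comes back with the same structure.
initial = (jnp.atleast_1d(1.0), jnp.atleast_1d(-1.0))
result = stochastic_gradient_descent(
    lambda xy: (xy[0] ** 2).sum() + (xy[1] ** 2).sum(),
    initial,
)
x_sol, y_sol = result.fn_args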

tests/test_solvers/test_sgd.py

Lines changed: 8 additions & 9 deletions
@@ -28,7 +28,7 @@
         pytest.param(
             lambda x: (x**2).sum(),
             jnp.atleast_1d(10.0),
-            RuntimeError("Did not converge after 1 iterations"),
+            "Did not converge after 1 iterations",
             {"maxiter": 1},
             id="Reaches iteration limit",
         ),
@@ -67,8 +67,7 @@ def test_sgd(
     obj_fn: Callable[[PyTree], npt.ArrayLike],
     initial_guess: PyTree,
     kwargs_to_sgd: dict[str, Any],
-    expected: PyTree | Exception,
-    raises_context,
+    expected: PyTree | str,
 ) -> None:
     """Test the SGD method on a (deterministic) problem.
 
@@ -79,12 +78,12 @@
     if not kwargs_to_sgd:
         kwargs_to_sgd = {}
 
-    if isinstance(expected, Exception):
-        with raises_context(expected):
-            stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
-    else:
-        result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)[0]
+    result = stochastic_gradient_descent(obj_fn, initial_guess, **kwargs_to_sgd)
 
+    if isinstance(expected, str):
+        assert not result.successful
+        assert result.reason == expected
+    else:
         assert jax.tree_util.tree_all(
-            jax.tree_util.tree_map(jax.numpy.allclose, result, expected)
+            jax.tree_util.tree_map(jax.numpy.allclose, result.fn_args, expected)
         )
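
The failure path exercised by the "Reaches iteration limit" case can also be reproduced directly; an illustrative sketch, not part of the commit:

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent

# One iteration is not enough to minimise x**2 from x = 10, so the solver
# reports failure on the result object rather than raising RuntimeError.
result = stochastic_gradient_descent(
    lambda x: (x**2).sum(),
    jnp.atleast_1d(10.0),
    maxiter=1,
)

assert not result.successful
assert result.reason == "Did not converge after 1 iterations"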

0 commit comments
