Commit eaba080

Callbacks, history logging and speed boosts to sgd
To sgd.py:
- Added single-pass JIT gradient and objective calculation for big speed boosts.
- Added a history-logging option for easy debugging and a better understanding of the optimisation process.
- Added callbacks functionality, allowing for user-specific callbacks, e.g. early stopping, live graphing, etc.

Added iteration_result.py to monitor the current state of convergence. Use of a dataclass ensures backwards compatibility of callbacks. Added solver_callbacks.py for callbacks and associated functions. Added history attributes to the solver result.
1 parent 16afa77 commit eaba080

File tree

4 files changed: +181 −9 lines changed
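
Taken together, the new pieces compose roughly as below. This is a hedged sketch, not code from the commit: the toy objective is illustrative, `obj_fn` is assumed to be the first positional parameter, and only the names visible in the diffs that follow (`initial_guess`, `maxiter`, `history_logging_interval`, `callbacks`) are confirmed.

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent

# Toy quadratic objective with its minimum at the origin.
result = stochastic_gradient_descent(
    lambda x: jnp.sum(x**2),
    initial_guess=jnp.array([3.0, -2.0]),
    maxiter=500,
    history_logging_interval=10,  # log state every 10th iteration
    callbacks=lambda ir: None,  # any callable accepting an IterationResult
)
print(result.successful, len(result.obj_val_history))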
src/causalprog/solvers/iteration_result.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
"""Container classes for outputs from each iteration of solver methods."""

from dataclasses import dataclass, field

import numpy.typing as npt

from causalprog.utils.norms import PyTree


@dataclass(frozen=False)
class IterationResult:
    """
    Container class storing state of solvers at iteration `iters`.

    Args:
        fn_args: Argument to the objective function at the final iteration (the
            solution, if `successful` is `True`).
        grad_val: Value of the gradient of the objective function at `fn_args`.
        iters: Number of iterations performed.
        obj_val: Value of the objective function at `fn_args`.
        iter_history: List of iteration numbers at which history was logged.
        fn_args_history: List of `fn_args` at each logged iteration.
        grad_val_history: List of `grad_val` at each logged iteration.
        obj_val_history: List of `obj_val` at each logged iteration.

    """

    fn_args: PyTree
    grad_val: PyTree
    iters: int
    obj_val: npt.ArrayLike

    iter_history: list[int] = field(default_factory=list)
    fn_args_history: list[PyTree] = field(default_factory=list)
    grad_val_history: list[PyTree] = field(default_factory=list)
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)


def _update_iteration_result(
    iter_result: IterationResult,
    current_params: PyTree,
    gradient_value: PyTree,
    iters: int,
    objective_value: npt.ArrayLike,
    history_logging_interval: int,
) -> None:
    """
    Update the `IterationResult` object with current iteration data.

    Only updates the history if `history_logging_interval` is positive and
    the current iteration is a multiple of `history_logging_interval`.

    """
    iter_result.fn_args = current_params
    iter_result.grad_val = gradient_value
    iter_result.iters = iters
    iter_result.obj_val = objective_value

    if history_logging_interval > 0 and iters % history_logging_interval == 0:
        iter_result.iter_history.append(iters)
        iter_result.fn_args_history.append(current_params)
        iter_result.grad_val_history.append(gradient_value)
        iter_result.obj_val_history.append(objective_value)
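
The commit message mentions user-specific callbacks such as early stopping or live graphing. As a sketch of the intended use, a callback is simply a callable that reads this dataclass; the name and reporting interval below are illustrative, not part of the commit:

from causalprog.solvers.iteration_result import IterationResult


def print_progress(ir: IterationResult) -> None:
    """Illustrative callback: report the objective every 100 iterations."""
    if ir.iters % 100 == 0:
        print(f"iteration {ir.iters}: objective = {ir.obj_val}")

Because callbacks only read attributes of the dataclass, fields can be added to `IterationResult` later without breaking existing callbacks, which is the backwards-compatibility point made in the commit message.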

src/causalprog/solvers/sgd.py

Lines changed: 47 additions & 8 deletions
@@ -8,6 +8,11 @@
 import numpy.typing as npt
 import optax

+from causalprog.solvers.iteration_result import (
+    IterationResult,
+    _update_iteration_result,
+)
+from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks
 from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq

@@ -23,6 +28,10 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
+    history_logging_interval: int = -1,
+    callbacks: Callable[[IterationResult], None]
+    | list[Callable[[IterationResult], None]]
+    | None = None,
 ) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).

@@ -65,12 +74,17 @@
             this number of iterations is exceeded.
         optimiser: The `optax` optimiser to use during the update step.
         tolerance: `tolerance` used when determining if a minimum has been found.
+        history_logging_interval: Interval (in number of iterations) at which to log
+            the history of optimisation. If `history_logging_interval` <= 0, no
+            history is logged.
+        callbacks: A `callable` or list of `callables` that take an
+            `IterationResult` as their only argument, and return `None`.
+            These will be called at the end of each iteration of the optimisation
+            procedure.
+

     Returns:
-        Minimising argument of `obj_fn`.
-        Value of `obj_fn` at the minimum.
-        Gradient of `obj_fn` at the minimum.
-        Number of iterations performed.
+        SolverResult: Result of the optimisation procedure.

     """
     if not fn_args:

@@ -82,28 +96,49 @@
     if not optimiser:
         optimiser = optax.adam(learning_rate)

+    callbacks = _normalise_callbacks(callbacks)
+
     def objective(x: npt.ArrayLike) -> npt.ArrayLike:
         return obj_fn(x, *fn_args, **fn_kwargs)

     def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         return convergence_criteria(x, dx) < tolerance

-    converged = False
+    value_and_grad_fn = jax.jit(jax.value_and_grad(objective))

+    # init state
     opt_state = optimiser.init(initial_guess)
     current_params = deepcopy(initial_guess)
-    gradient = jax.grad(objective)
+    converged = False
+    objective_value, gradient_value = value_and_grad_fn(current_params)
+
+    iter_result = IterationResult(
+        fn_args=current_params,
+        grad_val=gradient_value,
+        iters=0,
+        obj_val=objective_value,
+    )

     for _ in range(maxiter + 1):
-        objective_value = objective(current_params)
-        gradient_value = gradient(current_params)
+        _update_iteration_result(
+            iter_result,
+            current_params,
+            gradient_value,
+            _,
+            objective_value,
+            history_logging_interval,
+        )
+
+        _run_callbacks(iter_result, callbacks)

         if converged := is_converged(objective_value, gradient_value):
             break

         updates, opt_state = optimiser.update(gradient_value, opt_state)
         current_params = optax.apply_updates(current_params, updates)

+        objective_value, gradient_value = value_and_grad_fn(current_params)
+
     iters_used = _
     reason_msg = (
         f"Did not converge after {iters_used} iterations" if not converged else ""

@@ -117,4 +152,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         obj_val=objective_value,
         reason=reason_msg,
         successful=converged,
+        iter_history=iter_result.iter_history,
+        fn_args_history=iter_result.fn_args_history,
+        grad_val_history=iter_result.grad_val_history,
+        obj_val_history=iter_result.obj_val_history,
     )
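
The "single pass" speed boost comes from `jax.jit(jax.value_and_grad(objective))`: one compiled call now returns both the objective value and its gradient, where the old code evaluated `objective(current_params)` and `jax.grad(objective)(current_params)` separately. A standalone illustration of the pattern (toy function, not from the repository):

import jax
import jax.numpy as jnp


def objective(x: jax.Array) -> jax.Array:
    return jnp.sum(x**2)


# One JIT-compiled function returning (value, gradient) in a single pass.
value_and_grad_fn = jax.jit(jax.value_and_grad(objective))

val, grad = value_and_grad_fn(jnp.array([1.0, 2.0]))
print(val)   # 5.0
print(grad)  # [2. 4.]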
src/causalprog/solvers/solver_callbacks.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
"""Module for callback functions for solvers."""

from collections.abc import Callable

from tqdm.auto import tqdm

from causalprog.solvers.iteration_result import IterationResult


def _normalise_callbacks(
    callbacks: Callable[[IterationResult], None]
    | list[Callable[[IterationResult], None]]
    | None = None,
) -> list[Callable[[IterationResult], None]]:
    if callbacks is None:
        return []
    if callable(callbacks):
        return [callbacks]
    if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks):
        return callbacks

    msg = "Callbacks must be a callable or a sequence of callables"
    raise TypeError(msg)


def _run_callbacks(
    iter_result: IterationResult,
    callbacks: list[Callable[[IterationResult], None]],
) -> None:
    for cb in callbacks:
        cb(iter_result)


def tqdm_callback(total: int) -> Callable[[IterationResult], None]:
    """
    Progress bar callback using `tqdm`.

    Creates a callback function that can be passed to solvers to display a progress
    bar during optimisation. The progress bar updates based on the number of
    iterations and also displays the current objective value.

    Args:
        total: Total number of iterations for the progress bar.

    Returns:
        Callback function that updates the progress bar.

    """
    bar = tqdm(total=total)
    last_it = {"i": 0}

    def cb(ir: IterationResult) -> None:
        step = ir.iters - last_it["i"]
        if step > 0:
            bar.update(step)

        # Show the current objective value alongside the bar
        bar.set_postfix(obj=float(ir.obj_val))
        last_it["i"] = ir.iters

    return cb
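
A hedged usage sketch for `tqdm_callback`, with the solver call shaped as in the earlier sketch and the bar's `total` matched to `maxiter`:

import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent
from causalprog.solvers.solver_callbacks import tqdm_callback

maxiter = 1000
result = stochastic_gradient_descent(
    lambda x: jnp.sum(x**2),  # toy objective, as before
    initial_guess=jnp.array([3.0]),
    maxiter=maxiter,
    callbacks=tqdm_callback(total=maxiter),  # bar length matches maxiter
)

Note that `cb` only advances the bar when `ir.iters` has increased since the last call, so the iteration-0 invocation just sets the objective readout in the postfix.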

src/causalprog/solvers/solver_result.py

Lines changed: 10 additions & 1 deletion
@@ -1,6 +1,6 @@
 """Container class for outputs from solver methods."""

-from dataclasses import dataclass
+from dataclasses import dataclass, field

 import numpy.typing as npt

@@ -26,6 +26,10 @@ class SolverResult:
         successful: `True` if solver converged, in which case `fn_args` is the
             argument to the objective function at the solution of the problem being
             solved. `False` otherwise.
+        iter_history: List of iteration numbers at which history was logged.
+        fn_args_history: List of `fn_args` at each logged iteration.
+        grad_val_history: List of `grad_val` at each logged iteration.
+        obj_val_history: List of `obj_val` at each logged iteration.

     """

@@ -36,3 +40,8 @@
     obj_val: npt.ArrayLike
     reason: str
     successful: bool
+
+    iter_history: list[int] = field(default_factory=list)
+    fn_args_history: list[PyTree] = field(default_factory=list)
+    grad_val_history: list[PyTree] = field(default_factory=list)
+    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
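
With these history fields populated (a run with `history_logging_interval > 0`), convergence can be inspected after the fact. A sketch reusing `result` from the usage example near the top; matplotlib is an assumption here, not a dependency stated by this commit:

import matplotlib.pyplot as plt

# `result` as returned by the earlier stochastic_gradient_descent sketch,
# run with history_logging_interval > 0 so the history lists are non-empty.
plt.plot(result.iter_history, [float(v) for v in result.obj_val_history])
plt.xlabel("iteration")
plt.ylabel("objective value")
plt.show()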
