diff --git a/pyproject.toml b/pyproject.toml
index 4750ef0..9a314c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
     "Programming Language :: Python :: 3.13",
     "Typing :: Typed",
 ]
-dependencies = ["jax", "networkx", "numpy", "numpyro", "typing_extensions"]
+dependencies = ["jax", "networkx", "numpy", "numpyro", "tqdm", "typing_extensions"]
 description = "A Python package for causal modelling and inference with stochastic causal programming"
 dynamic = ["version"]
 keywords = []
diff --git a/src/causalprog/solvers/iteration_result.py b/src/causalprog/solvers/iteration_result.py
new file mode 100644
index 0000000..31aec62
--- /dev/null
+++ b/src/causalprog/solvers/iteration_result.py
@@ -0,0 +1,74 @@
+"""Container classes for outputs from each iteration of solver methods."""
+
+from dataclasses import dataclass, field
+
+import numpy.typing as npt
+
+from causalprog.utils.norms import PyTree
+
+
+@dataclass(frozen=False, slots=True)
+class IterationResult:
+    """
+    Result container for iterative solvers with optional history logging.
+
+    Stores the latest iterate and, if `history_logging_interval > 0`, `update`
+    appends snapshots of the iterate to the corresponding history lists each
+    time the iteration number is a multiple of `history_logging_interval`.
+    Instances are mutable but do not allow dynamic attribute creation.
+
+    Args:
+        fn_args: Argument to the objective function at the final iteration (the
+            solution, if `successful` is `True`).
+        grad_val: Value of the gradient of the objective function at `fn_args`.
+        iters: Number of iterations performed.
+        obj_val: Value of the objective function at `fn_args`.
+        iter_history: List of iteration numbers at which history was logged.
+        fn_args_history: List of `fn_args` at each logged iteration.
+        grad_val_history: List of `grad_val` at each logged iteration.
+        obj_val_history: List of `obj_val` at each logged iteration.
+        history_logging_interval: Interval at which to log history. If
+            `history_logging_interval <= 0`, then no history is logged.
+
+    """
+
+    fn_args: PyTree
+    grad_val: PyTree
+    iters: int
+    obj_val: npt.ArrayLike
+    history_logging_interval: int = 0
+
+    iter_history: list[int] = field(default_factory=list)
+    fn_args_history: list[PyTree] = field(default_factory=list)
+    grad_val_history: list[PyTree] = field(default_factory=list)
+    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
+
+    _log_enabled: bool = field(init=False, repr=False)
+
+    def __post_init__(self) -> None:
+        self._log_enabled = self.history_logging_interval > 0
+
+    def update(
+        self,
+        current_params: PyTree,
+        gradient_value: PyTree,
+        iters: int,
+        objective_value: npt.ArrayLike,
+    ) -> None:
+        """
+        Update the `IterationResult` object with current iteration data.
+
+        Only updates the history if `history_logging_interval` is positive and
+        the current iteration is a multiple of `history_logging_interval`.
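+
+        Example:
+            A minimal sketch (the names `x0`, `g0`, `f0`, etc. are
+            illustrative placeholders for iterates, gradients, and objective
+            values): with `history_logging_interval=2`, only iterations that
+            are multiples of 2 are snapshotted:
+
+                ir = IterationResult(
+                    fn_args=x0,
+                    grad_val=g0,
+                    iters=0,
+                    obj_val=f0,
+                    history_logging_interval=2,
+                )
+                ir.update(x1, g1, 1, f1)  # latest fields change, no snapshot
+                ir.update(x2, g2, 2, f2)  # 2 % 2 == 0: appended to histories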
+
+        """
+        self.fn_args = current_params
+        self.grad_val = gradient_value
+        self.iters = iters
+        self.obj_val = objective_value
+
+        if self._log_enabled and iters % self.history_logging_interval == 0:
+            self.iter_history.append(iters)
+            self.fn_args_history.append(current_params)
+            self.grad_val_history.append(gradient_value)
+            self.obj_val_history.append(objective_value)
diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index 141d5e3..ac80d51 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -8,6 +8,8 @@
 import numpy.typing as npt
 import optax
 
+from causalprog.solvers.iteration_result import IterationResult
+from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks
 from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq
 
@@ -23,6 +25,10 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
+    history_logging_interval: int = -1,
+    callbacks: Callable[[IterationResult], None]
+    | list[Callable[[IterationResult], None]]
+    | None = None,
 ) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).
@@ -65,12 +71,17 @@
             this number of iterations is exceeded.
         optimiser: The `optax` optimiser to use during the update step.
         tolerance: `tolerance` used when determining if a minimum has been found.
+        history_logging_interval: Interval (in number of iterations) at which to
+            log the history of the optimisation. If
+            `history_logging_interval <= 0`, no history is logged.
+        callbacks: A `callable` or list of `callables` that take an
+            `IterationResult` as their only argument and return `None`.
+            These will be called at the end of each iteration of the
+            optimisation procedure.
+
     Returns:
-        Minimising argument of `obj_fn`.
-        Value of `obj_fn` at the minimum.
-        Gradient of `obj_fn` at the minimum.
-        Number of iterations performed.
+        SolverResult: Result of the optimisation procedure.
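+
+    Example:
+        A minimal sketch (objective, iteration counts, and interval are
+        illustrative): minimise a sum of squares, logging history every 10
+        iterations and showing a progress bar via `tqdm_callback`:
+
+            import jax.numpy as jnp
+
+            from causalprog.solvers.sgd import stochastic_gradient_descent
+            from causalprog.solvers.solver_callbacks import tqdm_callback
+
+            result = stochastic_gradient_descent(
+                lambda x: (x**2).sum(),
+                jnp.atleast_1d(1.0),
+                maxiter=100,
+                history_logging_interval=10,
+                callbacks=tqdm_callback(100),
+            )
+            # Objective values at iterations 0, 10, 20, ...
+            print(result.obj_val_history)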
""" if not fn_args: @@ -82,21 +93,39 @@ def stochastic_gradient_descent( if not optimiser: optimiser = optax.adam(learning_rate) + callbacks = _normalise_callbacks(callbacks) + def objective(x: npt.ArrayLike) -> npt.ArrayLike: return obj_fn(x, *fn_args, **fn_kwargs) def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: return convergence_criteria(x, dx) < tolerance - converged = False + value_and_grad_fn = jax.jit(jax.value_and_grad(objective)) + # init state opt_state = optimiser.init(initial_guess) current_params = deepcopy(initial_guess) - gradient = jax.grad(objective) + converged = False + objective_value, gradient_value = value_and_grad_fn(current_params) - for _ in range(maxiter + 1): - objective_value = objective(current_params) - gradient_value = gradient(current_params) + iter_result = IterationResult( + fn_args=current_params, + grad_val=gradient_value, + iters=0, + obj_val=objective_value, + history_logging_interval=history_logging_interval, + ) + + for current_iter in range(maxiter + 1): + iter_result.update( + current_params=current_params, + gradient_value=gradient_value, + iters=current_iter, + objective_value=objective_value, + ) + + _run_callbacks(iter_result, callbacks) if converged := is_converged(objective_value, gradient_value): break @@ -104,7 +133,9 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: updates, opt_state = optimiser.update(gradient_value, opt_state) current_params = optax.apply_updates(current_params, updates) - iters_used = _ + objective_value, gradient_value = value_and_grad_fn(current_params) + + iters_used = current_iter reason_msg = ( f"Did not converge after {iters_used} iterations" if not converged else "" ) @@ -117,4 +148,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: obj_val=objective_value, reason=reason_msg, successful=converged, + iter_history=iter_result.iter_history, + fn_args_history=iter_result.fn_args_history, + grad_val_history=iter_result.grad_val_history, + obj_val_history=iter_result.obj_val_history, ) diff --git a/src/causalprog/solvers/solver_callbacks.py b/src/causalprog/solvers/solver_callbacks.py new file mode 100644 index 0000000..1d40cae --- /dev/null +++ b/src/causalprog/solvers/solver_callbacks.py @@ -0,0 +1,61 @@ +"""Module for callback functions for solvers.""" + +from collections.abc import Callable, Collection + +from tqdm.auto import tqdm + +from causalprog.solvers.iteration_result import IterationResult + + +def _normalise_callbacks( + callbacks: Callable[[IterationResult], None] + | Collection[Callable[[IterationResult], None]] + | None = None, +) -> list[Callable[[IterationResult], None]]: + if callbacks is None: + return [] + if callable(callbacks): + return [callbacks] + if all(callable(cb) for cb in callbacks): + return list(callbacks) + + msg = "Callbacks must be a callable or a sequence of callables" + raise TypeError(msg) + + +def _run_callbacks( + iter_result: IterationResult, + callbacks: list[Callable[[IterationResult], None]], +) -> None: + for cb in callbacks: + cb(iter_result) + + +def tqdm_callback(total: int) -> Callable[[IterationResult], None]: + """ + Progress bar callback using `tqdm`. + + Creates a callback function that can be passed to solvers to display a progress bar + during optimization. The progress bar updates based on the number of iterations and + also displays the current objective value. + + Args: + total: Total number of iterations for the progress bar. + + Returns: + Callback function that updates the progress bar. 
+
+    """
+    bar = tqdm(total=total)
+    last_it = {"i": 0}
+
+    def cb(ir: IterationResult) -> None:
+        step = ir.iters - last_it["i"]
+        if step > 0:
+            bar.update(step)
+
+        # Show the current objective value alongside the bar
+        bar.set_postfix(obj=float(ir.obj_val))
+        last_it["i"] = ir.iters
+
+    return cb
diff --git a/src/causalprog/solvers/solver_result.py b/src/causalprog/solvers/solver_result.py
index eb09457..e2fe8f6 100644
--- a/src/causalprog/solvers/solver_result.py
+++ b/src/causalprog/solvers/solver_result.py
@@ -1,6 +1,6 @@
 """Container class for outputs from solver methods."""
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import numpy.typing as npt
 
@@ -26,6 +26,10 @@ class SolverResult:
     successful: `True` if solver converged, in which case `fn_args` is the
         argument to the objective function at the solution of the problem being
         solved. `False` otherwise.
+    iter_history: List of iteration numbers at which history was logged.
+    fn_args_history: List of `fn_args` at each logged iteration.
+    grad_val_history: List of `grad_val` at each logged iteration.
+    obj_val_history: List of `obj_val` at each logged iteration.
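+
+    Example:
+        A minimal sketch of reading the logged history (assumes the solver was
+        run with `history_logging_interval > 0`, so the lists are populated):
+
+            for i, obj in zip(
+                result.iter_history, result.obj_val_history, strict=True
+            ):
+                print(f"iteration {i}: objective {obj}")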
 
     """
 
@@ -36,3 +40,8 @@ class SolverResult:
     obj_val: npt.ArrayLike
     reason: str
     successful: bool
+
+    iter_history: list[int] = field(default_factory=list)
+    fn_args_history: list[PyTree] = field(default_factory=list)
+    grad_val_history: list[PyTree] = field(default_factory=list)
+    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
diff --git a/tests/fixtures/solvers.py b/tests/fixtures/solvers.py
new file mode 100644
index 0000000..0b5aa91
--- /dev/null
+++ b/tests/fixtures/solvers.py
@@ -0,0 +1,16 @@
+from collections.abc import Callable
+
+import numpy.typing as npt
+import pytest
+
+from causalprog.utils.norms import PyTree
+
+
+@pytest.fixture
+def sum_of_squares_obj() -> Callable[[PyTree], npt.ArrayLike]:
+    """f(x) = ||x||_2^2 = sum_i x_i^2"""
+
+    def _inner(x):
+        return (x**2).sum()
+
+    return _inner
diff --git a/tests/test_solvers/test_normalise_solver_inputs.py b/tests/test_solvers/test_normalise_solver_inputs.py
new file mode 100644
index 0000000..3329495
--- /dev/null
+++ b/tests/test_solvers/test_normalise_solver_inputs.py
@@ -0,0 +1,35 @@
+import pytest
+
+from causalprog.solvers.iteration_result import IterationResult
+from causalprog.solvers.solver_callbacks import _normalise_callbacks
+
+
+def _cb(ir: IterationResult) -> None:
+    """Placeholder callback."""
+
+
+@pytest.mark.parametrize(
+    ("func_input", "expected"),
+    [
+        (_cb, [_cb]),
+        ([_cb, _cb], [_cb, _cb]),
+        (None, []),
+        ([], []),
+        (42, TypeError("'int' object is not iterable")),
+    ],
+    ids=[
+        "single callable",
+        "list of two callables",
+        "callbacks=None",
+        "callbacks=[]",
+        "callbacks=42",
+    ],
+)
+def test_normalise_callbacks(func_input, expected, raises_context) -> None:
+    """Test that callbacks are normalised correctly."""
+    if isinstance(expected, Exception):
+        with raises_context(expected):
+            _normalise_callbacks(func_input)
+    else:
+        assert _normalise_callbacks(func_input) == expected
diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py
index f602ee6..a583743 100644
--- a/tests/test_solvers/test_sgd.py
+++ b/tests/test_solvers/test_sgd.py
@@ -6,6 +6,7 @@
 import numpy.typing as npt
 import pytest
 
+from causalprog.solvers.iteration_result import IterationResult
 from causalprog.solvers.sgd import stochastic_gradient_descent
 from causalprog.utils.norms import PyTree
 
@@ -87,3 +88,242 @@ def test_sgd(
     assert jax.tree_util.tree_all(
         jax.tree_util.tree_map(jax.numpy.allclose, result.fn_args, expected)
     )
+
+
+@pytest.mark.parametrize(
+    (
+        "history_logging_interval",
+        "expected_iters",
+    ),
+    [
+        pytest.param(
+            1,
+            list(range(11)),
+            id="interval=1",
+        ),
+        pytest.param(
+            3,
+            list(range(0, 11, 3)),
+            id="interval=3",
+        ),
+        pytest.param(
+            0,
+            [],
+            id="interval=0 (no logging)",
+        ),
+    ],
+)
+def test_sgd_history_logging_intervals(
+    sum_of_squares_obj,
+    history_logging_interval: int,
+    expected_iters: list[int],
+) -> None:
+    """Test that history logging intervals work correctly."""
+    initial_guess = jnp.atleast_1d(1.0)
+
+    result = stochastic_gradient_descent(
+        sum_of_squares_obj,
+        initial_guess,
+        maxiter=10,
+        tolerance=0.0,
+        history_logging_interval=history_logging_interval,
+    )
+
+    # Check that the correct iterations were logged
+    assert result.iter_history == expected_iters, (
+        f"IterationResult.iter_history logged incorrectly. Expected {expected_iters}."
+        f" Got {result.iter_history}"
+    )
+
+    # Check that the correct number of fn_args, grad_val, obj_val were logged
+    assert len(result.fn_args_history) == len(expected_iters), (
+        "IterationResult.fn_args_history logged incorrectly."
+        f" Expected {len(expected_iters)} entries. Got {len(result.fn_args_history)}"
+    )
+    assert len(result.grad_val_history) == len(expected_iters), (
+        "IterationResult.grad_val_history logged incorrectly."
+        f" Expected {len(expected_iters)} entries. Got {len(result.grad_val_history)}"
+    )
+    assert len(result.obj_val_history) == len(expected_iters), (
+        "IterationResult.obj_val_history logged incorrectly."
+        f" Expected {len(expected_iters)} entries. Got {len(result.obj_val_history)}"
+    )
+
+    # Check that logged fn_args, grad_val, obj_val line up correctly
+    value_and_grad_fn = jax.jit(jax.value_and_grad(sum_of_squares_obj))
+
+    for fn_args, obj_val, grad_val in zip(
+        result.fn_args_history,
+        result.obj_val_history,
+        result.grad_val_history,
+        strict=True,
+    ):
+        real_obj_val, real_grad_val = value_and_grad_fn(fn_args)
+
+        # Check that logged obj_val and fn_args line up correctly
+        assert real_obj_val == obj_val, (
+            "Logged obj_val does not match obj_fn evaluated at logged fn_args."
+            f" For fn_args {fn_args}, we expected {real_obj_val},"
+            f" got {obj_val}."
+        )
+
+        # Check that logged grad_val and fn_args line up correctly
+        assert real_grad_val == grad_val, (
+            "Logged grad_val does not match gradient of obj_fn evaluated at"
+            f" logged fn_args. For fn_args {fn_args}, we expected"
+            f" {real_grad_val}, got {grad_val}."
+ ) + + +@pytest.mark.parametrize( + ( + "make_callbacks", + "expected", + ), + [ + ( + lambda cb: cb, + [0, 1, 2], + ), + ( + lambda cb: [cb], + [0, 1, 2], + ), + ( + lambda cb: [cb, cb], + [0, 0, 1, 1, 2, 2], + ), + ( + lambda cb: [], # noqa: ARG005 + [], + ), + ( + lambda cb: None, # noqa: ARG005 + [], + ), + ], + ids=[ + "single callable", + "list of one callable", + "list of two callables", + "callbacks=[]", + "callbacks=None", + ], +) +def test_sgd_callbacks_invocation( + sum_of_squares_obj, + make_callbacks: Callable, + expected: list[int], +) -> None: + """Test SGD invokes callbacks correctly for all shapes of callbacks input.""" + + calls = [] + + def callback(iter_result: IterationResult) -> None: + calls.append(iter_result.iters) + + callbacks = make_callbacks(callback) + + initial = jnp.atleast_1d(1.0) + + stochastic_gradient_descent( + sum_of_squares_obj, + initial, + maxiter=2, + tolerance=0.0, + callbacks=callbacks, + ) + + assert calls == expected, ( + f"Callback was not called correctly, got {calls}, expected {expected}" + ) + + +def test_sgd_invalid_callback(sum_of_squares_obj, raises_context) -> None: + initial = jnp.atleast_1d(1.0) + + with raises_context(TypeError("'int' object is not iterable")): + stochastic_gradient_descent( + sum_of_squares_obj, + initial, + maxiter=2, + tolerance=0.0, + callbacks=42, # type: ignore[arg-type] + ) + + +@pytest.mark.parametrize( + "history_logging_interval", [0, 1, 2], ids=lambda v: f"hist:{v}" +) +@pytest.mark.parametrize( + "make_callbacks", + [ + lambda cb: cb, + lambda cb: [cb], + lambda cb: [cb, cb], + lambda cb: [], # noqa: ARG005 + lambda cb: None, # noqa: ARG005 + ], + ids=["callable", "list_1", "list_2", "empty", "none"], +) +def test_logging_or_callbacks_affect_sgd_convergence( + sum_of_squares_obj, + history_logging_interval: int, + make_callbacks: Callable, +) -> None: + """Test that logging and callbacks don't affect convergence of SGD solver.""" + calls = [] + + def callback(iter_result: IterationResult) -> None: + calls.append(iter_result.iters) + + callbacks = make_callbacks(callback) + + initial_guess = jnp.atleast_1d(1.0) + + baseline_result = stochastic_gradient_descent( + sum_of_squares_obj, + initial_guess, + maxiter=6, + tolerance=0.0, + history_logging_interval=0, + ) + + result = stochastic_gradient_descent( + sum_of_squares_obj, + initial_guess, + maxiter=6, + tolerance=0.0, + history_logging_interval=history_logging_interval, + callbacks=callbacks, + ) + + baseline_attributes = [ + baseline_result.fn_args, + baseline_result.obj_val, + baseline_result.grad_val, + baseline_result.iters, + baseline_result.successful, + baseline_result.reason, + ] + + result_attributes = [ + result.fn_args, + result.obj_val, + result.grad_val, + result.iters, + result.successful, + result.reason, + ] + + for baseline_attr, result_attr in zip( + baseline_attributes, result_attributes, strict=True + ): + assert baseline_attr == result_attr, ( + "Logging or callbacks changed the convergence behaviour of the" + " solver. For history_logging_interval" + f" {history_logging_interval}, callbacks {callbacks}, expected" + f" {baseline_attributes}, got {result_attributes}" + )