From eaba0805baf3614118cbd07d77805d7388900c27 Mon Sep 17 00:00:00 2001
From: Sam Molyneux
Date: Sun, 5 Oct 2025 17:46:00 +0100
Subject: [PATCH 01/12] Callbacks, history logging and speed boosts to sgd

To sgd.py:
- Added single-pass JIT gradient and objective calculation for big
  speed boosts.
- Added a history logging option for easy debugging and a better
  understanding of the optimisation process.
- Added callbacks functionality, allowing for user-specific callbacks,
  e.g. early stopping, live graphs, etc.

Added iteration_result.py to monitor the current state of convergence.
Use of a dataclass ensures backwards compatibility of callbacks.

Added solver_callbacks.py for callbacks and associated functions.

Added history attributes to the solver result.
---
 src/causalprog/solvers/iteration_result.py | 63 ++++++++++++++++++++++
 src/causalprog/solvers/sgd.py              | 55 ++++++++++++++++---
 src/causalprog/solvers/solver_callbacks.py | 61 +++++++++++++++++++++
 src/causalprog/solvers/solver_result.py    | 11 +++-
 4 files changed, 181 insertions(+), 9 deletions(-)
 create mode 100644 src/causalprog/solvers/iteration_result.py
 create mode 100644 src/causalprog/solvers/solver_callbacks.py

diff --git a/src/causalprog/solvers/iteration_result.py b/src/causalprog/solvers/iteration_result.py
new file mode 100644
index 0000000..ae555f8
--- /dev/null
+++ b/src/causalprog/solvers/iteration_result.py
@@ -0,0 +1,63 @@
+"""Container classes for outputs from each iteration of solver methods."""
+
+from dataclasses import dataclass, field
+
+import numpy.typing as npt
+
+from causalprog.utils.norms import PyTree
+
+
+@dataclass(frozen=False)
+class IterationResult:
+    """
+    Container class storing state of solvers at iteration `iters`.
+
+    Args:
+        fn_args: Argument to the objective function at final iteration (the solution,
+            if `successful` is `True`).
+        grad_val: Value of the gradient of the objective function at `fn_args`.
+        iters: Number of iterations performed.
+        obj_val: Value of the objective function at `fn_args`.
+        iter_history: List of iteration numbers at which history was logged.
+        fn_args_history: List of `fn_args` at each logged iteration.
+        grad_val_history: List of `grad_val` at each logged iteration.
+        obj_val_history: List of `obj_val` at each logged iteration.
+
+    """
+
+    fn_args: PyTree
+    grad_val: PyTree
+    iters: int
+    obj_val: npt.ArrayLike
+
+    iter_history: list[int] = field(default_factory=list)
+    fn_args_history: list[PyTree] = field(default_factory=list)
+    grad_val_history: list[PyTree] = field(default_factory=list)
+    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
+
+
+def _update_iteration_result(
+    iter_result: IterationResult,
+    current_params: PyTree,
+    gradient_value: PyTree,
+    iters: int,
+    objective_value: npt.ArrayLike,
+    history_logging_interval: int,
+) -> None:
+    """
+    Update the `IterationResult` object with current iteration data.
+
+    Only updates the history if `history_logging_interval` is positive and
+    the current iteration is a multiple of `history_logging_interval`.
+ + """ + iter_result.fn_args = current_params + iter_result.grad_val = gradient_value + iter_result.iters = iters + iter_result.obj_val = objective_value + + if history_logging_interval > 0 and iters % history_logging_interval == 0: + iter_result.iter_history.append(iters) + iter_result.fn_args_history.append(current_params) + iter_result.grad_val_history.append(gradient_value) + iter_result.obj_val_history.append(objective_value) diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py index 141d5e3..08526b9 100644 --- a/src/causalprog/solvers/sgd.py +++ b/src/causalprog/solvers/sgd.py @@ -8,6 +8,11 @@ import numpy.typing as npt import optax +from causalprog.solvers.iteration_result import ( + IterationResult, + _update_iteration_result, +) +from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks from causalprog.solvers.solver_result import SolverResult from causalprog.utils.norms import PyTree, l2_normsq @@ -23,6 +28,10 @@ def stochastic_gradient_descent( maxiter: int = 1000, optimiser: optax.GradientTransformationExtraArgs | None = None, tolerance: float = 1.0e-8, + history_logging_interval: int = -1, + callbacks: Callable[[IterationResult], None] + | list[Callable[[IterationResult], None]] + | None = None, ) -> SolverResult: """ Minimise a function of one argument using Stochastic Gradient Descent (SGD). @@ -65,12 +74,17 @@ def stochastic_gradient_descent( this number of iterations is exceeded. optimiser: The `optax` optimiser to use during the update step. tolerance: `tolerance` used when determining if a minimum has been found. + history_logging_interval: Interval (in number of iterations) at which to log + the history of optimisation. If history_logging_interval <= 0, no + history is logged. + callbacks: A `callable` or list of `callables` that take an + `IterationResult` as their only argument, and return `None`. + These will be called at the end of each iteration of the optimisation + procedure. + Returns: - Minimising argument of `obj_fn`. - Value of `obj_fn` at the minimum. - Gradient of `obj_fn` at the minimum. - Number of iterations performed. + SolverResult: Result of the optimisation procedure. 
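+
+    Example:
+        An illustrative sketch (not part of the test suite); assumes
+        `jax.numpy` is imported as `jnp`:
+
+        >>> result = stochastic_gradient_descent(
+        ...     lambda x: (x**2).sum(),
+        ...     jnp.atleast_1d(1.0),
+        ...     history_logging_interval=10,
+        ...     callbacks=[lambda ir: None],
+        ... )
+        >>> result.iter_history[0]
+        0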
""" if not fn_args: @@ -82,21 +96,40 @@ def stochastic_gradient_descent( if not optimiser: optimiser = optax.adam(learning_rate) + callbacks = _normalise_callbacks(callbacks) + def objective(x: npt.ArrayLike) -> npt.ArrayLike: return obj_fn(x, *fn_args, **fn_kwargs) def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: return convergence_criteria(x, dx) < tolerance - converged = False + value_and_grad_fn = jax.jit(jax.value_and_grad(objective)) + # init state opt_state = optimiser.init(initial_guess) current_params = deepcopy(initial_guess) - gradient = jax.grad(objective) + converged = False + objective_value, gradient_value = value_and_grad_fn(current_params) + + iter_result = IterationResult( + fn_args=current_params, + grad_val=gradient_value, + iters=0, + obj_val=objective_value, + ) for _ in range(maxiter + 1): - objective_value = objective(current_params) - gradient_value = gradient(current_params) + _update_iteration_result( + iter_result, + current_params, + gradient_value, + _, + objective_value, + history_logging_interval, + ) + + _run_callbacks(iter_result, callbacks) if converged := is_converged(objective_value, gradient_value): break @@ -104,6 +137,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: updates, opt_state = optimiser.update(gradient_value, opt_state) current_params = optax.apply_updates(current_params, updates) + objective_value, gradient_value = value_and_grad_fn(current_params) + iters_used = _ reason_msg = ( f"Did not converge after {iters_used} iterations" if not converged else "" @@ -117,4 +152,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool: obj_val=objective_value, reason=reason_msg, successful=converged, + iter_history=iter_result.iter_history, + fn_args_history=iter_result.fn_args_history, + grad_val_history=iter_result.grad_val_history, + obj_val_history=iter_result.obj_val_history, ) diff --git a/src/causalprog/solvers/solver_callbacks.py b/src/causalprog/solvers/solver_callbacks.py new file mode 100644 index 0000000..482d601 --- /dev/null +++ b/src/causalprog/solvers/solver_callbacks.py @@ -0,0 +1,61 @@ +"""Module for callback functions for solvers.""" + +from collections.abc import Callable + +from tqdm.auto import tqdm + +from causalprog.solvers.iteration_result import IterationResult + + +def _normalise_callbacks( + callbacks: Callable[[IterationResult], None] + | list[Callable[[IterationResult], None]] + | None = None, +) -> list[Callable[[IterationResult], None]]: + if callbacks is None: + return [] + if callable(callbacks): + return [callbacks] + if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks): + return callbacks + + msg = "Callbacks must be a callable or a sequence of callables" + raise TypeError(msg) + + +def _run_callbacks( + iter_result: IterationResult, + callbacks: list[Callable[[IterationResult], None]], +) -> None: + for cb in callbacks: + cb(iter_result) + + +def tqdm_callback(total: int) -> Callable[[IterationResult], None]: + """ + Progress bar callback using `tqdm`. + + Creates a callback function that can be passed to solvers to display a progress bar + during optimization. The progress bar updates based on the number of iterations and + also displays the current objective value. + + Args: + total: Total number of iterations for the progress bar. + + Returns: + Callback function that updates the progress bar. 
+
+    """
+    bar = tqdm(total=total)
+    last_it = {"i": 0}
+
+    def cb(ir: IterationResult) -> None:
+        step = ir.iters - last_it["i"]
+        if step > 0:
+            bar.update(step)
+
+        # Show the current objective value in the progress bar postfix
+        bar.set_postfix(obj=float(ir.obj_val))
+        last_it["i"] = ir.iters
+
+    return cb
diff --git a/src/causalprog/solvers/solver_result.py b/src/causalprog/solvers/solver_result.py
index eb09457..e2fe8f6 100644
--- a/src/causalprog/solvers/solver_result.py
+++ b/src/causalprog/solvers/solver_result.py
@@ -1,6 +1,6 @@
 """Container class for outputs from solver methods."""
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import numpy.typing as npt
 
@@ -26,6 +26,10 @@ class SolverResult:
         successful: `True` if solver converged, in which case `fn_args` is the
             argument to the objective function at the solution of the problem being
             solved. `False` otherwise.
+        iter_history: List of iteration numbers at which history was logged.
+        fn_args_history: List of `fn_args` at each logged iteration.
+        grad_val_history: List of `grad_val` at each logged iteration.
+        obj_val_history: List of `obj_val` at each logged iteration.
 
     """
 
@@ -36,3 +40,8 @@ class SolverResult:
     obj_val: npt.ArrayLike
     reason: str
     successful: bool
+
+    iter_history: list[int] = field(default_factory=list)
+    fn_args_history: list[PyTree] = field(default_factory=list)
+    grad_val_history: list[PyTree] = field(default_factory=list)
+    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)

From 3fbed8dcceec9ba6f2fda12ba4097794c9650768 Mon Sep 17 00:00:00 2001
From: Sam Molyneux
Date: Sun, 5 Oct 2025 21:17:15 +0100
Subject: [PATCH 02/12] Add tests for sgd callbacks and history logging

---
 .../test_normalise_solver_inputs.py |  29 ++
 tests/test_solvers/test_sgd.py      | 258 ++++++++++++++++++++++++++++
 2 files changed, 287 insertions(+)
 create mode 100644 tests/test_solvers/test_normalise_solver_inputs.py

diff --git a/tests/test_solvers/test_normalise_solver_inputs.py b/tests/test_solvers/test_normalise_solver_inputs.py
new file mode 100644
index 0000000..ca1e32f
--- /dev/null
+++ b/tests/test_solvers/test_normalise_solver_inputs.py
@@ -0,0 +1,29 @@
+import pytest
+
+from causalprog.solvers.iteration_result import IterationResult
+from causalprog.solvers.solver_callbacks import _normalise_callbacks
+
+
+def test_normalise_callbacks() -> None:
+    """Test that callbacks are normalised correctly."""
+
+    def callback(iter_result: IterationResult) -> None:
+        pass
+
+    # Test single callable
+    assert _normalise_callbacks(callback) == [callback]
+
+    # Test sequence of callables
+    assert _normalise_callbacks([callback, callback]) == [callback, callback]
+
+    # Test None
+    assert _normalise_callbacks(None) == []
+
+    # Test empty sequence
+    assert _normalise_callbacks([]) == []
+
+    # Test invalid input
+    with pytest.raises(
+        TypeError, match="Callbacks must be a callable or a sequence of callables"
+    ):
+        _normalise_callbacks(42)  # type: ignore[arg-type]
diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py
index f602ee6..768a194 100644
--- a/tests/test_solvers/test_sgd.py
+++ b/tests/test_solvers/test_sgd.py
@@ -6,6 +6,7 @@
 import numpy.typing as npt
 import pytest
 
+from causalprog.solvers.iteration_result import IterationResult
 from causalprog.solvers.sgd import stochastic_gradient_descent
 from causalprog.utils.norms import PyTree
 
@@ -87,3 +88,260 @@ def test_sgd(
     assert jax.tree_util.tree_all(
         jax.tree_util.tree_map(jax.numpy.allclose, result.fn_args, expected)
     )
+
+
+@pytest.mark.parametrize(
+    (
"history_logging_interval", + "expected_iters", + ), + [ + pytest.param( + 1, + list(range(11)), + id="interval=1", + ), + pytest.param( + 2, + list(range(0, 11, 2)), + id="interval=2", + ), + pytest.param( + 3, + list(range(0, 11, 3)), + id="interval=3", + ), + pytest.param( + 0, + [], + id="interval=0 (no logging)", + ), + pytest.param( + -1, + [], + id="interval=-1 (no logging)", + ), + ], +) +def test_sgd_history_logging_intervals( + history_logging_interval: int, expected_iters: list[int] +) -> None: + """Test that history logging intervals work correctly.""" + + def obj_fn(x): + return (x**2).sum() + + initial_guess = jnp.atleast_1d(1.0) + + result = stochastic_gradient_descent( + obj_fn, + initial_guess, + maxiter=10, + tolerance=0.0, + history_logging_interval=history_logging_interval, + ) + + # Check that the correct iterations were logged + assert result.iter_history == expected_iters, ( + f"IterationResult.iter_history logged incorrectly. Expected {expected_iters}." + f"Got {result.iter_history}" + ) + + # Check that a correct number of fn_args, grad_val, obj_val were logged + assert len(result.fn_args_history) == len(expected_iters), ( + "IterationResult.fn_args_history logged incorrectly." + f"Expected {len(expected_iters)} entries. Got {len(result.fn_args_history)}" + ) + assert len(result.grad_val_history) == len(expected_iters), ( + "IterationResult.grad_val_history logged incorrectly." + f"Expected {len(expected_iters)} entries. Got {len(result.grad_val_history)}" + ) + assert len(result.obj_val_history) == len(expected_iters), ( + "IterationResult.obj_val_history logged incorrectly." + f"Expected {len(expected_iters)} entries. Got {len(result.obj_val_history)}" + ) + + # Check that logged fn_args, grad_val, obj_val line up correctly + value_and_grad_fn = jax.jit(jax.value_and_grad(obj_fn)) + + if len(expected_iters) > 0: + for fn_args, obj_val, grad_val in zip( + result.fn_args_history, + result.obj_val_history, + result.grad_val_history, + strict=True, + ): + real_obj_val, real_grad_val = value_and_grad_fn(fn_args) + + # Check that logged obj_val and fn_args line up correctly + assert real_obj_val == obj_val, ( + "Logged obj_val does not match obj_fn evaluated at logged fn_args." + f"For fn_args {fn_args}, we expected {obj_fn(fn_args)}, got {obj_val}." + ) + + # Check that logged gradient and fn_args line up correctly + assert real_grad_val == grad_val, ( + "Logged grad_val does not match gradient of obj_fn evaluated at" + f" logged fn_args. For fn_args {fn_args}, we expected" + f" {jax.gradient(obj_fn)(fn_args)}, got {grad_val}." 
+ ) + + +@pytest.mark.parametrize( + ( + "make_callbacks", + "expected", + ), + [ + ( + lambda cb: cb, + [0, 1, 2], + ), + ( + lambda cb: [cb], + [0, 1, 2], + ), + ( + lambda cb: [cb, cb], + [0, 0, 1, 1, 2, 2], + ), + ( + lambda cb: [], # noqa: ARG005 + [], + ), + ( + lambda cb: None, # noqa: ARG005 + [], + ), + ], + ids=[ + "single callable", + "list of one callable", + "list of two callables", + "callbacks=[]", + "callbacks=None", + ], +) +def test_sgd_callbacks_invocation( + make_callbacks: Callable, expected: list[int] +) -> None: + """Test SGD invokes callbacks correctly for all shapes of callbacks input.""" + + def obj_fn(x): + return (x**2).sum() + + calls = [] + + def callback(iter_result: IterationResult) -> None: + calls.append(iter_result.iters) + + callbacks = make_callbacks(callback) + + initial = jnp.atleast_1d(1.0) + + stochastic_gradient_descent( + obj_fn, + initial, + maxiter=2, + tolerance=0.0, + callbacks=callbacks, + ) + + assert calls == expected, ( + f"Callback was not called correctly, got {calls}, expected {expected}" + ) + + +def test_sgd_invalid_callback() -> None: + def obj_fn(x): + return (x**2).sum() + + initial = jnp.atleast_1d(1.0) + + with pytest.raises( + TypeError, match="Callbacks must be a callable or a sequence of callables" + ): + stochastic_gradient_descent( + obj_fn, + initial, + maxiter=2, + tolerance=0.0, + callbacks=42, # type: ignore[arg-type] + ) + + +@pytest.mark.parametrize( + "history_logging_interval", [0, 1, 2], ids=lambda v: f"hist:{v}" +) +@pytest.mark.parametrize( + "make_callbacks", + [ + lambda cb: cb, + lambda cb: [cb], + lambda cb: [cb, cb], + lambda cb: [], # noqa: ARG005 + lambda cb: None, # noqa: ARG005 + ], + ids=["callable", "list_1", "list_2", "empty", "none"], +) +def test_logging_or_callbacks_affect_sgd_convergence( + history_logging_interval, + make_callbacks, +) -> None: + """Test that logging and callbacks don't affect convergence of SGD solver.""" + calls = [] + + def callback(iter_result: IterationResult) -> None: + calls.append(iter_result.iters) + + callbacks = make_callbacks(callback) + + def obj_fn(x): + return (x**2).sum() + + initial_guess = jnp.atleast_1d(1.0) + + baseline_result = stochastic_gradient_descent( + obj_fn, + initial_guess, + maxiter=6, + tolerance=0.0, + history_logging_interval=0, + ) + + result = stochastic_gradient_descent( + obj_fn, + initial_guess, + maxiter=6, + tolerance=0.0, + history_logging_interval=history_logging_interval, + callbacks=callbacks, + ) + + baseline_attributes = [ + baseline_result.fn_args, + baseline_result.obj_val, + baseline_result.grad_val, + baseline_result.iters, + baseline_result.successful, + baseline_result.reason, + ] + + result_attributes = [ + result.fn_args, + result.obj_val, + result.grad_val, + result.iters, + result.successful, + result.reason, + ] + + for baseline_attr, result_attr in zip( + baseline_attributes, result_attributes, strict=True + ): + assert baseline_attr == result_attr, ( + "Logging or callbacks changed the convergence behaviour of the" + " solver. 
+            " solver. For history_logging_interval"
+            f" {history_logging_interval}, callbacks {callbacks}, expected"
+            f" {baseline_attributes}, got {result_attributes}"
+        )

From b5102f71e17eee496dc7bf7949696c9460b0cb0b Mon Sep 17 00:00:00 2001
From: Sam Molyneux
Date: Fri, 31 Oct 2025 22:22:06 +0000
Subject: [PATCH 03/12] combine update and iter_result

---
 src/causalprog/solvers/iteration_result.py | 61 +++++++++++++---------
 src/causalprog/solvers/sgd.py              | 18 +++----
 2 files changed, 43 insertions(+), 36 deletions(-)

diff --git a/src/causalprog/solvers/iteration_result.py b/src/causalprog/solvers/iteration_result.py
index ae555f8..31aec62 100644
--- a/src/causalprog/solvers/iteration_result.py
+++ b/src/causalprog/solvers/iteration_result.py
@@ -7,10 +7,15 @@
 from causalprog.utils.norms import PyTree
 
 
-@dataclass(frozen=False)
+@dataclass(frozen=False, slots=True)
 class IterationResult:
     """
-    Container class storing state of solvers at iteration `iters`.
+    Result container for iterative solvers with optional history logging.
+
+    Stores the latest iterate and, if `history_logging_interval > 0`, `update`
+    appends snapshots of the iterate to the corresponding history lists whenever
+    the iteration number is a multiple of `history_logging_interval`.
+    Instances are mutable, but `slots=True` prevents dynamic attribute creation.
 
     Args:
         fn_args: Argument to the objective function at final iteration (the solution,
@@ -22,6 +27,8 @@
         fn_args_history: List of `fn_args` at each logged iteration.
         grad_val_history: List of `grad_val` at each logged iteration.
         obj_val_history: List of `obj_val` at each logged iteration.
+        history_logging_interval: Interval at which to log history. If
+            `history_logging_interval <= 0`, then no history is logged.
 
     """
 
@@ -29,35 +36,39 @@
     grad_val: PyTree
     iters: int
     obj_val: npt.ArrayLike
+    history_logging_interval: int = 0
 
     iter_history: list[int] = field(default_factory=list)
     fn_args_history: list[PyTree] = field(default_factory=list)
    grad_val_history: list[PyTree] = field(default_factory=list)
     obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
 
+    _log_enabled: bool = field(init=False, repr=False)
 
-def _update_iteration_result(
-    iter_result: IterationResult,
-    current_params: PyTree,
-    gradient_value: PyTree,
-    iters: int,
-    objective_value: npt.ArrayLike,
-    history_logging_interval: int,
-) -> None:
-    """
-    Update the `IterationResult` object with current iteration data.
+    def __post_init__(self) -> None:
+        self._log_enabled = self.history_logging_interval > 0
 
-    Only updates the history if `history_logging_interval` is positive and
-    the current iteration is a multiple of `history_logging_interval`.
+    def update(
+        self,
+        current_params: PyTree,
+        gradient_value: PyTree,
+        iters: int,
+        objective_value: npt.ArrayLike,
+    ) -> None:
+        """
+        Update the `IterationResult` object with current iteration data.
 
-    """
-    iter_result.fn_args = current_params
-    iter_result.grad_val = gradient_value
-    iter_result.iters = iters
-    iter_result.obj_val = objective_value
-
-    if history_logging_interval > 0 and iters % history_logging_interval == 0:
-        iter_result.iter_history.append(iters)
-        iter_result.fn_args_history.append(current_params)
-        iter_result.grad_val_history.append(gradient_value)
-        iter_result.obj_val_history.append(objective_value)
+        Only updates the history if `history_logging_interval` is positive and
+        the current iteration is a multiple of `history_logging_interval`.
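+
+        Example:
+            An illustrative sketch; with an interval of 1, every iteration
+            is logged:
+
+            >>> ir = IterationResult(
+            ...     fn_args=1.0, grad_val=2.0, iters=0, obj_val=1.0,
+            ...     history_logging_interval=1,
+            ... )
+            >>> ir.update(0.5, 1.0, 1, 0.25)
+            >>> ir.iter_history
+            [1]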
+
+        """
+        self.fn_args = current_params
+        self.grad_val = gradient_value
+        self.iters = iters
+        self.obj_val = objective_value
+
+        if self._log_enabled and iters % self.history_logging_interval == 0:
+            self.iter_history.append(iters)
+            self.fn_args_history.append(current_params)
+            self.grad_val_history.append(gradient_value)
+            self.obj_val_history.append(objective_value)
diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index 08526b9..b7e0e94 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -8,10 +8,7 @@
 import numpy.typing as npt
 import optax
 
-from causalprog.solvers.iteration_result import (
-    IterationResult,
-    _update_iteration_result,
-)
+from causalprog.solvers.iteration_result import IterationResult
 from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks
 from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq
@@ -117,16 +114,15 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         grad_val=gradient_value,
         iters=0,
         obj_val=objective_value,
+        history_logging_interval=history_logging_interval,
     )
 
     for _ in range(maxiter + 1):
-        _update_iteration_result(
-            iter_result,
-            current_params,
-            gradient_value,
-            _,
-            objective_value,
-            history_logging_interval,
+        iter_result.update(
+            current_params=current_params,
+            gradient_value=gradient_value,
+            iters=_,
+            objective_value=objective_value,
         )
 
         _run_callbacks(iter_result, callbacks)

From 4aa1e504c01e1e9fb221646d4624d0ba732b2fa3 Mon Sep 17 00:00:00 2001
From: Sam Molyneux
Date: Fri, 31 Oct 2025 22:26:15 +0000
Subject: [PATCH 04/12] change _ to current_iter in loop

---
 src/causalprog/solvers/sgd.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/causalprog/solvers/sgd.py b/src/causalprog/solvers/sgd.py
index b7e0e94..ac80d51 100644
--- a/src/causalprog/solvers/sgd.py
+++ b/src/causalprog/solvers/sgd.py
@@ -117,11 +117,11 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         history_logging_interval=history_logging_interval,
     )
 
-    for _ in range(maxiter + 1):
+    for current_iter in range(maxiter + 1):
         iter_result.update(
             current_params=current_params,
             gradient_value=gradient_value,
-            iters=_,
+            iters=current_iter,
             objective_value=objective_value,
         )
 
@@ -135,7 +135,7 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
 
     objective_value, gradient_value = value_and_grad_fn(current_params)
 
-    iters_used = _
+    iters_used = current_iter
     reason_msg = (
         f"Did not converge after {iters_used} iterations" if not converged else ""
     )

From 864541f4525551c962eaa1e09e73860efa4078fa Mon Sep 17 00:00:00 2001
From: Sam Molyneux
Date: Fri, 31 Oct 2025 22:36:04 +0000
Subject: [PATCH 05/12] add tqdm to dependencies

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 4750ef0..9a314c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
   "Programming Language :: Python :: 3.13",
   "Typing :: Typed",
 ]
-dependencies = ["jax", "networkx", "numpy", "numpyro", "typing_extensions"]
+dependencies = ["jax", "networkx", "numpy", "numpyro", "tqdm", "typing_extensions"]
 description = "A Python package for causal modelling and inference with stochastic causal programming"
 dynamic = ["version"]
 keywords = []

From a625478172755bf46ce1916b75ba4b8acd94d62f Mon Sep 17 00:00:00 2001
From: Sam Molyneux
Date: Sat, 1 Nov 2025 01:08:47 +0000
Subject: [PATCH 06/12] allow any collection of callbacks
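
Callbacks may now be passed as any Collection of callables (e.g. a list,
tuple or set), not just a list. An illustrative sketch, assuming `cb1`
and `cb2` are valid callbacks:

    stochastic_gradient_descent(obj_fn, x0, callbacks=(cb1, cb2))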
--- src/causalprog/solvers/solver_callbacks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/causalprog/solvers/solver_callbacks.py b/src/causalprog/solvers/solver_callbacks.py index 482d601..1d40cae 100644 --- a/src/causalprog/solvers/solver_callbacks.py +++ b/src/causalprog/solvers/solver_callbacks.py @@ -1,6 +1,6 @@ """Module for callback functions for solvers.""" -from collections.abc import Callable +from collections.abc import Callable, Collection from tqdm.auto import tqdm @@ -9,15 +9,15 @@ def _normalise_callbacks( callbacks: Callable[[IterationResult], None] - | list[Callable[[IterationResult], None]] + | Collection[Callable[[IterationResult], None]] | None = None, ) -> list[Callable[[IterationResult], None]]: if callbacks is None: return [] if callable(callbacks): return [callbacks] - if isinstance(callbacks, list) and all(callable(cb) for cb in callbacks): - return callbacks + if all(callable(cb) for cb in callbacks): + return list(callbacks) msg = "Callbacks must be a callable or a sequence of callables" raise TypeError(msg) From 1498f2818c8f3e113d8b67ccfcf882cc7365ff6d Mon Sep 17 00:00:00 2001 From: Sam Molyneux Date: Sat, 1 Nov 2025 01:34:12 +0000 Subject: [PATCH 07/12] fix expected error on incorrect callback type --- tests/test_solvers/test_normalise_solver_inputs.py | 4 +--- tests/test_solvers/test_sgd.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_solvers/test_normalise_solver_inputs.py b/tests/test_solvers/test_normalise_solver_inputs.py index ca1e32f..0654851 100644 --- a/tests/test_solvers/test_normalise_solver_inputs.py +++ b/tests/test_solvers/test_normalise_solver_inputs.py @@ -23,7 +23,5 @@ def callback(iter_result: IterationResult) -> None: assert _normalise_callbacks([]) == [] # Test invalid input - with pytest.raises( - TypeError, match="Callbacks must be a callable or a sequence of callables" - ): + with pytest.raises(TypeError, match="'int' object is not iterable"): _normalise_callbacks(42) # type: ignore[arg-type] diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py index 768a194..3292bd0 100644 --- a/tests/test_solvers/test_sgd.py +++ b/tests/test_solvers/test_sgd.py @@ -258,9 +258,7 @@ def obj_fn(x): initial = jnp.atleast_1d(1.0) - with pytest.raises( - TypeError, match="Callbacks must be a callable or a sequence of callables" - ): + with pytest.raises(TypeError, match="'int' object is not iterable"): stochastic_gradient_descent( obj_fn, initial, From 005f383963117e0bb1dba7dee1889135b6512588 Mon Sep 17 00:00:00 2001 From: Sam Molyneux Date: Sat, 1 Nov 2025 01:42:55 +0000 Subject: [PATCH 08/12] change pytest.raises to raises_context --- tests/test_solvers/test_normalise_solver_inputs.py | 6 ++---- tests/test_solvers/test_sgd.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_solvers/test_normalise_solver_inputs.py b/tests/test_solvers/test_normalise_solver_inputs.py index 0654851..52c9099 100644 --- a/tests/test_solvers/test_normalise_solver_inputs.py +++ b/tests/test_solvers/test_normalise_solver_inputs.py @@ -1,10 +1,8 @@ -import pytest - from causalprog.solvers.iteration_result import IterationResult from causalprog.solvers.solver_callbacks import _normalise_callbacks -def test_normalise_callbacks() -> None: +def test_normalise_callbacks(raises_context) -> None: """Test that callbacks are normalised correctly.""" def callback(iter_result: IterationResult) -> None: @@ -23,5 +21,5 @@ def callback(iter_result: IterationResult) -> 
None: assert _normalise_callbacks([]) == [] # Test invalid input - with pytest.raises(TypeError, match="'int' object is not iterable"): + with raises_context(TypeError("'int' object is not iterable")): _normalise_callbacks(42) # type: ignore[arg-type] diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py index 3292bd0..63a7893 100644 --- a/tests/test_solvers/test_sgd.py +++ b/tests/test_solvers/test_sgd.py @@ -252,13 +252,13 @@ def callback(iter_result: IterationResult) -> None: ) -def test_sgd_invalid_callback() -> None: +def test_sgd_invalid_callback(raises_context) -> None: def obj_fn(x): return (x**2).sum() initial = jnp.atleast_1d(1.0) - with pytest.raises(TypeError, match="'int' object is not iterable"): + with raises_context(TypeError("'int' object is not iterable")): stochastic_gradient_descent( obj_fn, initial, From 40658ac16ddfd7c527e3e0ca89faae66f255c855 Mon Sep 17 00:00:00 2001 From: Sam Molyneux Date: Sat, 1 Nov 2025 01:47:58 +0000 Subject: [PATCH 09/12] remove redundant callback test cases --- tests/test_solvers/test_sgd.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py index 63a7893..898b87f 100644 --- a/tests/test_solvers/test_sgd.py +++ b/tests/test_solvers/test_sgd.py @@ -101,11 +101,6 @@ def test_sgd( list(range(11)), id="interval=1", ), - pytest.param( - 2, - list(range(0, 11, 2)), - id="interval=2", - ), pytest.param( 3, list(range(0, 11, 3)), @@ -116,11 +111,6 @@ def test_sgd( [], id="interval=0 (no logging)", ), - pytest.param( - -1, - [], - id="interval=-1 (no logging)", - ), ], ) def test_sgd_history_logging_intervals( From c75da1f2646b4fe8046bcfcb8465bf94ac01b833 Mon Sep 17 00:00:00 2001 From: Sam Molyneux Date: Sat, 1 Nov 2025 02:28:49 +0000 Subject: [PATCH 10/12] parameterise test_normalise_callbacks --- .../test_normalise_solver_inputs.py | 54 +++++++++++++++---- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/tests/test_solvers/test_normalise_solver_inputs.py b/tests/test_solvers/test_normalise_solver_inputs.py index 52c9099..d6a7ae8 100644 --- a/tests/test_solvers/test_normalise_solver_inputs.py +++ b/tests/test_solvers/test_normalise_solver_inputs.py @@ -1,25 +1,57 @@ +from collections.abc import Callable + +import pytest + from causalprog.solvers.iteration_result import IterationResult from causalprog.solvers.solver_callbacks import _normalise_callbacks -def test_normalise_callbacks(raises_context) -> None: +@pytest.mark.parametrize( + ( + "make_input_callbacks", + "make_expected_output_callbacks", + ), + [ + ( + lambda cb: cb, + lambda cb: [cb], + ), + ( + lambda cb: [cb, cb], + lambda cb: [cb, cb], + ), + ( + lambda cb: None, # noqa: ARG005 + lambda cb: [], # noqa: ARG005 + ), + ( + lambda cb: [], # noqa: ARG005 + lambda cb: [], # noqa: ARG005 + ), + ], + ids=[ + "single callable", + "list of two callables", + "callbacks=None", + "callbacks=[]", + ], +) +def test_normalise_callbacks( + make_input_callbacks: Callable, + make_expected_output_callbacks: Callable, +) -> None: """Test that callbacks are normalised correctly.""" def callback(iter_result: IterationResult) -> None: pass - # Test single callable - assert _normalise_callbacks(callback) == [callback] - - # Test sequence of callables - assert _normalise_callbacks([callback, callback]) == [callback, callback] + input_callbacks = make_input_callbacks(callback) + expected_output_callbacks = make_expected_output_callbacks(callback) - # Test None - assert _normalise_callbacks(None) == [] + 
assert _normalise_callbacks(input_callbacks) == expected_output_callbacks
+
+
+def test_normalise_invalid_callbacks(raises_context) -> None:
+    """Test that invalid callbacks raise TypeError."""
+    with raises_context(TypeError("'int' object is not iterable")):
+        _normalise_callbacks(42)  # type: ignore[arg-type]

From cf832c416a432fdf1b2acf3378c03de8e9d3a570 Mon Sep 17 00:00:00 2001
From: Sam Molyneux
Date: Sat, 1 Nov 2025 02:56:37 +0000
Subject: [PATCH 11/12] remove factories in favour of callback placeholder

---
 .../test_normalise_solver_inputs.py | 56 ++++++++++--------------------
 1 file changed, 17 insertions(+), 39 deletions(-)

diff --git a/tests/test_solvers/test_normalise_solver_inputs.py b/tests/test_solvers/test_normalise_solver_inputs.py
index d6a7ae8..3329495 100644
--- a/tests/test_solvers/test_normalise_solver_inputs.py
+++ b/tests/test_solvers/test_normalise_solver_inputs.py
@@ -1,57 +1,35 @@
-from collections.abc import Callable
-
 import pytest
 
 from causalprog.solvers.iteration_result import IterationResult
 from causalprog.solvers.solver_callbacks import _normalise_callbacks
 
 
+def _cb(ir: IterationResult) -> None:
+    """Placeholder callback."""
+
+
 @pytest.mark.parametrize(
-    (
-        "make_input_callbacks",
-        "make_expected_output_callbacks",
-    ),
+    ("func_input", "expected"),
     [
-        (
-            lambda cb: cb,
-            lambda cb: [cb],
-        ),
-        (
-            lambda cb: [cb, cb],
-            lambda cb: [cb, cb],
-        ),
-        (
-            lambda cb: None,  # noqa: ARG005
-            lambda cb: [],  # noqa: ARG005
-        ),
-        (
-            lambda cb: [],  # noqa: ARG005
-            lambda cb: [],  # noqa: ARG005
-        ),
+        (_cb, [_cb]),
+        ([_cb, _cb], [_cb, _cb]),
+        (None, []),
+        ([], []),
+        (42, TypeError("'int' object is not iterable")),
     ],
     ids=[
         "single callable",
         "list of two callables",
         "callbacks=None",
         "callbacks=[]",
+        "callbacks=42",
     ],
 )
-def test_normalise_callbacks(
-    make_input_callbacks: Callable,
-    make_expected_output_callbacks: Callable,
-) -> None:
+def test_normalise_callbacks(func_input, expected, raises_context) -> None:
     """Test that callbacks are normalised correctly."""
 
-    def callback(iter_result: IterationResult) -> None:
-        pass
-
-    input_callbacks = make_input_callbacks(callback)
-    expected_output_callbacks = make_expected_output_callbacks(callback)
-
-    assert _normalise_callbacks(input_callbacks) == expected_output_callbacks
-
-
-def test_normalise_invalid_callbacks(raises_context) -> None:
-    """Test that invalid callbacks raise TypeError."""
-    with raises_context(TypeError("'int' object is not iterable")):
-        _normalise_callbacks(42)  # type: ignore[arg-type]
+    if isinstance(expected, Exception):
+        with raises_context(expected):
+            _normalise_callbacks(func_input)
+    else:
+        assert _normalise_callbacks(func_input) == expected

From 23e2848c59e48235011408eb96c02b9a1b937939 Mon Sep 17 00:00:00 2001
From: Sam Molyneux
Date: Sat, 1 Nov 2025 03:17:50 +0000
Subject: [PATCH 12/12] move obj_fn to fixture

---
 tests/fixtures/solvers.py      | 16 +++++++++++++
 tests/test_solvers/test_sgd.py | 44 +++++++++++++++-------------------
 2 files changed, 35 insertions(+), 25 deletions(-)
 create mode 100644 tests/fixtures/solvers.py

diff --git a/tests/fixtures/solvers.py b/tests/fixtures/solvers.py
new file mode 100644
index 0000000..0b5aa91
--- /dev/null
+++ b/tests/fixtures/solvers.py
@@ -0,0 +1,16 @@
+from collections.abc import Callable
+
+import numpy.typing as npt
+import pytest
+
+from causalprog.utils.norms import PyTree
+
+
+@pytest.fixture
+def sum_of_squares_obj() -> Callable[[PyTree], npt.ArrayLike]:
"""f(x) = ||x||_2^2 = sum_i x_i^2""" + + def _inner(x): + return (x**2).sum() + + return _inner diff --git a/tests/test_solvers/test_sgd.py b/tests/test_solvers/test_sgd.py index 898b87f..a583743 100644 --- a/tests/test_solvers/test_sgd.py +++ b/tests/test_solvers/test_sgd.py @@ -114,17 +114,16 @@ def test_sgd( ], ) def test_sgd_history_logging_intervals( - history_logging_interval: int, expected_iters: list[int] + sum_of_squares_obj, + history_logging_interval: int, + expected_iters: list[int], ) -> None: """Test that history logging intervals work correctly.""" - def obj_fn(x): - return (x**2).sum() - initial_guess = jnp.atleast_1d(1.0) result = stochastic_gradient_descent( - obj_fn, + sum_of_squares_obj, initial_guess, maxiter=10, tolerance=0.0, @@ -152,7 +151,7 @@ def obj_fn(x): ) # Check that logged fn_args, grad_val, obj_val line up correctly - value_and_grad_fn = jax.jit(jax.value_and_grad(obj_fn)) + value_and_grad_fn = jax.jit(jax.value_and_grad(sum_of_squares_obj)) if len(expected_iters) > 0: for fn_args, obj_val, grad_val in zip( @@ -166,14 +165,15 @@ def obj_fn(x): # Check that logged obj_val and fn_args line up correctly assert real_obj_val == obj_val, ( "Logged obj_val does not match obj_fn evaluated at logged fn_args." - f"For fn_args {fn_args}, we expected {obj_fn(fn_args)}, got {obj_val}." + f"For fn_args {fn_args}, we expected {sum_of_squares_obj(fn_args)}," + "got {obj_val}." ) # Check that logged gradient and fn_args line up correctly assert real_grad_val == grad_val, ( "Logged grad_val does not match gradient of obj_fn evaluated at" f" logged fn_args. For fn_args {fn_args}, we expected" - f" {jax.gradient(obj_fn)(fn_args)}, got {grad_val}." + f" {jax.gradient(sum_of_squares_obj)(fn_args)}, got {grad_val}." ) @@ -213,13 +213,12 @@ def obj_fn(x): ], ) def test_sgd_callbacks_invocation( - make_callbacks: Callable, expected: list[int] + sum_of_squares_obj, + make_callbacks: Callable, + expected: list[int], ) -> None: """Test SGD invokes callbacks correctly for all shapes of callbacks input.""" - def obj_fn(x): - return (x**2).sum() - calls = [] def callback(iter_result: IterationResult) -> None: @@ -230,7 +229,7 @@ def callback(iter_result: IterationResult) -> None: initial = jnp.atleast_1d(1.0) stochastic_gradient_descent( - obj_fn, + sum_of_squares_obj, initial, maxiter=2, tolerance=0.0, @@ -242,15 +241,12 @@ def callback(iter_result: IterationResult) -> None: ) -def test_sgd_invalid_callback(raises_context) -> None: - def obj_fn(x): - return (x**2).sum() - +def test_sgd_invalid_callback(sum_of_squares_obj, raises_context) -> None: initial = jnp.atleast_1d(1.0) with raises_context(TypeError("'int' object is not iterable")): stochastic_gradient_descent( - obj_fn, + sum_of_squares_obj, initial, maxiter=2, tolerance=0.0, @@ -273,8 +269,9 @@ def obj_fn(x): ids=["callable", "list_1", "list_2", "empty", "none"], ) def test_logging_or_callbacks_affect_sgd_convergence( - history_logging_interval, - make_callbacks, + sum_of_squares_obj, + history_logging_interval: int, + make_callbacks: Callable, ) -> None: """Test that logging and callbacks don't affect convergence of SGD solver.""" calls = [] @@ -284,13 +281,10 @@ def callback(iter_result: IterationResult) -> None: callbacks = make_callbacks(callback) - def obj_fn(x): - return (x**2).sum() - initial_guess = jnp.atleast_1d(1.0) baseline_result = stochastic_gradient_descent( - obj_fn, + sum_of_squares_obj, initial_guess, maxiter=6, tolerance=0.0, @@ -298,7 +292,7 @@ def obj_fn(x): ) result = stochastic_gradient_descent( - obj_fn, 
+ sum_of_squares_obj, initial_guess, maxiter=6, tolerance=0.0,