2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"Typing :: Typed",
]
dependencies = ["jax", "networkx", "numpy", "numpyro", "typing_extensions"]
dependencies = ["jax", "networkx", "numpy", "numpyro", "tqdm", "typing_extensions"]
description = "A Python package for causal modelling and inference with stochastic causal programming"
dynamic = ["version"]
keywords = []
74 changes: 74 additions & 0 deletions src/causalprog/solvers/iteration_result.py
@@ -0,0 +1,74 @@
"""Container classes for outputs from each iteration of solver methods."""

from dataclasses import dataclass, field

import numpy.typing as npt

from causalprog.utils.norms import PyTree


@dataclass(frozen=False, slots=True)
class IterationResult:
"""
Result container for iterative solvers with optional history logging.

    Stores the latest iterate. If `history_logging_interval > 0`, `update` also
    appends a snapshot of the iterate to the corresponding history lists each
    time the iteration number is a multiple of `history_logging_interval`.
    Instances are mutable, but do not allow dynamic attribute creation (the
    dataclass is declared with `slots=True`).

Args:
        fn_args: Argument to the objective function at the final iteration (the
            solution, if the solver converged).
        grad_val: Value of the gradient of the objective function at `fn_args`.
iters: Number of iterations performed.
obj_val: Value of the objective function at `fn_args`.
iter_history: List of iteration numbers at which history was logged.
fn_args_history: List of `fn_args` at each logged iteration.
grad_val_history: List of `grad_val` at each logged iteration.
obj_val_history: List of `obj_val` at each logged iteration.
history_logging_interval: Interval at which to log history. If
`history_logging_interval <= 0`, then no history is logged.

"""

fn_args: PyTree
grad_val: PyTree
iters: int
obj_val: npt.ArrayLike
history_logging_interval: int = 0

iter_history: list[int] = field(default_factory=list)
fn_args_history: list[PyTree] = field(default_factory=list)
grad_val_history: list[PyTree] = field(default_factory=list)
obj_val_history: list[npt.ArrayLike] = field(default_factory=list)

_log_enabled: bool = field(init=False, repr=False)

def __post_init__(self) -> None:
self._log_enabled = self.history_logging_interval > 0

def update(
self,
current_params: PyTree,
gradient_value: PyTree,
iters: int,
objective_value: npt.ArrayLike,
) -> None:
"""
Update the `IterationResult` object with current iteration data.

Only updates the history if `history_logging_interval` is positive and
the current iteration is a multiple of `history_logging_interval`.

"""
self.fn_args = current_params
self.grad_val = gradient_value
self.iters = iters
self.obj_val = objective_value

if self._log_enabled and iters % self.history_logging_interval == 0:
self.iter_history.append(iters)
self.fn_args_history.append(current_params)
self.grad_val_history.append(gradient_value)
self.obj_val_history.append(objective_value)
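
A minimal sketch of how this container is driven; the values below are illustrative stand-ins for what a real solver loop would supply:

```python
import jax.numpy as jnp

from causalprog.solvers.iteration_result import IterationResult

x0 = jnp.array([1.0, 2.0])
result = IterationResult(
    fn_args=x0,
    grad_val=2.0 * x0,  # gradient of the sum-of-squares objective at x0
    iters=0,
    obj_val=float((x0**2).sum()),
    history_logging_interval=2,  # snapshot every 2nd iteration
)

for i in range(5):
    x = x0 / (i + 1)  # stand-in for a real optimiser update
    result.update(
        current_params=x,
        gradient_value=2.0 * x,
        iters=i,
        objective_value=float((x**2).sum()),
    )

print(result.iter_history)  # [0, 2, 4]
```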
55 changes: 45 additions & 10 deletions src/causalprog/solvers/sgd.py
@@ -8,6 +8,8 @@
import numpy.typing as npt
import optax

from causalprog.solvers.iteration_result import IterationResult
from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks
from causalprog.solvers.solver_result import SolverResult
from causalprog.utils.norms import PyTree, l2_normsq

@@ -23,6 +25,10 @@ def stochastic_gradient_descent(
maxiter: int = 1000,
optimiser: optax.GradientTransformationExtraArgs | None = None,
tolerance: float = 1.0e-8,
history_logging_interval: int = -1,
callbacks: Callable[[IterationResult], None]
| list[Callable[[IterationResult], None]]
| None = None,
) -> SolverResult:
"""
Minimise a function of one argument using Stochastic Gradient Descent (SGD).
@@ -65,12 +71,17 @@
this number of iterations is exceeded.
optimiser: The `optax` optimiser to use during the update step.
tolerance: `tolerance` used when determining if a minimum has been found.
        history_logging_interval: Interval (in number of iterations) at which to
            log the optimisation history. If `history_logging_interval <= 0`, no
            history is logged.
        callbacks: A callable, or list of callables, each taking an
            `IterationResult` as its only argument and returning `None`. These
            are called at the end of each iteration of the optimisation
            procedure.


Returns:
-        Minimising argument of `obj_fn`.
-        Value of `obj_fn` at the minimum.
-        Gradient of `obj_fn` at the minimum.
-        Number of iterations performed.
+        SolverResult: Result of the optimisation procedure.

"""
if not fn_args:
@@ -82,29 +93,49 @@
if not optimiser:
optimiser = optax.adam(learning_rate)

callbacks = _normalise_callbacks(callbacks)

def objective(x: npt.ArrayLike) -> npt.ArrayLike:
return obj_fn(x, *fn_args, **fn_kwargs)

def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
return convergence_criteria(x, dx) < tolerance

-    converged = False
value_and_grad_fn = jax.jit(jax.value_and_grad(objective))

    # Initialise optimiser state and the first objective/gradient evaluation
opt_state = optimiser.init(initial_guess)
current_params = deepcopy(initial_guess)
-    gradient = jax.grad(objective)
converged = False
objective_value, gradient_value = value_and_grad_fn(current_params)

-    for _ in range(maxiter + 1):
-        objective_value = objective(current_params)
-        gradient_value = gradient(current_params)
iter_result = IterationResult(
fn_args=current_params,
grad_val=gradient_value,
iters=0,
obj_val=objective_value,
history_logging_interval=history_logging_interval,
)

for current_iter in range(maxiter + 1):
iter_result.update(
current_params=current_params,
gradient_value=gradient_value,
iters=current_iter,
objective_value=objective_value,
)

_run_callbacks(iter_result, callbacks)

if converged := is_converged(objective_value, gradient_value):
break

updates, opt_state = optimiser.update(gradient_value, opt_state)
current_params = optax.apply_updates(current_params, updates)

-    iters_used = _
objective_value, gradient_value = value_and_grad_fn(current_params)

iters_used = current_iter
reason_msg = (
f"Did not converge after {iters_used} iterations" if not converged else ""
)
@@ -117,4 +148,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
obj_val=objective_value,
reason=reason_msg,
successful=converged,
iter_history=iter_result.iter_history,
fn_args_history=iter_result.fn_args_history,
grad_val_history=iter_result.grad_val_history,
obj_val_history=iter_result.obj_val_history,
)
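
A usage sketch of the new arguments. The leading positional parameters (`obj_fn`, then `initial_guess`) and the defaults for `fn_args`/`fn_kwargs` are inferred from the docstring, since the full signature is collapsed in this diff:

```python
import jax.numpy as jnp

from causalprog.solvers.sgd import stochastic_gradient_descent
from causalprog.solvers.solver_callbacks import tqdm_callback

result = stochastic_gradient_descent(
    lambda x, *args: jnp.sum(x**2),  # tolerates optional extra args from fn_args
    jnp.array([1.0, -2.0]),  # initial guess
    maxiter=500,
    history_logging_interval=10,
    callbacks=tqdm_callback(500),
)

print(result.successful, result.reason)
print(result.iter_history)  # e.g. [0, 10, 20, ...]
```

Since the histories come back on the returned `SolverResult`, no callback is needed just to record the trajectory; callbacks are for side effects such as progress reporting.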
61 changes: 61 additions & 0 deletions src/causalprog/solvers/solver_callbacks.py
@@ -0,0 +1,61 @@
"""Module for callback functions for solvers."""

from collections.abc import Callable, Collection

from tqdm.auto import tqdm

from causalprog.solvers.iteration_result import IterationResult


def _normalise_callbacks(
callbacks: Callable[[IterationResult], None]
| Collection[Callable[[IterationResult], None]]
| None = None,
) -> list[Callable[[IterationResult], None]]:
if callbacks is None:
return []
if callable(callbacks):
return [callbacks]
if all(callable(cb) for cb in callbacks):
return list(callbacks)

msg = "Callbacks must be a callable or a sequence of callables"
raise TypeError(msg)


def _run_callbacks(
iter_result: IterationResult,
callbacks: list[Callable[[IterationResult], None]],
) -> None:
for cb in callbacks:
cb(iter_result)


def tqdm_callback(total: int) -> Callable[[IterationResult], None]:
"""
Progress bar callback using `tqdm`.

    Creates a callback function that can be passed to solvers to display a
    progress bar during optimisation. The progress bar advances with the
    iteration count and also displays the current objective value.

Args:
total: Total number of iterations for the progress bar.

Returns:
Callback function that updates the progress bar.

"""
bar = tqdm(total=total)
    last_it = {"i": 0}  # mutable cell so the closure can track the last seen iteration

def cb(ir: IterationResult) -> None:
step = ir.iters - last_it["i"]
if step > 0:
bar.update(step)

        # Display the current objective value alongside the bar
bar.set_postfix(obj=float(ir.obj_val))
last_it["i"] = ir.iters

return cb
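
Any callable that accepts an `IterationResult` and returns `None` can be used alongside (or instead of) the built-in progress bar. A hypothetical `print_every` helper, for illustration only:

```python
from collections.abc import Callable

from causalprog.solvers.iteration_result import IterationResult


def print_every(n: int) -> Callable[[IterationResult], None]:
    """Build a callback that prints the objective every `n` iterations."""

    def cb(ir: IterationResult) -> None:
        if ir.iters % n == 0:
            print(f"iter {ir.iters}: obj={float(ir.obj_val):.3e}")

    return cb


# e.g. callbacks=[print_every(100), tqdm_callback(total=maxiter)]
```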
11 changes: 10 additions & 1 deletion src/causalprog/solvers/solver_result.py
@@ -1,6 +1,6 @@
"""Container class for outputs from solver methods."""

from dataclasses import dataclass
from dataclasses import dataclass, field

import numpy.typing as npt

@@ -26,6 +26,10 @@ class SolverResult:
successful: `True` if solver converged, in which case `fn_args` is the
argument to the objective function at the solution of the problem being
solved. `False` otherwise.
iter_history: List of iteration numbers at which history was logged.
fn_args_history: List of `fn_args` at each logged iteration.
grad_val_history: List of `grad_val` at each logged iteration.
obj_val_history: List of `obj_val` at each logged iteration.

"""

@@ -36,3 +40,8 @@
obj_val: npt.ArrayLike
reason: str
successful: bool

iter_history: list[int] = field(default_factory=list)
fn_args_history: list[PyTree] = field(default_factory=list)
grad_val_history: list[PyTree] = field(default_factory=list)
obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
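
A sketch of reading a logged trajectory back off a result. The fields collapsed out of this diff are assumed to match the `IterationResult` names (`fn_args`, `grad_val`, `iters`), and the values here are made up:

```python
import jax.numpy as jnp

from causalprog.solvers.solver_result import SolverResult

res = SolverResult(
    fn_args=jnp.zeros(2),
    grad_val=jnp.zeros(2),
    iters=20,
    obj_val=0.0,
    reason="",
    successful=True,
    iter_history=[0, 10, 20],
    obj_val_history=[5.0, 0.5, 0.0],
)

# Pair logged iteration numbers with logged objective values.
print(dict(zip(res.iter_history, res.obj_val_history, strict=True)))
```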
16 changes: 16 additions & 0 deletions tests/fixtures/solvers.py
@@ -0,0 +1,16 @@
from collections.abc import Callable

import numpy.typing as npt
import pytest

from causalprog.utils.norms import PyTree


@pytest.fixture
def sum_of_squares_obj() -> Callable[[PyTree], npt.ArrayLike]:
"""f(x) = ||x||_2^2 = sum_i x_i^2"""

def _inner(x):
return (x**2).sum()

return _inner
35 changes: 35 additions & 0 deletions tests/test_solvers/test_normalise_solver_inputs.py
@@ -0,0 +1,35 @@
import pytest

from causalprog.solvers.iteration_result import IterationResult
from causalprog.solvers.solver_callbacks import _normalise_callbacks


def _cb(ir: IterationResult) -> None:
    """Placeholder callback."""


@pytest.mark.parametrize(
("func_input", "expected"),
[
(_cb, [_cb]),
([_cb, _cb], [_cb, _cb]),
(None, []),
([], []),
(42, TypeError("'int' object is not iterable")),
],
ids=[
"single callable",
"list of two callables",
"callbacks=None",
"callbacks=[]",
"callbacks=42",
],
)
def test_normalise_callbacks(func_input, expected, raises_context) -> None:
"""Test that callbacks are normalised correctly."""

if isinstance(expected, Exception):
with raises_context(expected):
_normalise_callbacks(func_input)
else:
assert _normalise_callbacks(func_input) == expected