|
| 1 | +"""Minimisation via Stochastic Gradient Descent.""" |
| 2 | + |
| 3 | +from collections.abc import Callable |
| 4 | +from copy import deepcopy |
| 5 | + |
| 6 | +import jax |
| 7 | +import jax.numpy as jnp |
| 8 | +import numpy.typing as npt |
| 9 | +import optax |
| 10 | + |
| 11 | +from causalprog.utils.norms import PyTree, l2_normsq |
| 12 | + |
| 13 | + |
def stochastic_gradient_descent(
    obj_fn: Callable[[PyTree], npt.ArrayLike],
    initial_guess: PyTree,
    *,
    convergence_criteria: Callable[[PyTree, PyTree], npt.ArrayLike] | None = None,
    fn_args: tuple | None = None,
    fn_kwargs: dict | None = None,
    learning_rate: float = 1.0e-1,
    maxiter: int = 1000,
    optimiser: optax.GradientTransformationExtraArgs | None = None,
    tolerance: float = 1.0e-8,
) -> tuple[PyTree, npt.ArrayLike, npt.ArrayLike, int]:
    """
    Minimise a function of one argument using Stochastic Gradient Descent (SGD).

    The `obj_fn` provided will be minimised over its first argument. If you wish to
    minimise a function over a different argument, or multiple arguments, wrap it in a
    suitable `lambda` expression that has the correct call signature. For example, to
    minimise a function `f(x, y, z)` over `y` and `z`, use
    `g = lambda yz, x: f(x, yz[0], yz[1])`, and pass `g` in as `obj_fn`. Note that
    you will also need to provide a constant value for `x` via `fn_args` or `fn_kwargs`.

    The `fn_args` and `fn_kwargs` keys can be used to supply additional parameters that
    need to be passed to `obj_fn`, but which should be held constant.

    SGD terminates when the `convergence_criteria` is found to be smaller than the
    `tolerance`. That is, when
    `convergence_criteria(objective_value, gradient_value) <= tolerance` is found to
    be `True`, the algorithm considers a minimum to have been found. The default
    condition under which the algorithm terminates is when the norm of the gradient
    at the current argument value is smaller than the provided `tolerance`.

    The optimiser to use can be selected by passing in a suitable `optax` optimiser
    via the `optimiser` argument. By default, `optax.adam` is used with the supplied
    `learning_rate`. Providing an explicit value for `optimiser` will result in the
    `learning_rate` argument being ignored.

    Args:
        obj_fn: Function to be minimised over its first argument.
        initial_guess: Initial guess for the minimising argument.
        convergence_criteria: The quantity that will be tested against `tolerance`, to
            determine whether the method has converged to a minimum. It should be a
            `callable` that takes the current value of `obj_fn` as its 1st argument, and
            the current value of the gradient of `obj_fn` as its 2nd argument. The
            default criteria is the l2-norm of the gradient.
        fn_args: Positional arguments to be passed to `obj_fn`, and held constant.
        fn_kwargs: Keyword arguments to be passed to `obj_fn`, and held constant.
        learning_rate: Default learning rate (or step size) to use when using the
            default `optimiser`. No effect if `optimiser` is provided explicitly.
        maxiter: Maximum number of iterations to perform. An error will be reported if
            this number of iterations is exceeded.
        optimiser: The `optax` optimiser to use during the update step.
        tolerance: `tolerance` used when determining if a minimum has been found.

    Returns:
        Minimising argument of `obj_fn`.
        Value of `obj_fn` at the minimum.
        Gradient of `obj_fn` at the minimum.
        Number of iterations performed.

    Raises:
        RuntimeError: If no minimum is found within `maxiter` iterations (including
            the degenerate case `maxiter <= 0`, where no iterations are performed).

    """
    # Use explicit `is None` sentinel checks so that deliberately-falsy values
    # are not silently replaced by the defaults.
    if fn_args is None:
        fn_args = ()
    if fn_kwargs is None:
        fn_kwargs = {}
    if convergence_criteria is None:
        # Default criteria: l2-norm of the gradient (objective value is ignored).
        convergence_criteria = lambda _, dx: jnp.sqrt(l2_normsq(dx))  # noqa: E731
    if optimiser is None:
        optimiser = optax.adam(learning_rate)

    def objective(x: npt.ArrayLike) -> npt.ArrayLike:
        # Close over the constant arguments so jax.grad differentiates only in x.
        return obj_fn(x, *fn_args, **fn_kwargs)

    def is_converged(
        objective_value: npt.ArrayLike, gradient_value: npt.ArrayLike
    ) -> bool:
        # `<=` (not `<`) to match the documented termination contract.
        return convergence_criteria(objective_value, gradient_value) <= tolerance

    gradient = jax.grad(objective)
    opt_state = optimiser.init(initial_guess)

    # Copy so the caller's initial_guess is never mutated by apply_updates.
    current_params = deepcopy(initial_guess)
    gradient_value = gradient(current_params)
    # range(1, maxiter + 1) keeps the iteration count well-defined even when
    # maxiter <= 0 (the loop body never runs; previously this raised NameError).
    for iteration in range(1, maxiter + 1):
        updates, opt_state = optimiser.update(gradient_value, opt_state)
        current_params = optax.apply_updates(current_params, updates)

        objective_value = objective(current_params)
        gradient_value = gradient(current_params)

        if is_converged(objective_value, gradient_value):
            return current_params, objective_value, gradient_value, iteration

    msg = f"Did not converge after {maxiter} iterations."
    raise RuntimeError(msg)