 import numpy.typing as npt
 import optax
 
+from causalprog.solvers.iteration_result import (
+    IterationResult,
+    _update_iteration_result,
+)
+from causalprog.solvers.solver_callbacks import _normalise_callbacks, _run_callbacks
 from causalprog.solvers.solver_result import SolverResult
 from causalprog.utils.norms import PyTree, l2_normsq
 
@@ -23,6 +28,10 @@ def stochastic_gradient_descent(
     maxiter: int = 1000,
     optimiser: optax.GradientTransformationExtraArgs | None = None,
     tolerance: float = 1.0e-8,
+    history_logging_interval: int = -1,
+    callbacks: Callable[[IterationResult], None]
+    | list[Callable[[IterationResult], None]]
+    | None = None,
 ) -> SolverResult:
     """
     Minimise a function of one argument using Stochastic Gradient Descent (SGD).
@@ -65,12 +74,17 @@ def stochastic_gradient_descent(
             this number of iterations is exceeded.
         optimiser: The `optax` optimiser to use during the update step.
         tolerance: `tolerance` used when determining if a minimum has been found.
+        history_logging_interval: Interval (in number of iterations) at which to
+            log the optimisation history. If `history_logging_interval <= 0`, no
+            history is logged.
+        callbacks: A callable, or list of callables, each taking an
+            `IterationResult` as its only argument and returning `None`. These
+            are called at the end of each iteration of the optimisation
+            procedure.
+
 
     Returns:
-        Minimising argument of `obj_fn`.
-        Value of `obj_fn` at the minimum.
-        Gradient of `obj_fn` at the minimum.
-        Number of iterations performed.
+        SolverResult: Result of the optimisation procedure.
 
     """
     if not fn_args:
@@ -82,28 +96,49 @@ def stochastic_gradient_descent(
     if not optimiser:
         optimiser = optax.adam(learning_rate)
 
+    callbacks = _normalise_callbacks(callbacks)
+
     def objective(x: npt.ArrayLike) -> npt.ArrayLike:
         return obj_fn(x, *fn_args, **fn_kwargs)
 
     def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         return convergence_criteria(x, dx) < tolerance
 
-    converged = False
+    value_and_grad_fn = jax.jit(jax.value_and_grad(objective))
 
+    # init state
     opt_state = optimiser.init(initial_guess)
     current_params = deepcopy(initial_guess)
-    gradient = jax.grad(objective)
+    converged = False
+    objective_value, gradient_value = value_and_grad_fn(current_params)
+
+    iter_result = IterationResult(
+        fn_args=current_params,
+        grad_val=gradient_value,
+        iters=0,
+        obj_val=objective_value,
+    )
 
     for _ in range(maxiter + 1):
-        objective_value = objective(current_params)
-        gradient_value = gradient(current_params)
+        _update_iteration_result(
+            iter_result,
+            current_params,
+            gradient_value,
+            _,
+            objective_value,
+            history_logging_interval,
+        )
+
+        _run_callbacks(iter_result, callbacks)
 
         if converged := is_converged(objective_value, gradient_value):
             break
 
         updates, opt_state = optimiser.update(gradient_value, opt_state)
         current_params = optax.apply_updates(current_params, updates)
 
+        objective_value, gradient_value = value_and_grad_fn(current_params)
+
     iters_used = _
     reason_msg = (
         f"Did not converge after {iters_used} iterations" if not converged else ""
@@ -117,4 +152,8 @@ def is_converged(x: npt.ArrayLike, dx: npt.ArrayLike) -> bool:
         obj_val=objective_value,
         reason=reason_msg,
         successful=converged,
+        iter_history=iter_result.iter_history,
+        fn_args_history=iter_result.fn_args_history,
+        grad_val_history=iter_result.grad_val_history,
+        obj_val_history=iter_result.obj_val_history,
     )
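
For reference, a minimal usage sketch of the new `history_logging_interval` and `callbacks` parameters. The keyword names `obj_fn` and `initial_guess` are inferred from the function body in this diff; the solver's import path and the toy objective below are illustrative assumptions, not part of the PR.

import jax.numpy as jnp

from causalprog.solvers.iteration_result import IterationResult
# NOTE: hypothetical module path for the solver; adjust to wherever
# stochastic_gradient_descent actually lives in causalprog.
from causalprog.solvers.sgd import stochastic_gradient_descent


def print_progress(it: IterationResult) -> None:
    # Called at the end of every iteration with the current solver state.
    print(f"iter {it.iters}: obj={it.obj_val}")


result = stochastic_gradient_descent(
    obj_fn=lambda x: jnp.sum(x**2),        # toy quadratic objective
    initial_guess=jnp.array([3.0, -4.0]),
    history_logging_interval=10,           # log the state every 10 iterations
    callbacks=print_progress,              # a single callable or a list of callables
)
print(result.obj_val_history)              # populated because history logging is enabled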