-
Notifications
You must be signed in to change notification settings - Fork 0
Add callbacks, history logging and performance improvements to sgd. #105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
samjmolyneux
wants to merge
12
commits into
main
Choose a base branch
from
sjmolyneux/sgd-improvements
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
eaba080
Callbacks, history logging and speed boosts to sgd
samjmolyneux 3fbed8d
Add tests for sgd callbacks and history logging
samjmolyneux b5102f7
combine update and iter_result
samjmolyneux 4aa1e50
change _ to current_iter in loop
samjmolyneux 864541f
add tqd, to dependencies
samjmolyneux a625478
allow any collection of callbacks
samjmolyneux 1498f28
fix expected error on incorrect callback type
samjmolyneux 005f383
change pytest.raises to raises_context
samjmolyneux 40658ac
remove redundant callback test cases
samjmolyneux c75da1f
parameterise test_normalise_callbacks
samjmolyneux cf832c4
remove factories in favour of callback placeholder
samjmolyneux 23e2848
move obj_fn to fixture
samjmolyneux File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| """Container classes for outputs from each iteration of solver methods.""" | ||
|
|
||
| from dataclasses import dataclass, field | ||
|
|
||
| import numpy.typing as npt | ||
|
|
||
| from causalprog.utils.norms import PyTree | ||
|
|
||
|
|
||
| @dataclass(frozen=False, slots=True) | ||
| class IterationResult: | ||
| """ | ||
| Result container for iterative solvers with optional history logging. | ||
|
|
||
| Stores the latest iterate and if `history_logging_interval > 0`, `update` appends | ||
| snapshots of the iterate to corresponding history lists each time the iteration | ||
| number is a multiple of `history_logging_interval`. | ||
| Instances are mutable but do not allow dynamic attribute creation. | ||
|
|
||
| Args: | ||
| fn_args: Argument to the objective function at final iteration (the solution, | ||
| if `successful is `True`). | ||
| grad_val: Value of the gradient of the objective function at the `fn_args`. | ||
| iters: Number of iterations performed. | ||
| obj_val: Value of the objective function at `fn_args`. | ||
| iter_history: List of iteration numbers at which history was logged. | ||
| fn_args_history: List of `fn_args` at each logged iteration. | ||
| grad_val_history: List of `grad_val` at each logged iteration. | ||
| obj_val_history: List of `obj_val` at each logged iteration. | ||
| history_logging_interval: Interval at which to log history. If | ||
| `history_logging_interval <= 0`, then no history is logged. | ||
|
|
||
| """ | ||
|
|
||
| fn_args: PyTree | ||
| grad_val: PyTree | ||
| iters: int | ||
| obj_val: npt.ArrayLike | ||
| history_logging_interval: int = 0 | ||
|
|
||
| iter_history: list[int] = field(default_factory=list) | ||
| fn_args_history: list[PyTree] = field(default_factory=list) | ||
| grad_val_history: list[PyTree] = field(default_factory=list) | ||
| obj_val_history: list[npt.ArrayLike] = field(default_factory=list) | ||
|
|
||
| _log_enabled: bool = field(init=False, repr=False) | ||
|
|
||
| def __post_init__(self) -> None: | ||
| self._log_enabled = self.history_logging_interval > 0 | ||
|
|
||
| def update( | ||
| self, | ||
| current_params: PyTree, | ||
| gradient_value: PyTree, | ||
| iters: int, | ||
| objective_value: npt.ArrayLike, | ||
| ) -> None: | ||
| """ | ||
| Update the `IterationResult` object with current iteration data. | ||
|
|
||
| Only updates the history if `history_logging_interval` is positive and | ||
| the current iteration is a multiple of `history_logging_interval`. | ||
|
|
||
| """ | ||
| self.fn_args = current_params | ||
| self.grad_val = gradient_value | ||
| self.iters = iters | ||
| self.obj_val = objective_value | ||
|
|
||
| if self._log_enabled and iters % self.history_logging_interval == 0: | ||
| self.iter_history.append(iters) | ||
| self.fn_args_history.append(current_params) | ||
| self.grad_val_history.append(gradient_value) | ||
| self.obj_val_history.append(objective_value) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| """Module for callback functions for solvers.""" | ||
|
|
||
| from collections.abc import Callable, Collection | ||
|
|
||
| from tqdm.auto import tqdm | ||
|
|
||
| from causalprog.solvers.iteration_result import IterationResult | ||
|
|
||
|
|
||
| def _normalise_callbacks( | ||
| callbacks: Callable[[IterationResult], None] | ||
| | Collection[Callable[[IterationResult], None]] | ||
| | None = None, | ||
| ) -> list[Callable[[IterationResult], None]]: | ||
| if callbacks is None: | ||
| return [] | ||
| if callable(callbacks): | ||
| return [callbacks] | ||
| if all(callable(cb) for cb in callbacks): | ||
| return list(callbacks) | ||
|
|
||
| msg = "Callbacks must be a callable or a sequence of callables" | ||
| raise TypeError(msg) | ||
|
|
||
|
|
||
| def _run_callbacks( | ||
| iter_result: IterationResult, | ||
| callbacks: list[Callable[[IterationResult], None]], | ||
| ) -> None: | ||
| for cb in callbacks: | ||
| cb(iter_result) | ||
|
|
||
|
|
||
| def tqdm_callback(total: int) -> Callable[[IterationResult], None]: | ||
| """ | ||
| Progress bar callback using `tqdm`. | ||
|
|
||
| Creates a callback function that can be passed to solvers to display a progress bar | ||
| during optimization. The progress bar updates based on the number of iterations and | ||
| also displays the current objective value. | ||
|
|
||
| Args: | ||
| total: Total number of iterations for the progress bar. | ||
|
|
||
| Returns: | ||
| Callback function that updates the progress bar. | ||
|
|
||
| """ | ||
| bar = tqdm(total=total) | ||
| last_it = {"i": 0} | ||
|
|
||
| def cb(ir: IterationResult) -> None: | ||
| step = ir.iters - last_it["i"] | ||
| if step > 0: | ||
| bar.update(step) | ||
|
|
||
| # Show objective and grad norm | ||
| bar.set_postfix(obj=float(ir.obj_val)) | ||
| last_it["i"] = ir.iters | ||
|
|
||
| return cb | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| from collections.abc import Callable | ||
|
|
||
| import numpy.typing as npt | ||
| import pytest | ||
|
|
||
| from causalprog.utils.norms import PyTree | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def sum_of_squares_obj() -> Callable[[PyTree], npt.ArrayLike]: | ||
| """f(x) = ||x||_2^2 = sum_i x_i^2""" | ||
|
|
||
| def _inner(x): | ||
| return (x**2).sum() | ||
|
|
||
| return _inner |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| import pytest | ||
|
|
||
| from causalprog.solvers.iteration_result import IterationResult | ||
| from causalprog.solvers.solver_callbacks import _normalise_callbacks | ||
|
|
||
|
|
||
| def _cb(ir: IterationResult) -> None: | ||
| """Placeholder callback""" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("func_input", "expected"), | ||
| [ | ||
| (_cb, [_cb]), | ||
| ([_cb, _cb], [_cb, _cb]), | ||
| (None, []), | ||
| ([], []), | ||
| (42, TypeError("'int' object is not iterable")), | ||
| ], | ||
| ids=[ | ||
| "single callable", | ||
| "list of two callables", | ||
| "callbacks=None", | ||
| "callbacks=[]", | ||
| "callbacks=42", | ||
| ], | ||
| ) | ||
| def test_normalise_callbacks(func_input, expected, raises_context) -> None: | ||
| """Test that callbacks are normalised correctly.""" | ||
|
|
||
| if isinstance(expected, Exception): | ||
| with raises_context(expected): | ||
| _normalise_callbacks(func_input) | ||
| else: | ||
| assert _normalise_callbacks(func_input) == expected |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.