Add callbacks, history logging and performance improvements to sgd. #105
base: main
Conversation
To sgd.py:
- Added single-pass JIT grad and objective calculation for big speed boosts.
- Added history logging option for easy debugging and a better understanding of the optimisation process.
- Added callbacks functionality, allowing for user-specific callbacks, e.g. early stopping, live graphs etc.
Added iteration_result.py to monitor the current state of convergence. Use of a dataclass ensures backwards compatibility of callbacks.
Added solver_callbacks.py for callbacks and associated funcs.
Added history attributes to the solver result.
willGraham01 left a comment:
All in all I like these changes, and these are useful features that we could do with adding. Callbacks in particular should be very useful from a development/debugging perspective too.
Most of my comments relate to the design decisions for the code, considering what's in the rest of the codebase. Namely, I think we can do some code recycling in places, and we tend to write our tests in a particular format (though the test cases provided are good).
Also, there are only two commits on this branch (one for codebase changes, one for tests). In general, don't be afraid to use more granular commits in your PRs (just take a look at how long the other PRs are!) - we use squash merges, so everything gets condensed into a single commit on main anyway. And it's good to be able to roll things back.
```python
obj_val_history: list[npt.ArrayLike] = field(default_factory=list)
...
def _update_iteration_result(
```
Any reason why this isn't a method of the IterationResult class? I expected this to be something like IterationResult.update (and it's currently written in that form too -> swap iter_result to self).
Related: any reason why history_logging_interval is an argument that we pass in, rather than an attribute that's set at creation time? (I guess we could want dynamic logging, which we wouldn't get with a fixed attribute, but is that a common enough use case to design around?) It also means that we could just check history_logging_interval > 0 once at creation time, and not do it every time in the method.
> Any reason why this isn't a method of the IterationResult class? I expected this to be something like IterationResult.update (and it's currently written in that form too -> swap iter_result to self).
Nope. It should be.
> Related: any reason why history_logging_interval is an argument that we pass in, rather than an attribute that's set at creation time? (I guess we could want dynamic logging, which we wouldn't get with a fixed attribute, but is that a common enough use case to design around?) It also means that we could just check history_logging_interval > 0 once at creation time, and not do it every time in the method.
No. This is a good point.
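To illustrate the suggestion, here is a minimal sketch of what that refactor could look like. Field names are guesses apart from obj_val_history, which appears in the diff above; this is not the PR's actual code:

```python
from dataclasses import dataclass, field

import numpy.typing as npt


@dataclass(frozen=True)
class IterationResult:
    # Sketch only. With frozen=True the list fields can still be
    # appended to, since that mutates the list, not the dataclass field.
    history_logging_interval: int = 0
    fn_args_history: list[npt.ArrayLike] = field(default_factory=list)
    grad_val_history: list[npt.ArrayLike] = field(default_factory=list)
    obj_val_history: list[npt.ArrayLike] = field(default_factory=list)

    def update(self, params, grad_val, iteration: int, obj_val) -> None:
        # The interval is fixed at creation time, so the "is logging
        # enabled?" check could even be hoisted out of this method.
        if (
            self.history_logging_interval > 0
            and iteration % self.history_logging_interval == 0
        ):
            self.fn_args_history.append(params)
            self.grad_val_history.append(grad_val)
            self.obj_val_history.append(obj_val)
```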
```python
for _ in range(maxiter + 1):
    objective_value = objective(current_params)
    gradient_value = gradient(current_params)
    _update_iteration_result(
        iter_result,
        current_params,
        gradient_value,
        _,
        objective_value,
        history_logging_interval,
    )
```
Per Matt's convention from another PR, since we're using _ in the loop, we should probably use a name like current_iter or something for the loop variable. (Would add a suggestion, sorry, but GitHub doesn't let me suggest things for unchanged lines.)
Yeah, because of the iters_used = _ line, I wasn't sure if this was a convention you were using, so I didn't want to change it without checking.
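Concretely, the suggested rename applied to the loop shown above would look something like this (a fragment mirroring the snippet, not self-contained code):

```python
for current_iter in range(maxiter + 1):
    objective_value = objective(current_params)
    gradient_value = gradient(current_params)
    _update_iteration_result(
        iter_result,
        current_params,
        gradient_value,
        current_iter,  # was `_`
        objective_value,
        history_logging_interval,
    )

# After the loop, the old `iters_used = _` assignment becomes:
iters_used = current_iter
```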
```python
)
...
@pytest.mark.parametrize(
```
I wouldn't advocate for doing this here and now, but I think we can likely condense these tests a bit.
If I'm right, we're currently checking:
- The number of iterations logged is correct
- The parameters / objective function / gradient values logged at each iteration are correct
- Callbacks are invoked correctly regardless of shape (whether they are a list / single callable / None etc).
- And catching the error case of the above (when not given callables).
- Testing that callbacks don't affect the SGD result / convergence.
The correct invocation (and its associated error catch) are the same things that we're checking in _normalise_callbacks. As such, I'm of the opinion that we don't need to test for catching them here (since the tests for _normalise_callbacks will flag what happens if we pass bad things in here!) - and we should just pass valid entries to sgd's callbacks argument. Testing that these callbacks return and log the expected values, however, is of course something we should still be doing!
Value logging is probably worth checking, but we can probably drop one of the "interval=2" and "interval=3" cases (the purpose of both tests is to check the logging interval is respected), and one of the "interval=0" and "interval=-1" cases (which both check something sensible happens for a nonsensical input).
This means that it's probably possible to condense these 3 tests into a single test function (with parametrisation) along the lines of "test_sgd_logging", where each test case checks logging, recording, and non-effect on convergence. But that sounds like a lot of reorganisation, which I should probably just break out into a follow-on issue 😅
> If I'm right, we're currently checking:
> - The number of iterations logged is correct
> - The parameters / objective function / gradient values logged at each iteration are correct
> - Callbacks are invoked correctly regardless of shape (whether they are a list / single callable / None etc).
> - And catching the error case of the above (when not given callables).
> - Testing that callbacks don't affect the SGD result / convergence.
Yes, and for the last bullet point, we're also testing that convergence isn't affected by combinations of history logging and callbacks, with the additional caveat that if the IterationResult attributes are directly changed then of course the convergence will differ.
> The correct invocation (and its associated error catch) are the same things that we're checking in _normalise_callbacks. As such, I'm of the opinion that we don't need to test for catching them here (since the tests for _normalise_callbacks will flag what happens if we pass bad things in here!) - and we should just pass valid entries to sgd's callbacks argument.
Opting to test both was intentional. My thoughts are that 1) I would like to know _normalise_callbacks works correctly, and 2) it is implemented correctly in each solver. I think we could remove either one of test_sgd_callbacks_invocation or test_normalise_callbacks, but I personally favour removing test_normalise_callbacks and keeping test_sgd_callbacks_invocation, because I think it's more important to know that it is implemented correctly in each solver.
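For readers without the diff to hand, here is a plausible reconstruction of _normalise_callbacks based on the behaviour described in this thread; the exact signature and error type are assumptions:

```python
from collections.abc import Callable, Sequence


def _normalise_callbacks(
    callbacks: Callable | Sequence[Callable] | None,
) -> list[Callable]:
    # Accept None, a single callable, or a sequence of callables,
    # and always hand back a list of callables.
    if callbacks is None:
        return []
    if callable(callbacks):
        return [callbacks]
    if isinstance(callbacks, Sequence) and all(callable(c) for c in callbacks):
        return list(callbacks)
    # Anything else (e.g. a list containing non-callables) is rejected.
    raise TypeError("callbacks must be a callable or a sequence of callables")
```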
> Value logging is probably worth checking, but we can probably drop one of the "interval=2" and "interval=3" cases (the purpose of both tests is to check the logging interval is respected), and one of the "interval=0" and "interval=-1" cases (which both check something sensible happens for a nonsensical input).
Yeah I can remove those. My brain always just questions if there is something special about the first edge case that makes it work correctly, so I always feel the need to excessively add more!
> This means that it's probably possible to condense these 3 tests into a single test function (with parametrisation) along the lines of "test_sgd_logging", where each test case checks logging, recording, and non-effect on convergence. But that sounds like a lot of reorganisation, which I should probably just break out into a follow-on issue 😅
Got it 👍 .
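A rough shape for that condensed test, purely illustrative (the parametrised cases follow the pruning suggested above, and the body is left as a skeleton):

```python
import pytest


@pytest.mark.parametrize(
    ("history_logging_interval", "callbacks"),
    [
        (1, None),                        # log every iteration, no callbacks
        (2, [lambda iter_result: None]),  # coarser interval, with a callback
        (0, None),                        # nonsensical interval: nothing logged
    ],
)
def test_sgd_logging(history_logging_interval, callbacks):
    # Run sgd once with these settings, then assert that:
    #   1. the expected iterations (and their obj / fn_args / grad values)
    #      were logged,
    #   2. any callbacks were invoked with the expected IterationResult,
    #   3. the final result matches a plain run, i.e. logging and callbacks
    #      do not affect convergence.
    ...
```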
- By compiling with jax.jit and calculating the gradient and objective in a single pass with jax.value_and_grad, we get big speed boosts to sgd. The new time taken is approximately […]. (A sketch of this pattern follows this list.)
- The new history_logging_interval parameter for the stochastic_gradient_descent function allows the user to enable or disable logging of the optimisation history. The interval determines how frequently the history is logged. This makes it easier to debug optimisations and make decisions about hyperparameters.
- Added IterationResult, which is analogous to SolverResult. However, IterationResult uses frozen=True to allow for dataclass updates each iteration. Using a dataclass ensures backward compatibility for callbacks if a new attribute is logged.
- The new callbacks parameter for the stochastic_gradient_descent function allows the user to set a list of callback functions, as is standard in optimisation loops. In future, the callbacks can be used for early stopping or live plotting of the results. As an example, we include a useful tqdm callback that displays a progress bar for the iterations and the current objective value. (A sketch of such a callback also follows this list.)

Tests:
- test_normalise_callbacks: Tests that _normalise_callbacks does validation and casts valid types to list[Callable[[IterationResult], None]].
- test_sgd_history_logging_intervals: Tests that the correct iterations are logged for different intervals, and that the correct associated obj, fn_args and grad values are too, for sgd.
- test_callback_invocation: Tests that sgd callbacks are called in the correct order with the correct IterationResult.
- test_invalid_callback: Tests that sgd will raise an error if given an invalid callback.
- test_logging_or_callbacks_affect_sgd_convergence: Tests that various combinations of callbacks and logging intervals all result in the same convergence behaviour, and thus all have the same final obj, fn_args, grad etc.
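A minimal sketch of the single-pass pattern from the first point, with a toy quadratic objective standing in for the real one:

```python
import jax
import jax.numpy as jnp


def objective(params):
    # Toy quadratic objective standing in for the real one.
    return jnp.sum(params**2)


# One compiled function returns the objective value and its gradient
# together, instead of two separately traced objective() and gradient() calls.
value_and_grad_fn = jax.jit(jax.value_and_grad(objective))

obj_val, grad_val = value_and_grad_fn(jnp.array([1.0, 2.0]))
```

And a sketch of what the bundled tqdm progress-bar callback might look like; make_tqdm_callback and the obj_val attribute are assumed names, not necessarily those used in the PR:

```python
from tqdm import tqdm


def make_tqdm_callback(maxiter: int):
    # Assumed shape: callbacks receive the current IterationResult,
    # and `obj_val` is a guessed attribute name for the objective value.
    progress_bar = tqdm(total=maxiter)

    def callback(iter_result):
        progress_bar.update(1)
        progress_bar.set_postfix(objective=float(iter_result.obj_val))

    return callback
```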