[RFC] Refactor Input Transforms #1176
Closed
Summary:

Currently, we apply the input transforms in `train` mode at the `forward` call, and in `eval` mode at the `posterior` call. We also use a `transform_train_inputs` call at the `eval`/`train` calls to make sure that at `eval` time the `train_inputs` are stored as transformed (since they do not pass through `posterior`). This design supports `ExactGP` models, and it supports specifying where to apply which input transform via the flags (so that one-to-many transforms are only applied to test inputs). However, it does not work well with approximate GP models, since this setup does not transform the inducing points at `eval` time.
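
For context, a stripped-down sketch of the current dispatch, assuming a single `transform_inputs` helper (the real logic also consults per-transform flags such as `transform_on_train`; this is illustrative, not the actual BoTorch code):

```python
from torch import Tensor
from torch.nn import Module


class CurrentDesignSketch(Module):
    """Stripped-down sketch of today's behavior; not the actual BoTorch code."""

    def __init__(self, input_transform: Module) -> None:
        super().__init__()
        self.input_transform = input_transform

    def transform_inputs(self, X: Tensor) -> Tensor:
        return self.input_transform(X)

    def forward(self, X: Tensor) -> Tensor:
        if self.training:
            X = self.transform_inputs(X)  # train mode: transformed here ...
        return X  # stand-in for the prior computation

    def posterior(self, X: Tensor) -> Tensor:
        self.eval()
        X = self.transform_inputs(X)  # ... eval mode: transformed here instead.
        # An approximate GP's inducing points never pass through either hook,
        # so they remain untransformed at eval time.
        return self(X)
```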

This refactor splits out one-to-many transforms as `InputAugmentationTransform`, allowing us to revert to simply applying `transform_inputs` in the `forward` pass (at all times). We still need to apply one-to-many transforms (now called `InputAugmentationTransform`) in `posterior`, so we introduce an `augment_inputs` method.
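
A minimal sketch of what this split could look like; the `InputAugmentationTransform` name comes from this proposal, while the `AppendFeature` example is purely hypothetical (the `augment_inputs` hook is sketched with the `posterior` wrapper further below):

```python
import torch
from torch import Tensor
from torch.nn import Module


class InputAugmentationTransform(Module):
    """One-to-many input transform: maps an `n x d` input to an `n' x d'` one."""

    def forward(self, X: Tensor) -> Tensor:
        raise NotImplementedError


class AppendFeature(InputAugmentationTransform):
    """Hypothetical example: crosses each input row with a fixed feature set,
    turning an `n x d` input into an `(n * q) x (d + d_f)` one."""

    def __init__(self, feature_set: Tensor) -> None:
        # `feature_set` is a `q x d_f` tensor of feature values to append.
        super().__init__()
        self.register_buffer("feature_set", feature_set)

    def forward(self, X: Tensor) -> Tensor:
        # Non-batched `n x d` input, for simplicity.
        q = self.feature_set.shape[0]
        expanded = X.repeat_interleave(q, dim=0)           # (n * q) x d
        features = self.feature_set.repeat(X.shape[0], 1)  # (n * q) x d_f
        return torch.cat([expanded, features], dim=-1)     # (n * q) x (d + d_f)
```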

Inspired by the public-private APIs of Ax, and in order to minimize the transform-related knowledge expected from developers, this introduces a `Model.forward` call that applies `transform_inputs` and then calls `self._forward`. `<AnyGivenModel>._forward` is the usual `forward` call that computes the prior, except that it no longer has to worry about transforms.
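
A sketch of how that split could look, assuming standard GPyTorch mean/kernel modules for the `SingleTaskGP` prior (the real model has more moving parts; this is not the actual refactor code):

```python
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.means import ConstantMean
from torch import Tensor
from torch.nn import Module


class Model(Module):
    """Sketch of the proposed base-class wrapper; not the actual BoTorch code."""

    def transform_inputs(self, X: Tensor) -> Tensor:
        # Stand-in for the existing input transform machinery.
        transform = getattr(self, "input_transform", None)
        return X if transform is None else transform(X)

    def forward(self, X: Tensor) -> MultivariateNormal:
        # Public entry point: applies the input transform once, at all times.
        return self._forward(self.transform_inputs(X))

    def _forward(self, X: Tensor) -> MultivariateNormal:
        # Implemented by each concrete model; receives transformed inputs.
        raise NotImplementedError


class SingleTaskGP(Model):
    """Illustrative `_forward`: the usual prior, with no transform logic left."""

    def __init__(self) -> None:
        super().__init__()
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(MaternKernel())

    def _forward(self, X: Tensor) -> MultivariateNormal:
        return MultivariateNormal(self.mean_module(X), self.covar_module(X))
```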

Similarly, for the `posterior`, this makes `Model.posterior` into a simple wrapper around `Model._posterior`, which applies the `augment_inputs` call and the `posterior_transform`. Again, `<AnyGivenModel>._posterior` becomes the usual `posterior` call that no longer has to worry about the input or posterior transforms (it still has to deal with the outcome transform in the current implementation, though we can fix this by bringing back the `fantasize` flag).
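
And the corresponding sketch on the `posterior` side; the keyword arguments loosely mirror the existing `Model.posterior` signature, and `_posterior`/`augment_inputs` are the pieces introduced here:

```python
from typing import Any, Callable, Optional

from torch import Tensor
from torch.nn import Module


class Model(Module):
    """Continuing the sketch above with the posterior-side wrapper."""

    def augment_inputs(self, X: Tensor) -> Tensor:
        # Applies the one-to-many transform, if the model has one.
        transform = getattr(self, "input_augmentation_transform", None)
        return X if transform is None else transform(X)

    def posterior(
        self,
        X: Tensor,
        observation_noise: bool = False,
        posterior_transform: Optional[Callable] = None,
        **kwargs: Any,
    ):
        # Public wrapper: augment the inputs, compute the model-specific
        # posterior, then apply the (optional) posterior transform.
        X = self.augment_inputs(X)
        posterior = self._posterior(X, observation_noise=observation_noise, **kwargs)
        if posterior_transform is not None:
            posterior = posterior_transform(posterior)
        return posterior

    def _posterior(self, X: Tensor, observation_noise: bool = False, **kwargs: Any):
        # Implemented by each concrete model; no input/posterior transform logic.
        # (Outcome transforms still live in here for now, per the note above.)
        raise NotImplementedError
```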

This diff presents a minimal implementation around the `SingleTaskGP` model.

Differential Revision: D35129407