activation-level distillation #388
base: main
Conversation
Great progress! Did you freeze everything except the randomly initialized mixers?
Resetting and distilling only one layer while freezing the rest of the model gives satisfactory results. Note that some changes were required to allow loading a pretrained model while freezing certain layers (#394).
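A minimal sketch of that setup in plain PyTorch, assuming a generic `model.layers[i].mixer` attribute path (the names are illustrative, not Fast-LLM's actual module layout):

```python
import torch

def freeze_all_but_mixer(model: torch.nn.Module, layer_index: int) -> None:
    """Freeze every parameter except the mixer of the chosen layer."""
    for param in model.parameters():
        param.requires_grad = False
    # `layers[...].mixer` is an illustrative attribute path, not Fast-LLM's actual one.
    for param in model.layers[layer_index].mixer.parameters():
        param.requires_grad = True

# Hand only the unfrozen parameters to the optimizer, e.g.:
# optimizer = torch.optim.AdamW(
#     (p for p in model.parameters() if p.requires_grad), lr=1e-4
# )
```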
    phase: PhaseType,
    iteration: int,
    metrics: dict | None = None,
    setup_activation_storage: bool = False,
Not needed, you can communicate through preprocessed meta kwargs.
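A rough sketch of that alternative, assuming the flag is attached to the kwargs dict during preprocessing (the key name and helper functions are hypothetical):

```python
# Hypothetical key name; the real one would live next to the other kwargs keys.
ACTIVATION_STORAGE_KEY = "activation_distillation_storage"

def preprocess(kwargs: dict) -> dict:
    # Decide during preprocessing whether activations should be captured,
    # instead of threading a `setup_activation_storage` flag through the call chain.
    kwargs[ACTIVATION_STORAGE_KEY] = {}
    return kwargs

def forward_step(kwargs: dict) -> None:
    storage = kwargs.get(ACTIVATION_STORAGE_KEY)
    if storage is not None:
        pass  # the teacher stores mixer activations here during its forward pass
```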
| ("model", "base_model", "head", "distillation_model"): "teacher", | ||
| ("reference_models"): { | ||
| "teacher": { | ||
| "model": { |
You'll need complete model descriptions, e.g. copied from another config; otherwise the created model will be too big.
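A sketch of what that could look like, assuming the teacher's description is copied from an existing full config (`student_model_config` and the override keys are hypothetical):

```python
import copy

# Start from a complete model description (e.g. the student's) and override
# only what differs for the teacher, instead of providing a partial description
# that falls back to default (oversized) dimensions.
teacher_model_config = copy.deepcopy(student_model_config)  # assumed to exist
teacher_model_config["pretrained"] = {"path": "/path/to/teacher/checkpoint"}  # illustrative override

reference_models = {"teacher": {"model": teacher_model_config}}
```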
    ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
    ModelTestingGroup.generate: ModelTestingGroupAction.unimportant,
    ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
    ModelTestingGroup.distributed: ModelTestingGroupAction.broken,  # failing: tp2, stp2, stp2_ce4
We'll probably want to leave these as unimportant and run them once in a while, because the testing suite can't really support many distributed runs.
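Concretely, the suggestion amounts to something like the following (the surrounding dict name is illustrative):

```python
model_testing_groups = {
    ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
    ModelTestingGroup.generate: ModelTestingGroupAction.unimportant,
    ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
    # Run occasionally on demand instead of marking the group as broken.
    ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,
}
```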
| """ | ||
| Maybe apply activation distillation loss and setup backward hooks | ||
| """ | ||
| mixer_output = hidden_states if bias is None else hidden_states + bias |
This should only be evaluated if needed.
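A minimal sketch of the lazy form being suggested, reusing the identifiers from the diff (the storage key written here is hypothetical):

```python
# Only form the (possibly bias-added) mixer output when a storage dict was
# actually provided, so the extra tensor add is skipped in normal training.
activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage)
if activation_storage is not None:
    mixer_output = hidden_states if bias is None else hidden_states + bias
    activation_storage[self._block_name] = mixer_output  # key name is hypothetical
```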
    mixer_output = hidden_states if bias is None else hidden_states + bias
    # Teacher populates mixer activations for distillation.
    activation_storage = kwargs.get(BlockKwargs.activation_distillation_storage)
    if activation_storage is not None:
Consider using the new _debug / output_hidden_states interface instead? It does the exact same thing.
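For comparison, the same capture can be expressed with plain PyTorch forward hooks; this is a generic sketch, not the Fast-LLM `_debug` / `output_hidden_states` interface referenced above:

```python
import torch

def register_mixer_capture(mixer: torch.nn.Module, storage: dict, key: str):
    """Store the mixer's output in `storage` on every forward pass."""
    def hook(module, inputs, output):
        storage[key] = output.detach() if torch.is_tensor(output) else output
    return mixer.register_forward_hook(hook)

# Usage (attribute path is illustrative):
# storage = {}
# handle = register_mixer_capture(model.layers[0].mixer, storage, "layer_0")
# ...run the teacher forward pass...
# handle.remove()
```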
✨ Description
Closes #385
TODOs:
- Check that the loss goes to 0, and the gradients as well.

Sanity checks:
- Distilling the teacher into itself gives ~0 loss ✔️, but the loss then increases to a small value instead of staying at 0 (the 0-loss run is the orange curve).

With the caveat that distillation seems to experience memory spikes at specific points in training; the actual usage was lower most of the time.
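A hedged sketch of that self-distillation sanity check, using a simple MSE over matching activations (the loss choice and dict layout are illustrative, not necessarily what the PR implements):

```python
import torch
import torch.nn.functional as F

def activation_distillation_loss(
    student_acts: dict[str, torch.Tensor],
    teacher_acts: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Mean MSE over matching student/teacher activations."""
    losses = [F.mse_loss(student_acts[name], teacher_acts[name]) for name in student_acts]
    return torch.stack(losses).mean()

# Self-distillation sanity check: identical activations should give exactly 0 loss.
acts = {"layer_0": torch.randn(2, 8, 16)}
assert activation_distillation_loss(acts, acts).item() == 0.0
```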