[GPT-OSS] Add HF state dict adapter to support loading from HF checkpoints #2021
As titled, this PR adds an HF state dict adapter to support loading from the GPT-OSS HF checkpoint. The GPT-OSS checkpoint is quantized in MXFP4 format. The de-quantization steps are offloaded to `QuantizedHuggingFaceStorageReader` in `dcp`, so this feature depends on the PR that updates `QuantizedHuggingFaceStorageReader` (pytorch/pytorch#167672). The intended loading flow is sketched below.
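A minimal sketch of this loading flow, assuming a generic `adapter` object with `to_hf()`/`from_hf()` methods and assuming `QuantizedHuggingFaceStorageReader` is importable from `torch.distributed.checkpoint.quantized_hf_storage` (the exact adapter class and import path in this PR may differ):

```python
import torch.distributed.checkpoint as dcp
# Import path is an assumption; it may differ across PyTorch versions.
from torch.distributed.checkpoint.quantized_hf_storage import (
    QuantizedHuggingFaceStorageReader,
)


def load_from_quantized_hf_checkpoint(titan_model, adapter, input_dir: str) -> None:
    """Load a quantized GPT-OSS HF checkpoint into a TorchTitan model.

    `adapter` is assumed to expose to_hf()/from_hf() methods that translate
    between TorchTitan and HF key names (hypothetical API for this sketch).
    """
    # Build an HF-keyed state dict whose entries match the HF checkpoint layout.
    hf_state_dict = adapter.to_hf(titan_model.state_dict())

    # Load the checkpoint; the storage reader handles the MXFP4 de-quantization,
    # so the loaded tensors arrive in a regular floating-point dtype.
    dcp.load(
        hf_state_dict,
        storage_reader=QuantizedHuggingFaceStorageReader(path=input_dir),
    )

    # Map the HF state dict back to TorchTitan naming and load it into the model.
    titan_model.load_state_dict(adapter.from_hf(hf_state_dict))
```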
Test 1. We use `dcp.load(hf_state_dict, storage_reader=QuantizedHuggingFaceStorageReader(path=input_dir))` to load from the GPT-OSS HF checkpoint, and map the `hf_state_dict` back to the TorchTitan state dict. We build one test input and compare two outputs: (1) using the `transformers` library to load the GPT-OSS HF checkpoint and run inference on the test input; (2) using the converted TorchTitan model to run inference on the same test input. We compare the outputs by computing the KL divergence between the two output probability distributions (a sketch of this comparison follows below). The result shows the two models are very similar.

Test 2. We load the model directly from the quantized GPT-OSS HF checkpoint and run a test training.
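A small sketch of the KL divergence comparison used in Test 1, assuming `hf_model` (loaded via `transformers`), the converted `titan_model`, and an `input_ids` tensor for the test input are already set up; these names are placeholders, not the actual test code:

```python
import torch
import torch.nn.functional as F


def compare_outputs(hf_model, titan_model, input_ids: torch.Tensor) -> float:
    """Return KL(P_hf || P_titan) over the output token distributions for one test input."""
    with torch.no_grad():
        hf_logits = hf_model(input_ids).logits   # reference transformers model
        titan_logits = titan_model(input_ids)    # converted TorchTitan model

    # F.kl_div(input, target) computes KL(target || input); both arguments are log-probs here.
    kl = F.kl_div(
        F.log_softmax(titan_logits, dim=-1),
        F.log_softmax(hf_logits, dim=-1),
        log_target=True,
        reduction="batchmean",
    )
    return kl.item()  # values near 0 indicate near-identical output distributions
```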
