
Conversation

@shuhuayu
Contributor

As titled, this PR adds an HF state dict adapter to support loading from the GPT-OSS HF checkpoint. The GPT-OSS checkpoint is quantized in MXFP4 format; the de-quantization steps are offloaded to the QuantizedHuggingFaceStorageReader in dcp, so this feature depends on the PR that updates QuantizedHuggingFaceStorageReader (pytorch/pytorch#167672).
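
Below is a minimal sketch of that load path (not the exact PR code). It assumes QuantizedHuggingFaceStorageReader is importable from torch.distributed.checkpoint (the exact module path depends on the pytorch/pytorch PR above), and uses hypothetical `to_hf` / `from_hf` callables as stand-ins for the state dict adapter added here:

```python
# Minimal sketch of the load path described above (not the exact PR code).
# Assumptions: QuantizedHuggingFaceStorageReader is importable from
# torch.distributed.checkpoint (exact location depends on pytorch/pytorch#167672),
# and to_hf / from_hf are hypothetical stand-ins for the state dict adapter.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import QuantizedHuggingFaceStorageReader  # assumed import path


def load_from_quantized_hf_checkpoint(model, input_dir, to_hf, from_hf):
    # Convert the model's state dict to HF naming/layout so keys match the checkpoint.
    hf_state_dict = to_hf(model.state_dict())
    # The quantized storage reader de-quantizes the MXFP4 tensors while reading shards.
    dcp.load(
        hf_state_dict,
        storage_reader=QuantizedHuggingFaceStorageReader(path=input_dir),
    )
    # Map the HF-keyed tensors back to TorchTitan's native state dict and load them.
    model.load_state_dict(from_hf(hf_state_dict))
```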

  1. Test 1. We use dcp.load(hf_state_dict, storage_reader=QuantizedHuggingFaceStorageReader(path=input_dir)) to load from the GPT-OSS HF checkpoint and map the hf_state_dict back to the TorchTitan state dict. We build one test input and compare two outputs: (1) the transformers library loads the GPT-OSS HF checkpoint and runs inference on the test input; (2) the converted TorchTitan model runs inference on the same input. We compare the outputs via the KL divergence of the two output probability distributions (see the sketch after this list). The result shows the two models are very similar.

  2. Test 2. We load the model directly from the quantized GPT-OSS HF checkpoint and run a test training.
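
For reference, a rough sketch of the KL-divergence comparison in Test 1, assuming both models return logits over the same vocabulary for the tokenized test input (`hf_logits` / `titan_logits` are placeholder names, not names from this PR):

```python
# Rough sketch of the Test 1 comparison. hf_logits / titan_logits are
# placeholder tensors of shape [batch, seq, vocab] produced by running the
# two models on the same test input; they are not names from this PR.
import torch
import torch.nn.functional as F


@torch.no_grad()
def kl_between_models(hf_logits: torch.Tensor, titan_logits: torch.Tensor) -> float:
    # Per-position output distributions over the vocabulary, flattened to [N, vocab].
    hf_log_probs = F.log_softmax(hf_logits.float(), dim=-1).flatten(0, -2)
    titan_log_probs = F.log_softmax(titan_logits.float(), dim=-1).flatten(0, -2)
    # KL(hf || titan), averaged over all positions; near zero means the models agree.
    kl = F.kl_div(titan_log_probs, hf_log_probs, log_target=True, reduction="batchmean")
    return kl.item()
```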

@meta-cla bot added the CLA Signed label (this label is managed by the Meta Open Source bot) on Nov 12, 2025
@shuhuayu marked this pull request as draft on November 12, 2025 20:42
Contributor

@tianyu-l left a comment


SGTM!
