From 4294bf85d2e0755b03dd4bf0171c2d60969e82c3 Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Mon, 3 Nov 2025 16:49:18 -0800 Subject: [PATCH 1/2] [GPT-OSS] Add HF state dict adapter to support loading from HF checkpoints --- docs/checkpoint.md | 2 +- torchtitan/experiments/gpt_oss/__init__.py | 2 + .../gpt_oss/model/state_dict_adapter.py | 205 ++++++++++++++++++ 3 files changed, 208 insertions(+), 1 deletion(-) create mode 100644 torchtitan/experiments/gpt_oss/model/state_dict_adapter.py diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 6e3112309b..8aca58eb06 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -68,7 +68,7 @@ NGPU=1 CONFIG_FILE= ./run_train.sh --checkpoint.enable --c ### HuggingFace `torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu. -1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set. +1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. 
`--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set. 2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt. diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index c12ad13a5c..0ebc20645f 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -16,6 +16,7 @@ from .infra.parallelize import parallelize_gptoss from .model.args import GptOssModelArgs from .model.model import GptOssModel +from .model.state_dict_adapter import GptOssStateDictAdapter __all__ = [ "parallelize_gptoss", @@ -84,4 +85,5 @@ def get_train_spec() -> TrainSpec: build_dataloader_fn=build_text_dataloader, build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=GptOssStateDictAdapter, ) diff --git a/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py b/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py new file mode 100644 index 0000000000..0377c049f8 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +import re +from typing import Any + +import torch +from torch.distributed.tensor import DTensor +from torchtitan.models.utils import MoEStateDictAdapter + +from .args import GptOssModelArgs + + +FP4_VALUES = [ + +0.0, + +0.5, + +1.0, + +1.5, + +2.0, + +3.0, + +4.0, + +6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def get_mxfp4_tensor( + blocks, + scales, + *, + dtype: torch.dtype = torch.bfloat16, + rows_per_chunk: int = 16384 * 512, +) -> torch.Tensor: + """ + Adapted from openai's implementation of mxfp4 dequantization: + https://github.com/openai/gpt-oss/blob/8890e95919f975a490fc0ba09ffb10890ec7319d/gpt_oss/torch/weights.py#L68 + """ + + is_dtensor = isinstance(blocks, DTensor) + if is_dtensor: + device_mesh = blocks.device_mesh + placements = blocks.placements + blocks = blocks.to_local() + scales = scales.to_local() + + scales = scales.to(torch.int32) - 127 + + assert ( + blocks.shape[:-1] == scales.shape + ), f"{blocks.shape=} does not match {scales.shape=}" + + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) + + *prefix_shape, G, B = blocks.shape + rows_total = math.prod(prefix_shape) * G + + blocks = blocks.reshape(rows_total, B) + scales = scales.reshape(rows_total, 1) + + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) + + for r0 in range(0, rows_total, rows_per_chunk): + r1 = min(r0 + rows_per_chunk, rows_total) + + blk = blocks[r0:r1] + exp = scales[r0:r1] + + # nibble indices -> int64 + idx_lo = (blk & 0x0F).to(torch.long) + idx_hi = (blk >> 4).to(torch.long) + + sub = out[r0:r1] + sub[:, 0::2] = lut[idx_lo] + sub[:, 1::2] = lut[idx_hi] + + torch.ldexp(sub, exp, out=sub) + del idx_lo, idx_hi, blk, exp + + result = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) + + if is_dtensor: + result = DTensor.from_local( + result, device_mesh=device_mesh, placements=placements + ) + + return result + + +class GptOssStateDictAdapter(MoEStateDictAdapter): + 
def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None): + super().__init__(model_args, hf_assets_path) + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + # Attention module + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias", + "model.layers.{}.self_attn.sinks": "layers.{}.attention.sinks", + # Transformer layer + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + # MoE + ( + "model.layers.{}.mlp.experts.gate_up_proj_blocks", + "model.layers.{}.mlp.experts.gate_up_proj_scales", + ): "layers.{}.moe.experts.mlp1_weight", + "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias", + ( + "model.layers.{}.mlp.experts.down_proj_blocks", + "model.layers.{}.mlp.experts.down_proj_scales", + ): "layers.{}.moe.experts.mlp2_weight", + "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias", + "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight", + "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Convert from a tt model state dict to a hf format state dict. 
+ Warning: Conversion does not support mxfp4 quantization, + and the function is only for the purpose of loading from hf checkpoints. + TODO: Add support for exact conversion of mxfp4 quantized tensors, + then one can save into hf checkpoints with last_save_in_hf = true. + """ + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + if abstract_key not in to_hf_map: + continue + layer_num = re.search(r"\d+", key).group(0) + hf_key = to_hf_map[abstract_key] + match hf_key: + case (blocks, scales): + blocks = blocks.format(layer_num) + scales = scales.format(layer_num) + hf_state_dict[blocks] = value.new_empty( + (*value.shape[:2], value.shape[2] // 32, 16), + dtype=torch.uint8, + ) + hf_state_dict[scales] = value.new_empty( + (*value.shape[:2], value.shape[2] // 32), + dtype=torch.uint8, + ) + case tensor_name: + tensor_name = tensor_name.format(layer_num) + hf_state_dict[tensor_name] = value + else: + hf_key = to_hf_map[key] + hf_state_dict[hf_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Convert from quantized hf format state dict to tt model state dict. 
+ """ + + state_dict = {} + + subtract_key = lambda key: re.sub(r"(\d+)", "{}", key, count=1) + + for key, value in hf_state_dict.items(): + if "layers" in key: + layer_num = re.search(r"\d+", key).group(0) + if "_blocks" in key: + value_scale = hf_state_dict[key.replace("_blocks", "_scales")] + abstract_key = ( + subtract_key(key), + subtract_key(key.replace("_blocks", "_scales")), + ) + tt_key = self.from_hf_map[abstract_key] + tt_key = tt_key.format(layer_num) + dequantized_values = get_mxfp4_tensor(value, value_scale) + state_dict[tt_key] = dequantized_values + elif "_scales" not in key: + abstract_key = subtract_key(key) + tt_key = self.from_hf_map[abstract_key] + tt_key = tt_key.format(layer_num) + state_dict[tt_key] = value + else: + tt_key = self.from_hf_map[key] + state_dict[tt_key] = value + + return state_dict From 768ede312a8657fbc8a89ab6d01a3d0019b2a2ac Mon Sep 17 00:00:00 2001 From: Shuhua Yu Date: Wed, 12 Nov 2025 11:41:44 -0800 Subject: [PATCH 2/2] [GPT-OSS] Offload dequantization to QuantizedHuggingFaceStorageReader --- .../gpt_oss/model/state_dict_adapter.py | 170 +++++------------- 1 file changed, 40 insertions(+), 130 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py b/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py index 0377c049f8..ca85789baf 100644 --- a/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py +++ b/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py @@ -4,99 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import math import re from typing import Any -import torch -from torch.distributed.tensor import DTensor +from torch.distributed.checkpoint import HuggingFaceStorageReader from torchtitan.models.utils import MoEStateDictAdapter from .args import GptOssModelArgs -FP4_VALUES = [ - +0.0, - +0.5, - +1.0, - +1.5, - +2.0, - +3.0, - +4.0, - +6.0, - -0.0, - -0.5, - -1.0, - -1.5, - -2.0, - -3.0, - -4.0, - -6.0, -] - - -def get_mxfp4_tensor( - blocks, - scales, - *, - dtype: torch.dtype = torch.bfloat16, - rows_per_chunk: int = 16384 * 512, -) -> torch.Tensor: - """ - Adapted from openai's implementation of mxfp4 dequantization: - https://github.com/openai/gpt-oss/blob/8890e95919f975a490fc0ba09ffb10890ec7319d/gpt_oss/torch/weights.py#L68 - """ - - is_dtensor = isinstance(blocks, DTensor) - if is_dtensor: - device_mesh = blocks.device_mesh - placements = blocks.placements - blocks = blocks.to_local() - scales = scales.to_local() - - scales = scales.to(torch.int32) - 127 - - assert ( - blocks.shape[:-1] == scales.shape - ), f"{blocks.shape=} does not match {scales.shape=}" - - lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) - - *prefix_shape, G, B = blocks.shape - rows_total = math.prod(prefix_shape) * G - - blocks = blocks.reshape(rows_total, B) - scales = scales.reshape(rows_total, 1) - - out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) - - for r0 in range(0, rows_total, rows_per_chunk): - r1 = min(r0 + rows_per_chunk, rows_total) - - blk = blocks[r0:r1] - exp = scales[r0:r1] - - # nibble indices -> int64 - idx_lo = (blk & 0x0F).to(torch.long) - idx_hi = (blk >> 4).to(torch.long) - - sub = out[r0:r1] - sub[:, 0::2] = lut[idx_lo] - sub[:, 1::2] = lut[idx_hi] - - torch.ldexp(sub, exp, out=sub) - del idx_lo, idx_hi, blk, exp - - result = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) - - if is_dtensor: - result = DTensor.from_local( - result, device_mesh=device_mesh, placements=placements - ) - - return result - - 
class GptOssStateDictAdapter(MoEStateDictAdapter): def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None): super().__init__(model_args, hf_assets_path) @@ -116,15 +32,9 @@ def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None): "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", # MoE - ( - "model.layers.{}.mlp.experts.gate_up_proj_blocks", - "model.layers.{}.mlp.experts.gate_up_proj_scales", - ): "layers.{}.moe.experts.mlp1_weight", + "model.layers.{}.mlp.experts.gate_up_proj_blocks": "layers.{}.moe.experts.mlp1_weight", "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias", - ( - "model.layers.{}.mlp.experts.down_proj_blocks", - "model.layers.{}.mlp.experts.down_proj_scales", - ): "layers.{}.moe.experts.mlp2_weight", + "model.layers.{}.mlp.experts.down_proj_blocks": "layers.{}.moe.experts.mlp2_weight", "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias", "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight", "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias", @@ -132,13 +42,37 @@ def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None): "lm_head.weight": "output.weight", } + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + """ + Override get_hf_storage_reader to return QuantizedHuggingFaceStorageReader if from_quantized. + """ + if from_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + + # NOTE: Now we use Quantized HF storage reader to read GPT-OSS model where + # expert weights are saved in MXFP4 format. 
+ # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + return QuantizedHuggingFaceStorageReader( + path=path, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: """ Convert from a tt model state dict to a hf format state dict. - Warning: Conversion does not support mxfp4 quantization, - and the function is only for the purpose of loading from hf checkpoints. - TODO: Add support for exact conversion of mxfp4 quantized tensors, - then one can save into hf checkpoints with last_save_in_hf = true. + + Only the keys are mapped; tensor shapes are not changed to match the MXFP4 checkpoint. + For loading from quantized checkpoints, the QuantizedHuggingFaceStorageReader + will handle dequantization during load. + + Warning: Conversion does not support saving to mxfp4 quantization format. + One can save into unquantized hf checkpoints with last_save_in_hf = true. """ to_hf_map = {v: k for k, v in self.from_hf_map.items()} hf_state_dict = {} @@ -150,22 +84,11 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: continue layer_num = re.search(r"\d+", key).group(0) hf_key = to_hf_map[abstract_key] - match hf_key: - case (blocks, scales): - blocks = blocks.format(layer_num) - scales = scales.format(layer_num) - hf_state_dict[blocks] = value.new_empty( - (*value.shape[:2], value.shape[2] // 32, 16), - dtype=torch.uint8, - ) - hf_state_dict[scales] = value.new_empty( - (*value.shape[:2], value.shape[2] // 32), - dtype=torch.uint8, - ) - case tensor_name: - tensor_name = tensor_name.format(layer_num) - hf_state_dict[tensor_name] = value + hf_key = hf_key.format(layer_num) + hf_state_dict[hf_key] = value else: + if key not in to_hf_map: + continue hf_key = to_hf_map[key] hf_state_dict[hf_key] = value @@ -173,31 +96,18 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: """ - Convert from 
quantized hf format state dict to tt model state dict. + Convert from hf format state dict to tt model state dict. """ state_dict = {} - subtract_key = lambda key: re.sub(r"(\d+)", "{}", key, count=1) - for key, value in hf_state_dict.items(): if "layers" in key: layer_num = re.search(r"\d+", key).group(0) - if "_blocks" in key: - value_scale = hf_state_dict[key.replace("_blocks", "_scales")] - abstract_key = ( - subtract_key(key), - subtract_key(key.replace("_blocks", "_scales")), - ) - tt_key = self.from_hf_map[abstract_key] - tt_key = tt_key.format(layer_num) - dequantized_values = get_mxfp4_tensor(value, value_scale) - state_dict[tt_key] = dequantized_values - elif "_scales" not in key: - abstract_key = subtract_key(key) - tt_key = self.from_hf_map[abstract_key] - tt_key = tt_key.format(layer_num) - state_dict[tt_key] = value + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + tt_key = self.from_hf_map[abstract_key] + tt_key = tt_key.format(layer_num) + state_dict[tt_key] = value else: tt_key = self.from_hf_map[key] state_dict[tt_key] = value