diff --git a/docs/checkpoint.md b/docs/checkpoint.md
index 6e3112309b..8aca58eb06 100644
--- a/docs/checkpoint.md
+++ b/docs/checkpoint.md
@@ -68,7 +68,7 @@ NGPU=1 CONFIG_FILE= ./run_train.sh --checkpoint.enable --c
 ### HuggingFace
 
 `torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu.
-1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.
+1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.
 
 2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt.
 
diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py
index c12ad13a5c..0ebc20645f 100644
--- a/torchtitan/experiments/gpt_oss/__init__.py
+++ b/torchtitan/experiments/gpt_oss/__init__.py
@@ -16,6 +16,7 @@ from .infra.parallelize import parallelize_gptoss
 from .model.args import GptOssModelArgs
 from .model.model import GptOssModel
+from .model.state_dict_adapter import GptOssStateDictAdapter
 
 __all__ = [
     "parallelize_gptoss",
@@ -84,4 +85,5 @@ def get_train_spec() -> TrainSpec:
         build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
+        state_dict_adapter=GptOssStateDictAdapter,
     )
diff --git a/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py b/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py
new file mode 100644
index 0000000000..ca85789baf
--- /dev/null
+++ b/torchtitan/experiments/gpt_oss/model/state_dict_adapter.py
@@ -0,0 +1,115 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import re
+from typing import Any
+
+from torch.distributed.checkpoint import HuggingFaceStorageReader
+from torchtitan.models.utils import MoEStateDictAdapter
+
+from .args import GptOssModelArgs
+
+
+class GptOssStateDictAdapter(MoEStateDictAdapter):
+    def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None):
+        super().__init__(model_args, hf_assets_path)
+        self.from_hf_map = {
+            "model.embed_tokens.weight": "tok_embeddings.weight",
+            # Attention module
+            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+            "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
+            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+            "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
+            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+            "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
+            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+            "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
+            "model.layers.{}.self_attn.sinks": "layers.{}.attention.sinks",
+            # Transformer layer
+            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+            # MoE
+            "model.layers.{}.mlp.experts.gate_up_proj_blocks": "layers.{}.moe.experts.mlp1_weight",
+            "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias",
+            "model.layers.{}.mlp.experts.down_proj_blocks": "layers.{}.moe.experts.mlp2_weight",
+            "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias",
+            "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight",
+            "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias",
+            "model.norm.weight": "norm.weight",
+            "lm_head.weight": "output.weight",
+        }
+
+    def get_hf_storage_reader(
+        self, path: str, from_quantized: bool = False
+    ) -> HuggingFaceStorageReader:
+        """
+        Override the default get_hf_storage_reader to return a QuantizedHuggingFaceStorageReader.
+        """
+        if from_quantized:
+            from torch.distributed.checkpoint.quantized_hf_storage import (
+                QuantizedHuggingFaceStorageReader,
+            )
+
+            # NOTE: We use the quantized HF storage reader to read GPT-OSS checkpoints,
+            # whose expert weights are saved in MXFP4 format.
+            # If loading unquantized checkpoints, use HuggingFaceStorageReader instead.
+            return QuantizedHuggingFaceStorageReader(
+                path=path,
+                thread_count=4,
+            )
+        else:
+            return HuggingFaceStorageReader(path)
+
+    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
+        """
+        Convert a torchtitan (tt) model state dict to an HF-format state dict.
+
+        Only keys are remapped; tensor shapes are kept identical to the MXFP4 checkpoint.
+        For loading from quantized checkpoints, the QuantizedHuggingFaceStorageReader
+        will handle dequantization during load.
+
+        Warning: conversion does not support saving in the MXFP4 quantized format.
+        Unquantized HF checkpoints can be saved with last_save_in_hf = true.
+ """ + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + if abstract_key not in to_hf_map: + continue + layer_num = re.search(r"\d+", key).group(0) + hf_key = to_hf_map[abstract_key] + hf_key = hf_key.format(layer_num) + hf_state_dict[hf_key] = value + else: + if key not in to_hf_map: + continue + hf_key = to_hf_map[key] + hf_state_dict[hf_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Convert from hf format state dict to tt model state dict. + """ + + state_dict = {} + + for key, value in hf_state_dict.items(): + if "layers" in key: + layer_num = re.search(r"\d+", key).group(0) + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + tt_key = self.from_hf_map[abstract_key] + tt_key = tt_key.format(layer_num) + state_dict[tt_key] = value + else: + tt_key = self.from_hf_map[key] + state_dict[tt_key] = value + + return state_dict