2 changes: 1 addition & 1 deletion docs/checkpoint.md
@@ -68,7 +68,7 @@ NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_train.sh --checkpoint.enable --c
### HuggingFace
`torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu.

- 1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.
+ 1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.

2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt.

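As a rough illustration of the hf->tt direction described in item 2 of the doc above, the sketch below remaps a toy Hugging Face state dict into `torchtitan` naming using the adapter added in this PR. It is not part of the diff, and it assumes `GptOssModelArgs` is constructible with default values and that the adapter accepts `hf_assets_path=None`.

```python
import torch

from torchtitan.experiments.gpt_oss import GptOssModelArgs, GptOssStateDictAdapter

# Assumed: default-constructed model args and no local HF assets directory.
adapter = GptOssStateDictAdapter(GptOssModelArgs(), hf_assets_path=None)

# Toy stand-ins for tensors that would normally be read from safetensors shards.
hf_state_dict = {
    "model.embed_tokens.weight": torch.zeros(8, 4),
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4),
}

tt_state_dict = adapter.from_hf(hf_state_dict)
print(sorted(tt_state_dict))
# ['layers.0.attention.wq.weight', 'tok_embeddings.weight']
```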
2 changes: 2 additions & 0 deletions torchtitan/experiments/gpt_oss/__init__.py
@@ -16,6 +16,7 @@
from .infra.parallelize import parallelize_gptoss
from .model.args import GptOssModelArgs
from .model.model import GptOssModel
from .model.state_dict_adapter import GptOssStateDictAdapter

__all__ = [
"parallelize_gptoss",
@@ -84,4 +85,5 @@ def get_train_spec() -> TrainSpec:
        build_dataloader_fn=build_text_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
        state_dict_adapter=GptOssStateDictAdapter,
    )
115 changes: 115 additions & 0 deletions torchtitan/experiments/gpt_oss/model/state_dict_adapter.py
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any

from torch.distributed.checkpoint import HuggingFaceStorageReader
from torchtitan.models.utils import MoEStateDictAdapter

from .args import GptOssModelArgs


class GptOssStateDictAdapter(MoEStateDictAdapter):
    def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None):
        super().__init__(model_args, hf_assets_path)
        self.from_hf_map = {
            "model.embed_tokens.weight": "tok_embeddings.weight",
            # Attention module
            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
            "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
            "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
            "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
            "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
            "model.layers.{}.self_attn.sinks": "layers.{}.attention.sinks",
            # Transformer layer
            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
            # MoE
            "model.layers.{}.mlp.experts.gate_up_proj_blocks": "layers.{}.moe.experts.mlp1_weight",
            "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias",
            "model.layers.{}.mlp.experts.down_proj_blocks": "layers.{}.moe.experts.mlp2_weight",
            "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias",
            "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight",
            "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias",
            "model.norm.weight": "norm.weight",
            "lm_head.weight": "output.weight",
        }

    def get_hf_storage_reader(
        self, path: str, from_quantized: bool = False
    ) -> HuggingFaceStorageReader:
        """
        Override the default get_hf_storage_reader function to return a
        QuantizedHuggingFaceStorageReader when reading quantized checkpoints.
        """
        if from_quantized:
            from torch.distributed.checkpoint.quantized_hf_storage import (
                QuantizedHuggingFaceStorageReader,
            )

            # NOTE: The quantized HF storage reader is used to read GPT-OSS checkpoints
            # whose expert weights are saved in MXFP4 format. When loading checkpoints
            # without quantization, HuggingFaceStorageReader is used instead.
            return QuantizedHuggingFaceStorageReader(
                path=path,
                thread_count=4,
            )
        else:
            return HuggingFaceStorageReader(path)

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        """
        Convert a torchtitan (tt) model state dict to an HF-format state dict.

        Keys are only remapped; tensor shapes are kept identical to those in the
        MXFP4 checkpoint. When loading from quantized checkpoints, the
        QuantizedHuggingFaceStorageReader handles dequantization during load.

        Warning: conversion does not support saving in the MXFP4 quantization
        format. Unquantized HF checkpoints can be saved with `last_save_in_hf = true`.
        """
        to_hf_map = {v: k for k, v in self.from_hf_map.items()}
        hf_state_dict = {}

        for key, value in state_dict.items():
            if "layers" in key:
                # Replace the layer index with "{}" to look up the abstract key,
                # then substitute the index back into the mapped HF key.
                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
                if abstract_key not in to_hf_map:
                    continue
                layer_num = re.search(r"\d+", key).group(0)
                hf_key = to_hf_map[abstract_key]
                hf_key = hf_key.format(layer_num)
                hf_state_dict[hf_key] = value
            else:
                if key not in to_hf_map:
                    continue
                hf_key = to_hf_map[key]
                hf_state_dict[hf_key] = value

        return hf_state_dict

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """
        Convert an HF-format state dict to a torchtitan (tt) model state dict.
        """
        state_dict = {}

        for key, value in hf_state_dict.items():
            if "layers" in key:
                layer_num = re.search(r"\d+", key).group(0)
                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
                tt_key = self.from_hf_map[abstract_key]
                tt_key = tt_key.format(layer_num)
                state_dict[tt_key] = value
            else:
                tt_key = self.from_hf_map[key]
                state_dict[tt_key] = value

        return state_dict
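A hedged usage sketch for the adapter above: tensor shapes and the model-args values are illustrative placeholders rather than real GPT-OSS dimensions, the checkpoint path passed to `get_hf_storage_reader` is hypothetical, and the quantized-reader branch assumes a torch build that provides `torch.distributed.checkpoint.quantized_hf_storage`.

```python
import torch

from torchtitan.experiments.gpt_oss import GptOssModelArgs, GptOssStateDictAdapter

adapter = GptOssStateDictAdapter(GptOssModelArgs(), hf_assets_path=None)

# Placeholder tensors; real GPT-OSS weights have different shapes.
tt_state_dict = {
    "tok_embeddings.weight": torch.zeros(8, 4),
    "layers.0.attention.wq.weight": torch.zeros(4, 4),
    "layers.0.moe.router.gate.weight": torch.zeros(2, 4),
}

# tt -> hf: keys are remapped via the inverted from_hf_map; shapes are untouched.
hf_state_dict = adapter.to_hf(tt_state_dict)
assert "model.layers.0.mlp.router.weight" in hf_state_dict

# For an MXFP4-quantized GPT-OSS checkpoint, a quantized storage reader is returned;
# the path here is a placeholder, not a real checkpoint directory.
reader = adapter.get_hf_storage_reader(
    "/path/to/gpt_oss_hf_checkpoint", from_quantized=True
)
```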