Commit 446c0ee

[GPT-OSS] Add HF state dict adapter to support loading from HF checkpoints

1 parent 157d30d
File tree

3 files changed, +208 -1 lines changed

- docs/checkpoint.md
- torchtitan/experiments/gpt_oss/__init__.py
- torchtitan/experiments/gpt_oss/model/state_dict_adapter.py

docs/checkpoint.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -68,7 +68,7 @@ NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_train.sh --checkpoint.enable --c
 ### HuggingFace
 `torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on CPU.
 
-1. You can directly save Hugging Face model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a Hugging Face safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the Hugging Face checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.
+1. You can directly save Hugging Face model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a Hugging Face safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the Hugging Face checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.
 
 2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. `convert_from_hf` means convert HF -> TT.
 
```
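For intuition, the HF -> TT direction of option 2 can be sketched with the adapter this commit adds (shown below). This is an editor's sketch, not the actual conversion script: the shard filename and `model_args` are placeholders, and the real scripts also handle multi-shard checkpoints.

```python
from safetensors.torch import load_file

hf_sd = load_file("model-00001-of-00002.safetensors")  # placeholder shard name
adapter = GptOssStateDictAdapter(model_args, hf_assets_path=None)  # model_args assumed
tt_sd = adapter.from_hf(hf_sd)  # HF names -> TT names; mxfp4 weights dequantized
```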
torchtitan/experiments/gpt_oss/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -16,6 +16,7 @@
 from .infra.parallelize import parallelize_gptoss
 from .model.args import GptOssModelArgs
 from .model.model import GptOssModel
+from .model.state_dict_adapter import GptOssStateDictAdapter
 
 __all__ = [
     "parallelize_gptoss",
@@ -84,4 +85,5 @@ def get_train_spec() -> TrainSpec:
         build_dataloader_fn=build_text_dataloader,
         build_tokenizer_fn=build_hf_tokenizer,
         build_loss_fn=build_cross_entropy_loss,
+        state_dict_adapter=GptOssStateDictAdapter,
     )
```
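Registering the adapter on the `TrainSpec` is what lets the rest of `torchtitan` find it. A hedged sketch of the effect (`model_args`, `hf_assets_path`, and `hf_sd` are assumed to exist; the actual call sites live in torchtitan's checkpoint machinery):

```python
from torchtitan.experiments.gpt_oss import get_train_spec

spec = get_train_spec()
adapter = spec.state_dict_adapter(model_args, hf_assets_path)
tt_sd = adapter.from_hf(hf_sd)  # e.g. when --checkpoint.initial_load_in_hf is set
```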
torchtitan/experiments/gpt_oss/model/state_dict_adapter.py

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@ (new file)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import re
from typing import Any

import torch
from torch.distributed.tensor import DTensor
from torchtitan.models.utils import MoEStateDictAdapter

from .args import GptOssModelArgs


# The 16 fp4 (e2m1) values; a 4-bit nibble indexes directly into this table.
FP4_VALUES = [
    +0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]


def get_mxfp4_tensor(
    blocks,
    scales,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 16384 * 512,
) -> torch.Tensor:
    """
    Dequantize mxfp4 (blocks, scales) pairs into a dense tensor.

    Adapted from OpenAI's implementation of mxfp4 dequantization:
    https://github.com/openai/gpt-oss/blob/8890e95919f975a490fc0ba09ffb10890ec7319d/gpt_oss/torch/weights.py#L68
    """
    # If the inputs are DTensors, dequantize the local shards and rewrap below.
    is_dtensor = isinstance(blocks, DTensor)
    if is_dtensor:
        device_mesh = blocks.device_mesh
        placements = blocks.placements
        blocks = blocks.to_local()
        scales = scales.to_local()

    # Scales are stored as biased e8m0 exponents; remove the bias of 127.
    scales = scales.to(torch.int32) - 127

    assert (
        blocks.shape[:-1] == scales.shape
    ), f"{blocks.shape=} does not match {scales.shape=}"

    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

    *prefix_shape, G, B = blocks.shape
    rows_total = math.prod(prefix_shape) * G

    blocks = blocks.reshape(rows_total, B)
    scales = scales.reshape(rows_total, 1)

    # Each uint8 byte packs two fp4 nibbles, so the output has 2x the columns.
    out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

    # Process rows in chunks to bound the peak memory of the int64 index tensors.
    for r0 in range(0, rows_total, rows_per_chunk):
        r1 = min(r0 + rows_per_chunk, rows_total)

        blk = blocks[r0:r1]
        exp = scales[r0:r1]

        # nibble indices -> int64
        idx_lo = (blk & 0x0F).to(torch.long)
        idx_hi = (blk >> 4).to(torch.long)

        sub = out[r0:r1]
        sub[:, 0::2] = lut[idx_lo]
        sub[:, 1::2] = lut[idx_hi]

        # Apply the per-block power-of-two scale in place: sub * 2**exp.
        torch.ldexp(sub, exp, out=sub)
        del idx_lo, idx_hi, blk, exp

    result = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)

    if is_dtensor:
        result = DTensor.from_local(
            result, device_mesh=device_mesh, placements=placements
        )

    return result
```
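To make the nibble packing and scaling concrete, here is a tiny worked example (editor's illustration, not part of the commit):

```python
import torch

# One row, one group, 16 packed bytes = 32 fp4 values.
blocks = torch.full((1, 1, 16), 0x21, dtype=torch.uint8)  # nibbles: low 0x1, high 0x2
scales = torch.full((1, 1), 128, dtype=torch.uint8)       # e8m0: 128 - 127 = exponent +1

deq = get_mxfp4_tensor(blocks, scales)
print(deq.shape)   # torch.Size([1, 32])
print(deq[0, :2])  # tensor([1., 2.], dtype=torch.bfloat16)
# FP4_VALUES[0x1] = +0.5 and FP4_VALUES[0x2] = +1.0; ldexp by +1 doubles both.
```

The adapter class that uses this helper follows.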
```python
class GptOssStateDictAdapter(MoEStateDictAdapter):
    def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None):
        super().__init__(model_args, hf_assets_path)
        # Keys are HF checkpoint names, values are torchtitan (tt) names.
        # Tuple keys pair the (blocks, scales) halves of an mxfp4-quantized
        # weight that dequantize into a single tt tensor.
        self.from_hf_map = {
            "model.embed_tokens.weight": "tok_embeddings.weight",
            # Attention module
            "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
            "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
            "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
            "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
            "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
            "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
            "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
            "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
            "model.layers.{}.self_attn.sinks": "layers.{}.attention.sinks",
            # Transformer layer
            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
            # MoE
            (
                "model.layers.{}.mlp.experts.gate_up_proj_blocks",
                "model.layers.{}.mlp.experts.gate_up_proj_scales",
            ): "layers.{}.moe.experts.mlp1_weight",
            "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias",
            (
                "model.layers.{}.mlp.experts.down_proj_blocks",
                "model.layers.{}.mlp.experts.down_proj_scales",
            ): "layers.{}.moe.experts.mlp2_weight",
            "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias",
            "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight",
            "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias",
            "model.norm.weight": "norm.weight",
            "lm_head.weight": "output.weight",
        }
```
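The tuple keys deserve a note: a quantized HF weight arrives as two tensors (`*_blocks` and `*_scales`) that fuse into one tt tensor, so a pair of abstract HF names maps to a single tt name. A small illustration (editor's sketch; `adapter` stands for a constructed `GptOssStateDictAdapter`):

```python
import re

key = "model.layers.7.mlp.experts.gate_up_proj_blocks"
abstract = re.sub(r"(\d+)", "{}", key, count=1)
pair = (abstract, abstract.replace("_blocks", "_scales"))
print(adapter.from_hf_map[pair].format(7))  # layers.7.moe.experts.mlp1_weight
```

The class continues with the two conversion methods: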
```python
    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        """
        Convert a tt model state dict to an hf-format state dict.
        Warning: the conversion does not support mxfp4 quantization;
        this function exists only to support loading from hf checkpoints.
        TODO: Add support for exact conversion of mxfp4-quantized tensors,
        after which one can save into hf checkpoints with last_save_in_hf = true.
        """
        to_hf_map = {v: k for k, v in self.from_hf_map.items()}
        hf_state_dict = {}

        for key, value in state_dict.items():
            if "layers" in key:
                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
                if abstract_key not in to_hf_map:
                    continue
                layer_num = re.search(r"\d+", key).group(0)
                hf_key = to_hf_map[abstract_key]
                match hf_key:
                    case (blocks, scales):
                        # A quantized weight maps to a (blocks, scales) pair in hf.
                        # Re-quantization is unsupported, so emit uninitialized
                        # uint8 placeholders with the correct mxfp4 shapes
                        # (32 fp4 values, i.e. 16 packed bytes, per scale).
                        blocks = blocks.format(layer_num)
                        scales = scales.format(layer_num)
                        hf_state_dict[blocks] = value.new_empty(
                            (*value.shape[:2], value.shape[2] // 32, 16),
                            dtype=torch.uint8,
                        )
                        hf_state_dict[scales] = value.new_empty(
                            (*value.shape[:2], value.shape[2] // 32),
                            dtype=torch.uint8,
                        )
                    case tensor_name:
                        tensor_name = tensor_name.format(layer_num)
                        hf_state_dict[tensor_name] = value
            else:
                hf_key = to_hf_map[key]
                hf_state_dict[hf_key] = value

        return hf_state_dict

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """
        Convert a quantized hf-format state dict to a tt model state dict.
        """
        state_dict = {}

        # Replace the first layer index in a key with "{}" to form an abstract key.
        subtract_key = lambda key: re.sub(r"(\d+)", "{}", key, count=1)

        for key, value in hf_state_dict.items():
            if "layers" in key:
                layer_num = re.search(r"\d+", key).group(0)
                if "_blocks" in key:
                    # Pair the blocks tensor with its scales twin and dequantize.
                    value_scale = hf_state_dict[key.replace("_blocks", "_scales")]
                    abstract_key = (
                        subtract_key(key),
                        subtract_key(key.replace("_blocks", "_scales")),
                    )
                    tt_key = self.from_hf_map[abstract_key]
                    tt_key = tt_key.format(layer_num)
                    dequantized_values = get_mxfp4_tensor(value, value_scale)
                    state_dict[tt_key] = dequantized_values
                elif "_scales" not in key:
                    # Plain (unquantized) per-layer tensors map directly; "_scales"
                    # keys were already consumed with their "_blocks" twin above.
                    abstract_key = subtract_key(key)
                    tt_key = self.from_hf_map[abstract_key]
                    tt_key = tt_key.format(layer_num)
                    state_dict[tt_key] = value
            else:
                tt_key = self.from_hf_map[key]
                state_dict[tt_key] = value

        return state_dict
```
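As the warning in `to_hf` notes, the tt -> hf direction emits shape-only placeholders for quantized weights. A hedged demonstration (editor's sketch; assumes `GptOssModelArgs()` is constructible with defaults, and uses a deliberately tiny illustrative tensor):

```python
import torch

adapter = GptOssStateDictAdapter(GptOssModelArgs(), hf_assets_path=None)
sd = {"layers.0.moe.experts.mlp1_weight": torch.randn(4, 64, 128)}
hf = adapter.to_hf(sd)
print(sorted(hf))
# ['model.layers.0.mlp.experts.gate_up_proj_blocks',
#  'model.layers.0.mlp.experts.gate_up_proj_scales']
print(hf["model.layers.0.mlp.experts.gate_up_proj_blocks"].shape)
# torch.Size([4, 64, 4, 16]) -- uninitialized uint8, shapes only
```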
