Commit 7661fb0

[GPT-OSS] Offload dequantization to QuantizedHuggingFaceStorageReader
1 parent 446c0ee

File tree: 1 file changed, +40 -130 lines

torchtitan/experiments/gpt_oss/model/state_dict_adapter.py (40 additions, 130 deletions)
@@ -4,99 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import math
 import re
 from typing import Any

-import torch
-from torch.distributed.tensor import DTensor
+from torch.distributed.checkpoint import HuggingFaceStorageReader
 from torchtitan.models.utils import MoEStateDictAdapter

 from .args import GptOssModelArgs


-FP4_VALUES = [
-    +0.0,
-    +0.5,
-    +1.0,
-    +1.5,
-    +2.0,
-    +3.0,
-    +4.0,
-    +6.0,
-    -0.0,
-    -0.5,
-    -1.0,
-    -1.5,
-    -2.0,
-    -3.0,
-    -4.0,
-    -6.0,
-]
-
-
-def get_mxfp4_tensor(
-    blocks,
-    scales,
-    *,
-    dtype: torch.dtype = torch.bfloat16,
-    rows_per_chunk: int = 16384 * 512,
-) -> torch.Tensor:
-    """
-    Adapted from openai's implementation of mxfp4 dequantization:
-    https://github.com/openai/gpt-oss/blob/8890e95919f975a490fc0ba09ffb10890ec7319d/gpt_oss/torch/weights.py#L68
-    """
-
-    is_dtensor = isinstance(blocks, DTensor)
-    if is_dtensor:
-        device_mesh = blocks.device_mesh
-        placements = blocks.placements
-        blocks = blocks.to_local()
-        scales = scales.to_local()
-
-    scales = scales.to(torch.int32) - 127
-
-    assert (
-        blocks.shape[:-1] == scales.shape
-    ), f"{blocks.shape=} does not match {scales.shape=}"
-
-    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
-
-    *prefix_shape, G, B = blocks.shape
-    rows_total = math.prod(prefix_shape) * G
-
-    blocks = blocks.reshape(rows_total, B)
-    scales = scales.reshape(rows_total, 1)
-
-    out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
-
-    for r0 in range(0, rows_total, rows_per_chunk):
-        r1 = min(r0 + rows_per_chunk, rows_total)
-
-        blk = blocks[r0:r1]
-        exp = scales[r0:r1]
-
-        # nibble indices -> int64
-        idx_lo = (blk & 0x0F).to(torch.long)
-        idx_hi = (blk >> 4).to(torch.long)
-
-        sub = out[r0:r1]
-        sub[:, 0::2] = lut[idx_lo]
-        sub[:, 1::2] = lut[idx_hi]
-
-        torch.ldexp(sub, exp, out=sub)
-        del idx_lo, idx_hi, blk, exp
-
-    result = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
-
-    if is_dtensor:
-        result = DTensor.from_local(
-            result, device_mesh=device_mesh, placements=placements
-        )
-
-    return result
-
-
 class GptOssStateDictAdapter(MoEStateDictAdapter):
     def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None):
         super().__init__(model_args, hf_assets_path)
@@ -116,29 +32,47 @@ def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None):
             "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
             "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
             # MoE
-            (
-                "model.layers.{}.mlp.experts.gate_up_proj_blocks",
-                "model.layers.{}.mlp.experts.gate_up_proj_scales",
-            ): "layers.{}.moe.experts.mlp1_weight",
+            "model.layers.{}.mlp.experts.gate_up_proj_blocks": "layers.{}.moe.experts.mlp1_weight",
             "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias",
-            (
-                "model.layers.{}.mlp.experts.down_proj_blocks",
-                "model.layers.{}.mlp.experts.down_proj_scales",
-            ): "layers.{}.moe.experts.mlp2_weight",
+            "model.layers.{}.mlp.experts.down_proj_blocks": "layers.{}.moe.experts.mlp2_weight",
             "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias",
             "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight",
             "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias",
             "model.norm.weight": "norm.weight",
             "lm_head.weight": "output.weight",
         }

+    def get_hf_storage_reader(
+        self, path: str, from_quantized: bool = False
+    ) -> HuggingFaceStorageReader:
+        """
+        Override the default get_hf_storage_reader to return a QuantizedHuggingFaceStorageReader.
+        """
+        if from_quantized:
+            from torch.distributed.checkpoint.quantized_hf_storage import (
+                QuantizedHuggingFaceStorageReader,
+            )
+
+            # NOTE: We use the quantized HF storage reader to read GPT-OSS checkpoints whose
+            # expert weights are saved in MXFP4 format.
+            # When loading checkpoints without quantization, use HuggingFaceStorageReader instead.
+            return QuantizedHuggingFaceStorageReader(
+                path=path,
+                thread_count=4,
+            )
+        else:
+            return HuggingFaceStorageReader(path)
+
     def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         """
         Convert from a tt model state dict to a hf format state dict.
-        Warning: Conversion does not support mxfp4 quantization,
-        and the function is only for the purpose of loading from hf checkpoints.
-        TODO: Add support for exact conversion of mxfp4 quantized tensors,
-        then one can save into hf checkpoints with last_save_in_hf = true.
+
+        Only map keys, without changing shapes, so they match the MXFP4 checkpoint layout.
+        When loading from quantized checkpoints, the QuantizedHuggingFaceStorageReader
+        handles dequantization during load.
+
+        Warning: Conversion does not support saving in the mxfp4 quantized format.
+        One can save unquantized hf checkpoints with last_save_in_hf = true.
         """
         to_hf_map = {v: k for k, v in self.from_hf_map.items()}
         hf_state_dict = {}
@@ -150,54 +84,30 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                     continue
                 layer_num = re.search(r"\d+", key).group(0)
                 hf_key = to_hf_map[abstract_key]
-                match hf_key:
-                    case (blocks, scales):
-                        blocks = blocks.format(layer_num)
-                        scales = scales.format(layer_num)
-                        hf_state_dict[blocks] = value.new_empty(
-                            (*value.shape[:2], value.shape[2] // 32, 16),
-                            dtype=torch.uint8,
-                        )
-                        hf_state_dict[scales] = value.new_empty(
-                            (*value.shape[:2], value.shape[2] // 32),
-                            dtype=torch.uint8,
-                        )
-                    case tensor_name:
-                        tensor_name = tensor_name.format(layer_num)
-                        hf_state_dict[tensor_name] = value
+                hf_key = hf_key.format(layer_num)
+                hf_state_dict[hf_key] = value
             else:
+                if key not in to_hf_map:
+                    continue
                 hf_key = to_hf_map[key]
                 hf_state_dict[hf_key] = value

         return hf_state_dict

     def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         """
-        Convert from quantized hf format state dict to tt model state dict.
+        Convert from hf format state dict to tt model state dict.
         """

         state_dict = {}

-        subtract_key = lambda key: re.sub(r"(\d+)", "{}", key, count=1)
-
         for key, value in hf_state_dict.items():
             if "layers" in key:
                 layer_num = re.search(r"\d+", key).group(0)
-                if "_blocks" in key:
-                    value_scale = hf_state_dict[key.replace("_blocks", "_scales")]
-                    abstract_key = (
-                        subtract_key(key),
-                        subtract_key(key.replace("_blocks", "_scales")),
-                    )
-                    tt_key = self.from_hf_map[abstract_key]
-                    tt_key = tt_key.format(layer_num)
-                    dequantized_values = get_mxfp4_tensor(value, value_scale)
-                    state_dict[tt_key] = dequantized_values
-                elif "_scales" not in key:
-                    abstract_key = subtract_key(key)
-                    tt_key = self.from_hf_map[abstract_key]
-                    tt_key = tt_key.format(layer_num)
-                    state_dict[tt_key] = value
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                tt_key = self.from_hf_map[abstract_key]
+                tt_key = tt_key.format(layer_num)
+                state_dict[tt_key] = value
             else:
                 tt_key = self.from_hf_map[key]
                 state_dict[tt_key] = value
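
After this change, key remapping in from_hf reduces to the single regex substitution shown in the diff. The snippet below is an illustration only (not part of the commit) of how one HF expert-weight key is templated and remapped through from_hf_map; the example key and layer index are made up.

import re

key = "model.layers.7.mlp.experts.gate_up_proj_blocks"
layer_num = re.search(r"\d+", key).group(0)           # "7"
abstract_key = re.sub(r"(\d+)", "{}", key, count=1)   # "model.layers.{}.mlp.experts.gate_up_proj_blocks"
# from_hf_map maps this template to "layers.{}.moe.experts.mlp1_weight"
tt_key = "layers.{}.moe.experts.mlp1_weight".format(layer_num)
assert tt_key == "layers.7.moe.experts.mlp1_weight"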
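For context, a minimal sketch of how the new get_hf_storage_reader override might be exercised end to end. This is an assumption-laden simplification: in torchtitan the flow is driven by the checkpoint manager, and the model, model_args, and checkpoint path below are placeholders.

import torch.distributed.checkpoint as dcp

from torchtitan.experiments.gpt_oss.model.args import GptOssModelArgs
from torchtitan.experiments.gpt_oss.model.state_dict_adapter import GptOssStateDictAdapter


def load_gpt_oss_from_hf(model, model_args: GptOssModelArgs, hf_checkpoint_path: str):
    # Hypothetical helper for illustration; arguments are placeholders.
    adapter = GptOssStateDictAdapter(model_args, hf_assets_path=None)

    # QuantizedHuggingFaceStorageReader dequantizes the MXFP4 expert weights while
    # reading, so the adapter only needs to remap key names.
    reader = adapter.get_hf_storage_reader(hf_checkpoint_path, from_quantized=True)

    hf_state_dict = adapter.to_hf(model.state_dict())   # tt -> hf key names
    dcp.load(hf_state_dict, storage_reader=reader)       # read (and dequantize) hf tensors
    state_dict = adapter.from_hf(hf_state_dict)          # hf -> tt key names
    model.load_state_dict(state_dict)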
