# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-import math
import re
from typing import Any

-import torch
-from torch.distributed.tensor import DTensor
+from torch.distributed.checkpoint import HuggingFaceStorageReader
from torchtitan.models.utils import MoEStateDictAdapter

from .args import GptOssModelArgs

-FP4_VALUES = [
-    +0.0,
-    +0.5,
-    +1.0,
-    +1.5,
-    +2.0,
-    +3.0,
-    +4.0,
-    +6.0,
-    -0.0,
-    -0.5,
-    -1.0,
-    -1.5,
-    -2.0,
-    -3.0,
-    -4.0,
-    -6.0,
-]
-
-
-def get_mxfp4_tensor(
-    blocks,
-    scales,
-    *,
-    dtype: torch.dtype = torch.bfloat16,
-    rows_per_chunk: int = 16384 * 512,
-) -> torch.Tensor:
-    """
-    Adapted from openai's implementation of mxfp4 dequantization:
-    https://github.com/openai/gpt-oss/blob/8890e95919f975a490fc0ba09ffb10890ec7319d/gpt_oss/torch/weights.py#L68
-    """
-
-    is_dtensor = isinstance(blocks, DTensor)
-    if is_dtensor:
-        device_mesh = blocks.device_mesh
-        placements = blocks.placements
-        blocks = blocks.to_local()
-        scales = scales.to_local()
-
-    scales = scales.to(torch.int32) - 127
-
-    assert (
-        blocks.shape[:-1] == scales.shape
-    ), f"{blocks.shape=} does not match {scales.shape=}"
-
-    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
-
-    *prefix_shape, G, B = blocks.shape
-    rows_total = math.prod(prefix_shape) * G
-
-    blocks = blocks.reshape(rows_total, B)
-    scales = scales.reshape(rows_total, 1)
-
-    out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
-
-    for r0 in range(0, rows_total, rows_per_chunk):
-        r1 = min(r0 + rows_per_chunk, rows_total)
-
-        blk = blocks[r0:r1]
-        exp = scales[r0:r1]
-
-        # nibble indices -> int64
-        idx_lo = (blk & 0x0F).to(torch.long)
-        idx_hi = (blk >> 4).to(torch.long)
-
-        sub = out[r0:r1]
-        sub[:, 0::2] = lut[idx_lo]
-        sub[:, 1::2] = lut[idx_hi]
-
-        torch.ldexp(sub, exp, out=sub)
-        del idx_lo, idx_hi, blk, exp
-
-    result = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
-
-    if is_dtensor:
-        result = DTensor.from_local(
-            result, device_mesh=device_mesh, placements=placements
-        )
-
-    return result
-
-
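# Editorial aside, not part of the diff: a minimal sketch of the MXFP4 layout
# that the removed helper above decoded. Each uint8 in `blocks` packs two
# 4-bit indices into FP4_VALUES (low nibble first), and each block shares one
# power-of-two scale stored in `scales` as a uint8 exponent biased by 127.
_FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
               -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

_byte = 0x21                                # low nibble 1, high nibble 2
_exp = 128 - 127                            # biased scale byte 128 -> 2**1
_lo = _FP4_VALUES[_byte & 0x0F] * 2**_exp   # 0.5 * 2 = 1.0
_hi = _FP4_VALUES[_byte >> 4] * 2**_exp     # 1.0 * 2 = 2.0
assert (_lo, _hi) == (1.0, 2.0)
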
class GptOssStateDictAdapter(MoEStateDictAdapter):
    def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None):
        super().__init__(model_args, hf_assets_path)
@@ -116,29 +32,47 @@ def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None):
            "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
            "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
            # MoE
-            (
-                "model.layers.{}.mlp.experts.gate_up_proj_blocks",
-                "model.layers.{}.mlp.experts.gate_up_proj_scales",
-            ): "layers.{}.moe.experts.mlp1_weight",
+            "model.layers.{}.mlp.experts.gate_up_proj_blocks": "layers.{}.moe.experts.mlp1_weight",
            "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias",
-            (
-                "model.layers.{}.mlp.experts.down_proj_blocks",
-                "model.layers.{}.mlp.experts.down_proj_scales",
-            ): "layers.{}.moe.experts.mlp2_weight",
+            "model.layers.{}.mlp.experts.down_proj_blocks": "layers.{}.moe.experts.mlp2_weight",
            "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias",
            "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight",
            "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias",
            "model.norm.weight": "norm.weight",
            "lm_head.weight": "output.weight",
        }

+    def get_hf_storage_reader(
+        self, path: str, from_quantized: bool = False
+    ) -> HuggingFaceStorageReader:
+        """
+        Override the default get_hf_storage_reader to return a
+        QuantizedHuggingFaceStorageReader when reading quantized checkpoints.
+        """
+        if from_quantized:
+            from torch.distributed.checkpoint.quantized_hf_storage import (
+                QuantizedHuggingFaceStorageReader,
+            )
+
+            # NOTE: We use the quantized HF storage reader to read GPT-OSS
+            # checkpoints, where expert weights are saved in MXFP4 format.
+            # When loading checkpoints without quantization,
+            # HuggingFaceStorageReader is used instead.
+            return QuantizedHuggingFaceStorageReader(
+                path=path,
+                thread_count=4,
+            )
+        else:
+            return HuggingFaceStorageReader(path)
+
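# Editorial aside, not part of the diff: a usage sketch (the path is
# hypothetical, and GptOssModelArgs is assumed to construct with defaults).
# With from_quantized=True the returned reader dequantizes MXFP4 expert
# weights during load, so downstream code only ever sees plain tensors.
#
#     adapter = GptOssStateDictAdapter(GptOssModelArgs(), hf_assets_path=None)
#     reader = adapter.get_hf_storage_reader("/ckpts/gpt-oss", from_quantized=True)
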
    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        """
        Convert from a tt model state dict to a hf format state dict.
-        Warning: Conversion does not support mxfp4 quantization,
-        and the function is only for the purpose of loading from hf checkpoints.
-        TODO: Add support for exact conversion of mxfp4 quantized tensors,
-        then one can save into hf checkpoints with last_save_in_hf = true.
+
+        This only remaps keys; tensor shapes are kept identical to those in the
+        MXFP4 checkpoint. When loading from quantized checkpoints, the
+        QuantizedHuggingFaceStorageReader handles dequantization during load.
+
+        Warning: Conversion does not support saving in the MXFP4 quantized
+        format. One can still save unquantized hf checkpoints with
+        last_save_in_hf = true.
        """
        to_hf_map = {v: k for k, v in self.from_hf_map.items()}
        hf_state_dict = {}
@@ -150,54 +84,30 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                    continue
                layer_num = re.search(r"\d+", key).group(0)
                hf_key = to_hf_map[abstract_key]
-                match hf_key:
-                    case (blocks, scales):
-                        blocks = blocks.format(layer_num)
-                        scales = scales.format(layer_num)
-                        hf_state_dict[blocks] = value.new_empty(
-                            (*value.shape[:2], value.shape[2] // 32, 16),
-                            dtype=torch.uint8,
-                        )
-                        hf_state_dict[scales] = value.new_empty(
-                            (*value.shape[:2], value.shape[2] // 32),
-                            dtype=torch.uint8,
-                        )
-                    case tensor_name:
-                        tensor_name = tensor_name.format(layer_num)
-                        hf_state_dict[tensor_name] = value
+                hf_key = hf_key.format(layer_num)
+                hf_state_dict[hf_key] = value
            else:
+                if key not in to_hf_map:
+                    continue
                hf_key = to_hf_map[key]
                hf_state_dict[hf_key] = value

        return hf_state_dict

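# Editorial aside, not part of the diff: the key translation to_hf performs,
# shown standalone with one entry of the mapping above. Numerals are abstracted
# to "{}", looked up in the inverted map, then substituted back in.
_to_hf_map = {
    "layers.{}.moe.experts.mlp1_weight":
        "model.layers.{}.mlp.experts.gate_up_proj_blocks",
}
_key = "layers.7.moe.experts.mlp1_weight"
_abstract = re.sub(r"(\d+)", "{}", _key, count=1)
_hf_key = _to_hf_map[_abstract].format(re.search(r"\d+", _key).group(0))
assert _hf_key == "model.layers.7.mlp.experts.gate_up_proj_blocks"
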
    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        """
-        Convert from quantized hf format state dict to tt model state dict.
+        Convert from hf format state dict to tt model state dict.
        """

        state_dict = {}

-        subtract_key = lambda key: re.sub(r"(\d+)", "{}", key, count=1)
-
        for key, value in hf_state_dict.items():
            if "layers" in key:
                layer_num = re.search(r"\d+", key).group(0)
-                if "_blocks" in key:
-                    value_scale = hf_state_dict[key.replace("_blocks", "_scales")]
-                    abstract_key = (
-                        subtract_key(key),
-                        subtract_key(key.replace("_blocks", "_scales")),
-                    )
-                    tt_key = self.from_hf_map[abstract_key]
-                    tt_key = tt_key.format(layer_num)
-                    dequantized_values = get_mxfp4_tensor(value, value_scale)
-                    state_dict[tt_key] = dequantized_values
-                elif "_scales" not in key:
-                    abstract_key = subtract_key(key)
-                    tt_key = self.from_hf_map[abstract_key]
-                    tt_key = tt_key.format(layer_num)
-                    state_dict[tt_key] = value
+                abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
+                tt_key = self.from_hf_map[abstract_key]
+                tt_key = tt_key.format(layer_num)
+                state_dict[tt_key] = value
            else:
                tt_key = self.from_hf_map[key]
                state_dict[tt_key] = value
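# Editorial aside, not part of the diff: from_hf runs the same substitution in
# the opposite direction, so a concrete HF checkpoint key lands on the matching
# torchtitan key.
_from_hf = {"model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight"}
_hf_key = "model.layers.3.mlp.router.weight"
_abstract = re.sub(r"(\d+)", "{}", _hf_key, count=1)
_tt_key = _from_hf[_abstract].format(re.search(r"\d+", _hf_key).group(0))
assert _tt_key == "layers.3.moe.router.gate.weight"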