 from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
 from llmcompressor.utils.dev import skip_weights_initialize

-from compressed_tensors import update_offload_parameter
+from compressed_tensors.utils import update_offload_parameter, align_module_device


 class GptOssExpert(torch.nn.Module):
+    gate_proj: torch.nn.Linear
+    up_proj: torch.nn.Linear
+    down_proj: torch.nn.Linear
+
     def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
         super().__init__()

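For context on what `alpha` and `limit` parameterize: each `GptOssExpert` unfuses one expert from the fused `GptOssExperts` module, whose forward in upstream transformers is a clamped, alpha-scaled SwiGLU. Below is a minimal sketch of the per-expert math, assuming it mirrors the fused implementation (worth double-checking against transformers); `expert_forward` and the example values are illustrative, not part of this PR:

import torch

def expert_forward(x, gate_proj, up_proj, down_proj, alpha: float, limit: float):
    # Sketch only: gate activations are capped at `limit`, the up projection is
    # clamped to [-limit, limit], and `alpha` scales the sigmoid inside the GLU.
    gate = gate_proj(x).clamp(max=limit)
    up = up_proj(x).clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    return down_proj((up + 1) * glu)

# illustrative sizes and hyperparameters, not taken from the PR
h, d = 8, 16
y = expert_forward(
    torch.randn(2, h),
    torch.nn.Linear(h, d), torch.nn.Linear(h, d), torch.nn.Linear(d, h),
    alpha=1.702, limit=7.0,
)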
@@ -57,17 +61,21 @@ def __init__(self, experts: GptOssExpert):
         self.limit = experts.limit

     def load_weights(self, experts: GptOssExperts):
-        for expert_index, expert in enumerate(self.experts):
-            update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
-            update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])
+        # TODO: this code is inefficient. If there were a "get_offloaded_data" util,
+        # we could avoid having to move from cpu -> gpu -> cpu
+        with align_module_device(experts):
+            for expert_index, expert in enumerate(self.experts):
+                update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
+                update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])

-            update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
-            update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])
+                update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
+                update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])

-            update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
-            update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])
+                update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
+                update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])

     def to_original(self) -> GptOssExperts:
+        # TODO: this doesn't really handle offloading or correct device placement
         with skip_weights_initialize():
             fake_config = GptOssConfig(
                 intermediate_size=self.intermediate_size,
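A note on the indexing in `load_weights`: the fused `gate_up_proj` interleaves gate and up columns along its last dimension (shape [num_experts, hidden_size, 2 * expert_dim]), so `[..., ::2]` selects the gate columns and `[..., 1::2]` the up columns, while `.T` converts from the fused [in, out] layout to `nn.Linear`'s [out, in] weight layout. A self-contained sanity check of the interleaving (names here are illustrative):

import torch

hidden_size, expert_dim = 4, 3
fused = torch.randn(hidden_size, 2 * expert_dim)  # one expert's gate_up_proj slice

gate_w = fused[..., ::2].T   # [expert_dim, hidden_size], nn.Linear weight layout
up_w = fused[..., 1::2].T

# re-interleaving the de-interleaved columns recovers the fused tensor exactly
refused = torch.empty_like(fused)
refused[..., ::2] = gate_w.T
refused[..., 1::2] = up_w.T
assert torch.equal(refused, fused)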
@@ -78,14 +86,17 @@ def to_original(self) -> GptOssExperts:
         experts = GptOssExperts(fake_config)

         for expert_index, expert in enumerate(self.experts):
-            experts.gate_up_proj[expert_index, ..., ::2].data = expert.gate_proj.weight.data.T
-            experts.gate_up_proj_bias[expert_index, ..., ::2].data = expert.gate_proj.bias.data
+            # TODO: this code is inefficient. If there were a "get_offloaded_data" util,
+            # we could avoid having to move from cpu -> gpu -> cpu
+            with align_module_device(expert):
+                experts.gate_up_proj[expert_index, ..., ::2].data = expert.gate_proj.weight.data.T
+                experts.gate_up_proj_bias[expert_index, ..., ::2].data = expert.gate_proj.bias.data

-            experts.gate_up_proj[expert_index, ..., 1::2].data = expert.up_proj.weight.data.T
-            experts.gate_up_proj_bias[expert_index, ..., 1::2].data = expert.up_proj.bias.data
+                experts.gate_up_proj[expert_index, ..., 1::2].data = expert.up_proj.weight.data.T
+                experts.gate_up_proj_bias[expert_index, ..., 1::2].data = expert.up_proj.bias.data

-            experts.down_proj[expert_index].data = expert.down_proj.weight.data.T
-            experts.down_proj_bias[expert_index] = expert.down_proj.bias.data
+                experts.down_proj[expert_index].data = expert.down_proj.weight.data.T
+                experts.down_proj_bias[expert_index] = expert.down_proj.bias.data

         # update offloaded state dict
         update_offload_parameter(experts, "gate_up_proj", experts.gate_up_proj)
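Since `to_original` is meant to invert `load_weights`, a cheap round-trip check catches transposition or interleaving mistakes. A hedged sketch, reusing the `original` and `linear` names from the test hunk below and assuming both modules fit on one device:

# round-trip: fused -> per-expert Linears -> fused should be lossless
restored = linear.to_original()
torch.testing.assert_close(restored.gate_up_proj, original.gate_up_proj)
torch.testing.assert_close(restored.gate_up_proj_bias, original.gate_up_proj_bias)
torch.testing.assert_close(restored.down_proj, original.down_proj)
torch.testing.assert_close(restored.down_proj_bias, original.down_proj_bias)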
@@ -134,6 +145,7 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
         setattr(original, name, getattr(original, name).normal_())

     original.eval()
+    assert not original.training
     true_output = original(input, routing_weights=routing_weights)

     linear = GptOssExpertsLinear(original)
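The natural assertion that presumably follows in the test: the unfused module should match the fused reference output. A sketch using `torch.testing.assert_close` with its default tolerances (the right tolerance depends on dtype):

linear_output = linear(input, routing_weights=routing_weights)
torch.testing.assert_close(linear_output, true_output)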