 from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
 from llmcompressor.utils.dev import skip_weights_initialize
 
+from compressed_tensors import update_offload_parameter
+
 
 class GptOssExpert(torch.nn.Module):
     def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
@@ -56,18 +58,42 @@ def __init__(self, experts: GptOssExpert):
 
     def load_weights(self, experts: GptOssExperts):
         for expert_index, expert in enumerate(self.experts):
-            expert.gate_proj.weight.data = experts.gate_up_proj[expert_index, ..., ::2].data.T
-            expert.gate_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., ::2].data
-
-            expert.up_proj.weight.data = experts.gate_up_proj[expert_index, ..., 1::2].data.T
-            expert.up_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., 1::2].data
+            update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
+            update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])
 
-            expert.down_proj.weight.data = experts.down_proj[expert_index].T
-            expert.down_proj.bias.data = experts.down_proj_bias[expert_index]
+            update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
+            update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])
 
+            update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
+            update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])
 
     def to_original(self) -> GptOssExperts:
-        pass
+        with skip_weights_initialize():
+            fake_config = GptOssConfig(
+                intermediate_size=self.intermediate_size,
+                num_local_experts=self.num_experts,
+                hidden_size=self.hidden_size,
+            )
+            experts = GptOssExperts(fake_config)
+
+        for expert_index, expert in enumerate(self.experts):
+            # index into .data so the values land in the fused parameter storage
+            experts.gate_up_proj.data[expert_index, ..., ::2] = expert.gate_proj.weight.data.T
+            experts.gate_up_proj_bias.data[expert_index, ..., ::2] = expert.gate_proj.bias.data
+
+            experts.gate_up_proj.data[expert_index, ..., 1::2] = expert.up_proj.weight.data.T
+            experts.gate_up_proj_bias.data[expert_index, ..., 1::2] = expert.up_proj.bias.data
+
+            experts.down_proj.data[expert_index] = expert.down_proj.weight.data.T
+            experts.down_proj_bias.data[expert_index] = expert.down_proj.bias.data
+
+        # update offloaded state dict
+        update_offload_parameter(experts, "gate_up_proj", experts.gate_up_proj)
+        update_offload_parameter(experts, "gate_up_proj_bias", experts.gate_up_proj_bias)
+        update_offload_parameter(experts, "down_proj", experts.down_proj)
+        update_offload_parameter(experts, "down_proj_bias", experts.down_proj_bias)
+
+        return experts
 
 
     def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
@@ -113,5 +139,8 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
     linear = GptOssExpertsLinear(original)
     output = linear(input, routing_weights=routing_weights)
 
-    breakpoint()
-    assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
+    assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
+
+    restored = linear.to_original()
+    restored_output = restored(input, routing_weights=routing_weights)
+    assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0)
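
For reference, a minimal sketch of the load/restore round trip that the new to_original() enables. This is illustrative only: the sizes are made up, GptOssExpertsLinear is assumed to be importable from the module modified above, and GptOssExperts is assumed to live in transformers.models.gpt_oss.modeling_gpt_oss; the test in this diff remains the authoritative example.

import torch
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts

# hypothetical sizes, for illustration only
config = GptOssConfig(hidden_size=64, intermediate_size=128, num_local_experts=4)
original = GptOssExperts(config)
with torch.no_grad():
    for param in original.parameters():
        param.normal_()  # give the fused expert tensors concrete values

# unpack the interleaved gate_up_proj (even columns = gate, odd columns = up)
# into per-expert linear layers, then repack; the round trip should be lossless
linear = GptOssExpertsLinear(original)  # class defined in the module above
restored = linear.to_original()
assert torch.allclose(restored.gate_up_proj, original.gate_up_proj)
assert torch.allclose(restored.down_proj, original.down_proj)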