from typing import List

import torch

from transformers import GptOssForCausalLM
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from llmcompressor.utils.dev import skip_weights_initialize

from compressed_tensors.utils import update_offload_parameter, align_module_device


class GptOssExpert(torch.nn.Module):
    """A single GPT-OSS MoE expert expressed as three nn.Linear layers.

    Weights are populated by GptOssExpertsLinear.load_weights from the packed
    tensors of a GptOssExperts module.
    """

    gate_proj: torch.nn.Linear
    up_proj: torch.nn.Linear
    down_proj: torch.nn.Linear

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim
        self.alpha = experts.alpha
        self.limit = experts.limit

        assert experts.gate_up_proj.dtype == experts.gate_up_proj_bias.dtype
        assert experts.down_proj.dtype == experts.down_proj_bias.dtype

        with skip_weights_initialize():
            self.gate_proj = torch.nn.Linear(
                self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype
            )
            self.up_proj = torch.nn.Linear(
                self.hidden_size, self.expert_dim, bias=True, dtype=experts.gate_up_proj.dtype
            )
            self.down_proj = torch.nn.Linear(
                self.expert_dim, self.hidden_size, bias=True, dtype=experts.down_proj.dtype
            )

    def forward(self, hidden_states: torch.Tensor):
        # Clamped SwiGLU activation, matching GptOssExperts.forward
        gate = self.gate_proj(hidden_states)
        gate = gate.clamp(min=None, max=self.limit)

        up = self.up_proj(hidden_states)
        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(gate * self.alpha)
        return self.down_proj((up + 1) * glu)
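
# The forward above implements GPT-OSS's clamped SwiGLU,
#   out = down_proj((up + 1) * gate * sigmoid(alpha * gate)),
# with `gate` clamped from above at `limit` and `up` clamped to [-limit, limit].
# Hedged sanity sketch (toy sizes; `expert` stands for any loaded GptOssExpert):
#
#   x = torch.randn(2, expert.hidden_size)
#   gate = expert.gate_proj(x).clamp(max=expert.limit)
#   up = expert.up_proj(x).clamp(-expert.limit, expert.limit)
#   manual = expert.down_proj((up + 1) * gate * torch.sigmoid(gate * expert.alpha))
#   assert torch.allclose(expert(x), manual)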


class GptOssExpertsLinear(torch.nn.Module):
    """Drop-in replacement for GptOssExperts which stores each expert as a
    separate GptOssExpert built from nn.Linear modules."""

    experts: List[GptOssExpert]

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.intermediate_size = experts.intermediate_size
        self.num_experts = experts.num_experts
        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim

        with skip_weights_initialize():
            self.experts = torch.nn.ModuleList(
                [GptOssExpert(experts) for _ in range(self.num_experts)]
            )

        self.load_weights(experts)

        self.alpha = experts.alpha
        self.limit = experts.limit

    def load_weights(self, experts: GptOssExperts):
        # gate_up_proj interleaves gate and up along the last dim:
        # even columns are the gate projection, odd columns are the up projection
        with align_module_device(experts):
            for expert_index, expert in enumerate(self.experts):
                update_offload_parameter(expert.gate_proj, "weight", experts.gate_up_proj[expert_index, ..., ::2].T)
                update_offload_parameter(expert.gate_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., ::2])

                update_offload_parameter(expert.up_proj, "weight", experts.gate_up_proj[expert_index, ..., 1::2].T)
                update_offload_parameter(expert.up_proj, "bias", experts.gate_up_proj_bias[expert_index, ..., 1::2])

                update_offload_parameter(expert.down_proj, "weight", experts.down_proj[expert_index].T)
                update_offload_parameter(expert.down_proj, "bias", experts.down_proj_bias[expert_index])
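
    # Layout sketch (an assumption inferred from the slicing above): for a fused
    # tensor of shape (num_experts, hidden_size, 2 * expert_dim), strided views
    # recover the per-expert Linear weights. Toy illustration:
    #
    #   fused = torch.arange(6).reshape(1, 1, 6)  # E=1, H=1, 2*D=6
    #   fused[..., ::2]   # -> tensor([[[0, 2, 4]]])  gate columns
    #   fused[..., 1::2]  # -> tensor([[[1, 3, 5]]])  up columns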

    def to_original(self) -> GptOssExperts:
        # TODO: this doesn't really handle offloading or correct device placement
        with skip_weights_initialize(use_zeros=True):
            fake_config = GptOssConfig(
                intermediate_size=self.intermediate_size,
                num_local_experts=self.num_experts,
                hidden_size=self.hidden_size,
            )
            experts = GptOssExperts(fake_config)

        # cast the fused parameters to match the dtypes of the unfused linears
        experts.gate_up_proj = torch.nn.Parameter(
            experts.gate_up_proj.to(dtype=self.experts[0].gate_proj.weight.dtype),
            requires_grad=False,
        )
        experts.gate_up_proj_bias = torch.nn.Parameter(
            experts.gate_up_proj_bias.to(dtype=self.experts[0].gate_proj.weight.dtype),
            requires_grad=False,
        )
        experts.down_proj = torch.nn.Parameter(
            experts.down_proj.to(dtype=self.experts[0].down_proj.weight.dtype),
            requires_grad=False,
        )
        experts.down_proj_bias = torch.nn.Parameter(
            experts.down_proj_bias.to(dtype=self.experts[0].down_proj.weight.dtype),
            requires_grad=False,
        )

        # re-interleave each expert's linear weights back into the fused tensors
        for expert_index, expert in enumerate(self.experts):
            with align_module_device(expert.gate_proj, "cpu"), align_module_device(
                expert.up_proj, "cpu"
            ), align_module_device(expert.down_proj, "cpu"):
                experts.gate_up_proj[expert_index, ..., ::2].copy_(expert.gate_proj.weight.data.T)
                experts.gate_up_proj_bias[expert_index, ..., ::2].copy_(expert.gate_proj.bias.data)

                experts.gate_up_proj[expert_index, ..., 1::2].copy_(expert.up_proj.weight.data.T)
                experts.gate_up_proj_bias[expert_index, ..., 1::2].copy_(expert.up_proj.bias.data)

                experts.down_proj[expert_index].copy_(expert.down_proj.weight.data.T)
                experts.down_proj_bias[expert_index].copy_(expert.down_proj.bias.data)

        # NOTE: conversion reportedly slows down over time; cause unknown
        experts.eval()
        return experts
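
    # Round-trip usage sketch (toy sizes, mirroring test_restore below):
    #
    #   config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)
    #   fused = GptOssExperts(config)
    #   unfused = GptOssExpertsLinear(fused)   # per-expert nn.Linear modules
    #   restored = unfused.to_original()       # back to the fused HF layout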

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """
        When training, it is more efficient to just loop over the experts and compute the
        output for each expert, as otherwise the memory would explode.

        For inference, we can sacrifice some memory and compute the outputs for all experts
        at once by repeating the inputs.

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k), unused here
            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
        Returns:
            torch.Tensor
        """
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)

        # dense mixture: every expert processes every token, scaled by its routing weight
        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
        for expert_index, expert in enumerate(self.experts):
            next_states += expert(hidden_states) * routing_weights.T[expert_index].unsqueeze(-1)

        next_states = next_states.reshape(original_shape)
        return next_states
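
# Call sketch for GptOssExpertsLinear.forward (hypothetical shapes): the module is a
# drop-in for the dense (eval) path of GptOssExperts, so output shape equals input shape.
#
#   x = torch.rand(2, 4, config.hidden_size)
#   w = torch.rand(2 * 4, config.num_local_experts)
#   out = experts_linear(x, routing_weights=w)
#   assert out.shape == x.shape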


def replace_gpt_oss(config: GptOssConfig, module: GptOssExperts) -> GptOssExpertsLinear:
    return GptOssExpertsLinear(module)
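

# A hedged sketch (not part of the original file) of applying replace_gpt_oss to a
# whole model. `unfuse_experts` and its traversal are illustrative assumptions; it
# uses only standard torch.nn.Module APIs rather than a specific llmcompressor helper.
def unfuse_experts(model: GptOssForCausalLM):
    # collect targets first so the module tree is not mutated during iteration
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, GptOssExperts)
    ]
    for name, module in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, replace_gpt_oss(model.config, module))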


def test_restore():
    config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)

    original = GptOssExperts(config)
    linear = GptOssExpertsLinear(original)

    restored = linear.to_original()
    for param_name, param in original.named_parameters(recurse=False):
        restored_param = getattr(restored, param_name)
        assert param.shape == restored_param.shape
        assert param.dtype == restored_param.dtype
        assert torch.all(restored_param == param)


def test_correctness():
    batch_size, seq_len = 13, 12
    config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)

    hidden_states = torch.rand((batch_size, seq_len, config.hidden_size))
    routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts))

    with torch.no_grad():
        original = GptOssExperts(config)
        for name in ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias"]:
            setattr(original, name, getattr(original, name).normal_())

        original.eval()
        assert not original.training
        true_output = original(hidden_states, routing_weights=routing_weights)

        linear = GptOssExpertsLinear(original)
        output = linear(hidden_states, routing_weights=routing_weights)
        assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)

        restored = linear.to_original()
        restored_output = restored(hidden_states, routing_weights=routing_weights)
        assert torch.allclose(restored_output, true_output, atol=1e-3, rtol=0.0)


if __name__ == "__main__":
    test_restore()
    test_correctness()