import torch

from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from llmcompressor.utils.dev import skip_weights_initialize


class GptOssExpert(torch.nn.Module):
    """A single GPT-OSS expert with unfused gate/up/down linear projections."""

    def __init__(self, hidden_size: int, expert_dim: int, alpha: float, limit: float):
        super().__init__()

        self.hidden_size = hidden_size
        self.expert_dim = expert_dim
        self.alpha = alpha
        self.limit = limit

        with skip_weights_initialize():
            self.gate_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True)
            self.up_proj = torch.nn.Linear(self.hidden_size, self.expert_dim, bias=True)
            self.down_proj = torch.nn.Linear(self.expert_dim, self.hidden_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # clamp activations to match the numerics of the fused GptOssExperts
        gate = self.gate_proj(hidden_states)
        gate = gate.clamp(min=None, max=self.limit)

        up = self.up_proj(hidden_states)
        up = up.clamp(min=-self.limit, max=self.limit)

        # gated activation: gate * sigmoid(alpha * gate), scaled by (up + 1)
        glu = gate * torch.sigmoid(gate * self.alpha)
        return self.down_proj((up + 1) * glu)

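# A minimal sanity sketch for the unfused expert above: with randomly filled
# weights, `GptOssExpert.forward` should match the clamped gated activation
# written out explicitly. The alpha and limit values here are illustrative.
def _demo_single_expert() -> None:
    torch.manual_seed(0)
    expert = GptOssExpert(hidden_size=4, expert_dim=3, alpha=1.702, limit=7.0)
    for proj in (expert.gate_proj, expert.up_proj, expert.down_proj):
        proj.weight.data.normal_()  # weights were skipped at init, fill before use
        proj.bias.data.normal_()

    x = torch.randn(2, 4)
    gate = expert.gate_proj(x).clamp(max=7.0)
    up = expert.up_proj(x).clamp(min=-7.0, max=7.0)
    expected = expert.down_proj((up + 1) * gate * torch.sigmoid(gate * 1.702))
    assert torch.allclose(expert(x), expected)
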
class GptOssExpertsLinear(torch.nn.Module):
    experts: torch.nn.ModuleList  # of GptOssExpert

    def __init__(self, experts: GptOssExperts):
        super().__init__()

        self.intermediate_size = experts.intermediate_size
        self.num_experts = experts.num_experts
        self.hidden_size = experts.hidden_size
        self.expert_dim = experts.expert_dim

        with skip_weights_initialize():
            # register the unfused experts as submodules so their Linear layers
            # are visible to state_dict, .to(), and module traversal
            self.experts = torch.nn.ModuleList(
                [
                    GptOssExpert(self.hidden_size, self.expert_dim, experts.alpha, experts.limit)
                    for _ in range(self.num_experts)
                ]
            )

        self.load_weights(experts)

        self.alpha = experts.alpha
        self.limit = experts.limit

    def load_weights(self, experts: GptOssExperts):
        # the fused gate_up_proj interleaves gate (even columns) and up (odd
        # columns); Linear weights are (out_features, in_features), hence the
        # transposes (see the layout sketch in _demo_interleaved_layout below)
        for expert_index, expert in enumerate(self.experts):
            expert.gate_proj.weight.data = experts.gate_up_proj[expert_index, ..., ::2].data.T
            expert.gate_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., ::2].data

            expert.up_proj.weight.data = experts.gate_up_proj[expert_index, ..., 1::2].data.T
            expert.up_proj.bias.data = experts.gate_up_proj_bias[expert_index, ..., 1::2].data

            expert.down_proj.weight.data = experts.down_proj[expert_index].T
            expert.down_proj.bias.data = experts.down_proj_bias[expert_index]

    def to_original(self) -> GptOssExperts:
        # Sketch of the inverse of `load_weights` (assumes the config kwargs
        # below are sufficient to rebuild matching parameter shapes)
        config = GptOssConfig(
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_local_experts=self.num_experts,
            expert_dim=self.expert_dim,
        )
        with skip_weights_initialize():
            experts = GptOssExperts(config)

        for expert_index, expert in enumerate(self.experts):
            # interleave gate (even) and up (odd) columns back into the fused tensors
            experts.gate_up_proj.data[expert_index, ..., ::2] = expert.gate_proj.weight.data.T
            experts.gate_up_proj.data[expert_index, ..., 1::2] = expert.up_proj.weight.data.T
            experts.gate_up_proj_bias.data[expert_index, ..., ::2] = expert.gate_proj.bias.data
            experts.gate_up_proj_bias.data[expert_index, ..., 1::2] = expert.up_proj.bias.data
            experts.down_proj.data[expert_index] = expert.down_proj.weight.data.T
            experts.down_proj_bias.data[expert_index] = expert.down_proj.bias.data

        return experts

    def forward(
        self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
    ) -> torch.Tensor:
        """
        When training, it is more efficient to loop over the experts and compute
        each expert's output separately; otherwise memory usage explodes.

        For inference we can sacrifice some memory and compute the output for
        all experts at once by repeating the inputs.

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * token_num, top_k),
                accepted for interface compatibility and unused here
            routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
        Returns:
            torch.Tensor
        """
        original_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)

        # accumulate every expert's output, weighted by its routing probability
        next_states = torch.zeros_like(hidden_states)
        for expert_index, expert in enumerate(self.experts):
            next_states += expert(hidden_states) * routing_weights.T[expert_index].unsqueeze(-1)

        return next_states.reshape(original_shape)

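# A small layout demo mirroring the slicing used by `load_weights` and
# `to_original`: even columns of the fused gate_up_proj hold the gate
# projection, odd columns hold the up projection. Illustrative only.
def _demo_interleaved_layout() -> None:
    fused = torch.arange(6).reshape(1, 1, 6)  # (num_experts=1, hidden_size=1, 2 * expert_dim)
    assert fused[0, ..., ::2].tolist() == [[0, 2, 4]]   # gate columns
    assert fused[0, ..., 1::2].tolist() == [[1, 3, 5]]  # up columns
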

if __name__ == "__main__":
    batch_size, seq_len = 13, 12
    config = GptOssConfig(hidden_size=7, num_local_experts=3, expert_dim=5)

    input = torch.rand((batch_size, seq_len, config.hidden_size))
    routing_weights = torch.rand((batch_size * seq_len, config.num_local_experts))

    with torch.no_grad():
        original = GptOssExperts(config)
        # randomize the fused expert parameters (allocated but not initialized)
        for name in ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias"]:
            setattr(original, name, getattr(original, name).normal_())

        original.eval()
        true_output = original(input, routing_weights=routing_weights)

        linear = GptOssExpertsLinear(original)
        output = linear(input, routing_weights=routing_weights)

        assert torch.allclose(output, true_output, atol=1e-3, rtol=0.0)
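
        # Hedged round-trip sketch: `to_original` above is a best-guess inverse
        # of `load_weights`; if the config kwargs reconstruct matching shapes,
        # repacking should reproduce the original fused parameters exactly.
        repacked = linear.to_original()
        assert torch.allclose(repacked.gate_up_proj, original.gate_up_proj)
        assert torch.allclose(repacked.down_proj, original.down_proj)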