@@ -75,11 +75,11 @@ def create_quantized_state_dict(self):
7575 cur_state_dict [f"{ fqn } .weight" ] = int8_weight
7676 cur_state_dict [f"{ fqn } .scales" ] = scales .to (mod .weight .dtype )
7777 elif isinstance (mod , ConditionalFeedForward ):
78- num_experts , intermediate_size , dim = mod .w1 .shape
7978 for weight_idx in range (0 , 3 ):
8079 weight_name = f"w{ weight_idx + 1 } "
8180 scales_name = f"scales{ weight_idx + 1 } "
8281 weight = getattr (mod , weight_name )
82+ num_experts , intermediate_size , dim = weight .shape
8383
8484 bit8_weight_list = []
8585 scales_list = []
@@ -125,20 +125,20 @@ def __init__(self, num_experts, intermediate_size, dim, target_dtype):
125125 self .target_dtype = target_dtype
126126
127127 self .register_buffer ("w1" , torch .empty (num_experts , intermediate_size , dim , dtype = target_dtype ))
128- self .register_buffer ("w2" , torch .empty (num_experts , intermediate_size , dim , dtype = target_dtype ))
128+ self .register_buffer ("w2" , torch .empty (num_experts , dim , intermediate_size , dtype = target_dtype ))
129129 self .register_buffer ("w3" , torch .empty (num_experts , intermediate_size , dim , dtype = target_dtype ))
130130
131131 self .register_buffer ("scales1" , torch .empty (num_experts , intermediate_size , dtype = torch .bfloat16 ))
132- self .register_buffer ("scales2" , torch .empty (num_experts , intermediate_size , dtype = torch .bfloat16 ))
132+ self .register_buffer ("scales2" , torch .empty (num_experts , dim , dtype = torch .bfloat16 ))
133133 self .register_buffer ("scales3" , torch .empty (num_experts , intermediate_size , dtype = torch .bfloat16 ))
134134
def forward(self, x, expert_indices):
    """Run the selected experts of an int8-quantized MoE feed-forward.

    Args:
        x: activation tensor of shape [T, i] (T tokens, i = model dim —
           inferred from the 'ti' einsum subscript; confirm with callers).
        expert_indices: integer tensor [T, A] of expert ids per token
           (A = number of active experts).

    Returns:
        expert_outs: tensor [T, A, i] — one output per token per active
        expert.

    NOTE(review): reconstructed from a garbled diff view. w1/w3 buffers are
    [num_experts, intermediate, dim] and w2 is [num_experts, dim,
    intermediate] (see the buffer registrations in __init__); per-channel
    scales are applied to each einsum result rather than to the weights,
    which is equivalent but multiplies fewer elements.
    """
    # Gather each selected expert's int8 weights and promote them to the
    # activation dtype before the matmuls.
    w1_weights = self.w1.to(x.dtype)[expert_indices]  # [T, A, o, i]
    w3_weights = self.w3.to(x.dtype)[expert_indices]  # [T, A, o, i]
    w2_weights = self.w2.to(x.dtype)[expert_indices]  # [T, A, i, o]

    # Gate branch: SiLU(x @ W1^T), dequantized by the per-row scales.
    gate = torch.einsum('ti,taoi -> tao', x, w1_weights)
    x1 = F.silu(gate * self.scales1[expert_indices].to(x.dtype))

    # Up-projection branch, dequantized the same way.
    up = torch.einsum('ti, taoi -> tao', x, w3_weights)
    x3 = up * self.scales3[expert_indices].to(x.dtype)

    # Down-projection of the gated product back to model dim.
    down = torch.einsum('tao, taio -> tai', x1 * x3, w2_weights)
    expert_outs = down * self.scales2[expert_indices].to(x.dtype)
    return expert_outs
143143
144144
0 commit comments