@@ -62,8 +62,8 @@ def apply(self, router_logits) -> (torch.Tensor, torch.Tensor):
62 62 raise NotImplementedError(f"Not support balance_method {self.balance_method}")
63 63 return token_selected_experts, token_final_scales
6464
65- @functools.cache
66 65 @staticmethod
66+ @functools.cache
67 67 def get_balanced_selection(num_tokens, top_k, num_experts, dtype, world_size, rank):
68 68 a = torch.arange(num_tokens * world_size * top_k, dtype=dtype, device="cuda").view(
69 69 num_tokens, world_size, top_k
@@ -90,8 +90,8 @@ def apply_balance_ratio(imbalanced_experts, num_experts, balance_ratio, world_si
90 90 mixed_experts[num_balanced_tokens:] = imbalanced_experts[num_balanced_tokens:]
91 91 return mixed_experts
9292
93- @functools.cache
94 93 @staticmethod
94+ @functools.cache
95 95 def get_all_to_one_selection(
96 96 num_tokens, top_k, num_experts, balance_ratio, dtype, world_size, rank
97 97 ):
@@ -103,8 +103,8 @@ def get_all_to_one_selection(
103 103 imbalanced_experts, num_experts, balance_ratio, world_size, rank
104 104 )
105 105
106- @functools.cache
107 106 @staticmethod
107+ @functools.cache
108 108 def get_balanced_rank_imbalanced_expert_selection(
109 109 num_tokens, top_k, num_experts, balance_ratio, dtype, world_size, rank
110 110 ):
0 commit comments