@@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor):
4242 # multilora implementation: lora_ids <batch_size, 1>
4343 other_indices_a = torch .arange (self .lora_a_weights .shape [2 ]).view (1 , 1 , - 1 )
4444 selected_lora_a_weights = CtxGatherFuncCB .apply (
45- self .lora_a_weights , lora_ids , other_indices_a
45+ self .lora_a_weights , lora_ids , other_indices_a , self . lora_a_weights . shape [ 2 ]
4646 ) # <num_loras, 1, feature, r>
4747 other_indices_b = torch .arange (self .lora_b_weights .shape [2 ]).view (1 , 1 , - 1 )
4848 selected_lora_b_weights = CtxGatherFuncCB .apply (
49- self .lora_b_weights , lora_ids , other_indices_b
49+ self .lora_b_weights , lora_ids , other_indices_b , self . lora_b_weights . shape [ 2 ]
5050 ) # <num_loras, 1, r, feature>
5151 other_indices_s = torch .arange (self .lora_scalings .shape [2 ]).view (1 , 1 , - 1 )
5252 selected_lora_scalings = CtxGatherFuncCB .apply (
53- self .lora_scalings , lora_ids , other_indices_s
53+ self .lora_scalings , lora_ids , other_indices_s , self . lora_scalings . shape [ 2 ]
5454 ) # <num_loras, 1, 1, 1>
5555
5656 selected_lora_a_weights = selected_lora_a_weights .squeeze (1 )
0 commit comments