Skip to content

Commit 2ade913

Browse files
committed
fixing lora testing
Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 0b88a32 commit 2ade913

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

QEfficient/peft/lora/layers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor):
4242
# multilora implementation: lora_ids <batch_size, 1>
4343
other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1)
4444
selected_lora_a_weights = CtxGatherFuncCB.apply(
45-
self.lora_a_weights, lora_ids, other_indices_a
45+
self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2]
4646
) # <num_loras, 1, feature, r>
4747
other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1)
4848
selected_lora_b_weights = CtxGatherFuncCB.apply(
49-
self.lora_b_weights, lora_ids, other_indices_b
49+
self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2]
5050
) # <num_loras, 1, r, feature>
5151
other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1)
5252
selected_lora_scalings = CtxGatherFuncCB.apply(
53-
self.lora_scalings, lora_ids, other_indices_s
53+
self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2]
5454
) # <num_loras, 1, 1, 1>
5555

5656
selected_lora_a_weights = selected_lora_a_weights.squeeze(1)

0 commit comments

Comments
 (0)