@@ -138,14 +138,14 @@ def __call__(self, xq, xk, xv, mask, cache):
         """
         Args:
           xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
-          xk: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
-          xv: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
+          xk: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
+          xv: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
           mask: mask with 0 and -inf, or None
           cache: CacheManagerInterface object
         """
         bsz, num_heads, seqlen, head_dim = xq.shape
-        _, _, _, kv_head_dim = xk.shape
-        n_rep = head_dim // kv_head_dim
+        _, num_kv_heads, _, kv_head_dim = xk.shape
+        n_rep = num_heads // num_kv_heads
         if seqlen == 1:
             xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))

@@ -191,14 +191,14 @@ def __call__(self, xq, xk, xv, mask, cache):
         """
         Args:
           xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
-          xk: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
-          xv: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
+          xk: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
+          xv: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
           mask: mask with 0 and -inf, or None
           cache: CacheManagerInterface object
         """
         bsz, num_heads, seqlen, head_dim = xq.shape
-        _, _, _, kv_head_dim = xk.shape
-        n_rep = head_dim // kv_head_dim
+        _, num_kv_heads, _, kv_head_dim = xk.shape
+        n_rep = num_heads // num_kv_heads
         if seqlen == 1:
             xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))

0 commit comments