Commit 3de239b

Fix attention kernel for the GQA use case (#78)
fix gqa
1 parent 84ac213 commit 3de239b

File tree

1 file changed: +8 -8 lines changed

jetstream_pt/layers.py

Lines changed: 8 additions & 8 deletions
@@ -138,14 +138,14 @@ def __call__(self, xq, xk, xv, mask, cache):
     """
     Args:
       xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
-      xk: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
-      xv: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
+      xk: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
+      xv: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
       mask: mask with 0 and -inf, or None
       cache: CacheManagerInterface object
     """
     bsz, num_heads, seqlen, head_dim = xq.shape
-    _, _, _, kv_head_dim = xk.shape
-    n_rep = head_dim // kv_head_dim
+    _, num_kv_heads, _, kv_head_dim = xk.shape
+    n_rep = num_heads // num_kv_heads
     if seqlen == 1:
       xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))

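Why the old formula was wrong: n_rep (the number of query heads that share each KV head) is a ratio of head counts, not of head dimensions. A minimal illustration, assuming Llama-2-70B-style shapes (64 query heads, 8 KV heads, head_dim 128); these numbers are illustrative and not taken from this repository:

num_heads, num_kv_heads, head_dim, kv_head_dim = 64, 8, 128, 128

# Before the fix: compares feature dimensions, which are equal for Q and KV.
n_rep_old = head_dim // kv_head_dim    # 128 // 128 == 1, so KV heads would never be repeated

# After the fix: compares head counts, giving the true group size.
n_rep_new = num_heads // num_kv_heads  # 64 // 8 == 8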
@@ -191,14 +191,14 @@ def __call__(self, xq, xk, xv, mask, cache):
     """
     Args:
       xq: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
-      xk: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
-      xv: torch.Tensor of (batch size, num_heads, seqlen, head_dim)
+      xk: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
+      xv: torch.Tensor of (batch size, num_kv_heads, seqlen, head_dim)
       mask: mask with 0 and -inf, or None
       cache: CacheManagerInterface object
     """
     bsz, num_heads, seqlen, head_dim = xq.shape
-    _, _, _, kv_head_dim = xk.shape
-    n_rep = head_dim // kv_head_dim
+    _, num_kv_heads, _, kv_head_dim = xk.shape
+    n_rep = num_heads // num_kv_heads
     if seqlen == 1:
       xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))

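With n_rep computed from head counts, the KV tensors can be expanded so that every query head has a matching KV head before the attention matmul. The sketch below shows the common GQA expansion pattern; repeat_kv is an illustrative helper written for this note, not code taken from jetstream_pt/layers.py:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
  # Expand (bsz, num_kv_heads, seqlen, head_dim) into
  # (bsz, num_kv_heads * n_rep, seqlen, head_dim), repeating each KV head
  # n_rep times along the head axis.
  if n_rep == 1:
    return x
  bsz, num_kv_heads, seqlen, head_dim = x.shape
  # Insert a repeat axis, broadcast it, then fold it into the head axis.
  x = x[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seqlen, head_dim)
  return x.reshape(bsz, num_kv_heads * n_rep, seqlen, head_dim)

# Example: 8 KV heads repeated 8 times to serve 64 query heads.
xk = torch.randn(1, 8, 16, 128)
xk_expanded = repeat_kv(xk, n_rep=8)
assert xk_expanded.shape == (1, 64, 16, 128)

The expand-then-reshape idiom broadcasts the repeat axis lazily and only materializes the larger tensor at the reshape, which is a common way to express this expansion in PyTorch.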