
Commit 517d847

wang2yn84 and lsy323 authored
Integrates ragged attention into JetStream PyTorch (#93)
* Stable version of ragged attention.
* Converts the attention output type to match q's.
* Fixes the typo for the ragged attention.
* Provides the default value for partition_by_axis.
* Provides mesh to the shard_map.
* Fixes typo.
* Fixes typo: should be start instead of start_pos.
* Should use "//" instead of "/" to get int results.
* Uses block size // 2 as the starting current position for better initial performance. Fixes the typo: should use jax.lax.div instead of jnp.div.
* Updates the run_interactive script to use the correct result-token processing API from JetStream.
* Fixes typo: should use token_utils.process_result_token.
* Fixes typo.
* Fixes the sampled tokens list.
* Uses text_tokens_to_str to convert the output tokens.
* Reshapes the precomputed grid indices to 1D. Removes dense_attention_quantized and uses an option to control whether quantization is applied. Uses the new torch_xla2 API.
* Should check if X is None instead of if X.
* Fixes dense_attention not returning data.
* Reshapes the KV scaler to 3 dimensions for ragged attention.
* Cannot stop the input_pos counter from increasing since we are using a ring buffer; stopping it would cause an error.
* Adds starting_position and profiling_prefill for better testing and benchmarking.
* Moves flags in scripts to a common function (#92): refactor flags; clean up; fix run_server; move common flags to global; format; update; update readme; update run_interactive.
* Stable version of ragged attention.
* Fixes the merge conflicts.
* Fixes the missing pieces after merging conflicts. Adds a couple of new flags for debugging and performance tuning.
* Integrates ragged attention into Gemma too.
* Somehow had some local changes to run_interactive; reverting them to align with main.
* Sets the default value for the newly added parameters.
* Adds more descriptions to the ragged attention index precomputation function.
* Merges the quantized ragged attention kernel with the non-quantized version.
* Moves the attention calculation to attention.py for better code structure.
* Fixes run issues from the refactoring.
* Fixes the quantized version of ragged attention.
* Fixes test_attention by adding default values for the newly added arguments (the error was about missing positional arguments).
* Fixes unit tests; changes the Transformer model call argument order (input_pos) back to the original to avoid unnecessary issues.
* Formats attention_kernel.py.
* Adds descriptions to ragged attention outputs.
* Fixes quantization tests by adding a default value to the quantization kernel class.
* Reformats attention_kernel.py; formatting with black doesn't comply with the pylint rules.
* Ignores the R0913 (too many arguments) lint error for the ragged attention kernel. Fixes other lint errors.
* Ignores R0903 (too few public methods). Fixes lint errors.
* Fixes the rest of the lint errors.

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.com>
1 parent 65c39d4 · commit 517d847

File tree

9 files changed: +679 −66 lines

.pylintrc

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 [MESSAGES CONTROL]
-disable=C0114,R0801,E1102,W0613,R1711,too-many-locals
+disable=C0114,R0801,R0903,R0913,E1102,W0613,R1711,too-many-locals

jetstream_pt/attention_kernel.py

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
import functools
import math

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map

import torch
import torch.nn.functional as F

import numpy as np

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)


def ragged_flash_attention_kernel(
    start_ref,
    end_ref,
    line_end_ref,
    pre_b_ref,
    pre_i_ref,
    q_ref,
    k_ref,
    v_ref,
    k_scaler_ref,
    v_scaler_ref,
    o_ref,  # outputs
    m_ref,  # row max
    l_ref,  # propagation coefficient
    bk: int,
    mask_value: float,
    normalize_var: bool,
    quantized: bool,
):
  """Pallas kernel for flash attention."""
  with jax.named_scope("attention_kernel"):
    b, i = pl.program_id(0), pl.program_id(1)

    @pl.when(i == 0)
    def init():
      with jax.named_scope("init"):
        m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
        l_ref[...] = jnp.zeros_like(l_ref)
        o_ref[...] = jnp.zeros_like(o_ref)

    length = line_end_ref[b]
    start = start_ref[b]
    end = end_ref[b]

    @pl.when(jnp.logical_and(i * bk < length, start != end))
    def run():
      with jax.named_scope("run_qk"):
        q = q_ref[...].astype(jnp.float32)
        k = k_ref[...].astype(jnp.float32)
        v = v_ref[...].astype(jnp.float32)
        m_prev, l_prev = m_ref[...], l_ref[...]

        qk = jax.lax.dot_general(
            q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32
        )
        if normalize_var:
          qk = qk / jnp.sqrt(k.shape[-1])
        if quantized:
          qk = qk * k_scaler_ref[...]
      with jax.named_scope("run_mask"):
        start = start_ref[b]
        end = end_ref[b]
        iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1)
        mask_start_lt_end = jnp.logical_and(
            i * bk + iota >= start, i * bk + iota < end
        ).astype(jnp.int32)
        mask_start_gt_end = jnp.logical_or(
            i * bk + iota >= start, i * bk + iota < end
        ).astype(jnp.int32)
        # mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end)
        mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end)

        qk = qk + jnp.where(mask, 0.0, mask_value)

      with jax.named_scope("run_softmax"):
        m_curr = qk.max(axis=-1)

        s_curr = jnp.exp(qk - m_curr[..., None])

        l_curr = jax.lax.broadcast_in_dim(
            s_curr.sum(axis=-1), l_prev.shape, (0,)
        )
        if quantized:
          s_curr = s_curr * v_scaler_ref[...]
        o_curr_times_l_curr = jnp.dot(s_curr, v)
        m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
        m_next = jnp.maximum(m_prev, m_curr)
        alpha = jnp.exp(m_prev - m_next)
        beta = jnp.exp(m_curr - m_next)
        l_next = alpha * l_prev + beta * l_curr
        l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

        m_ref[...], l_ref[...] = m_next, l_next_safe
        o_ref[...] = (
            (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr)
            / l_next_safe
        ).astype(o_ref.dtype)


@functools.partial(
    jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]
)
def ragged_mqa(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    start: jax.Array,
    end: jax.Array,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    ragged_batch_index=None,
    ragged_block_index=None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi query attention."""
  with jax.named_scope("ragged_mqa"):
    batch_size, num_heads, head_dim = q.shape
    seq_len = k.shape[1]

    def kv_index_map(
        b,
        i,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      index = b * (seq_len // bk) + i
      return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0

    def q_index_map(
        b,
        i,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      index = b * (seq_len // bk) + i
      return ragged_batch_index_ref[index], 0, 0

    def scaler_index_map(b, i, *_):
      return b, 0, i

    line_end = jnp.where(start < end, end, seq_len - 1)

    in_specs = [
        pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
        pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
        pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
    ]
    inputs = (
        start,
        end,
        line_end,
        ragged_batch_index,
        ragged_block_index,
        q,
        k,
        v,
    )
    quantized = False
    if k_scaler is not None:
      in_specs = in_specs + [
          pl.BlockSpec(scaler_index_map, (None, 1, bk)),
          pl.BlockSpec(scaler_index_map, (None, 1, bk)),
      ]
      inputs = inputs + (k_scaler, v_scaler)
      quantized = True

    out, m, l = pl.pallas_call(
        functools.partial(
            ragged_flash_attention_kernel,
            bk=bk,
            mask_value=mask_value,
            normalize_var=normalize_var,
            quantized=quantized,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=5,
            in_specs=in_specs,
            out_specs=[
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
            ],
            grid=(batch_size, seq_len // bk),
        ),
        compiler_params={"dimension_semantics": ("parallel", "arbitrary")},
        out_shape=[
            q,
            jax.ShapeDtypeStruct(
                (batch_size, num_heads, head_dim), jnp.float32
            ),
            jax.ShapeDtypeStruct(
                (batch_size, num_heads, head_dim), jnp.float32
            ),
        ],
    )(*inputs)
  return out, (m[..., 0], l[..., 0])


@functools.partial(
    jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"]
)
def ragged_mha(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    start: jax.Array,
    end: jax.Array,
    ragged_batch_index: jax.Array,
    ragged_block_index: jax.Array,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
    shard_axis: int = 1,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi head attention.
  Args:
    q: A [batch_size, compute_dim, num_heads, head_dim] jax.Array.
    k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
      PartitionQuantizedTensor.
    v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
      PartitionQuantizedTensor.
    start: A i32[batch_size] jax.Array
    end: A i32[batch_size] jax.Array
    bk: An integer that is the sequence block size.
    logit_cap: An optional float that caps logits via tanh. By default there is
      no logit capping.
    mask_value: The value used for padding in attention. By default it is a very
      negative floating point number.
    out_dtype: An optional dtype for the output. If not provided, the output
      dtype will be q's dtype.
  Returns:
    The output of attention([batch_size, num_heads, compute_dim, head_dim]),
    along with the max logit ([batch_size, num_heads, compute_dim, 1]) and
    softmax denominator ([batch_size, num_heads, compute_dim, 1]).
  """
  mask_value = DEFAULT_MASK_VALUE
  if k_scaler is None:
    replicated_in_axes = 4
    replicated_inputs = (ragged_batch_index, ragged_block_index)
  else:
    replicated_in_axes = 6
    replicated_inputs = (
        jnp.squeeze(k_scaler, -1),
        jnp.squeeze(v_scaler, -1),
        ragged_batch_index,
        ragged_block_index,
    )

  with jax.named_scope("ragged_mha_vmap"):
    out, (m, l) = jax.vmap(
        functools.partial(
            ragged_mqa,
            bk=bk,
            mask_value=mask_value,
            normalize_var=normalize_var,
            # out_dtype=out_dtype,
        ),
        in_axes=(
            shard_axis,
            shard_axis,
            shard_axis,
            *([None] * replicated_in_axes),
        ),
        out_axes=shard_axis,
    )(q, k, v, start, end, *replicated_inputs)
  return out, (m, l)


def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
  """The vanilla attention kernel implementation."""

  bsz, _, _, head_dim = xq.shape
  with jax.named_scope("attn_mat1"):
    ## Attention start
    # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim)
    scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
    if k_scaler is not None:
      scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
    if mask is not None:
      # if mask.shape != (1,1,16,16):
      #   breakpoint()
      scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
  with jax.named_scope("attn_soft"):
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    if v_scaler is not None:
      scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))

  with jax.named_scope("attn_mat2"):
    # output = torch.einsum(
    #     "ikjm,ikml->ikjl", scores, values
    # )  # (bs, n_local_heads, seqlen, head_dim)
    output = torch.einsum("ikjm,ikml->ikjl", scores, values)
  return output


class RaggedAttentionKernel:
  """Ragged attention kernel."""

  def __init__(self, env, input_specs, output_specs, sharding_axis):
    self.binded_ragged_mha = functools.partial(
        ragged_mha, bk=env.block_size, shard_axis=sharding_axis
    )
    self.binded_ragged_mha = shard_map(
        ragged_mha, env.mesh, input_specs, output_specs, check_rep=False
    )
    self.binded_ragged_mha = jax.jit(self.binded_ragged_mha)

  def __call__(
      self,
      xq,
      keys,
      values,
      start,
      end,
      ragged_batch_index,
      ragged_block_index,
      k_scaler=None,
      v_scaler=None,
  ):
    return self.binded_ragged_mha(
        xq,
        keys,
        values,
        start,
        end,
        ragged_batch_index,
        ragged_block_index,
        k_scaler,
        v_scaler,
    )
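
The kernel above treats the per-request KV cache as a ring buffer: when start <= end the live tokens occupy [start, end), and when start > end they wrap around the end of the buffer, which is why two candidate masks (logical_and vs. logical_or) are built and then selected between. Below is a small, self-contained sketch of that selection; it is not taken from the repo and runs on CPU with plain jax.numpy:

import jax.numpy as jnp


def block_mask(i, bk, start, end):
  # Token positions covered by block i of size bk.
  pos = i * bk + jnp.arange(bk)
  in_segment = jnp.logical_and(pos >= start, pos < end)  # normal case: start <= end
  wrapped = jnp.logical_or(pos >= start, pos < end)      # cache has wrapped: start > end
  return jnp.where(start <= end, in_segment, wrapped)


print(block_mask(i=0, bk=8, start=2, end=6).astype(int))  # [0 0 1 1 1 1 0 0]
print(block_mask(i=0, bk=8, start=6, end=2).astype(int))  # [1 1 0 0 0 0 1 1]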

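The kernel also relies on two flattened 1D arrays, ragged_batch_index and ragged_block_index, with one entry per (batch, block) grid cell at index b * (seq_len // bk) + i, as used by kv_index_map above. The actual precomputation lives outside this diff (the commit message refers to a ragged attention index precomputation function), so the following is only a hypothetical illustration of one plausible scheme: a cell keeps its own (batch, block) while the block still holds live tokens, and otherwise points at another row's work so the prefetch pipeline keeps fetching useful data while the pl.when guard skips the compute.

import numpy as np


def precompute_ragged_indices(lengths, seq_len, bk):
  """Hypothetical sketch; lengths holds the per-request number of live tokens."""
  batch = len(lengths)
  num_blocks = seq_len // bk
  batch_index = np.zeros((batch, num_blocks), dtype=np.int32)
  block_index = np.zeros((batch, num_blocks), dtype=np.int32)
  for b in range(batch):
    for i in range(num_blocks):
      if i * bk < lengths[b]:  # block i still holds live tokens for request b
        batch_index[b, i], block_index[b, i] = b, i
      else:  # redirect the cell to the next request's first block
        batch_index[b, i], block_index[b, i] = (b + 1) % batch, 0
  # Flattened to 1D, matching the "Reshape the precomputed grid indices to 1D"
  # step in the commit message and the indexing used by kv_index_map.
  return batch_index.reshape(-1), block_index.reshape(-1)
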
jetstream_pt/config.py

Lines changed: 22 additions & 0 deletions
@@ -58,6 +58,26 @@
     lambda value: value in _VALID_QUANTIZATION_TYPE,
     f"quantize_type is invalid, supported quantization types are {_VALID_QUANTIZATION_TYPE}",
 )
+flags.DEFINE_bool(
+    "profiling_prefill",
+    False,
+    "Whether to profile the prefill, "
+    "if set to false, profile generate function only",
+    required=False,
+)
+flags.DEFINE_bool(
+    "ragged_mha",
+    False,
+    "Whether to enable Ragged multi head attention",
+    required=False,
+)
+flags.DEFINE_integer(
+    "starting_position",
+    512,
+    "The starting position of decoding, "
+    "for performance tuning and debugging only",
+    required=False,
+)


 def create_quantization_config_from_flags():
@@ -112,6 +132,8 @@ def create_engine_from_config_flags():
       max_cache_length=FLAGS.max_cache_length,
       sharding_config=sharding_file_name,
       shard_on_batch=FLAGS.shard_on_batch,
+      ragged_mha=FLAGS.ragged_mha,
+      starting_position=FLAGS.starting_position,
   )

   print("Initialize engine", time.perf_counter() - start)
