Commit b0d9bc1

Add sinks to tokamax splash
1 parent 11d5e08 commit b0d9bc1

File tree

1 file changed: +4 −6 lines changed


src/MaxText/layers/attention_op.py

Lines changed: 4 additions & 6 deletions
@@ -1288,12 +1288,10 @@ def wrap_flash_attention(
       decoder_segment_ids_tuple = None
 
     if self.config.use_tokamax_splash:
-      if max_logit_value is not None:
-        attention_output = jax.vmap(partial(splash_kernel, max_logit_value=max_logit_value))(
-            query, key, value, decoder_segment_ids_tuple
-        )
-      else:
-        attention_output = jax.vmap(splash_kernel)(query, key, value, decoder_segment_ids_tuple)
+      kernel = partial(splash_kernel, max_logit_value=max_logit_value)
+      attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
+          query, key, value, decoder_segment_ids_tuple, sinks
+      )
     else:
       attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))(
           query, key, value, decoder_segment_ids_tuple, sinks
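For readers unfamiliar with the vmap pattern used in this hunk, the sketch below (not part of the commit) illustrates how jax.vmap with in_axes=(0, 0, 0, 0, None) maps a kernel over the leading head axis of query, key, value, and segment ids while passing one shared, unbatched sinks array to every call. toy_kernel and its treatment of sinks as extra softmax logits are illustrative assumptions, not the real splash_kernel.

import jax
import jax.numpy as jnp

def toy_kernel(q, k, v, seg_ids, sinks=None):
  # Hypothetical stand-in for splash_kernel: plain softmax attention for one head.
  # seg_ids is accepted but unused, only to keep the call signature comparable.
  logits = q @ k.T  # (seq, seq)
  if sinks is not None:
    # Assumed sink semantics: extra logits that join the softmax normalization
    # but contribute no value rows to the output.
    sink_logits = jnp.broadcast_to(sinks, (q.shape[0], sinks.shape[0]))
    logits = jnp.concatenate([logits, sink_logits], axis=-1)
  weights = jax.nn.softmax(logits, axis=-1)[:, : k.shape[0]]
  return weights @ v

heads, seq, dim = 4, 8, 16
query = jnp.ones((heads, seq, dim))
key = jnp.ones((heads, seq, dim))
value = jnp.ones((heads, seq, dim))
seg = jnp.zeros((heads, seq), dtype=jnp.int32)
sinks = jnp.zeros((2,))  # one shared sink vector, hence in_axes=None

# Same pattern as the new hunk: batch over axis 0 of the first four arguments,
# broadcast the single unbatched sinks array to every per-head kernel call.
out = jax.vmap(lambda q, k, v, d, s: toy_kernel(q, k, v, d, sinks=s),
               in_axes=(0, 0, 0, 0, None))(query, key, value, seg, sinks)
print(out.shape)  # (4, 8, 16)

Passing sinks with in_axes=None mirrors the existing non-tokamax branch below, so after this change both paths thread the same sinks argument through the kernel.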

0 commit comments
