Skip to content

Commit 1f49061

Browse files
Merge pull request #2716 from AI-Hypercomputer:tokamax_splash_sink
PiperOrigin-RevId: 834512715
2 parents (432fb3d + b0d9bc1); commit 1f49061

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

src/MaxText/layers/attention_op.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1291,12 +1291,10 @@ def wrap_flash_attention(
12911291
decoder_segment_ids_tuple = None
12921292

12931293
if self.config.use_tokamax_splash:
1294-
if max_logit_value is not None:
1295-
attention_output = jax.vmap(partial(splash_kernel, max_logit_value=max_logit_value))(
1296-
query, key, value, decoder_segment_ids_tuple
1297-
)
1298-
else:
1299-
attention_output = jax.vmap(splash_kernel)(query, key, value, decoder_segment_ids_tuple)
1294+
kernel = partial(splash_kernel, max_logit_value=max_logit_value)
1295+
attention_output = jax.vmap(lambda q, k, v, d, s: kernel(q, k, v, d, sinks=s), in_axes=(0, 0, 0, 0, None))(
1296+
query, key, value, decoder_segment_ids_tuple, sinks
1297+
)
13001298
else:
13011299
attention_output = jax.vmap(splash_kernel, in_axes=(0, 0, 0, 0, None))(
13021300
query, key, value, decoder_segment_ids_tuple, sinks

0 commit comments

Comments
 (0)