Skip to content

Commit b8fb668

Browse files
Merge pull request #2604 from AI-Hypercomputer:rbierneni-qwen3-next-fullattention
PiperOrigin-RevId: 828972171
2 parents e3ddb1a + 9d13e0d commit b8fb668

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/MaxText/layers/attentions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,9 +481,6 @@ def __init__(
481481
else:
482482
self.sinks = None
483483

484-
self.query_norm = None
485-
self.key_norm = None
486-
487484
is_llama4_decoder_block = self.config.decoder_block == DecoderBlockType.LLAMA4
488485
if self.use_qk_norm and not is_llama4_decoder_block:
489486
self.query_norm = RMSNorm(
@@ -519,6 +516,9 @@ def __init__(
519516
weight_dtype=self.config.weight_dtype,
520517
rngs=self.rngs,
521518
)
519+
else:
520+
self.query_norm = None
521+
self.key_norm = None
522522

523523
self._maybe_shard_with_logical = functools.partial(
524524
maybe_shard_with_logical,

tests/train_compile_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,3 +701,19 @@ def test_gpt3_6b(self):
701701
"per_device_batch_size=1",
702702
)
703703
)
704+
705+
@pytest.mark.cpu_only
706+
def test_qwen3_qk_norm(self):
707+
"""AOT test for non-llama qk norm models"""
708+
compiled_trainstep_file = "/tmp/test_qwen3_qk_norm"
709+
train_compile_main(
710+
(
711+
"",
712+
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
713+
f"compiled_trainstep_file={compiled_trainstep_file}",
714+
"compile_topology=v5p-8",
715+
"compile_topology_num_slices=1",
716+
"model_name=qwen3-0.6b",
717+
"per_device_batch_size=1",
718+
)
719+
)

0 commit comments

Comments
 (0)