From 515ac8795d49c85c3b291f6782145686bc9e1bde Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Mon, 21 Jul 2025 21:46:21 -0400
Subject: [PATCH 1/6] fix tests

Signed-off-by: Yu Chin Fabian Lim
---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index 6a3f21ba543f..c3f9f57942d5 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -216,14 +216,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                          return_final_states=True)
 
     # just test the last in sequence
-    torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)
+    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=5e-2, rtol=5e-2)
 
     # just test the last head
     # NOTE, in the kernel we always cast states to fp32
-    torch.allclose(final_state[:, -1],
+    torch.testing.assert_close(final_state[:, -1],
                    final_state_min[:, -1].to(torch.float32),
-                   atol=1e-3,
-                   rtol=1e-3)
+                   atol=1e-2,
+                   rtol=5e-2)
 
 
 @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@@ -300,7 +300,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             # just test one dim and dstate
             Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
             Y_min_eg = Y_min[i][:, 0, 0]
-            torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
+            torch.testing.assert_close(Y_eg, Y_min_eg, atol=5e-1, rtol=1)
 
             # update states
             states = new_states

From 2670f047b267a2fc3a477f0e69fb4685e9d5d2b4 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Mon, 21 Jul 2025 22:06:43 -0400
Subject: [PATCH 2/6] relax

Signed-off-by: Yu Chin Fabian Lim
---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index c3f9f57942d5..5a52dd5663d0 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -300,7 +300,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             # just test one dim and dstate
             Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
             Y_min_eg = Y_min[i][:, 0, 0]
-            torch.testing.assert_close(Y_eg, Y_min_eg, atol=5e-1, rtol=1)
+            torch.testing.assert_close(Y_eg, Y_min_eg, atol=5e-1, rtol=5e-1)
 
             # update states
             states = new_states

From 32ce905dc04b6ebf558053d0c19d6eebc42ec0c2 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Mon, 21 Jul 2025 22:14:29 -0400
Subject: [PATCH 3/6] fix mixer2 test

Signed-off-by: Yu Chin Fabian Lim
---
 tests/kernels/mamba/test_mamba_mixer2.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py
index f5c6a18614ff..7f001a8c9c56 100644
--- a/tests/kernels/mamba/test_mamba_mixer2.py
+++ b/tests/kernels/mamba/test_mamba_mixer2.py
@@ -119,7 +119,7 @@ def mixer2_gated_norm_tensor_parallel(
         gate_states[..., local_rank * N:(local_rank + 1) * N],
     )
     ref_output = mixer_single_gpu(hidden_states, gate_states)
-    torch.allclose(output,
+    torch.testing.assert_close(output,
                    ref_output[..., local_rank * N:(local_rank + 1) * N],
-                   atol=1e-3,
+                   atol=5e-3,
                    rtol=1e-3)

From 7ae9b85a664f5e2413f9b4670f39a0b8061c5e4f Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Tue, 22 Jul 2025 15:12:10 +0000
Subject: [PATCH 4/6] fix yapf

Signed-off-by: Yu Chin Fabian Lim
---
 tests/kernels/mamba/test_mamba_mixer2.py  | 7 ++++---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 6 +++---
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py
index 7f001a8c9c56..16c310726ad1 100644
--- a/tests/kernels/mamba/test_mamba_mixer2.py
+++ b/tests/kernels/mamba/test_mamba_mixer2.py
@@ -120,6 +120,7 @@ def mixer2_gated_norm_tensor_parallel(
     )
     ref_output = mixer_single_gpu(hidden_states, gate_states)
     torch.testing.assert_close(output,
-                   ref_output[..., local_rank * N:(local_rank + 1) * N],
-                   atol=5e-3,
-                   rtol=1e-3)
+                               ref_output[...,
+                                          local_rank * N:(local_rank + 1) * N],
+                               atol=5e-3,
+                               rtol=1e-3)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index 5a52dd5663d0..0dd8e9aa6e46 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -221,9 +221,9 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
     # just test the last head
     # NOTE, in the kernel we always cast states to fp32
     torch.testing.assert_close(final_state[:, -1],
-                   final_state_min[:, -1].to(torch.float32),
-                   atol=1e-2,
-                   rtol=5e-2)
+                               final_state_min[:, -1].to(torch.float32),
+                               atol=1e-2,
+                               rtol=5e-2)
 
 
 @pytest.mark.parametrize("itype", [torch.float32, torch.float16])

From 83565de7d668e517c76e99c2e669153efed377af Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Tue, 22 Jul 2025 14:34:36 -0400
Subject: [PATCH 5/6] adjust thresholds

Signed-off-by: Yu Chin Fabian Lim
---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index 0dd8e9aa6e46..08fb3018a5a6 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
 
     # this tests the kernels on a single example (no batching)
 
+    # TODO: the bfloat16 case requires higher thresholds. To be investigated
+
+    if itype == torch.bfloat16:
+        atol, rtol = 5e-2, 5e-2
+    else:
+        atol, rtol = 8e-3, 5e-3
+
     # set seed
     batch_size = 1  # batch_size
     # ssd_minimal_discrete requires chunk_size divide seqlen
@@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                          return_final_states=True)
 
     # just test the last in sequence
-    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=5e-2, rtol=5e-2)
+    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
 
     # just test the last head
     # NOTE, in the kernel we always cast states to fp32
     torch.testing.assert_close(final_state[:, -1],
                                final_state_min[:, -1].to(torch.float32),
-                               atol=1e-2,
-                               rtol=5e-2)
+                               atol=atol,
+                               rtol=rtol)
 
 
 @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@@ -262,6 +269,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     # (i.e. chunked prefill)
 
     seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
+ 
+    # TODO: the irregular chunk size cases have some issues and require higher
+    # tolerance. This is to be investigated
+    if chunk_size not in {8, 256}:
+        atol, rtol = 5e-1, 5e-1
+    else:
+        atol, rtol = 5e-3, 5e-3
 
     # hold state during the cutting process so we know if an
     # example has been exhausted and needs to cycle
@@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             # just test one dim and dstate
             Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
             Y_min_eg = Y_min[i][:, 0, 0]
-            torch.testing.assert_close(Y_eg, Y_min_eg, atol=5e-1, rtol=5e-1)
+            torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
 
             # update states
             states = new_states

From 2360bd1a260473248a43bd2baf6127ead0bf9383 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Tue, 22 Jul 2025 22:02:23 +0000
Subject: [PATCH 6/6] yapf

Signed-off-by: Yu Chin Fabian Lim
---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index 08fb3018a5a6..00c1a2911d7d 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -269,7 +269,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     # (i.e. chunked prefill)
 
     seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
- 
+
     # TODO: the irregular chunk size cases have some issues and require higher
    # tolerance. This is to be investigated
     if chunk_size not in {8, 256}:
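
Note on the torch.allclose -> torch.testing.assert_close change made in these
patches: torch.allclose only returns a bool, so a bare torch.allclose(...)
statement discards the result and can never fail a test, whereas
torch.testing.assert_close raises an AssertionError on mismatch. A minimal
sketch of the difference (hypothetical tensors, not taken from the test suite):

    import torch

    actual = torch.zeros(4)
    expected = torch.full((4,), 0.1)

    # Returns False, but as a bare statement nothing raises, so a test
    # containing only this line would still pass.
    torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)

    # Raises AssertionError with a mismatch report, so the test fails
    # loudly when the kernel and reference outputs diverge.
    torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)

Both functions accept atol/rtol; torch.testing.assert_close requires them to be
passed together, checks shape and dtype by default, and reports the greatest
absolute and relative differences on failure, which is what makes the relaxed
thresholds in the later patches observable at all.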