@@ -187,7 +187,7 @@ def end_boundary(n: int):
                          [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
 @pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
-@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
+@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
 def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                          itype):
 
@@ -253,15 +253,15 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
         (8, 8, 16, 32, 16),
     ]),  # mode examples with varied lengths
 
-    # odd chunk_size
-    (64, 29, 2, [(11, 4), (13, 23), (19, 22),
-                 (21, 15)]),  # irregular sizes
-
     # large-ish chunk_size (256)
     (64, 256, 1, [(5, ), (1, ), (1, ),
                   (1, )]),  # irregular sizes with small sequences
     (64, 256, 2, [(5, 30), (1, 2), (1, 2),
                   (1, 2)]),  # irregular sizes with small sequences
+
+    # we also need to test some large seqlen
+    # to catch errors with init states decay
+    (768, 128, 2, [(138, 225), (138, 225)]),
 ])
 def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                                      itype):
@@ -271,10 +271,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
 
     seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
 
-    # TODO: the irregular chunk size cases have some issues and require higher
-    # tolerance. This is to be invesigated
-    if chunk_size not in {8, 256}:
-        atol, rtol = 5e-1, 5e-1
+    # This test can have larger error for longer sequences
+    if seqlen > 256:
+        atol, rtol = 1e-2, 5e-3
     else:
         atol, rtol = 5e-3, 5e-3
 
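
For context, a minimal sketch of how the seqlen-dependent tolerances chosen above are typically applied when comparing the chunked-scan output against a reference. The helper names are hypothetical and not part of this diff; the only facts taken from it are the tolerance values and the seqlen > 256 branch.

import torch

# Hypothetical helper mirroring the tolerance selection in the diff above:
# chunked scans over longer sequences accumulate more floating-point error
# (e.g. through repeated decay of the initial states), so seqlen > 256 gets
# a slightly looser absolute tolerance.
def select_tolerances(seqlen: int) -> tuple[float, float]:
    if seqlen > 256:
        return 1e-2, 5e-3   # atol, rtol for long sequences
    return 5e-3, 5e-3       # tighter default for shorter sequences


def assert_matches_reference(out: torch.Tensor, ref: torch.Tensor,
                             seqlen: int) -> None:
    # Compare the kernel output against the reference with the
    # sequence-length-aware tolerances.
    atol, rtol = select_tolerances(seqlen)
    torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)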