 import torch
 import torch.nn as nn

-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, MultiQueryAttentionV2
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2

 import importlib
 import os
@@ -120,6 +120,7 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None

+
 @pytest.mark.parametrize("dim", [128])
 @pytest.mark.parametrize("dim_out", [128, 256])
 @pytest.mark.parametrize("use_m", [True, False])
@@ -134,4 +135,26 @@ def test_mqa_v2(dim, dim_out, use_m):

     y = mqa(x, m=m)

-    assert (y.shape) == (1, dim_out, 32, 48)
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("expand_first", [True, False])
+@pytest.mark.parametrize("head_first", [True, False])
+@pytest.mark.parametrize("attn_mask", [True, False])
+def test_attn2d(bias, expand_first, head_first, attn_mask):
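+    # check that the fused and non-fused attention paths of Attention2d agree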
+    x = torch.randn(1, 128, 32, 48)
+    attn = Attention2d(
+        128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
+    )
+
+    if attn_mask:
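+        # NOTE: randint's upper bound is exclusive, so this is an all-zero
+        # additive float mask; it exercises mask handling without altering scores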
+        mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
+    else:
+        mask = None
+
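+    # first pass uses the default fused kernel, second forces the manual fallback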
+    o1 = attn(x, mask)
+    attn.fused_attn = False
+    o2 = attn(x, mask)
+
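+    # both code paths should produce matching outputs within tolerance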
+    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"