@@ -224,7 +224,7 @@ def __init__(
224224 qk_norm : bool = False ,
225225 attn_drop : float = 0. ,
226226 proj_drop : float = 0. ,
227- input_norm_layer : nn .Module = partial (LayerNorm2d , eps = 1e-5 ),
227+ input_norm_layer : nn .Module = partial (LayerNorm , eps = 1e-5 ),
228228 norm_layer : nn .Module = partial (LayerNorm , eps = 1e-5 ),
229229 init_values : Optional [float ] = None ,
230230 drop_path : float = 0. ,
@@ -326,7 +326,7 @@ def __init__(
326326 qk_norm : bool = False ,
327327 attn_drop : float = 0. ,
328328 proj_drop : float = 0. ,
329- input_norm_layer = LayerNorm2d ,
329+ input_norm_layer = LayerNorm ,
330330 norm_layer : nn .Module = LayerNorm ,
331331 init_values : Optional [float ] = None ,
332332 drop_path_rates : List [float ] = [0. ],
@@ -417,7 +417,7 @@ def __init__(
417417 qk_norm : bool = False ,
418418 attn_drop : float = 0. ,
419419 proj_drop : float = 0. ,
420- input_norm_layer = partial (LayerNorm2d , eps = 1e-5 ),
420+ input_norm_layer = partial (LayerNorm , eps = 1e-5 ),
421421 norm_layer : nn .Module = partial (LayerNorm , eps = 1e-5 ),
422422 init_values : Optional [float ] = None ,
423423 drop_path_rate : float = 0. ,
0 commit comments