Skip to content

Commit 3eaf729

Browse files
committed
F.scaled_dot_product_attention (SDPA) for Visformer fails without calling .contiguous() on q, k, and v; mark its fused attention as experimental
1 parent cf1884b commit 3eaf729

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

timm/models/visformer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=
         head_dim = round(dim // num_heads * head_dim_ratio)
         self.head_dim = head_dim
         self.scale = head_dim ** -0.5
-        self.fused_attn = use_fused_attn()
+        self.fused_attn = use_fused_attn(experimental=True)
 
         self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -94,7 +94,7 @@ def forward(self, x):
 
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v,
+                q.contiguous(), k.contiguous(), v.contiguous(),
                 dropout_p=self.attn_drop.p,
             )
         else:

0 commit comments

Comments
 (0)