Skip to content

Commit 122621d

Browse files
committed
Add Final annotation to attn_fas to avoid symbol lookup of new scaled_dot_product_attn fn on old PyTorch in jit
1 parent 621e1b2 commit 122621d

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

timm/models/maxxvit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import torch
4444
from torch import nn
45+
from torch.jit import Final
4546

4647
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
4748
from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
@@ -140,6 +141,8 @@ class MaxxVitCfg:
140141

141142

142143
class Attention2d(nn.Module):
144+
fast_attn: Final[bool]
145+
143146
""" multi-head attention for 2D NCHW tensors"""
144147
def __init__(
145148
self,
@@ -208,6 +211,8 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
208211

209212
class AttentionCl(nn.Module):
210213
""" Channels-last multi-head attention (B, ..., C) """
214+
fast_attn: Final[bool]
215+
211216
def __init__(
212217
self,
213218
dim: int,

timm/models/vision_transformer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import torch.nn as nn
3434
import torch.nn.functional as F
3535
import torch.utils.checkpoint
36+
from torch.jit import Final
3637

3738
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
3839
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
@@ -51,6 +52,8 @@
5152

5253

5354
class Attention(nn.Module):
55+
fast_attn: Final[bool]
56+
5457
def __init__(
5558
self,
5659
dim,

timm/models/vision_transformer_relpos.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torch
1313
import torch.nn as nn
14+
from torch.jit import Final
1415
from torch.utils.checkpoint import checkpoint
1516

1617
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -25,6 +26,8 @@
2526

2627

2728
class RelPosAttention(nn.Module):
29+
fast_attn: Final[bool]
30+
2831
def __init__(
2932
self,
3033
dim,

0 commit comments

Comments (0)