Skip to content

Commit d4c00d6

Browse files
committed
Merge branch 'amaarora-convit'
2 parents 6e04da0 + b7de82e commit d4c00d6

File tree

3 files changed

+353
-1
lines changed

3 files changed

+353
-1
lines changed

tests/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
torch._C._jit_set_profiling_mode(False)
1616

1717
# transformer models don't support many of the spatial / feature based model functionalities
18-
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*']
18+
NON_STD_FILTERS = [
19+
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*']
1920
NUM_NON_STD = len(NON_STD_FILTERS)
2021

2122
# exclude models that cause specific test failures

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .byobnet import *
33
from .cait import *
44
from .coat import *
5+
from .convit import *
56
from .cspnet import *
67
from .densenet import *
78
from .dla import *

timm/models/convit.py

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
""" ConViT Model
2+
3+
@article{d2021convit,
4+
title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
5+
author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
6+
journal={arXiv preprint arXiv:2103.10697},
7+
year={2021}
8+
}
9+
10+
Paper link: https://arxiv.org/abs/2103.10697
11+
Original code: https://github.com/facebookresearch/convit, original copyright below
12+
"""
13+
# Copyright (c) 2015-present, Facebook, Inc.
14+
# All rights reserved.
15+
#
16+
# This source code is licensed under the CC-by-NC license found in the
17+
# LICENSE file in the root directory of this source tree.
18+
#
19+
'''These modules are adapted from those of timm, see
20+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
21+
'''
22+
23+
import torch
24+
import torch.nn as nn
25+
from functools import partial
26+
import torch.nn.functional as F
27+
28+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
29+
from .helpers import build_model_with_cfg
30+
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
31+
from .registry import register_model
32+
from .vision_transformer_hybrid import HybridEmbed
33+
34+
import torch
35+
import torch.nn as nn
36+
37+
38+
def _cfg(url='', **kwargs):
39+
return {
40+
'url': url,
41+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
42+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
43+
'first_conv': 'patch_embed.proj', 'classifier': 'head',
44+
**kwargs
45+
}
46+
47+
48+
default_cfgs = {
49+
# ConViT
50+
'convit_tiny': _cfg(
51+
url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"),
52+
'convit_small': _cfg(
53+
url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"),
54+
'convit_base': _cfg(
55+
url="https://dl.fbaipublicfiles.com/convit/convit_base.pth")
56+
}
57+
58+
59+
class GPSA(nn.Module):
60+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
61+
locality_strength=1.):
62+
super().__init__()
63+
self.num_heads = num_heads
64+
self.dim = dim
65+
head_dim = dim // num_heads
66+
self.scale = qk_scale or head_dim ** -0.5
67+
self.locality_strength = locality_strength
68+
69+
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
70+
self.v = nn.Linear(dim, dim, bias=qkv_bias)
71+
72+
self.attn_drop = nn.Dropout(attn_drop)
73+
self.proj = nn.Linear(dim, dim)
74+
self.pos_proj = nn.Linear(3, num_heads)
75+
self.proj_drop = nn.Dropout(proj_drop)
76+
self.locality_strength = locality_strength
77+
self.gating_param = nn.Parameter(torch.ones(self.num_heads))
78+
self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None
79+
80+
def forward(self, x):
81+
B, N, C = x.shape
82+
if self.rel_indices is None or self.rel_indices.shape[1] != N:
83+
self.rel_indices = self.get_rel_indices(N)
84+
attn = self.get_attention(x)
85+
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
86+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
87+
x = self.proj(x)
88+
x = self.proj_drop(x)
89+
return x
90+
91+
def get_attention(self, x):
92+
B, N, C = x.shape
93+
qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
94+
q, k = qk[0], qk[1]
95+
pos_score = self.rel_indices.expand(B, -1, -1, -1)
96+
pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
97+
patch_score = (q @ k.transpose(-2, -1)) * self.scale
98+
patch_score = patch_score.softmax(dim=-1)
99+
pos_score = pos_score.softmax(dim=-1)
100+
101+
gating = self.gating_param.view(1, -1, 1, 1)
102+
attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
103+
attn /= attn.sum(dim=-1).unsqueeze(-1)
104+
attn = self.attn_drop(attn)
105+
return attn
106+
107+
def get_attention_map(self, x, return_map=False):
108+
attn_map = self.get_attention(x).mean(0) # average over batch
109+
distances = self.rel_indices.squeeze()[:, :, -1] ** .5
110+
dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
111+
if return_map:
112+
return dist, attn_map
113+
else:
114+
return dist
115+
116+
def local_init(self):
117+
self.v.weight.data.copy_(torch.eye(self.dim))
118+
locality_distance = 1 # max(1,1/locality_strength**.5)
119+
120+
kernel_size = int(self.num_heads ** .5)
121+
center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
122+
for h1 in range(kernel_size):
123+
for h2 in range(kernel_size):
124+
position = h1 + kernel_size * h2
125+
self.pos_proj.weight.data[position, 2] = -1
126+
self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
127+
self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
128+
self.pos_proj.weight.data *= self.locality_strength
129+
130+
def get_rel_indices(self, num_patches: int) -> torch.Tensor:
131+
img_size = int(num_patches ** .5)
132+
rel_indices = torch.zeros(1, num_patches, num_patches, 3)
133+
ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
134+
indx = ind.repeat(img_size, img_size)
135+
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
136+
indd = indx ** 2 + indy ** 2
137+
rel_indices[:, :, :, 2] = indd.unsqueeze(0)
138+
rel_indices[:, :, :, 1] = indy.unsqueeze(0)
139+
rel_indices[:, :, :, 0] = indx.unsqueeze(0)
140+
device = self.qk.weight.device
141+
return rel_indices.to(device)
142+
143+
144+
class MHSA(nn.Module):
145+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
146+
super().__init__()
147+
self.num_heads = num_heads
148+
head_dim = dim // num_heads
149+
self.scale = qk_scale or head_dim ** -0.5
150+
151+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
152+
self.attn_drop = nn.Dropout(attn_drop)
153+
self.proj = nn.Linear(dim, dim)
154+
self.proj_drop = nn.Dropout(proj_drop)
155+
156+
def get_attention_map(self, x, return_map=False):
157+
B, N, C = x.shape
158+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
159+
q, k, v = qkv[0], qkv[1], qkv[2]
160+
attn_map = (q @ k.transpose(-2, -1)) * self.scale
161+
attn_map = attn_map.softmax(dim=-1).mean(0)
162+
163+
img_size = int(N ** .5)
164+
ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
165+
indx = ind.repeat(img_size, img_size)
166+
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
167+
indd = indx ** 2 + indy ** 2
168+
distances = indd ** .5
169+
distances = distances.to('cuda')
170+
171+
dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
172+
if return_map:
173+
return dist, attn_map
174+
else:
175+
return dist
176+
177+
def forward(self, x):
178+
B, N, C = x.shape
179+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
180+
q, k, v = qkv[0], qkv[1], qkv[2]
181+
182+
attn = (q @ k.transpose(-2, -1)) * self.scale
183+
attn = attn.softmax(dim=-1)
184+
attn = self.attn_drop(attn)
185+
186+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
187+
x = self.proj(x)
188+
x = self.proj_drop(x)
189+
return x
190+
191+
192+
class Block(nn.Module):
193+
194+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
195+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
196+
super().__init__()
197+
self.norm1 = norm_layer(dim)
198+
self.use_gpsa = use_gpsa
199+
if self.use_gpsa:
200+
self.attn = GPSA(
201+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
202+
proj_drop=drop, **kwargs)
203+
else:
204+
self.attn = MHSA(
205+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
206+
proj_drop=drop, **kwargs)
207+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
208+
self.norm2 = norm_layer(dim)
209+
mlp_hidden_dim = int(dim * mlp_ratio)
210+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
211+
212+
def forward(self, x):
213+
x = x + self.drop_path(self.attn(self.norm1(x)))
214+
x = x + self.drop_path(self.mlp(self.norm2(x)))
215+
return x
216+
217+
218+
class ConViT(nn.Module):
219+
""" Vision Transformer with support for patch or hybrid CNN input stage
220+
"""
221+
222+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
223+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
224+
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
225+
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
226+
super().__init__()
227+
embed_dim *= num_heads
228+
self.num_classes = num_classes
229+
self.local_up_to_layer = local_up_to_layer
230+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
231+
self.locality_strength = locality_strength
232+
self.use_pos_embed = use_pos_embed
233+
234+
if hybrid_backbone is not None:
235+
self.patch_embed = HybridEmbed(
236+
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
237+
else:
238+
self.patch_embed = PatchEmbed(
239+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
240+
num_patches = self.patch_embed.num_patches
241+
self.num_patches = num_patches
242+
243+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
244+
self.pos_drop = nn.Dropout(p=drop_rate)
245+
246+
if self.use_pos_embed:
247+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
248+
trunc_normal_(self.pos_embed, std=.02)
249+
250+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
251+
self.blocks = nn.ModuleList([
252+
Block(
253+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
254+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
255+
use_gpsa=True,
256+
locality_strength=locality_strength)
257+
if i < local_up_to_layer else
258+
Block(
259+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
260+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
261+
use_gpsa=False)
262+
for i in range(depth)])
263+
self.norm = norm_layer(embed_dim)
264+
265+
# Classifier head
266+
self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
267+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
268+
269+
trunc_normal_(self.cls_token, std=.02)
270+
self.apply(self._init_weights)
271+
for n, m in self.named_modules():
272+
if hasattr(m, 'local_init'):
273+
m.local_init()
274+
275+
def _init_weights(self, m):
276+
if isinstance(m, nn.Linear):
277+
trunc_normal_(m.weight, std=.02)
278+
if isinstance(m, nn.Linear) and m.bias is not None:
279+
nn.init.constant_(m.bias, 0)
280+
elif isinstance(m, nn.LayerNorm):
281+
nn.init.constant_(m.bias, 0)
282+
nn.init.constant_(m.weight, 1.0)
283+
284+
@torch.jit.ignore
285+
def no_weight_decay(self):
286+
return {'pos_embed', 'cls_token'}
287+
288+
def get_classifier(self):
289+
return self.head
290+
291+
def reset_classifier(self, num_classes, global_pool=''):
292+
self.num_classes = num_classes
293+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
294+
295+
def forward_features(self, x):
296+
B = x.shape[0]
297+
x = self.patch_embed(x)
298+
299+
cls_tokens = self.cls_token.expand(B, -1, -1)
300+
301+
if self.use_pos_embed:
302+
x = x + self.pos_embed
303+
x = self.pos_drop(x)
304+
305+
for u, blk in enumerate(self.blocks):
306+
if u == self.local_up_to_layer:
307+
x = torch.cat((cls_tokens, x), dim=1)
308+
x = blk(x)
309+
310+
x = self.norm(x)
311+
return x[:, 0]
312+
313+
def forward(self, x):
314+
x = self.forward_features(x)
315+
x = self.head(x)
316+
return x
317+
318+
319+
def _create_convit(variant, pretrained=False, **kwargs):
320+
return build_model_with_cfg(
321+
ConViT, variant, pretrained,
322+
default_cfg=default_cfgs[variant],
323+
**kwargs)
324+
325+
326+
@register_model
327+
def convit_tiny(pretrained=False, **kwargs):
328+
model_args = dict(
329+
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
330+
num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
331+
model = _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args)
332+
return model
333+
334+
335+
@register_model
336+
def convit_small(pretrained=False, **kwargs):
337+
model_args = dict(
338+
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
339+
num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
340+
model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args)
341+
return model
342+
343+
344+
@register_model
345+
def convit_base(pretrained=False, **kwargs):
346+
model_args = dict(
347+
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
348+
num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
349+
model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args)
350+
return model

0 commit comments

Comments
 (0)