""" ConViT Model

@article{d2021convit,
  title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
  author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
  journal={arXiv preprint arXiv:2103.10697},
  year={2021}
}

Paper link: https://arxiv.org/abs/2103.10697
Original code: https://github.com/facebookresearch/convit, original copyright below
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
'''These modules are adapted from those of timm, see
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''

import torch
import torch.nn as nn
from functools import partial
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # ConViT
    'convit_tiny': _cfg(
        url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"),
    'convit_small': _cfg(
        url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"),
    'convit_base': _cfg(
        url="https://dl.fbaipublicfiles.com/convit/convit_base.pth")
}
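# Each entry above provides the pretrained checkpoint URL and the default input/normalization
# metadata for that variant; _create_convit below passes it to build_model_with_cfg as
# default_cfg when the model is constructed and, if requested, its weights are loaded.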


class GPSA(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 locality_strength=1.):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.locality_strength = locality_strength

        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.pos_proj = nn.Linear(3, num_heads)  # maps relative position encodings (dx, dy, d^2) to per-head scores
        self.proj_drop = nn.Dropout(proj_drop)
        self.gating_param = nn.Parameter(torch.ones(self.num_heads))  # per-head gate between content and positional attention
        self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3)  # silly torchscript hack, won't work with None

    def forward(self, x):
        B, N, C = x.shape
        if self.rel_indices is None or self.rel_indices.shape[1] != N:
            self.rel_indices = self.get_rel_indices(N)
        attn = self.get_attention(x)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def get_attention(self, x):
        B, N, C = x.shape
        qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k = qk[0], qk[1]
        pos_score = self.rel_indices.expand(B, -1, -1, -1)
        pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
        patch_score = (q @ k.transpose(-2, -1)) * self.scale
        patch_score = patch_score.softmax(dim=-1)
        pos_score = pos_score.softmax(dim=-1)

        gating = self.gating_param.view(1, -1, 1, 1)
        attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
        attn /= attn.sum(dim=-1).unsqueeze(-1)
        attn = self.attn_drop(attn)
        return attn

    def get_attention_map(self, x, return_map=False):
        attn_map = self.get_attention(x).mean(0)  # average over batch
        distances = self.rel_indices.squeeze()[:, :, -1] ** .5
        dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
        if return_map:
            return dist, attn_map
        else:
            return dist

    def local_init(self):
        self.v.weight.data.copy_(torch.eye(self.dim))
        locality_distance = 1  # max(1,1/locality_strength**.5)

        kernel_size = int(self.num_heads ** .5)
        center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
        for h1 in range(kernel_size):
            for h2 in range(kernel_size):
                position = h1 + kernel_size * h2
                self.pos_proj.weight.data[position, 2] = -1
                self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
                self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
        self.pos_proj.weight.data *= self.locality_strength

    def get_rel_indices(self, num_patches: int) -> torch.Tensor:
        img_size = int(num_patches ** .5)
        rel_indices = torch.zeros(1, num_patches, num_patches, 3)
        ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
        indx = ind.repeat(img_size, img_size)
        indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
        indd = indx ** 2 + indy ** 2
        rel_indices[:, :, :, 2] = indd.unsqueeze(0)
        rel_indices[:, :, :, 1] = indy.unsqueeze(0)
        rel_indices[:, :, :, 0] = indx.unsqueeze(0)
        device = self.qk.weight.device
        return rel_indices.to(device)

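# Note on GPSA above: each head h blends a content-based attention softmax(q @ k^T * scale)
# with a purely positional attention softmax(pos_proj(rel_indices)); the learned gate
# sigmoid(gating_param[h]) weights the positional term and (1 - sigmoid) the content term,
# and the blend is renormalized so every attention row sums to one. local_init() sets
# pos_proj so that, at initialization, each head attends to a fixed spatial offset,
# mimicking a convolutional kernel of size sqrt(num_heads).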

class MHSA(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def get_attention_map(self, x, return_map=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_map = (q @ k.transpose(-2, -1)) * self.scale
        attn_map = attn_map.softmax(dim=-1).mean(0)

        img_size = int(N ** .5)
        ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
        indx = ind.repeat(img_size, img_size)
        indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
        indd = indx ** 2 + indy ** 2
        distances = indd ** .5
        distances = distances.to(x.device)  # follow the input's device instead of assuming CUDA

        dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
        if return_map:
            return dist, attn_map
        else:
            return dist

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.use_gpsa = use_gpsa
        if self.use_gpsa:
            self.attn = GPSA(
                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
                proj_drop=drop, **kwargs)
        else:
            self.attn = MHSA(
                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
                proj_drop=drop, **kwargs)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

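# Note on Block above: use_gpsa selects GPSA (locality-biased attention, used for the early
# layers in ConViT below) versus plain MHSA (used for the later layers, once the class token
# has been prepended).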

class ConViT(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
                 local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
        super().__init__()
        embed_dim *= num_heads  # embed_dim is specified per head; total embedding dim scales with num_heads
        self.num_classes = num_classes
        self.local_up_to_layer = local_up_to_layer
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.locality_strength = locality_strength
        self.use_pos_embed = use_pos_embed

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        if self.use_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.pos_embed, std=.02)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                use_gpsa=True,
                locality_strength=locality_strength)
            if i < local_up_to_layer else
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                use_gpsa=False)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        for n, m in self.named_modules():
            if hasattr(m, 'local_init'):
                m.local_init()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.use_pos_embed:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        for u, blk in enumerate(self.blocks):
            if u == self.local_up_to_layer:
                x = torch.cat((cls_tokens, x), dim=1)
            x = blk(x)

        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


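# Shape sketch for ConViT.forward (illustrative, using the convit_tiny defaults below): a
# 224x224 input with patch_size=16 gives 196 patch tokens of dim 48 * 4 = 192; the first
# local_up_to_layer blocks apply GPSA to the patch tokens only, the class token is prepended
# before the remaining MHSA blocks, and the head maps the final class token to num_classes
# logits.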
def _create_convit(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ConViT, variant, pretrained,
        default_cfg=default_cfgs[variant],
        **kwargs)


@register_model
def convit_tiny(pretrained=False, **kwargs):
    model_args = dict(
        local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
        num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model = _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args)
    return model


@register_model
def convit_small(pretrained=False, **kwargs):
    model_args = dict(
        local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
        num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args)
    return model


@register_model
def convit_base(pretrained=False, **kwargs):
    model_args = dict(
        local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
        num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args)
    return model
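
# Example usage (illustrative sketch; assumes this file is installed as part of timm so the
# variants above are visible to the model registry):
#
#   import torch
#   import timm
#
#   model = timm.create_model('convit_tiny', pretrained=False)
#   logits = model(torch.randn(1, 3, 224, 224))  # -> torch.Size([1, 1000])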