|
7 | 7 | import logging |
8 | 8 | import math |
9 | 9 | from functools import partial |
10 | | -from typing import Optional, Tuple |
| 10 | +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List |
| 11 | +try: |
| 12 | + from typing import Literal |
| 13 | +except ImportError: |
| 14 | + from typing_extensions import Literal |
11 | 15 |
|
12 | 16 | import torch |
13 | 17 | import torch.nn as nn |
14 | 18 | from torch.jit import Final |
15 | 19 | from torch.utils.checkpoint import checkpoint |
16 | 20 |
|
17 | 21 | from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD |
18 | | -from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn |
| 22 | +from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType |
19 | 23 | from ._builder import build_model_with_cfg |
| 24 | +from ._manipulate import named_apply |
20 | 25 | from ._registry import generate_default_cfgs, register_model |
| 26 | +from .vision_transformer import get_init_weights_vit |
21 | 27 |
|
22 | 28 | __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this |
23 | 29 |
|
@@ -215,59 +221,61 @@ class VisionTransformerRelPos(nn.Module): |
215 | 221 |
|
216 | 222 | def __init__( |
217 | 223 | self, |
218 | | - img_size=224, |
219 | | - patch_size=16, |
220 | | - in_chans=3, |
221 | | - num_classes=1000, |
222 | | - global_pool='avg', |
223 | | - embed_dim=768, |
224 | | - depth=12, |
225 | | - num_heads=12, |
226 | | - mlp_ratio=4., |
227 | | - qkv_bias=True, |
228 | | - qk_norm=False, |
229 | | - init_values=1e-6, |
230 | | - class_token=False, |
231 | | - fc_norm=False, |
232 | | - rel_pos_type='mlp', |
233 | | - rel_pos_dim=None, |
234 | | - shared_rel_pos=False, |
235 | | - drop_rate=0., |
236 | | - proj_drop_rate=0., |
237 | | - attn_drop_rate=0., |
238 | | - drop_path_rate=0., |
239 | | - weight_init='skip', |
240 | | - embed_layer=PatchEmbed, |
241 | | - norm_layer=None, |
242 | | - act_layer=None, |
243 | | - block_fn=RelPosBlock |
| 224 | + img_size: Union[int, Tuple[int, int]] = 224, |
| 225 | + patch_size: Union[int, Tuple[int, int]] = 16, |
| 226 | + in_chans: int = 3, |
| 227 | + num_classes: int = 1000, |
| 228 | + global_pool: Literal['', 'avg', 'token', 'map'] = 'avg', |
| 229 | + embed_dim: int = 768, |
| 230 | + depth: int = 12, |
| 231 | + num_heads: int = 12, |
| 232 | + mlp_ratio: float = 4., |
| 233 | + qkv_bias: bool = True, |
| 234 | + qk_norm: bool = False, |
| 235 | + init_values: Optional[float] = 1e-6, |
| 236 | + class_token: bool = False, |
| 237 | + fc_norm: bool = False, |
| 238 | + rel_pos_type: str = 'mlp', |
| 239 | + rel_pos_dim: Optional[int] = None, |
| 240 | + shared_rel_pos: bool = False, |
| 241 | + drop_rate: float = 0., |
| 242 | + proj_drop_rate: float = 0., |
| 243 | + attn_drop_rate: float = 0., |
| 244 | + drop_path_rate: float = 0., |
| 245 | + weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip', |
| 246 | + fix_init: bool = False, |
| 247 | + embed_layer: Type[nn.Module] = PatchEmbed, |
| 248 | + norm_layer: Optional[LayerType] = None, |
| 249 | + act_layer: Optional[LayerType] = None, |
| 250 | + block_fn: Type[nn.Module] = RelPosBlock |
244 | 251 | ): |
245 | 252 | """ |
246 | 253 | Args: |
247 | | - img_size (int, tuple): input image size |
248 | | - patch_size (int, tuple): patch size |
249 | | - in_chans (int): number of input channels |
250 | | - num_classes (int): number of classes for classification head |
251 | | - global_pool (str): type of global pooling for final sequence (default: 'avg') |
252 | | - embed_dim (int): embedding dimension |
253 | | - depth (int): depth of transformer |
254 | | - num_heads (int): number of attention heads |
255 | | - mlp_ratio (int): ratio of mlp hidden dim to embedding dim |
256 | | - qkv_bias (bool): enable bias for qkv if True |
257 | | - qk_norm (bool): Enable normalization of query and key in attention |
258 | | - init_values: (float): layer-scale init values |
259 | | - class_token (bool): use class token (default: False) |
260 | | - fc_norm (bool): use pre classifier norm instead of pre-pool |
261 | | -            rel_pos_type (str): type of relative position |
262 | | - shared_rel_pos (bool): share relative pos across all blocks |
263 | | - drop_rate (float): dropout rate |
264 | | - proj_drop_rate (float): projection dropout rate |
265 | | - attn_drop_rate (float): attention dropout rate |
266 | | - drop_path_rate (float): stochastic depth rate |
267 | | - weight_init (str): weight init scheme |
268 | | - embed_layer (nn.Module): patch embedding layer |
269 | | - norm_layer: (nn.Module): normalization layer |
270 | | - act_layer: (nn.Module): MLP activation layer |
| 254 | + img_size: input image size |
| 255 | + patch_size: patch size |
| 256 | + in_chans: number of input channels |
| 257 | + num_classes: number of classes for classification head |
| 258 | + global_pool: type of global pooling for final sequence (default: 'avg') |
| 259 | + embed_dim: embedding dimension |
| 260 | + depth: depth of transformer |
| 261 | + num_heads: number of attention heads |
| 262 | + mlp_ratio: ratio of mlp hidden dim to embedding dim |
| 263 | + qkv_bias: enable bias for qkv if True |
| 264 | + qk_norm: Enable normalization of query and key in attention |
| 265 | + init_values: layer-scale init values |
| 266 | + class_token: use class token (default: False) |
| 267 | + fc_norm: use pre classifier norm instead of pre-pool |
| 268 | + rel_pos_type: type of relative position |
| 269 | + shared_rel_pos: share relative pos across all blocks |
| 270 | + drop_rate: dropout rate |
| 271 | + proj_drop_rate: projection dropout rate |
| 272 | + attn_drop_rate: attention dropout rate |
| 273 | + drop_path_rate: stochastic depth rate |
| 274 | + weight_init: weight init scheme |
| 275 | + fix_init: apply weight initialization fix (scaling w/ layer index) |
| 276 | + embed_layer: patch embedding layer |
| 277 | + norm_layer: normalization layer |
| 278 | + act_layer: MLP activation layer |
271 | 279 | """ |
272 | 280 | super().__init__() |
273 | 281 | assert global_pool in ('', 'avg', 'token') |
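For reference, a minimal usage sketch of the retyped constructor (assuming a timm build that includes this change; the argument values below are arbitrary examples, not the defaults from the diff):

    from timm.models.vision_transformer_relpos import VisionTransformerRelPos
    import torch

    # Small illustrative config; num_classes keeps the default of 1000.
    model = VisionTransformerRelPos(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        rel_pos_type='mlp',
        weight_init='',   # '' / 'jax' / 'moco' run init_weights(); 'skip' (default) bypasses it
        fix_init=True,    # triggers the new fix_init_weight() rescaling
    )
    x = torch.randn(1, 3, 224, 224)
    logits = model(x)     # shape: (1, 1000)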
@@ -332,13 +340,22 @@ def __init__( |
332 | 340 |
|
333 | 341 | if weight_init != 'skip': |
334 | 342 | self.init_weights(weight_init) |
| 343 | + if fix_init: |
| 344 | + self.fix_init_weight() |
335 | 345 |
|
336 | 346 | def init_weights(self, mode=''): |
337 | 347 | assert mode in ('jax', 'moco', '') |
338 | 348 | if self.cls_token is not None: |
339 | 349 | nn.init.normal_(self.cls_token, std=1e-6) |
340 | | -        # FIXME weight init scheme using PyTorch defaults currently |
341 | | - #named_apply(get_init_weights_vit(mode, head_bias), self) |
| 350 | + named_apply(get_init_weights_vit(mode), self) |
| 351 | + |
| 352 | + def fix_init_weight(self): |
| 353 | + def rescale(param, _layer_id): |
| 354 | + param.div_(math.sqrt(2.0 * _layer_id)) |
| 355 | + |
| 356 | + for layer_id, layer in enumerate(self.blocks): |
| 357 | + rescale(layer.attn.proj.weight.data, layer_id + 1) |
| 358 | + rescale(layer.mlp.fc2.weight.data, layer_id + 1) |
342 | 359 |
|
343 | 360 | @torch.jit.ignore |
344 | 361 | def no_weight_decay(self): |
|
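The new fix_init_weight() appears to follow the depth-scaled initialization fix used in BEiT-style ViT implementations: each block's residual-branch output projections (attn.proj and mlp.fc2) are divided by sqrt(2 * layer_id), so deeper blocks start with progressively smaller contributions to the residual stream. A self-contained sketch of the same rescaling on toy blocks (ToyBlock is hypothetical, standing in for timm's RelPosBlock):

    import math
    import torch.nn as nn

    class ToyBlock(nn.Module):
        # Stand-in exposing the two submodules that fix_init_weight() touches.
        def __init__(self, dim: int):
            super().__init__()
            self.attn = nn.Module()
            self.attn.proj = nn.Linear(dim, dim)
            self.mlp = nn.Module()
            self.mlp.fc2 = nn.Linear(dim, dim)

    blocks = nn.ModuleList([ToyBlock(64) for _ in range(12)])

    def fix_init_weight(blocks: nn.ModuleList) -> None:
        # layer_id is 1-based, matching the diff: rescale by 1 / sqrt(2 * layer_id).
        for layer_id, layer in enumerate(blocks, start=1):
            layer.attn.proj.weight.data.div_(math.sqrt(2.0 * layer_id))
            layer.mlp.fc2.weight.data.div_(math.sqrt(2.0 * layer_id))

    fix_init_weight(blocks)
    # Block 1 is scaled by 1/sqrt(2) ~= 0.71, block 12 by 1/sqrt(24) ~= 0.20.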