
Commit ac1b08d

fix_init on vit & relpos vit

1 parent 935950c commit ac1b08d
2 files changed: +85 -56 lines


timm/models/vision_transformer.py

Lines changed: 14 additions & 2 deletions
@@ -421,6 +421,7 @@ def __init__(
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
             weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
+            fix_init: bool = False,
             embed_layer: Callable = PatchEmbed,
             norm_layer: Optional[LayerType] = None,
             act_layer: Optional[LayerType] = None,
@@ -449,6 +450,7 @@ def __init__(
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             weight_init: Weight initialization scheme.
+            fix_init: Apply weight initialization fix (scaling w/ layer index).
             embed_layer: Patch embedding layer.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
@@ -536,8 +538,18 @@ def __init__(
 
         if weight_init != 'skip':
             self.init_weights(weight_init)
+        if fix_init:
+            self.fix_init_weight()
 
-    def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def init_weights(self, mode: str = '') -> None:
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         trunc_normal_(self.pos_embed, std=.02)
@@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
         module.init_weights()
 
 
-def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None:
+def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     if 'jax' in mode:
         return partial(init_weights_vit_jax, head_bias=head_bias)
     elif 'moco' in mode:
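
The new fix_init_weight matches the rescale trick used in BEiT-style reference code: each block's residual output projections (attn.proj and mlp.fc2) are divided by sqrt(2 * layer_id), so deeper residual branches start proportionally smaller and the accumulated residual variance stays roughly depth-independent. A standalone sketch of the same logic, using a hypothetical ToyBlock rather than timm's real blocks:

# Standalone sketch of the fix_init rescale above; ToyBlock is a made-up
# stand-in for a transformer block, only the rescale mirrors the commit.
import math

import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # stands in for block.attn.proj
        self.fc2 = nn.Linear(dim, dim)   # stands in for block.mlp.fc2


blocks = nn.ModuleList(ToyBlock() for _ in range(12))

for layer_id, block in enumerate(blocks):
    # divide by sqrt(2 * (layer_id + 1)): later blocks get smaller weights
    block.proj.weight.data.div_(math.sqrt(2.0 * (layer_id + 1)))
    block.fc2.weight.data.div_(math.sqrt(2.0 * (layer_id + 1)))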

timm/models/vision_transformer_relpos.py

Lines changed: 71 additions & 54 deletions
@@ -7,17 +7,23 @@
 import logging
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal
 
 import torch
 import torch.nn as nn
 from torch.jit import Final
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn
+from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
 from ._builder import build_model_with_cfg
+from ._manipulate import named_apply
 from ._registry import generate_default_cfgs, register_model
+from .vision_transformer import get_init_weights_vit
 
 __all__ = ['VisionTransformerRelPos']  # model_registry will add each entrypoint fn to this
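
A note on the guarded import: typing.Literal only exists on Python >= 3.8, so the try/except falls back to the typing_extensions backport on 3.7. The payoff is that type checkers can restrict string arguments to a fixed set of values; a minimal illustration with a hypothetical function, not from this commit:

try:
    from typing import Literal  # stdlib on Python >= 3.8
except ImportError:
    from typing_extensions import Literal  # backport for Python 3.7

# A checker such as mypy flags values outside the declared set; at runtime
# the annotation has no effect.
def make_pool(pool: Literal['', 'avg', 'token', 'map'] = 'avg') -> str:
    return pool or 'none'

make_pool('avg')     # accepted
# make_pool('max')   # rejected by the type checker, still runs if called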

@@ -215,59 +221,61 @@ class VisionTransformerRelPos(nn.Module):
 
     def __init__(
             self,
-            img_size=224,
-            patch_size=16,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='avg',
-            embed_dim=768,
-            depth=12,
-            num_heads=12,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            qk_norm=False,
-            init_values=1e-6,
-            class_token=False,
-            fc_norm=False,
-            rel_pos_type='mlp',
-            rel_pos_dim=None,
-            shared_rel_pos=False,
-            drop_rate=0.,
-            proj_drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            weight_init='skip',
-            embed_layer=PatchEmbed,
-            norm_layer=None,
-            act_layer=None,
-            block_fn=RelPosBlock
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 16,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
+            embed_dim: int = 768,
+            depth: int = 12,
+            num_heads: int = 12,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            init_values: Optional[float] = 1e-6,
+            class_token: bool = False,
+            fc_norm: bool = False,
+            rel_pos_type: str = 'mlp',
+            rel_pos_dim: Optional[int] = None,
+            shared_rel_pos: bool = False,
+            drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip',
+            fix_init: bool = False,
+            embed_layer: Type[nn.Module] = PatchEmbed,
+            norm_layer: Optional[LayerType] = None,
+            act_layer: Optional[LayerType] = None,
+            block_fn: Type[nn.Module] = RelPosBlock
     ):
         """
         Args:
-            img_size (int, tuple): input image size
-            patch_size (int, tuple): patch size
-            in_chans (int): number of input channels
-            num_classes (int): number of classes for classification head
-            global_pool (str): type of global pooling for final sequence (default: 'avg')
-            embed_dim (int): embedding dimension
-            depth (int): depth of transformer
-            num_heads (int): number of attention heads
-            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
-            qkv_bias (bool): enable bias for qkv if True
-            qk_norm (bool): Enable normalization of query and key in attention
-            init_values: (float): layer-scale init values
-            class_token (bool): use class token (default: False)
-            fc_norm (bool): use pre classifier norm instead of pre-pool
-            rel_pos_type (str): type of relative position
-            shared_rel_pos (bool): share relative pos across all blocks
-            drop_rate (float): dropout rate
-            proj_drop_rate (float): projection dropout rate
-            attn_drop_rate (float): attention dropout rate
-            drop_path_rate (float): stochastic depth rate
-            weight_init (str): weight init scheme
-            embed_layer (nn.Module): patch embedding layer
-            norm_layer: (nn.Module): normalization layer
-            act_layer: (nn.Module): MLP activation layer
+            img_size: input image size
+            patch_size: patch size
+            in_chans: number of input channels
+            num_classes: number of classes for classification head
+            global_pool: type of global pooling for final sequence (default: 'avg')
+            embed_dim: embedding dimension
+            depth: depth of transformer
+            num_heads: number of attention heads
+            mlp_ratio: ratio of mlp hidden dim to embedding dim
+            qkv_bias: enable bias for qkv if True
+            qk_norm: Enable normalization of query and key in attention
+            init_values: layer-scale init values
+            class_token: use class token (default: False)
+            fc_norm: use pre classifier norm instead of pre-pool
+            rel_pos_type: type of relative position
+            shared_rel_pos: share relative pos across all blocks
+            drop_rate: dropout rate
+            proj_drop_rate: projection dropout rate
+            attn_drop_rate: attention dropout rate
+            drop_path_rate: stochastic depth rate
+            weight_init: weight init scheme
+            fix_init: apply weight initialization fix (scaling w/ layer index)
+            embed_layer: patch embedding layer
+            norm_layer: normalization layer
+            act_layer: MLP activation layer
         """
         super().__init__()
         assert global_pool in ('', 'avg', 'token')
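
With the typed signature in place, direct construction reads as below. This is a hypothetical small config (only weight_init and fix_init relate to this commit), assuming a timm checkout that includes it:

from timm.models.vision_transformer_relpos import VisionTransformerRelPos

model = VisionTransformerRelPos(
    embed_dim=384,
    depth=12,
    num_heads=6,
    rel_pos_type='mlp',  # 'mlp' -> RelPosMlp; other values fall back to RelPosBias
    weight_init='',      # '', 'jax', or 'moco' now run the ViT init scheme (next hunk)
    fix_init=True,       # apply the depth-dependent rescale after init
)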
@@ -332,13 +340,22 @@ def __init__(
 
         if weight_init != 'skip':
             self.init_weights(weight_init)
+        if fix_init:
+            self.fix_init_weight()
 
     def init_weights(self, mode=''):
         assert mode in ('jax', 'moco', '')
         if self.cls_token is not None:
             nn.init.normal_(self.cls_token, std=1e-6)
-        # FIXME weight init scheme using PyTorch defaults curently
-        #named_apply(get_init_weights_vit(mode, head_bias), self)
+        named_apply(get_init_weights_vit(mode), self)
+
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
 
     @torch.jit.ignore
     def no_weight_decay(self):
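
This last hunk is the behavioral change beyond fix_init: the old FIXME left the relpos model on PyTorch-default init, whereas init_weights now reuses the plain ViT initializers. get_init_weights_vit(mode) returns a per-module callable and named_apply walks the module tree with dotted names; a simplified sketch of that traversal (timm's real helper lives in timm.models._manipulate and takes extra arguments such as depth_first/include_root):

import torch.nn as nn

def named_apply(fn, module: nn.Module, name: str = '') -> nn.Module:
    # simplified depth-first walk: visit children, then call fn on the module
    for child_name, child in module.named_children():
        full_name = f'{name}.{child_name}' if name else child_name
        named_apply(fn, child, name=full_name)
    fn(module=module, name=name)
    return module

def init_linear(module: nn.Module, name: str = '') -> None:
    # stand-in for init_weights_vit_jax / init_weights_vit_moco
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

seq = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
named_apply(init_linear, seq)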
