
Commit 621e1b2

Add ideas from 'Scaling ViT to 22-B Params', testing PyTorch 2.0 fused F.scaled_dot_product_attention impl in vit, vit_relpos, maxxvit / coatnet.
1 parent a3d5285 commit 621e1b2

7 files changed: +420 additions, −91 deletions


timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same

timm/layers/fast_norm.py

Lines changed: 35 additions & 0 deletions
@@ -17,6 +17,12 @@
 except ImportError:
     has_apex = False

+try:
+    from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
+    has_apex_rmsnorm = True
+except ImportError:
+    has_apex_rmsnorm = False
+

 # fast (ie lower precision LN) can be disabled with this flag if issues crop up
 _USE_FAST_NORM = False  # defaulting to False for now
@@ -76,3 +82,32 @@ def fast_layer_norm(

     with torch.cuda.amp.autocast(enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
+
+
+def rms_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+):
+    dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
+    v = torch.var(x, dim=dims, keepdim=True)
+    x = x * torch.rsqrt(v + eps)
+    if weight is not None:
+        x = x * weight
+    return x
+
+
+def fast_rms_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    if torch.jit.is_scripting() or not has_apex_rmsnorm:
+        return rms_norm(x, normalized_shape, weight, eps)
+
+    if weight is None:
+        return fused_rms_norm(x, normalized_shape, eps)
+    else:
+        return fused_rms_norm_affine(x, weight, normalized_shape, eps)
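
The fallback `rms_norm` above scales `x` by the reciprocal square root of its variance over the trailing `normalized_shape` dims, then applies the optional affine `weight`. A minimal standalone sketch of the same computation (illustrative only, not importing timm; names here are hypothetical):

import torch

def rms_norm_ref(x, normalized_shape, weight=None, eps=1e-5):
    # mirror of the fallback above: rsqrt of the variance over the trailing dims
    dims = tuple(range(-len(normalized_shape), 0))
    x = x * torch.rsqrt(torch.var(x, dim=dims, keepdim=True) + eps)
    return x * weight if weight is not None else x

x = torch.randn(2, 196, 768)    # (batch, tokens, channels)
y = rms_norm_ref(x, [768], torch.ones(768))
print(y.shape)                  # torch.Size([2, 196, 768])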

timm/layers/norm.py

Lines changed: 38 additions & 1 deletion
@@ -4,12 +4,14 @@

 Hacked together by / Copyright 2022 Ross Wightman
 """
+import numbers
+from typing import Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm


 class GroupNorm(nn.GroupNorm):
@@ -115,3 +117,38 @@ def forward(self, x) -> torch.Tensor:
         else:
             x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
         return x
+
+
+class RmsNorm(nn.Module):
+    """ RmsNorm w/ fast (apex) norm if available
+    """
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE fast norm fallback needs our rms norm impl, so both paths go through here.
+        # Since there is no built-in PyTorch impl, always use the APEX RmsNorm if it is installed.
+        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        return x
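
The new module is a drop-in norm layer for channels-last tensors; a brief usage sketch (assuming a timm checkout that includes this commit, since `RmsNorm` is re-exported by the `__init__.py` change above):

import torch
from timm.layers import RmsNorm    # exported via the timm/layers/__init__.py change

norm = RmsNorm(768)                # affine=True by default, weight initialized to ones
x = torch.randn(2, 196, 768)       # channels-last (B, N, C) tokens
y = norm(x)
print(y.shape, norm.weight.shape)  # torch.Size([2, 196, 768]) torch.Size([768])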

timm/layers/pos_embed_rel.py

Lines changed: 5 additions & 18 deletions
@@ -83,8 +83,8 @@ def gen_relative_log_coords(
         pretrained_win_size: Tuple[int, int] = (0, 0),
         mode='swin',
 ):
-    assert mode in ('swin', 'cr', 'rw')
-    # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
+    assert mode in ('swin', 'cr')
+    # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
     relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
     relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
     relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@@ -100,18 +100,9 @@ def gen_relative_log_coords(
         relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
             1.0 + relative_coords_table.abs()) / math.log2(8)
     else:
-        if mode == 'rw':
-            # cr w/ window size normalization -> [-1,1] log coords
-            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
-            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
-            relative_coords_table *= 8  # scale to -8, 8
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-                1.0 + relative_coords_table.abs())
-            relative_coords_table /= math.log2(9)  # -> [-1, 1]
-        else:
-            # mode == 'cr'
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log(
-                1.0 + relative_coords_table.abs())
+        # mode == 'cr'
+        relative_coords_table = torch.sign(relative_coords_table) * torch.log(
+            1.0 + relative_coords_table.abs())

     return relative_coords_table

@@ -141,10 +132,6 @@ def __init__(
             self.bias_act = nn.Sigmoid()
             self.bias_gain = 16
             mlp_bias = (True, False)
-        elif mode == 'rw':
-            self.bias_act = nn.Tanh()
-            self.bias_gain = 4
-            mlp_bias = True
         else:
             self.bias_act = nn.Identity()
             self.bias_gain = None
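
The surviving 'cr' branch compresses relative offsets with a signed natural log. A tiny standalone example of that transform for a 3x3 window (illustrative only, not using the timm helper):

import torch

win_size = (3, 3)
rel_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
rel_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
table = torch.stack(torch.meshgrid([rel_h, rel_w], indexing='ij'), dim=-1)

# 'cr' mode: signed log compression of the raw relative offsets
table = torch.sign(table) * torch.log(1.0 + table.abs())
print(table.shape)    # torch.Size([5, 5, 2]) -> (2*Wh-1, 2*Ww-1, 2)
print(table[..., 0])  # row offsets mapped into [-log(3), log(3)]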

timm/models/maxxvit.py

Lines changed: 50 additions & 17 deletions
@@ -160,6 +160,7 @@ def __init__(
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME

         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@@ -175,15 +176,31 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
         else:
             q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

-        attn = (q.transpose(-2, -1) @ k) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q.transpose(-1, -2),
+                k.transpose(-1, -2),
+                v.transpose(-1, -2),
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            ).transpose(-1, -2).reshape(B, -1, H, W)
+        else:
+            q = q * self.scale
+            attn = q.transpose(-2, -1) @ k
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

-        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -211,6 +228,7 @@ def __init__(
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME

         self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@@ -227,15 +245,30 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
         else:
             q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)

-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,))
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(restore_shape + (-1,))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
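
The new fast path assumes `F.scaled_dot_product_attention` (PyTorch 2.0) reproduces the manual softmax attention when no bias is passed and dropout is inactive; a small standalone sanity check of that equivalence (illustrative, not part of the commit):

import torch
import torch.nn.functional as F

B, H, N, D = 2, 4, 64, 32
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# manual path, as in the non-fast branch above (scale applied to q)
attn = (q * D ** -0.5) @ k.transpose(-2, -1)
out_manual = attn.softmax(dim=-1) @ v

# fused path; applies the 1/sqrt(D) scaling internally, dropout_p defaults to 0.0
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_manual, out_fused, atol=1e-5))  # expect True within fp32 tolerance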
