@@ -11,17 +11,24 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
+
+try:
+    from torch.nn.functional import rms_norm
+except ImportError:
+    from .fast_norm import rms_norm


 class GroupNorm(nn.GroupNorm):
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

     def forward(self, x):
-        if self.fast_norm:
+        if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
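The try/except at the top prefers PyTorch's built-in `torch.nn.functional.rms_norm` (added in PyTorch 2.4) and falls back to the package's own implementation on older releases. For orientation, a pure-PyTorch fallback behaves roughly like the sketch below; `rms_norm_ref` is a hypothetical name used for illustration, and the real `.fast_norm.rms_norm` may differ in detail.

```python
import torch
from typing import List, Optional


def rms_norm_ref(
        x: torch.Tensor,
        normalized_shape: List[int],
        weight: Optional[torch.Tensor] = None,
        eps: float = 1e-6,
) -> torch.Tensor:
    # hypothetical reference impl: reduce over the trailing
    # len(normalized_shape) dims, as F.layer_norm does
    dims = [-i for i in range(len(normalized_shape), 0, -1)]
    v = x.pow(2).mean(dim=dims, keepdim=True)
    x = x * torch.rsqrt(v + eps)
    if weight is not None:
        x = x * weight
    return x
```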
@@ -31,13 +38,14 @@ class GroupNorm1(nn.GroupNorm):
     """ Group Normalization with 1 group.
     Input: tensor in shape [B, C, *]
     """
+    _fast_norm: torch.jit.Final[bool]

     def __init__(self, num_channels, **kwargs):
         super().__init__(1, num_channels, **kwargs)
-        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.fast_norm:
+        if self._fast_norm:
             return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
         else:
             return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
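The new `_fast_norm: torch.jit.Final[bool]` class annotations exist for TorchScript: the flag has to live on the module instance (globals can't be scripted, per the inline comment), and marking it `Final` lets the scripted `if self._fast_norm:` branch resolve as a constant. A quick scripting check, assuming the classes are importable as defined here:

```python
import torch

m = GroupNorm1(64)                       # GroupNorm with a single group over 64 channels
scripted = torch.jit.script(m)           # scriptable because _fast_norm is an instance flag
y = scripted(torch.randn(2, 64, 8, 8))   # [B, C, *] input
print(y.shape)  # torch.Size([2, 64, 8, 8])
```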
@@ -46,6 +54,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class LayerNorm(nn.LayerNorm):
     """ LayerNorm w/ fast norm option
     """
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -60,6 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 class LayerNorm2d(nn.LayerNorm):
     """ LayerNorm for channels of '2D' spatial NCHW tensors """
+    _fast_norm: torch.jit.Final[bool]
+
     def __init__(self, num_channels, eps=1e-6, affine=True):
         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
         self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
@@ -121,10 +133,11 @@ def forward(self, x) -> torch.Tensor:
 class RmsNorm(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool

     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -136,6 +149,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -150,17 +165,21 @@ def reset_parameters(self) -> None:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x

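RmsNorm now dispatches explicitly: the fast (e.g. APEX) path when enabled, otherwise the `rms_norm` resolved by the import fallback at the top. Either way it computes `x / sqrt(mean(x**2) + eps)`, scaled by the learned weight when affine; unlike LayerNorm there is no mean subtraction and no bias. A minimal usage sketch, assuming the constructor signature shown:

```python
import torch

norm = RmsNorm(256)           # normalizes over a trailing 256-wide channel dim
x = torch.randn(8, 197, 256)  # e.g. ViT tokens, channels last
y = norm(x)
print(y.shape)  # torch.Size([8, 197, 256])
```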
 class RmsNorm2d(nn.Module):
     """ RmsNorm w/ fast (apex) norm if available
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool

     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -172,6 +191,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -187,18 +208,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
-        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x

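RmsNorm2d applies the same dispatch to NCHW tensors by permuting to NHWC, normalizing over the now-trailing channel dim, and permuting back, so callers see NCHW throughout. A small sketch under the same constructor assumption:

```python
import torch

norm2d = RmsNorm2d(64)
x = torch.randn(2, 64, 32, 32)  # NCHW feature map
y = norm2d(x)                   # per-position RMS norm over the 64 channels
print(y.shape)  # torch.Size([2, 64, 32, 32])
```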
 class SimpleNorm(nn.Module):
     """ SimpleNorm (x / std(x))
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool

     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -210,6 +235,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -222,17 +249,21 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         return x

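Reading the `SimpleNorm (x / std(x))` docstring literally, the new `simple_norm` fallback divides by the standard deviation over the trailing dims (a centered variance, unlike RMSNorm's mean of squares) and applies the optional weight. A hedged reference sketch under that assumption; `simple_norm_ref` is a hypothetical name, and the actual `.fast_norm.simple_norm` may differ (e.g. in variance correction):

```python
import torch
from typing import List, Optional


def simple_norm_ref(
        x: torch.Tensor,
        normalized_shape: List[int],
        weight: Optional[torch.Tensor] = None,
        eps: float = 1e-6,
) -> torch.Tensor:
    # hypothetical reference impl: x / std(x) over the trailing dims
    dims = [-i for i in range(len(normalized_shape), 0, -1)]
    v = x.var(dim=dims, unbiased=False, keepdim=True)
    x = x * torch.rsqrt(v + eps)
    if weight is not None:
        x = x * weight
    return x
```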
 class SimpleNorm2d(nn.Module):
     """ SimpleNorm for NCHW tensors
     """
-    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
     eps: float
     elementwise_affine: bool
+    _fast_norm: bool

     def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
@@ -244,6 +275,8 @@ def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) ->
         self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
         self.eps = eps
         self.elementwise_affine = affine
+        self._fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)
+
         if self.elementwise_affine:
             self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
         else:
@@ -257,6 +290,9 @@ def reset_parameters(self) -> None:

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        if self._fast_norm:
+            x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = simple_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x