@@ -236,7 +236,7 @@ def forward(self, x):
236236
237237
238238class NormFreeBlock (nn .Module ):
239- """Normalization-free pre-activation block.
239+ """Normalization-Free pre-activation block.
240240 """
241241
242242 def __init__ (
@@ -351,6 +351,7 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
351351 return nn .Sequential (stem ), stem_stride , stem_feature
352352
353353
354+ # from https://github.com/deepmind/deepmind-research/tree/master/nfnets
354355_nonlin_gamma = dict (
355356 identity = 1.0 ,
356357 celu = 1.270926833152771 ,
@@ -371,10 +372,13 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
371372
372373
373374class NormFreeNet (nn .Module ):
374- """ Normalization-free ResNets and RegNets
375+ """ Normalization-Free Network
375376
376- As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
377+ As described in:
378+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
377379 - https://arxiv.org/abs/2101.08692
380+ and
381+ `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
378382
379383 This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
380384 the (preact) ResNet models described earlier in the paper.
@@ -432,7 +436,7 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
432436 blocks += [NormFreeBlock (
433437 in_chs = prev_chs , out_chs = out_chs ,
434438 alpha = cfg .alpha ,
435- beta = 1. / expected_var ** 0.5 , # NOTE: beta used as multiplier in block
439+ beta = 1. / expected_var ** 0.5 ,
436440 stride = stride if block_idx == 0 else 1 ,
437441 dilation = dilation ,
438442 first_dilation = first_dilation ,
@@ -477,8 +481,6 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
477481 if m .bias is not None :
478482 nn .init .zeros_ (m .bias )
479483 elif isinstance (m , nn .Conv2d ):
480- # as per discussion with paper authors, original in haiku is
481- # hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')' w/ zero'd bias
482484 nn .init .kaiming_normal_ (m .weight , mode = 'fan_in' , nonlinearity = 'linear' )
483485 if m .bias is not None :
484486 nn .init .zeros_ (m .bias )
0 commit comments