11from collections import OrderedDict
2- from typing import Callable , List , Optional , Type , Union
2+ from functools import partial
3+ from typing import Any , Callable , List , Optional , Type , Union
34
45import torch .nn as nn
5- from pydantic import BaseModel
6+ from pydantic import BaseModel , root_validator
67
78from .layers import ConvBnAct , SEModule , SimpleSelfAttention
89
910__all__ = [
1011 "init_cnn" ,
11- "act_fn" ,
1212 "ResBlock" ,
1313 "ModelConstructor" ,
1414 "XResNet34" ,
1515 "XResNet50" ,
1616]
1717
1818
19- act_fn = nn .ReLU (inplace = True )
20-
21-
2219class ResBlock (nn .Module ):
2320 """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
2421
@@ -29,13 +26,13 @@ def __init__(
2926 mid_channels : int ,
3027 stride : int = 1 ,
3128 conv_layer = ConvBnAct ,
32- act_fn : nn .Module = act_fn ,
29+ act_fn : Type [ nn .Module ] = nn . ReLU ,
3330 zero_bn : bool = True ,
3431 bn_1st : bool = True ,
3532 groups : int = 1 ,
3633 dw : bool = False ,
3734 div_groups : Union [None , int ] = None ,
38- pool : Union [nn .Module , None ] = None ,
35+ pool : Union [Callable [[ Any ], nn .Module ] , None ] = None ,
3936 se : Union [nn .Module , None ] = None ,
4037 sa : Union [nn .Module , None ] = None ,
4138 ):
@@ -100,7 +97,7 @@ def __init__(
10097 if stride != 1 or in_channels != out_channels :
10198 id_layers = []
10299 if stride != 1 and pool is not None : # if pool - reduce by pool else stride 2 art id_conv
103- id_layers .append (("pool" , pool ))
100+ id_layers .append (("pool" , pool () ))
104101 if in_channels != out_channels or (stride != 1 and pool is None ):
105102 id_layers += [("id_conv" , conv_layer (
106103 in_channels ,
@@ -112,7 +109,7 @@ def __init__(
112109 self .id_conv = nn .Sequential (OrderedDict (id_layers ))
113110 else :
114111 self .id_conv = None
115- self .act_fn = act_fn
112+ self .act_fn = act_fn ( inplace = True ) # type: ignore
116113
117114 def forward (self , x ):
118115 identity = self .id_conv (x ) if self .id_conv is not None else x
@@ -130,8 +127,8 @@ class ModelCfg(BaseModel):
130127 block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
131128 layers : List [int ] = [2 , 2 , 2 , 2 ]
132129 norm : Type [nn .Module ] = nn .BatchNorm2d
133- act_fn : nn .Module = nn .ReLU ( inplace = True )
134- pool : nn .Module = nn .AvgPool2d ( 2 , ceil_mode = True )
130+ act_fn : Type [ nn .Module ] = nn .ReLU
131+ pool : Callable [[ Any ], nn .Module ] = partial ( nn .AvgPool2d , kernel_size = 2 , ceil_mode = True )
135132 expansion : int = 1
136133 groups : int = 1
137134 dw : bool = False
@@ -144,7 +141,7 @@ class ModelCfg(BaseModel):
144141 zero_bn : bool = True
145142 stem_stride_on : int = 0
146143 stem_sizes : List [int ] = [32 , 32 , 64 ]
147- stem_pool : Union [nn .Module , None ] = nn .MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 1 ) # type: ignore
144+ stem_pool : Union [Callable [[ Any ], nn .Module ] , None ] = partial ( nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
148145 stem_bn_end : bool = False
149146 init_cnn : Optional [Callable [[nn .Module ], None ]] = None
150147 make_stem : Optional [Callable ] = None
@@ -192,7 +189,7 @@ def make_stem(self: ModelCfg) -> nn.Sequential:
192189 for i in range (len (self .stem_sizes ) - 1 )
193190 ]
194191 if self .stem_pool :
195- stem .append (("stem_pool" , self .stem_pool ))
192+ stem .append (("stem_pool" , self .stem_pool () ))
196193 if self .stem_bn_end :
197194 stem .append (("norm" , self .norm (self .stem_sizes [- 1 ]))) # type: ignore
198195 return nn .Sequential (OrderedDict (stem ))
@@ -260,29 +257,30 @@ def make_head(cfg: ModelCfg) -> nn.Sequential:
260257class ModelConstructor (ModelCfg ):
261258 """Model constructor. As default - xresnet18"""
262259
263- def __init__ ( self , ** data ):
264- super (). __init__ ( ** data )
265- if self . init_cnn is None :
266- self . init_cnn = init_cnn
267- if self . make_stem is None :
268- self . make_stem = make_stem
269- if self . make_layer is None :
270- self . make_layer = make_layer
271- if self . make_body is None :
272- self . make_body = make_body
273- if self . make_head is None :
274- self . make_head = make_head
275-
276- if self . stem_sizes [0 ] != self . in_chans :
277- self . stem_sizes = [self . in_chans ] + self . stem_sizes
278- if self . se and isinstance (self . se , (bool , int )): # if se=1 or se=True
279- self . se = SEModule
280- if self . sa and isinstance (self . sa , (bool , int )): # if sa=1 or sa=True
281- self . sa = SimpleSelfAttention # default: ks=1, sym=sym
282- if self . se_module or self . se_reduction : # pragma: no cover
260+ @ root_validator
261+ def post_init ( cls , values ):
262+ if values [ " init_cnn" ] is None :
263+ values [ " init_cnn" ] = init_cnn
264+ if values [ " make_stem" ] is None :
265+ values [ " make_stem" ] = make_stem
266+ if values [ " make_layer" ] is None :
267+ values [ " make_layer" ] = make_layer
268+ if values [ " make_body" ] is None :
269+ values [ " make_body" ] = make_body
270+ if values [ " make_head" ] is None :
271+ values [ " make_head" ] = make_head
272+
273+ if values [ " stem_sizes" ] [0 ] != values [ " in_chans" ] :
274+ values [ " stem_sizes" ] = [values [ " in_chans" ]] + values [ " stem_sizes" ]
275+ if values [ "se" ] and isinstance (values [ "se" ] , (bool , int )): # if se=1 or se=True
276+ values [ "se" ] = SEModule
277+ if values [ "sa" ] and isinstance (values [ "sa" ] , (bool , int )): # if sa=1 or sa=True
278+ values [ "sa" ] = SimpleSelfAttention # default: ks=1, sym=sym
279+ if values [ " se_module" ] or values [ " se_reduction" ] : # pragma: no cover
283280 print (
284281 "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
285282 ) # add deprecation warning.
283+ return values
286284
287285 @property
288286 def stem (self ):
@@ -309,11 +307,12 @@ def __call__(self):
309307 return model
310308
311309 def __repr__ (self ):
310+ se_repr = self .se .__name__ if self .se else "False"
312311 return (
313312 f"{ self .name } constructor\n "
314313 f" in_chans: { self .in_chans } , num_classes: { self .num_classes } \n "
315314 f" expansion: { self .expansion } , groups: { self .groups } , dw: { self .dw } , div_groups: { self .div_groups } \n "
316- f" sa : { self .sa } , se : { self .se } \n "
315+ f" act_fn : { self .act_fn . __name__ } , sa : { self .sa } , se: { se_repr } \n "
317316 f" stem sizes: { self .stem_sizes } , stride on { self .stem_stride_on } \n "
318317 f" body sizes { self .block_sizes } \n "
319318 f" layers: { self .layers } "
0 commit comments