33from typing import Any , Callable , List , Optional , Type , Union
44
55import torch .nn as nn
6- from pydantic import BaseModel
6+ from pydantic import BaseModel , root_validator
77
88from .layers import ConvBnAct , SEModule , SimpleSelfAttention
99
1010__all__ = [
1111 "init_cnn" ,
12- # "act_fn",
1312 "ResBlock" ,
1413 "ModelConstructor" ,
1514 "XResNet34" ,
1615 "XResNet50" ,
1716]
1817
1918
20- # act_fn = nn.ReLU
21-
22-
2319class ResBlock (nn .Module ):
2420 """Universal Resnet block. Basic block if expansion is 1, otherwise is Bottleneck."""
2521
@@ -261,29 +257,30 @@ def make_head(cfg: ModelCfg) -> nn.Sequential:
261257class ModelConstructor (ModelCfg ):
262258 """Model constructor. As default - xresnet18"""
263259
264- def __init__ ( self , ** data ):
265- super (). __init__ ( ** data )
266- if self . init_cnn is None :
267- self . init_cnn = init_cnn
268- if self . make_stem is None :
269- self . make_stem = make_stem
270- if self . make_layer is None :
271- self . make_layer = make_layer
272- if self . make_body is None :
273- self . make_body = make_body
274- if self . make_head is None :
275- self . make_head = make_head
276-
277- if self . stem_sizes [0 ] != self . in_chans :
278- self . stem_sizes = [self . in_chans ] + self . stem_sizes
279- if self . se and isinstance (self . se , (bool , int )): # if se=1 or se=True
280- self . se = SEModule
281- if self . sa and isinstance (self . sa , (bool , int )): # if sa=1 or sa=True
282- self . sa = SimpleSelfAttention # default: ks=1, sym=sym
283- if self . se_module or self . se_reduction : # pragma: no cover
260+ @ root_validator
261+ def post_init ( cls , values ):
262+ if values [ " init_cnn" ] is None :
263+ values [ " init_cnn" ] = init_cnn
264+ if values [ " make_stem" ] is None :
265+ values [ " make_stem" ] = make_stem
266+ if values [ " make_layer" ] is None :
267+ values [ " make_layer" ] = make_layer
268+ if values [ " make_body" ] is None :
269+ values [ " make_body" ] = make_body
270+ if values [ " make_head" ] is None :
271+ values [ " make_head" ] = make_head
272+
273+ if values [ " stem_sizes" ] [0 ] != values [ " in_chans" ] :
274+ values [ " stem_sizes" ] = [values [ " in_chans" ]] + values [ " stem_sizes" ]
275+ if values [ "se" ] and isinstance (values [ "se" ] , (bool , int )): # if se=1 or se=True
276+ values [ "se" ] = SEModule
277+ if values [ "sa" ] and isinstance (values [ "sa" ] , (bool , int )): # if sa=1 or sa=True
278+ values [ "sa" ] = SimpleSelfAttention # default: ks=1, sym=sym
279+ if values [ " se_module" ] or values [ " se_reduction" ] : # pragma: no cover
284280 print (
285281 "Deprecated. Pass se_module as se argument, se_reduction as arg to se."
286282 ) # add deprecation warning.
283+ return values
287284
288285 @property
289286 def stem (self ):
@@ -310,11 +307,12 @@ def __call__(self):
310307 return model
311308
312309 def __repr__ (self ):
310+ se_repr = self .se .__name__ if self .se else "False"
313311 return (
314312 f"{ self .name } constructor\n "
315313 f" in_chans: { self .in_chans } , num_classes: { self .num_classes } \n "
316314 f" expansion: { self .expansion } , groups: { self .groups } , dw: { self .dw } , div_groups: { self .div_groups } \n "
317- f" act_fn: { self .act_fn .__name__ } , sa: { self .sa } , se: { self . se } \n "
315+ f" act_fn: { self .act_fn .__name__ } , sa: { self .sa } , se: { se_repr } \n "
318316 f" stem sizes: { self .stem_sizes } , stride on { self .stem_stride_on } \n "
319317 f" body sizes { self .block_sizes } \n "
320318 f" layers: { self .layers } "
0 commit comments