1- from dataclasses import dataclass , field , asdict
2-
31from collections import OrderedDict
4- # from functools import partial
52from typing import Callable , List , Optional , Type , Union
63
74import torch .nn as nn
5+ from pydantic import BaseModel
86
97from .layers import ConvBnAct , SEModule , SimpleSelfAttention
108
11-
129__all__ = [
1310 "init_cnn" ,
1411 "act_fn" ,
@@ -122,17 +119,16 @@ def forward(self, x):
122119 return self .act_fn (self .convs (x ) + identity )
123120
124121
125- @dataclass
126- class CfgMC :
122+ class CfgMC (BaseModel ):
127123 """Model constructor Config. As default - xresnet18"""
128124
129125 name : str = "MC"
130126 in_chans : int = 3
131127 num_classes : int = 1000
132128 block : Type [nn .Module ] = ResBlock
133129 conv_layer : Type [nn .Module ] = ConvBnAct
134- block_sizes : List [int ] = field ( default_factory = lambda : [64 , 128 , 256 , 512 ])
135- layers : List [int ] = field ( default_factory = lambda : [2 , 2 , 2 , 2 ])
130+ block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
131+ layers : List [int ] = [2 , 2 , 2 , 2 ]
136132 norm : Type [nn .Module ] = nn .BatchNorm2d
137133 act_fn : nn .Module = nn .ReLU (inplace = True )
138134 pool : nn .Module = nn .AvgPool2d (2 , ceil_mode = True )
@@ -142,19 +138,22 @@ class CfgMC:
142138 div_groups : Union [int , None ] = None
143139 sa : Union [bool , int , Type [nn .Module ]] = False
144140 se : Union [bool , int , Type [nn .Module ]] = False
145- se_module = None
146- se_reduction = None
141+ se_module : Union [ bool , None ] = None
142+ se_reduction : Union [ int , None ] = None
147143 bn_1st : bool = True
148144 zero_bn : bool = True
149145 stem_stride_on : int = 0
150- stem_sizes : List [int ] = field ( default_factory = lambda : [32 , 32 , 64 ])
146+ stem_sizes : List [int ] = [32 , 32 , 64 ]
151147 stem_pool : Union [nn .Module , None ] = nn .MaxPool2d (kernel_size = 3 , stride = 2 , padding = 1 ) # type: ignore
152148 stem_bn_end : bool = False
153- _init_cnn : Optional [Callable [[nn .Module ], None ]] = field (repr = False , default = None )
154- _make_stem : Optional [Callable ] = field (repr = False , default = None )
155- _make_layer : Optional [Callable ] = field (repr = False , default = None )
156- _make_body : Optional [Callable ] = field (repr = False , default = None )
157- _make_head : Optional [Callable ] = field (repr = False , default = None )
149+ init_cnn : Optional [Callable [[nn .Module ], None ]] = None
150+ make_stem : Optional [Callable ] = None
151+ make_layer : Optional [Callable ] = None
152+ make_body : Optional [Callable ] = None
153+ make_head : Optional [Callable ] = None
154+
155+ class Config :
156+ arbitrary_types_allowed = True
158157
159158
160159def init_cnn (module : nn .Module ):
@@ -230,7 +229,7 @@ def make_body(cfg: CfgMC) -> nn.Sequential:
230229 [
231230 (
232231 f"l_{ layer_num } " ,
233- cfg ._make_layer (cfg , layer_num ) # type: ignore
232+ cfg .make_layer (cfg , layer_num ) # type: ignore
234233 )
235234 for layer_num in range (len (cfg .layers ))
236235 ]
@@ -247,21 +246,21 @@ def make_head(cfg: CfgMC) -> nn.Sequential:
247246 return nn .Sequential (OrderedDict (head ))
248247
249248
250- @dataclass
251249class ModelConstructor (CfgMC ):
252250 """Model constructor. As default - xresnet18"""
253251
254- def __post_init__ (self ):
255- if self ._init_cnn is None :
256- self ._init_cnn = init_cnn
257- if self ._make_stem is None :
258- self ._make_stem = make_stem
259- if self ._make_layer is None :
260- self ._make_layer = make_layer
261- if self ._make_body is None :
262- self ._make_body = make_body
263- if self ._make_head is None :
264- self ._make_head = make_head
252+ def __init__ (self , ** data ):
253+ super ().__init__ (** data )
254+ if self .init_cnn is None :
255+ self .init_cnn = init_cnn
256+ if self .make_stem is None :
257+ self .make_stem = make_stem
258+ if self .make_layer is None :
259+ self .make_layer = make_layer
260+ if self .make_body is None :
261+ self .make_body = make_body
262+ if self .make_head is None :
263+ self .make_head = make_head
265264
266265 if self .stem_sizes [0 ] != self .in_chans :
267266 self .stem_sizes = [self .in_chans ] + self .stem_sizes
@@ -276,30 +275,30 @@ def __post_init__(self):
276275
277276 @property
278277 def stem (self ):
279- return self ._make_stem (self ) # type: ignore
278+ return self .make_stem (self ) # type: ignore
280279
281280 @property
282281 def head (self ):
283- return self ._make_head (self ) # type: ignore
282+ return self .make_head (self ) # type: ignore
284283
285284 @property
286285 def body (self ):
287- return self ._make_body (self ) # type: ignore
286+ return self .make_body (self ) # type: ignore
288287
289288 @classmethod
290289 def from_cfg (cls , cfg : CfgMC ):
291- return cls (** asdict ( cfg ))
290+ return cls (** cfg . dict ( ))
292291
293292 def __call__ (self ):
294293 model = nn .Sequential (
295294 OrderedDict ([("stem" , self .stem ), ("body" , self .body ), ("head" , self .head )])
296295 )
297- self ._init_cnn (model ) # type: ignore
296+ self .init_cnn (model ) # type: ignore
298297 model .extra_repr = lambda : f"{ self .name } "
299298 return model
300299
301- def print_cfg (self ):
302- print (
300+ def __repr__ (self ):
301+ return (
303302 f"{ self .name } constructor\n "
304303 f" in_chans: { self .in_chans } , num_classes: { self .num_classes } \n "
305304 f" expansion: { self .expansion } , groups: { self .groups } , dw: { self .dw } , div_groups: { self .div_groups } \n "
@@ -310,14 +309,11 @@ def print_cfg(self):
310309 )
311310
312311
313- @dataclass
314312class XResNet34 (ModelConstructor ):
315313 name : str = "xresnet34"
316- layers : list [int ] = field ( default_factory = lambda : [3 , 4 , 6 , 3 ])
314+ layers : list [int ] = [3 , 4 , 6 , 3 ]
317315
318316
319- @dataclass
320- class XResNet50 (ModelConstructor ):
317+ class XResNet50 (XResNet34 ):
321318 name : str = "xresnet50"
322319 expansion : int = 4
323- layers : list [int ] = field (default_factory = lambda : [3 , 4 , 6 , 3 ])
0 commit comments