11from collections import OrderedDict
2- from typing import Callable , List , Optional , Type , Union
2+ from functools import partial
3+ from typing import Any , Callable , List , Optional , Type , Union
34
45import torch .nn as nn
56from pydantic import BaseModel
89
910__all__ = [
1011 "init_cnn" ,
11- "act_fn" ,
12+ # "act_fn",
1213 "ResBlock" ,
1314 "ModelConstructor" ,
1415 "XResNet34" ,
1516 "XResNet50" ,
1617]
1718
1819
19- act_fn = nn .ReLU ( inplace = True )
20+ # act_fn = nn.ReLU
2021
2122
2223class ResBlock (nn .Module ):
@@ -29,13 +30,13 @@ def __init__(
2930 mid_channels : int ,
3031 stride : int = 1 ,
3132 conv_layer = ConvBnAct ,
32- act_fn : nn .Module = act_fn ,
33+ act_fn : Type [ nn .Module ] = nn . ReLU ,
3334 zero_bn : bool = True ,
3435 bn_1st : bool = True ,
3536 groups : int = 1 ,
3637 dw : bool = False ,
3738 div_groups : Union [None , int ] = None ,
38- pool : Union [nn .Module , None ] = None ,
39+ pool : Union [Callable [[ Any ], nn .Module ] , None ] = None ,
3940 se : Union [nn .Module , None ] = None ,
4041 sa : Union [nn .Module , None ] = None ,
4142 ):
@@ -100,7 +101,7 @@ def __init__(
100101 if stride != 1 or in_channels != out_channels :
101102 id_layers = []
102103 if stride != 1 and pool is not None : # if pool - reduce by pool else stride 2 at id_conv
103- id_layers .append (("pool" , pool ))
104+ id_layers .append (("pool" , pool () ))
104105 if in_channels != out_channels or (stride != 1 and pool is None ):
105106 id_layers += [("id_conv" , conv_layer (
106107 in_channels ,
@@ -112,7 +113,7 @@ def __init__(
112113 self .id_conv = nn .Sequential (OrderedDict (id_layers ))
113114 else :
114115 self .id_conv = None
115- self .act_fn = act_fn
116+ self .act_fn = act_fn ( inplace = True ) # type: ignore
116117
117118 def forward (self , x ):
118119 identity = self .id_conv (x ) if self .id_conv is not None else x
@@ -130,8 +131,8 @@ class ModelCfg(BaseModel):
130131 block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
131132 layers : List [int ] = [2 , 2 , 2 , 2 ]
132133 norm : Type [nn .Module ] = nn .BatchNorm2d
133- act_fn : nn .Module = nn .ReLU ( inplace = True )
134- pool : nn .Module = nn .AvgPool2d ( 2 , ceil_mode = True )
134+ act_fn : Type [ nn .Module ] = nn .ReLU
135+ pool : Callable [[ Any ], nn .Module ] = partial ( nn .AvgPool2d , kernel_size = 2 , ceil_mode = True )
135136 expansion : int = 1
136137 groups : int = 1
137138 dw : bool = False
@@ -144,7 +145,7 @@ class ModelCfg(BaseModel):
144145 zero_bn : bool = True
145146 stem_stride_on : int = 0
146147 stem_sizes : List [int ] = [32 , 32 , 64 ]
147- stem_pool : Union [nn .Module , None ] = nn .MaxPool2d ( kernel_size = 3 , stride = 2 , padding = 1 ) # type: ignore
148+ stem_pool : Union [Callable [[ Any ], nn .Module ] , None ] = partial ( nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
148149 stem_bn_end : bool = False
149150 init_cnn : Optional [Callable [[nn .Module ], None ]] = None
150151 make_stem : Optional [Callable ] = None
@@ -192,7 +193,7 @@ def make_stem(self: ModelCfg) -> nn.Sequential:
192193 for i in range (len (self .stem_sizes ) - 1 )
193194 ]
194195 if self .stem_pool :
195- stem .append (("stem_pool" , self .stem_pool ))
196+ stem .append (("stem_pool" , self .stem_pool () ))
196197 if self .stem_bn_end :
197198 stem .append (("norm" , self .norm (self .stem_sizes [- 1 ]))) # type: ignore
198199 return nn .Sequential (OrderedDict (stem ))
@@ -313,7 +314,7 @@ def __repr__(self):
313314 f"{ self .name } constructor\n "
314315 f" in_chans: { self .in_chans } , num_classes: { self .num_classes } \n "
315316 f" expansion: { self .expansion } , groups: { self .groups } , dw: { self .dw } , div_groups: { self .div_groups } \n "
316- f" sa: { self .sa } , se: { self .se } \n "
317+ f" act_fn: { self . act_fn . __name__ } , sa: { self .sa } , se: { self .se } \n "
317318 f" stem sizes: { self .stem_sizes } , stride on { self .stem_stride_on } \n "
318319 f" body sizes { self .block_sizes } \n "
319320 f" layers: { self .layers } "
0 commit comments