55import torch .nn as nn
66from pydantic import BaseModel , root_validator
77
8- from .layers import ConvBnAct , SEModule , SimpleSelfAttention
8+ from .layers import ConvBnAct , SEModule , SimpleSelfAttention , get_act
99
1010__all__ = [
1111 "init_cnn" ,
@@ -32,7 +32,7 @@ def __init__(
3232 groups : int = 1 ,
3333 dw : bool = False ,
3434 div_groups : Union [None , int ] = None ,
35- pool : Union [Callable [[Any ], nn .Module ], None ] = None ,
35+ pool : Union [Callable [[], nn .Module ], None ] = None ,
3636 se : Union [nn .Module , None ] = None ,
3737 sa : Union [nn .Module , None ] = None ,
3838 ):
@@ -109,7 +109,7 @@ def __init__(
109109 self .id_conv = nn .Sequential (OrderedDict (id_layers ))
110110 else :
111111 self .id_conv = None
112- self .act_fn = act_fn ( inplace = True ) # type: ignore
112+ self .act_fn = get_act ( act_fn ) # type: ignore
113113
114114 def forward (self , x ):
115115 identity = self .id_conv (x ) if self .id_conv is not None else x
@@ -141,13 +141,13 @@ class ModelCfg(BaseModel):
141141 zero_bn : bool = True
142142 stem_stride_on : int = 0
143143 stem_sizes : List [int ] = [32 , 32 , 64 ]
144- stem_pool : Union [Callable [[Any ], nn .Module ], None ] = partial (nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
144+ stem_pool : Union [Callable [[], nn .Module ], None ] = partial (nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1 )
145145 stem_bn_end : bool = False
146146 init_cnn : Optional [Callable [[nn .Module ], None ]] = None
147- make_stem : Optional [Callable ] = None
148- make_layer : Optional [Callable ] = None
149- make_body : Optional [Callable ] = None
150- make_head : Optional [Callable ] = None
147+ make_stem : Optional [Callable [[ "ModelCfg" ], nn . Module ] ] = None
148+ make_layer : Optional [Callable [[ "ModelCfg" ], nn . Module ] ] = None
149+ make_body : Optional [Callable [[ "ModelCfg" ], nn . Module ] ] = None
150+ make_head : Optional [Callable [[ "ModelCfg" ], nn . Module ] ] = None
151151
152152 class Config :
153153 arbitrary_types_allowed = True
0 commit comments