77from torch import nn
88
99from .blocks import BasicBlock , BottleneckBlock
10- from .helpers import (Cfg , ListStrMod , ModSeq , init_cnn , instantiate_module ,
11- is_module , nn_seq )
10+ from .helpers import (
11+ Cfg ,
12+ ListStrMod ,
13+ ModSeq ,
14+ init_cnn ,
15+ instantiate_module ,
16+ is_module ,
17+ nn_seq ,
18+ )
1219from .layers import ConvBnAct , SEModule , SimpleSelfAttention
1320
1421__all__ = [
@@ -60,7 +67,8 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
6067
6168 @field_validator ("act_fn" , "block" , "conv_layer" , "norm" , "pool" , "stem_pool" )
6269 def set_modules ( # pylint: disable=no-self-argument
63- cls , value : Union [nnModule , str ],
70+ cls ,
71+ value : Union [nnModule , str ],
6472 ) -> nnModule :
6573 """Check values, if string, convert to nn.Module."""
6674 if is_module (value ):
@@ -69,7 +77,9 @@ def set_modules( # pylint: disable=no-self-argument
6977
7078 @field_validator ("se" , "sa" )
7179 def set_se ( # pylint: disable=no-self-argument
72- cls , value : Union [bool , nnModule , str ], info : FieldValidationInfo ,
80+ cls ,
81+ value : Union [bool , nnModule , str ],
82+ info : FieldValidationInfo ,
7383 ) -> nnModule :
7484 if isinstance (value , (int , bool )):
7585 return DEFAULT_SE_SA [info .field_name ]
@@ -154,8 +164,8 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential: # type: ignore
154164
155165
156166def make_body (
157- cfg : ModelCfg ,
158- layer_constructor : Callable [[ModelCfg , int ], nn .Sequential ] = make_layer ,
167+ cfg : ModelCfg ,
168+ layer_constructor : Callable [[ModelCfg , int ], nn .Sequential ] = make_layer ,
159169) -> nn .Sequential :
160170 """Create model body."""
161171 if hasattr (cfg , "make_layer" ):
0 commit comments