11from collections import OrderedDict
22from functools import partial
3- from typing import Any , Callable , List , Optional , Type , TypeVar , Union
3+ from typing import Any , Callable , Optional , TypeVar , Union
44
55import torch .nn as nn
66from pydantic import BaseModel , root_validator
@@ -39,7 +39,7 @@ def __init__(
3939 mid_channels : int ,
4040 stride : int = 1 ,
4141 conv_layer = ConvBnAct ,
42- act_fn : Type [nn .Module ] = nn .ReLU ,
42+ act_fn : type [nn .Module ] = nn .ReLU ,
4343 zero_bn : bool = True ,
4444 bn_1st : bool = True ,
4545 groups : int = 1 ,
@@ -153,7 +153,7 @@ def forward(self, x):
153153
154154def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
155155 len_stem = len (cfg .stem_sizes )
156- stem : List [tuple [str , nn .Module ]] = [
156+ stem : list [tuple [str , nn .Module ]] = [
157157 (
158158 f"conv_{ i } " ,
159159 cfg .conv_layer (
@@ -238,27 +238,27 @@ class ModelCfg(BaseModel):
238238 name : Optional [str ] = None
239239 in_chans : int = 3
240240 num_classes : int = 1000
241- block : Type [nn .Module ] = ResBlock
242- conv_layer : Type [nn .Module ] = ConvBnAct
243- block_sizes : List [int ] = [64 , 128 , 256 , 512 ]
244- layers : List [int ] = [2 , 2 , 2 , 2 ]
245- norm : Type [nn .Module ] = nn .BatchNorm2d
246- act_fn : Type [nn .Module ] = nn .ReLU
241+ block : type [nn .Module ] = ResBlock
242+ conv_layer : type [nn .Module ] = ConvBnAct
243+ block_sizes : list [int ] = [64 , 128 , 256 , 512 ]
244+ layers : list [int ] = [2 , 2 , 2 , 2 ]
245+ norm : type [nn .Module ] = nn .BatchNorm2d
246+ act_fn : type [nn .Module ] = nn .ReLU
247247 pool : Callable [[Any ], nn .Module ] = partial (
248248 nn .AvgPool2d , kernel_size = 2 , ceil_mode = True
249249 )
250250 expansion : int = 1
251251 groups : int = 1
252252 dw : bool = False
253253 div_groups : Union [int , None ] = None
254- sa : Union [bool , int , Type [nn .Module ]] = False
255- se : Union [bool , int , Type [nn .Module ]] = False
254+ sa : Union [bool , int , type [nn .Module ]] = False
255+ se : Union [bool , int , type [nn .Module ]] = False
256256 se_module : Union [bool , None ] = None
257257 se_reduction : Union [int , None ] = None
258258 bn_1st : bool = True
259259 zero_bn : bool = True
260260 stem_stride_on : int = 0
261- stem_sizes : List [int ] = [32 , 32 , 64 ]
261+ stem_sizes : list [int ] = [32 , 32 , 64 ]
262262 stem_pool : Union [Callable [[], nn .Module ], None ] = partial (
263263 nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
264264 )
@@ -286,7 +286,7 @@ def _get_str_value(self, field: str) -> str:
286286 def __repr__ (self ) -> str :
287287 return f"{ self .__repr_name__ ()} (\n { self .__repr_str__ (chr (10 ) + ' ' )} )"
288288
289- def __repr_args__ (self ):
289+ def __repr_args__ (self ) -> list [ tuple [ str , str ]] :
290290 return [
291291 (field , str_value )
292292 for field in self .__fields__
@@ -325,7 +325,7 @@ def body(self):
325325 def from_cfg (cls , cfg : ModelCfg ):
326326 return cls (** cfg .dict ())
327327
328- def __call__ (self ):
328+ def __call__ (self ) -> nn . Sequential :
329329 model_name = self .name or self .__class__ .__name__
330330 named_sequential = type (model_name , (nn .Sequential ,), {})
331331 model = named_sequential (
@@ -338,13 +338,14 @@ def __call__(self):
338338 return model
339339
340340 def _get_extra_repr (self ) -> str :
341+ """Repr for changed fields"""
341342 return " " .join (
342343 f"{ field } : { self ._get_str_value (field )} ,"
343344 for field in self .__fields_set__
344345 if field != "name"
345- )[:- 1 ]
346+ )[:- 1 ] # strip last comma.
346347
347- def __repr__ (self ):
348+ def __repr__ (self ) -> str :
348349 se_repr = self .se .__name__ if self .se else "False" # type: ignore
349350 model_name = self .name or self .__class__ .__name__
350351 return (
0 commit comments