1- from typing import Callable , Optional
1+ from typing import Callable , List , Optional , Type
22
33import torch
44from torch import nn
@@ -26,8 +26,8 @@ def __init__(
2626 in_channels : int ,
2727 mid_channels : int ,
2828 stride : int = 1 ,
29- conv_layer : type [ConvBnAct ] = ConvBnAct ,
30- act_fn : type [nn .Module ] = nn .ReLU ,
29+ conv_layer : Type [ConvBnAct ] = ConvBnAct ,
30+ act_fn : Type [nn .Module ] = nn .ReLU ,
3131 zero_bn : bool = True ,
3232 bn_1st : bool = True ,
3333 groups : int = 1 ,
@@ -150,16 +150,16 @@ def __init__(
150150 in_channels : int ,
151151 mid_channels : int ,
152152 stride : int = 1 ,
153- conv_layer : type [ConvBnAct ] = ConvBnAct ,
154- act_fn : type [nn .Module ] = nn .ReLU ,
153+ conv_layer : Type [ConvBnAct ] = ConvBnAct ,
154+ act_fn : Type [nn .Module ] = nn .ReLU ,
155155 zero_bn : bool = True ,
156156 bn_1st : bool = True ,
157157 groups : int = 1 ,
158158 dw : bool = False ,
159159 div_groups : Optional [int ] = None ,
160160 pool : Optional [Callable [[], nn .Module ]] = None ,
161- se : Optional [type [nn .Module ]] = None ,
162- sa : Optional [type [nn .Module ]] = None ,
161+ se : Optional [Type [nn .Module ]] = None ,
162+ sa : Optional [Type [nn .Module ]] = None ,
163163 ):
164164 super ().__init__ ()
165165 # pool defined at ModelConstructor.
@@ -265,7 +265,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
265265def make_stem (cfg : ModelCfg ) -> nn .Sequential :
266266 """Create xResnet stem -> 3 conv 3*3 instead of 1 conv 7*7"""
267267 len_stem = len (cfg .stem_sizes )
268- stem : list [ tuple [ str , nn . Module ]] = [
268+ stem : ListStrMod = [
269269 (
270270 f"conv_{ i } " ,
271271 cfg .conv_layer (
@@ -341,11 +341,11 @@ class XResNet(ModelConstructor):
341341 make_layer : Callable [[ModelCfg , int ], ModSeq ] = make_layer
342342 make_body : Callable [[ModelCfg ], ModSeq ] = make_body
343343 make_head : Callable [[ModelCfg ], ModSeq ] = make_head
344- block : type [nn .Module ] = XResBlock
344+ block : Type [nn .Module ] = XResBlock
345345
346346
347347class XResNet34 (XResNet ):
348- layers : list [int ] = [3 , 4 , 6 , 3 ]
348+ layers : List [int ] = [3 , 4 , 6 , 3 ]
349349
350350
351351class XResNet50 (XResNet34 ):
@@ -357,13 +357,13 @@ class YaResNet(XResNet):
357357 YaResBlock, Mish activation, custom stem.
358358 """
359359
360- block : type [nn .Module ] = YaResBlock
361- stem_sizes : list [int ] = [3 , 32 , 64 , 64 ]
362- act_fn : type [nn .Module ] = nn .Mish
360+ block : Type [nn .Module ] = YaResBlock
361+ stem_sizes : List [int ] = [3 , 32 , 64 , 64 ]
362+ act_fn : Type [nn .Module ] = nn .Mish
363363
364364
365365class YaResNet34 (YaResNet ):
366- layers : list [int ] = [3 , 4 , 6 , 3 ]
366+ layers : List [int ] = [3 , 4 , 6 , 3 ]
367367
368368
369369class YaResNet50 (YaResNet34 ):
0 commit comments