11from typing import Callable , Union
22
3+ import torch
34from torch import nn
45
56from .helpers import nn_seq
67from .layers import ConvBnAct , get_act
7- from .model_constructor import ModelCfg , ModelConstructor
8+ from .model_constructor import ListStrMod , ModelCfg , ModelConstructor
89
910__all__ = [
1011 "XResBlock" ,
@@ -25,7 +26,7 @@ def __init__(
2526 in_channels : int ,
2627 mid_channels : int ,
2728 stride : int = 1 ,
28- conv_layer : type [nn . Module ] = ConvBnAct ,
29+ conv_layer : type [ConvBnAct ] = ConvBnAct ,
2930 act_fn : type [nn .Module ] = nn .ReLU ,
3031 zero_bn : bool = True ,
3132 bn_1st : bool = True ,
@@ -42,7 +43,7 @@ def __init__(
4243 if div_groups is not None : # check if groups != 1 and div_groups
4344 groups = int (mid_channels / div_groups )
4445 if expansion == 1 :
45- layers = [
46+ layers : ListStrMod = [
4647 (
4748 "conv_0" ,
4849 conv_layer (
@@ -69,7 +70,7 @@ def __init__(
6970 ),
7071 ]
7172 else :
72- layers = [
73+ layers : ListStrMod = [
7374 (
7475 "conv_0" ,
7576 conv_layer (
@@ -110,13 +111,13 @@ def __init__(
110111 layers .append (("sa" , sa (out_channels )))
111112 self .convs = nn_seq (layers )
112113 if stride != 1 or in_channels != out_channels :
113- id_layers = []
114+ id_layers : ListStrMod = []
114115 if (
115116 stride != 1 and pool is not None
116117 ): # if pool - reduce by pool else stride 2 art id_conv
117118 id_layers .append (("pool" , pool ()))
118119 if in_channels != out_channels or (stride != 1 and pool is None ):
119- id_layers += [
120+ id_layers . append (
120121 (
121122 "id_conv" ,
122123 conv_layer (
@@ -127,13 +128,13 @@ def __init__(
127128 act_fn = False ,
128129 ),
129130 )
130- ]
131+ )
131132 self .id_conv = nn_seq (id_layers )
132133 else :
133134 self .id_conv = None
134135 self .act_fn = get_act (act_fn )
135136
136- def forward (self , x ):
137+ def forward (self , x : torch . Tensor ): # type: ignore
137138 identity = self .id_conv (x ) if self .id_conv is not None else x
138139 return self .act_fn (self .convs (x ) + identity )
139140
@@ -147,7 +148,7 @@ def __init__(
147148 in_channels : int ,
148149 mid_channels : int ,
149150 stride : int = 1 ,
150- conv_layer = ConvBnAct ,
151+ conv_layer : type [ ConvBnAct ] = ConvBnAct ,
151152 act_fn : type [nn .Module ] = nn .ReLU ,
152153 zero_bn : bool = True ,
153154 bn_1st : bool = True ,
@@ -173,7 +174,7 @@ def __init__(
173174 else :
174175 self .reduce = None
175176 if expansion == 1 :
176- layers = [
177+ layers : ListStrMod = [
177178 (
178179 "conv_0" ,
179180 conv_layer (
@@ -200,7 +201,7 @@ def __init__(
200201 ),
201202 ]
202203 else :
203- layers = [
204+ layers : ListStrMod = [
204205 (
205206 "conv_0" ,
206207 conv_layer (
@@ -252,15 +253,15 @@ def __init__(
252253 self .id_conv = None
253254 self .merge = get_act (act_fn )
254255
255- def forward (self , x ):
256+ def forward (self , x : torch . Tensor ): # type: ignore
256257 if self .reduce :
257258 x = self .reduce (x )
258259 identity = self .id_conv (x ) if self .id_conv is not None else x
259260 return self .merge (self .convs (x ) + identity )
260261
261262
262263def make_stem (cfg : ModelCfg ) -> nn .Sequential : # type: ignore
263- """Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
264+ """Create xResnet stem -> 3 conv 3*3 instead of 1 conv 7*7"""
264265 len_stem = len (cfg .stem_sizes )
265266 stem : list [tuple [str , nn .Module ]] = [
266267 (
0 commit comments