Skip to content

Commit a0a7e3d

Browse files
committed
typing uni block
1 parent eb20472 commit a0a7e3d

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

src/model_constructor/universal_blocks.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from typing import Callable, Union
22

3+
import torch
34
from torch import nn
45

56
from .helpers import nn_seq
67
from .layers import ConvBnAct, get_act
7-
from .model_constructor import ModelCfg, ModelConstructor
8+
from .model_constructor import ListStrMod, ModelCfg, ModelConstructor
89

910
__all__ = [
1011
"XResBlock",
@@ -25,7 +26,7 @@ def __init__(
2526
in_channels: int,
2627
mid_channels: int,
2728
stride: int = 1,
28-
conv_layer: type[nn.Module] = ConvBnAct,
29+
conv_layer: type[ConvBnAct] = ConvBnAct,
2930
act_fn: type[nn.Module] = nn.ReLU,
3031
zero_bn: bool = True,
3132
bn_1st: bool = True,
@@ -42,7 +43,7 @@ def __init__(
4243
if div_groups is not None: # check if groups != 1 and div_groups
4344
groups = int(mid_channels / div_groups)
4445
if expansion == 1:
45-
layers = [
46+
layers: ListStrMod = [
4647
(
4748
"conv_0",
4849
conv_layer(
@@ -69,7 +70,7 @@ def __init__(
6970
),
7071
]
7172
else:
72-
layers = [
73+
layers: ListStrMod = [
7374
(
7475
"conv_0",
7576
conv_layer(
@@ -110,13 +111,13 @@ def __init__(
110111
layers.append(("sa", sa(out_channels)))
111112
self.convs = nn_seq(layers)
112113
if stride != 1 or in_channels != out_channels:
113-
id_layers = []
114+
id_layers: ListStrMod = []
114115
if (
115116
stride != 1 and pool is not None
116117
): # if pool - reduce by pool else stride 2 art id_conv
117118
id_layers.append(("pool", pool()))
118119
if in_channels != out_channels or (stride != 1 and pool is None):
119-
id_layers += [
120+
id_layers.append(
120121
(
121122
"id_conv",
122123
conv_layer(
@@ -127,13 +128,13 @@ def __init__(
127128
act_fn=False,
128129
),
129130
)
130-
]
131+
)
131132
self.id_conv = nn_seq(id_layers)
132133
else:
133134
self.id_conv = None
134135
self.act_fn = get_act(act_fn)
135136

136-
def forward(self, x):
137+
def forward(self, x: torch.Tensor): # type: ignore
137138
identity = self.id_conv(x) if self.id_conv is not None else x
138139
return self.act_fn(self.convs(x) + identity)
139140

@@ -147,7 +148,7 @@ def __init__(
147148
in_channels: int,
148149
mid_channels: int,
149150
stride: int = 1,
150-
conv_layer=ConvBnAct,
151+
conv_layer: type[ConvBnAct] = ConvBnAct,
151152
act_fn: type[nn.Module] = nn.ReLU,
152153
zero_bn: bool = True,
153154
bn_1st: bool = True,
@@ -173,7 +174,7 @@ def __init__(
173174
else:
174175
self.reduce = None
175176
if expansion == 1:
176-
layers = [
177+
layers: ListStrMod = [
177178
(
178179
"conv_0",
179180
conv_layer(
@@ -200,7 +201,7 @@ def __init__(
200201
),
201202
]
202203
else:
203-
layers = [
204+
layers: ListStrMod = [
204205
(
205206
"conv_0",
206207
conv_layer(
@@ -252,15 +253,15 @@ def __init__(
252253
self.id_conv = None
253254
self.merge = get_act(act_fn)
254255

255-
def forward(self, x):
256+
def forward(self, x: torch.Tensor): # type: ignore
256257
if self.reduce:
257258
x = self.reduce(x)
258259
identity = self.id_conv(x) if self.id_conv is not None else x
259260
return self.merge(self.convs(x) + identity)
260261

261262

262263
def make_stem(cfg: ModelCfg) -> nn.Sequential: # type: ignore
263-
"""Create xResnet stem -> 3 conv 3*3 instead 1 conv 7*7"""
264+
"""Create xResnet stem -> 3 conv 3*3 instead of 1 conv 7*7"""
264265
len_stem = len(cfg.stem_sizes)
265266
stem: list[tuple[str, nn.Module]] = [
266267
(

0 commit comments

Comments
 (0)