Commit de2bb6a

act, pool
1 parent 13404b9 commit de2bb6a

3 files changed: +23 -22 lines changed

src/model_constructor/layers.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 from collections import OrderedDict
-from typing import List, Optional, Union
+from typing import List, Optional, Type, Union
 
 import torch
 import torch.nn as nn
@@ -61,7 +61,7 @@ def __init__(
         padding: Optional[int] = None,
         bias: bool = False,
         groups: int = 1,
-        act_fn: Union[nn.Module, bool] = act_fn,
+        act_fn: Union[Type[nn.Module], bool] = nn.ReLU,
         pre_act: bool = False,
         bn_layer: bool = True,
         bn_1st: bool = True,
@@ -88,14 +88,14 @@ def __init__(
             bn = self.batchnorm_module(out_channels)
             nn.init.constant_(bn.weight, 0.0 if zero_bn else 1.0)
             layers.append(("bn", bn))
-        if isinstance(act_fn, nn.Module):  # act_fn either nn.Module or False
+        if act_fn:  # act_fn either nn.Module subclass or False
             if pre_act:
                 act_position = 0
             elif not bn_1st:
                 act_position = 1
             else:
                 act_position = len(layers)
-            layers.insert(act_position, ("act_fn", act_fn))
+            layers.insert(act_position, ("act_fn", act_fn(inplace=True)))  # type: ignore
         super().__init__(OrderedDict(layers))

src/model_constructor/model_constructor.py

Lines changed: 13 additions & 12 deletions
@@ -1,5 +1,6 @@
 from collections import OrderedDict
-from typing import Callable, List, Optional, Type, Union
+from functools import partial
+from typing import Any, Callable, List, Optional, Type, Union
 
 import torch.nn as nn
 from pydantic import BaseModel
@@ -8,15 +9,15 @@
 
 __all__ = [
     "init_cnn",
-    "act_fn",
+    # "act_fn",
     "ResBlock",
     "ModelConstructor",
     "XResNet34",
     "XResNet50",
 ]
 
 
-act_fn = nn.ReLU(inplace=True)
+# act_fn = nn.ReLU
 
 
 class ResBlock(nn.Module):
@@ -29,13 +30,13 @@ def __init__(
         mid_channels: int,
         stride: int = 1,
         conv_layer=ConvBnAct,
-        act_fn: nn.Module = act_fn,
+        act_fn: Type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
         dw: bool = False,
         div_groups: Union[None, int] = None,
-        pool: Union[nn.Module, None] = None,
+        pool: Union[Callable[[Any], nn.Module], None] = None,
         se: Union[nn.Module, None] = None,
         sa: Union[nn.Module, None] = None,
     ):
@@ -100,7 +101,7 @@ def __init__(
         if stride != 1 or in_channels != out_channels:
             id_layers = []
             if stride != 1 and pool is not None:  # if pool - reduce by pool else stride 2 art id_conv
-                id_layers.append(("pool", pool))
+                id_layers.append(("pool", pool()))
             if in_channels != out_channels or (stride != 1 and pool is None):
                 id_layers += [("id_conv", conv_layer(
                     in_channels,
@@ -112,7 +113,7 @@ def __init__(
             self.id_conv = nn.Sequential(OrderedDict(id_layers))
         else:
             self.id_conv = None
-        self.act_fn = act_fn
+        self.act_fn = act_fn(inplace=True)  # type: ignore
 
     def forward(self, x):
         identity = self.id_conv(x) if self.id_conv is not None else x
@@ -130,8 +131,8 @@ class ModelCfg(BaseModel):
     block_sizes: List[int] = [64, 128, 256, 512]
     layers: List[int] = [2, 2, 2, 2]
     norm: Type[nn.Module] = nn.BatchNorm2d
-    act_fn: nn.Module = nn.ReLU(inplace=True)
-    pool: nn.Module = nn.AvgPool2d(2, ceil_mode=True)
+    act_fn: Type[nn.Module] = nn.ReLU
+    pool: Callable[[Any], nn.Module] = partial(nn.AvgPool2d, kernel_size=2, ceil_mode=True)
     expansion: int = 1
     groups: int = 1
     dw: bool = False
@@ -144,7 +145,7 @@ class ModelCfg(BaseModel):
     zero_bn: bool = True
     stem_stride_on: int = 0
     stem_sizes: List[int] = [32, 32, 64]
-    stem_pool: Union[nn.Module, None] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # type: ignore
+    stem_pool: Union[Callable[[Any], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
     stem_bn_end: bool = False
     init_cnn: Optional[Callable[[nn.Module], None]] = None
     make_stem: Optional[Callable] = None
@@ -192,7 +193,7 @@ def make_stem(self: ModelCfg) -> nn.Sequential:
         for i in range(len(self.stem_sizes) - 1)
     ]
     if self.stem_pool:
-        stem.append(("stem_pool", self.stem_pool))
+        stem.append(("stem_pool", self.stem_pool()))
     if self.stem_bn_end:
         stem.append(("norm", self.norm(self.stem_sizes[-1])))  # type: ignore
     return nn.Sequential(OrderedDict(stem))
@@ -313,7 +314,7 @@ def __repr__(self):
             f"{self.name} constructor\n"
             f" in_chans: {self.in_chans}, num_classes: {self.num_classes}\n"
             f" expansion: {self.expansion}, groups: {self.groups}, dw: {self.dw}, div_groups: {self.div_groups}\n"
-            f" sa: {self.sa}, se: {self.se}\n"
+            f" act_fn: {self.act_fn.__name__}, sa: {self.sa}, se: {self.se}\n"
             f" stem sizes: {self.stem_sizes}, stride on {self.stem_stride_on}\n"
             f" body sizes {self.block_sizes}\n"
             f" layers: {self.layers}"

src/model_constructor/yaresnet.py

Lines changed: 6 additions & 6 deletions
@@ -2,7 +2,7 @@
 # Yet another ResNet.
 
 from collections import OrderedDict
-from typing import List, Type, Union
+from typing import Any, Callable, List, Type, Union
 
 import torch.nn as nn
 from torch.nn import Mish
@@ -15,7 +15,7 @@
 ]
 
 
-act_fn = nn.ReLU(inplace=True)
+# act_fn = nn.ReLU(inplace=True)
 
 
 class YaResBlock(nn.Module):
@@ -28,13 +28,13 @@ def __init__(
         mid_channels: int,
         stride: int = 1,
         conv_layer=ConvBnAct,
-        act_fn: nn.Module = act_fn,
+        act_fn: Type[nn.Module] = nn.ReLU,
         zero_bn: bool = True,
         bn_1st: bool = True,
         groups: int = 1,
         dw: bool = False,
         div_groups: Union[None, int] = None,
-        pool: Union[nn.Module, None] = None,
+        pool: Union[Callable[[Any], nn.Module], None] = None,
         se: Union[nn.Module, None] = None,
         sa: Union[nn.Module, None] = None,
     ):
@@ -49,7 +49,7 @@ def __init__(
                 self.reduce = conv_layer(in_channels, in_channels, 1, stride=2)
                 # warnings.warn("pool not passed")  # need to warn?
             else:
-                self.reduce = pool
+                self.reduce = pool()
         else:
             self.reduce = None
         if expansion == 1:
@@ -115,7 +115,7 @@ def __init__(
             )
         else:
             self.id_conv = None
-        self.merge = act_fn
+        self.merge = act_fn()
 
     def forward(self, x):
         if self.reduce:
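One practical consequence of switching from module instances to classes and factories, visible in self.merge = act_fn() and self.reduce = pool(): default arguments no longer point at a single shared nn.Module, so every block builds its own activation and pooling layers. A quick standalone check of that behaviour (plain torch.nn, no package imports assumed):

import torch.nn as nn

# Old style: a single default instance, shared by every block that used the default.
shared_act = nn.ReLU(inplace=True)
block_a_act, block_b_act = shared_act, shared_act
assert block_a_act is block_b_act  # the very same module object

# New style: each block calls the class, e.g. self.merge = act_fn() in YaResBlock.
act_fn = nn.Mish
block_a_merge, block_b_merge = act_fn(), act_fn()
assert block_a_merge is not block_b_merge  # distinct module per block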
