Skip to content

Commit 3ef9204

Browse files
committed
act func inplace
1 parent 8f9d6ba commit 3ef9204

File tree

5 files changed: 29 additions (+29) and 20 deletions (-20).

src/model_constructor/layers.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
"noop",
1111
"Noop",
1212
"ConvLayer",
13-
"act_fn",
13+
"act",
1414
"conv1d",
1515
"SimpleSelfAttention",
1616
"SEBlock",
@@ -43,7 +43,16 @@ def forward(self, x):
4343
return x
4444

4545

46-
act_fn = nn.ReLU(inplace=True)
46+
act = nn.ReLU(inplace=True)
47+
48+
49+
def get_act(act_fn: Type[nn.Module], inplace: bool = True) -> nn.Module:
50+
"""Return obj of act_fn, inplace if possible."""
51+
try:
52+
res = act_fn(inplace=inplace) # type: ignore
53+
except TypeError:
54+
res = act_fn()
55+
return res
4756

4857

4958
class ConvBnAct(nn.Sequential):
@@ -95,7 +104,7 @@ def __init__(
95104
act_position = 1
96105
else:
97106
act_position = len(layers)
98-
layers.insert(act_position, ("act_fn", act_fn(inplace=True))) # type: ignore
107+
layers.insert(act_position, ("act_fn", get_act(act_fn))) # type: ignore
99108
super().__init__(OrderedDict(layers))
100109

101110

@@ -112,7 +121,7 @@ def __init__(
112121
ks=3,
113122
stride=1,
114123
act=True,
115-
act_fn=act_fn,
124+
act_fn=act,
116125
bn_layer=True,
117126
bn_1st=True,
118127
zero_bn=False,

src/model_constructor/model_constructor.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import torch.nn as nn
66
from pydantic import BaseModel, root_validator
77

8-
from .layers import ConvBnAct, SEModule, SimpleSelfAttention
8+
from .layers import ConvBnAct, SEModule, SimpleSelfAttention, get_act
99

1010
__all__ = [
1111
"init_cnn",
@@ -32,7 +32,7 @@ def __init__(
3232
groups: int = 1,
3333
dw: bool = False,
3434
div_groups: Union[None, int] = None,
35-
pool: Union[Callable[[Any], nn.Module], None] = None,
35+
pool: Union[Callable[[], nn.Module], None] = None,
3636
se: Union[nn.Module, None] = None,
3737
sa: Union[nn.Module, None] = None,
3838
):
@@ -109,7 +109,7 @@ def __init__(
109109
self.id_conv = nn.Sequential(OrderedDict(id_layers))
110110
else:
111111
self.id_conv = None
112-
self.act_fn = act_fn(inplace=True) # type: ignore
112+
self.act_fn = get_act(act_fn) # type: ignore
113113

114114
def forward(self, x):
115115
identity = self.id_conv(x) if self.id_conv is not None else x
@@ -141,13 +141,13 @@ class ModelCfg(BaseModel):
141141
zero_bn: bool = True
142142
stem_stride_on: int = 0
143143
stem_sizes: List[int] = [32, 32, 64]
144-
stem_pool: Union[Callable[[Any], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
144+
stem_pool: Union[Callable[[], nn.Module], None] = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
145145
stem_bn_end: bool = False
146146
init_cnn: Optional[Callable[[nn.Module], None]] = None
147-
make_stem: Optional[Callable] = None
148-
make_layer: Optional[Callable] = None
149-
make_body: Optional[Callable] = None
150-
make_head: Optional[Callable] = None
147+
make_stem: Optional[Callable[["ModelCfg"], nn.Module]] = None
148+
make_layer: Optional[Callable[["ModelCfg"], nn.Module]] = None
149+
make_body: Optional[Callable[["ModelCfg"], nn.Module]] = None
150+
make_head: Optional[Callable[["ModelCfg"], nn.Module]] = None
151151

152152
class Config:
153153
arbitrary_types_allowed = True

src/model_constructor/twist.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from collections import OrderedDict
2-
from .layers import ConvLayer, noop, act_fn, SimpleSelfAttention
2+
from .layers import ConvLayer, noop, act, SimpleSelfAttention
33

44
import torch
55
import torch.nn as nn
@@ -112,7 +112,7 @@ class NewResBlockTwist(nn.Module):
112112
Now YaResBlock.'''
113113

114114
def __init__(self, expansion, ni, nh, stride=1,
115-
conv_layer=ConvLayer, act_fn=act_fn, bn_1st=True,
115+
conv_layer=ConvLayer, act_fn=act, bn_1st=True,
116116
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, zero_bn=True, **kwargs):
117117
super().__init__()
118118
nf, ni = nh * expansion, ni * expansion
@@ -139,7 +139,7 @@ class ResBlockTwist(nn.Module):
139139
'''Resnet block with ConvTwist'''
140140

141141
def __init__(self, expansion, ni, nh, stride=1,
142-
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
142+
conv_layer=ConvLayer, act_fn=act, zero_bn=True, bn_1st=True,
143143
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, **kwargs):
144144
super().__init__()
145145
nf, ni = nh * expansion, ni * expansion

src/model_constructor/xresnet.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from collections import OrderedDict
33

44
from .base_constructor import Net
5-
from .layers import ConvLayer, Noop, act_fn
5+
from .layers import ConvLayer, Noop, act
66

77

88
__all__ = ['DownsampleLayer', 'XResBlock', 'xresnet18', 'xresnet34', 'xresnet50']
@@ -25,7 +25,7 @@ class XResBlock(nn.Module):
2525
'''XResnet block'''
2626

2727
def __init__(self, ni, nh, expansion=1, stride=1, zero_bn=True,
28-
conv_layer=ConvLayer, act_fn=act_fn, **kwargs):
28+
conv_layer=ConvLayer, act_fn=act, **kwargs):
2929
super().__init__()
3030
nf, ni = nh * expansion, ni * expansion
3131
layers = [('conv_0', conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, **kwargs)),

src/model_constructor/yaresnet.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import torch.nn as nn
88
from torch.nn import Mish
99

10-
from .layers import ConvBnAct
10+
from .layers import ConvBnAct, get_act
1111
from .model_constructor import ModelConstructor
1212

1313
__all__ = [
@@ -34,7 +34,7 @@ def __init__(
3434
groups: int = 1,
3535
dw: bool = False,
3636
div_groups: Union[None, int] = None,
37-
pool: Union[Callable[[Any], nn.Module], None] = None,
37+
pool: Union[Callable[[], nn.Module], None] = None,
3838
se: Union[nn.Module, None] = None,
3939
sa: Union[nn.Module, None] = None,
4040
):
@@ -115,7 +115,7 @@ def __init__(
115115
)
116116
else:
117117
self.id_conv = None
118-
self.merge = act_fn()
118+
self.merge = get_act(act_fn)
119119

120120
def forward(self, x):
121121
if self.reduce:

0 commit comments

Comments
 (0)