
Commit f70f617

Commit message: black
1 parent 89015aa commit f70f617

5 files changed: +40 -21 lines changed

src/model_constructor/activations.py

Lines changed: 33 additions & 12 deletions
@@ -5,9 +5,24 @@
 from torch.nn import Mish
 
 
-__all__ = ['mish', 'Mish', 'mish_jit', 'MishJit', 'mish_jit_fwd', 'mish_jit_bwd', 'MishJitAutoFn', 'mish_me', 'MishMe',
-           'hard_mish_jit', 'HardMishJit', 'hard_mish_jit_fwd', 'hard_mish_jit_bwd', 'HardMishJitAutoFn',
-           'hard_mish_me', 'HardMishMe']
+__all__ = [
+    "mish",
+    "Mish",
+    "mish_jit",
+    "MishJit",
+    "mish_jit_fwd",
+    "mish_jit_bwd",
+    "MishJitAutoFn",
+    "mish_me",
+    "MishMe",
+    "hard_mish_jit",
+    "HardMishJit",
+    "hard_mish_jit_fwd",
+    "hard_mish_jit_bwd",
+    "HardMishJitAutoFn",
+    "hard_mish_me",
+    "HardMishMe",
+]
 
 
 def mish(x, inplace: bool = False):
@@ -40,7 +55,8 @@ def mish_jit(x, _inplace: bool = False):
 class MishJit(nn.Module):
     def __init__(self, inplace: bool = False):
         """Jit version of Mish.
-        Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681"""
+        Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+        """
         super(MishJit, self).__init__()
 
     def forward(self, x):
@@ -61,8 +77,9 @@ def mish_jit_bwd(x, grad_output):
 
 
 class MishJitAutoFn(torch.autograd.Function):
-    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
     A memory efficient, jit scripted variant of Mish"""
+
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
@@ -79,8 +96,9 @@ def mish_me(x, inplace=False):
 
 
 class MishMe(nn.Module):
-    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
     A memory efficient, jit scripted variant of Mish"""
+
     def __init__(self, inplace: bool = False):
         super(MishMe, self).__init__()
 
@@ -90,18 +108,19 @@ def forward(self, x):
 
 @torch.jit.script
 def hard_mish_jit(x, inplace: bool = False):
-    """ Hard Mish
+    """Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
     return 0.5 * x * (x + 2).clamp(min=0, max=2)
 
 
 class HardMishJit(nn.Module):
-    """ Hard Mish
+    """Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
+
     def __init__(self, inplace: bool = False):
         super(HardMishJit, self).__init__()
 
@@ -116,16 +135,17 @@ def hard_mish_jit_fwd(x):
 
 @torch.jit.script
 def hard_mish_jit_bwd(x, grad_output):
-    m = torch.ones_like(x) * (x >= -2.)
-    m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
+    m = torch.ones_like(x) * (x >= -2.0)
+    m = torch.where((x >= -2.0) & (x <= 0.0), x + 1.0, m)
     return grad_output * m
 
 
 class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+    """A memory efficient, jit scripted variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
+
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
@@ -142,10 +162,11 @@ def hard_mish_me(x, inplace: bool = False):
 
 
 class HardMishMe(nn.Module):
-    """ A memory efficient, jit scripted variant of Hard Mish
+    """A memory efficient, jit scripted variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
     https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """
+
     def __init__(self, inplace: bool = False):
         super(HardMishMe, self).__init__()
 
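
A quick way to sanity-check the hand-written Hard Mish backward above (not part of this commit, and assuming the module is importable as model_constructor.activations) is to compare the gradient of the memory-efficient HardMishJitAutoFn against autograd run through the plain hard_mish_jit expression:

import torch

from model_constructor.activations import HardMishJitAutoFn, hard_mish_jit

# Random double-precision input; requires_grad so autograd can differentiate.
x = torch.randn(1000, dtype=torch.float64, requires_grad=True)

# Memory-efficient path: backward recomputes the gradient from the saved input
# as grad * m, with m = 1 for x > 0, m = x + 1 on [-2, 0], m = 0 below -2.
(grad_me,) = torch.autograd.grad(HardMishJitAutoFn.apply(x).sum(), x)

# Reference path: let autograd differentiate 0.5 * x * (x + 2).clamp(0, 2) itself.
x_ref = x.detach().clone().requires_grad_(True)
(grad_ref,) = torch.autograd.grad(hard_mish_jit(x_ref).sum(), x_ref)

assert torch.allclose(grad_me, grad_ref)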

src/model_constructor/convmixer.py

Lines changed: 0 additions & 1 deletion
@@ -67,7 +67,6 @@ def __init__(
         bn_1st: bool = False,
         pre_act: bool = False,
     ):
-
         conv_layer: List[tuple[str, nn.Module]] = [
             (
                 "conv",

src/model_constructor/layers.py

Lines changed: 2 additions & 2 deletions
@@ -70,7 +70,6 @@ def __init__(
         bn_1st: bool = True,
         zero_bn: bool = False,
     ):
-
         if padding is None:
             padding = kernel_size // 2
         layers: List[tuple[str, nn.Module]] = [
@@ -124,7 +123,6 @@ def __init__(
         groups=1,
         **kwargs  # pylint: disable=unused-argument
     ):
-
         if padding is None:
             padding = ks // 2
         layers = [
@@ -197,6 +195,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class SEBlock(nn.Module):
     """se block"""
+
     # first version
     se_layer = nn.Linear
     act_fn = nn.ReLU(inplace=True)
@@ -226,6 +225,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class SEBlockConv(nn.Module):
     """se block with conv on excitation"""
+
     # first version
     se_layer = nn.Conv2d
     act_fn = nn.ReLU(inplace=True)

src/model_constructor/mxresnet.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ class MxResNet(XResNet):
     stem_sizes: list[int] = [3, 32, 64, 64]
     act_fn: type[nn.Module] = nn.Mish
 
+
 class MxResNet34(XResNet34):
     stem_sizes: list[int] = [3, 32, 64, 64]
     act_fn: type[nn.Module] = nn.Mish

src/model_constructor/universal_blocks.py

Lines changed: 4 additions & 6 deletions
@@ -297,14 +297,10 @@ def make_layer(cfg: ModelCfg, layer_num: int) -> nn.Sequential:  # type: ignore
                     f"bl_{block_num}",
                     cfg.block(
                         cfg.expansion,  # type: ignore
-                        block_chs[layer_num]
-                        if block_num == 0
-                        else block_chs[layer_num + 1],
+                        block_chs[layer_num] if block_num == 0 else block_chs[layer_num + 1],
                         block_chs[layer_num + 1],
                         stride if block_num == 0 else 1,
-                        sa=cfg.sa
-                        if (block_num == num_blocks - 1) and layer_num == 0
-                        else None,
+                        sa=cfg.sa if (block_num == num_blocks - 1) and layer_num == 0 else None,
                         conv_layer=cfg.conv_layer,
                         act_fn=cfg.act_fn,
                         pool=cfg.pool,
@@ -340,6 +336,7 @@ def make_head(cfg: ModelCfg) -> nn.Sequential:  # type: ignore
 
 class XResNet(ModelConstructor):
     """Base Xresnet constructor."""
+
     make_stem: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_stem
     make_layer: Callable[[ModelCfg, int], Union[nn.Module, nn.Sequential]] = make_layer
     make_body: Callable[[ModelCfg], Union[nn.Module, nn.Sequential]] = make_body
@@ -359,6 +356,7 @@ class YaResNet(XResNet):
     """Base Yaresnet constructor.
     YaResBlock, Mish activation, custom stem.
     """
+
     block: type[nn.Module] = YaResBlock
     stem_sizes: list[int] = [3, 32, 64, 64]
     act_fn: type[nn.Module] = nn.Mish
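
For context (a minimal sketch, not from this commit): the collapsed conditionals above sit inside the comprehension that make_layer uses to assemble one stage of blocks, where only the first block changes channels and applies the stride, and only the last block of layer 0 optionally gets self-attention. DummyBlock, make_layer_sketch, and the plain arguments below are hypothetical stand-ins for cfg.block and the ModelCfg fields:

from collections import OrderedDict

import torch.nn as nn


class DummyBlock(nn.Module):
    # Stand-in for cfg.block; a real block would also take expansion, conv_layer, act_fn, etc.
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1, sa=None):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.sa = sa  # self-attention module or None

    def forward(self, x):
        return self.conv(x)


def make_layer_sketch(block_chs, layer_num, num_blocks, stride, sa=None) -> nn.Sequential:
    return nn.Sequential(
        OrderedDict(
            [
                (
                    f"bl_{block_num}",
                    DummyBlock(
                        block_chs[layer_num] if block_num == 0 else block_chs[layer_num + 1],
                        block_chs[layer_num + 1],
                        stride if block_num == 0 else 1,
                        sa=sa if (block_num == num_blocks - 1) and layer_num == 0 else None,
                    ),
                )
                for block_num in range(num_blocks)
            ]
        )
    )


# Example: second stage (layer_num=1), 2 blocks, 64 -> 128 channels, stride 2 on the first block.
layer = make_layer_sketch(block_chs=[64, 64, 128], layer_num=1, num_blocks=2, stride=2)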
