
Commit ece3c83

basic & bottle blocks
1 parent a0a7e3d commit ece3c83


2 files changed: +8 -8 lines changed


src/model_constructor/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@

 def nn_seq(list_of_tuples: Iterable[tuple[str, nn.Module]]) -> nn.Sequential:
     """return nn.Sequential from OrderedDict from list of tuples"""
-    return nn.Sequential(OrderedDict(list_of_tuples))
+    return nn.Sequential(OrderedDict(list_of_tuples))  #
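For reference, a minimal usage sketch of nn_seq (illustrative only; the layer names and shapes below are invented, not taken from the repo):

from torch import nn

from model_constructor.helpers import nn_seq

# nn_seq builds a named nn.Sequential from (name, module) tuples,
# so each layer is reachable as an attribute afterwards.
stem = nn_seq([
    ("conv", nn.Conv2d(3, 8, kernel_size=3, padding=1)),
    ("bn", nn.BatchNorm2d(8)),
    ("act", nn.ReLU(inplace=True)),
])
print(stem.conv)  # the Conv2d layer, accessed by its tuple name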

src/model_constructor/layers.py

Lines changed: 7 additions & 7 deletions
@@ -2,7 +2,7 @@
 from typing import List, Optional, Type, Union

 import torch
-import torch.nn as nn
+from torch import nn
 from torch.nn.utils.spectral_norm import spectral_norm

 __all__ = [
@@ -21,19 +21,19 @@
 class Flatten(nn.Module):
     """flat x to vector"""

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.view(x.size(0), -1)


-def noop(x):
+def noop(x: torch.Tensor) -> torch.Tensor:
     """Dummy func. Return input"""
     return x


 class Noop(nn.Module):
     """Dummy module"""

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x

@@ -176,7 +176,7 @@ def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
         self.sym = sym
         self.n_in = n_in

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.sym: # check ks=3
             # symmetry hack by https://github.com/mgrankin
             c = self.conv.weight.view(self.n_in, self.n_in)
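For readers unfamiliar with the symmetry hack: the weight of the (assumed) 1x1 Conv1d is viewed as an n_in x n_in matrix, and the usual trick is to replace it with its symmetric part. A rough, self-contained illustration of that step (not the repo's code; shapes and names are assumptions):

import torch

n_in = 4
weight = torch.randn(n_in, n_in, 1)     # weight shape of a Conv1d(n_in, n_in, kernel_size=1)
c = weight.view(n_in, n_in)             # view the kernel as an n_in x n_in matrix
c_sym = (c + c.t()) / 2                 # keep only the symmetric part
weight_sym = c_sym.view(n_in, n_in, 1)  # back to the conv weight shape
# the symmetrized matrix equals its own transpose
assert torch.equal(c_sym, c_sym.t())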
@@ -202,7 +202,7 @@ class SEBlock(nn.Module):
     act_fn = nn.ReLU(inplace=True)
     use_bias = True

-    def __init__(self, c, r=16):
+    def __init__(self, c: int, r: int = 16):
         super().__init__()
         ch = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
@@ -217,7 +217,7 @@ def __init__(self, c, r=16):
             )
         )

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         bs, c, _, _ = x.shape
         y = self.squeeze(x).view(bs, c)
         y = self.excitation(y).view(bs, c, 1, 1)
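And a quick smoke test for the annotated SEBlock (a sketch, assuming the block follows standard squeeze-and-excitation and returns a tensor of the same shape as its input):

import torch

from model_constructor.layers import SEBlock

se = SEBlock(64, r=16)        # 64 channels, reduction 16 -> hidden size 4
x = torch.randn(2, 64, 8, 8)
out = se(x)
print(out.shape)              # expected: torch.Size([2, 64, 8, 8])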
