@@ -2,7 +2,7 @@
 from typing import List, Optional, Type, Union

 import torch
-import torch.nn as nn
+from torch import nn
 from torch.nn.utils.spectral_norm import spectral_norm

 __all__ = [
@@ -21,19 +21,19 @@
 class Flatten(nn.Module):
     """Flatten x to a vector."""

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.view(x.size(0), -1)


-def noop(x):
+def noop(x: torch.Tensor) -> torch.Tensor:
     """Dummy function: returns the input unchanged."""
     return x


 class Noop(nn.Module):
     """Dummy (identity) module."""

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x

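For reference, a minimal usage sketch of the annotated `Flatten`; the module body is copied from the hunk above, while the tensor shapes are illustrative assumptions, not taken from the repo.

```python
# Minimal usage sketch for Flatten (definition copied from the hunk above);
# the input shape is an arbitrary example.
import torch
from torch import nn


class Flatten(nn.Module):
    """Flatten x to a vector."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(x.size(0), -1)


x = torch.randn(8, 3, 4, 4)           # batch of 8 feature maps
assert Flatten()(x).shape == (8, 48)  # batch dim kept, the rest flattened
```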
@@ -176,7 +176,7 @@ def __init__(self, n_in: int, ks=1, sym=False, use_bias=False):
         self.sym = sym
         self.n_in = n_in

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.sym:  # check ks=3
             # symmetry hack by https://github.com/mgrankin
             c = self.conv.weight.view(self.n_in, self.n_in)
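The hunk above shows only the first line of the symmetry hack. As a hedged sketch of the idea, the 1x1 conv weight is viewed as an `n_in x n_in` matrix and averaged with its transpose, so the resulting attention matrix is symmetric; note the `(c + c.t()) / 2` step is an assumption based on the linked repository, since the diff cuts off after the `view()` call.

```python
# Hedged sketch of the symmetry hack: view the 1x1 Conv1d weight
# (shape (n_in, n_in, 1)) as an (n_in, n_in) matrix and symmetrize it.
# The (c + c.t()) / 2 step is an assumed continuation, not shown in the diff.
import torch
from torch import nn

n_in = 4
conv = nn.Conv1d(n_in, n_in, kernel_size=1, bias=False)

c = conv.weight.view(n_in, n_in)   # as in the hunk above
c = (c + c.t()) / 2                # symmetrize (assumed continuation)
w = c.view(n_in, n_in, 1)          # back to the Conv1d weight shape
assert torch.allclose(w.squeeze(-1), w.squeeze(-1).t())
```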
@@ -202,7 +202,7 @@ class SEBlock(nn.Module):
     act_fn = nn.ReLU(inplace=True)
     use_bias = True

-    def __init__(self, c, r=16):
+    def __init__(self, c: int, r: int = 16):
         super().__init__()
         ch = max(c // r, 1)
         self.squeeze = nn.AdaptiveAvgPool2d(1)
@@ -217,7 +217,7 @@ def __init__(self, c, r=16):
             )
         )

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         bs, c, _, _ = x.shape
         y = self.squeeze(x).view(bs, c)
         y = self.excitation(y).view(bs, c, 1, 1)
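Since the diff cuts off before the end of `forward`, here is a hedged, self-contained squeeze-and-excitation sketch matching the annotated lines; the `excitation` layers and the final rescaling line are assumptions (the original builds them from class attributes such as `act_fn` and `use_bias`).

```python
# Hedged, self-contained squeeze-and-excitation sketch. The squeeze/forward
# lines mirror the hunk above; the excitation layers and the final
# `return x * y` rescale are assumptions, as the diff cuts off early.
import torch
from torch import nn


class SEBlockSketch(nn.Module):
    def __init__(self, c: int, r: int = 16):
        super().__init__()
        ch = max(c // r, 1)                     # reduced channel width
        self.squeeze = nn.AdaptiveAvgPool2d(1)  # global average pool
        self.excitation = nn.Sequential(
            nn.Linear(c, ch, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(ch, c, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)           # (bs, c): per-channel stats
        y = self.excitation(y).view(bs, c, 1, 1)  # per-channel gates in [0, 1]
        return x * y.expand_as(x)                 # rescale input channels


x = torch.randn(2, 32, 8, 8)
assert SEBlockSketch(32)(x).shape == x.shape
```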