@@ -12,23 +12,17 @@ class SigmoidModel(nn.Module):
1212 -pytorch-and-make-your-life-simpler-ec5367895199
1313 """
1414
def __init__(self, num_in: int, num_hidden: int, num_out: int) -> None:
    """Build a two-layer sigmoid-output MLP: num_in -> num_hidden -> num_out.

    Args:
        num_in: width of the input features.
        num_hidden: width of the single hidden layer.
        num_out: width of the output layer (passed through a sigmoid).
    """
    super().__init__()
    # Keep the layer widths around for later inspection by callers.
    self.num_in, self.num_hidden, self.num_out = num_in, num_hidden, num_out
    # Submodule registration order is deliberately unchanged:
    # two affine maps, a ReLU between them, and a sigmoid on the output.
    self.lin1 = nn.Linear(num_in, num_hidden)
    self.lin2 = nn.Linear(num_hidden, num_out)
    self.relu1 = nn.ReLU()
    self.sigmoid = nn.Sigmoid()
2824
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Apply lin1 -> ReLU -> lin2 -> sigmoid to *input* and return the result."""
    hidden = self.relu1(self.lin1(input))
    return self.sigmoid(self.lin2(hidden))
@@ -40,14 +34,12 @@ class SoftmaxModel(nn.Module):
4034 https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/
4135 """
4236
43- # pyre-fixme[2]: Parameter must be annotated.
44- def __init__ (self , num_in , num_hidden , num_out , inplace : bool = False ) -> None :
37+ def __init__ (
38+ self , num_in : int , num_hidden : int , num_out : int , inplace : bool = False
39+ ) -> None :
4540 super ().__init__ ()
46- # pyre-fixme[4]: Attribute must be annotated.
4741 self .num_in = num_in
48- # pyre-fixme[4]: Attribute must be annotated.
4942 self .num_hidden = num_hidden
50- # pyre-fixme[4]: Attribute must be annotated.
5143 self .num_out = num_out
5244 self .lin1 = nn .Linear (num_in , num_hidden )
5345 self .lin2 = nn .Linear (num_hidden , num_hidden )
@@ -56,9 +48,7 @@ def __init__(self, num_in, num_hidden, num_out, inplace: bool = False) -> None:
5648 self .relu2 = nn .ReLU (inplace = inplace )
5749 self .softmax = nn .Softmax (dim = 1 )
5850
59- # pyre-fixme[3]: Return type must be annotated.
60- # pyre-fixme[2]: Parameter must be annotated.
61- def forward (self , input ):
51+ def forward (self , input : torch .Tensor ) -> torch .Tensor :
6252 lin1 = self .relu1 (self .lin1 (input ))
6353 lin2 = self .relu2 (self .lin2 (lin1 ))
6454 lin3 = self .lin3 (lin2 )
@@ -72,14 +62,10 @@ class SigmoidDeepLiftModel(nn.Module):
7262 -pytorch-and-make-your-life-simpler-ec5367895199
7363 """
7464
75- # pyre-fixme[2]: Parameter must be annotated.
76- def __init__ (self , num_in , num_hidden , num_out ) -> None :
65+ def __init__ (self , num_in : int , num_hidden : int , num_out : int ) -> None :
7766 super ().__init__ ()
78- # pyre-fixme[4]: Attribute must be annotated.
7967 self .num_in = num_in
80- # pyre-fixme[4]: Attribute must be annotated.
8168 self .num_hidden = num_hidden
82- # pyre-fixme[4]: Attribute must be annotated.
8369 self .num_out = num_out
8470 self .lin1 = nn .Linear (num_in , num_hidden , bias = False )
8571 self .lin2 = nn .Linear (num_hidden , num_out , bias = False )
@@ -88,9 +74,7 @@ def __init__(self, num_in, num_hidden, num_out) -> None:
8874 self .relu1 = nn .ReLU ()
8975 self .sigmoid = nn .Sigmoid ()
9076
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Return sigmoid(lin2(relu(lin1(input)))) — the model's forward pass."""
    pre_activation = self.lin1(input)
    activated = self.relu1(pre_activation)
    return self.sigmoid(self.lin2(activated))
@@ -102,14 +86,10 @@ class SoftmaxDeepLiftModel(nn.Module):
10286 https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/
10387 """
10488
105- # pyre-fixme[2]: Parameter must be annotated.
106- def __init__ (self , num_in , num_hidden , num_out ) -> None :
89+ def __init__ (self , num_in : int , num_hidden : int , num_out : int ) -> None :
10790 super ().__init__ ()
108- # pyre-fixme[4]: Attribute must be annotated.
10991 self .num_in = num_in
110- # pyre-fixme[4]: Attribute must be annotated.
11192 self .num_hidden = num_hidden
112- # pyre-fixme[4]: Attribute must be annotated.
11393 self .num_out = num_out
11494 self .lin1 = nn .Linear (num_in , num_hidden )
11595 self .lin2 = nn .Linear (num_hidden , num_hidden )
@@ -121,9 +101,7 @@ def __init__(self, num_in, num_hidden, num_out) -> None:
121101 self .relu2 = nn .ReLU ()
122102 self .softmax = nn .Softmax (dim = 1 )
123103
124- # pyre-fixme[3]: Return type must be annotated.
125- # pyre-fixme[2]: Parameter must be annotated.
126- def forward (self , input ):
104+ def forward (self , input : torch .Tensor ) -> torch .Tensor :
127105 lin1 = self .relu1 (self .lin1 (input ))
128106 lin2 = self .relu2 (self .lin2 (lin1 ))
129107 lin3 = self .lin3 (lin2 )
0 commit comments