@@ -34,9 +34,10 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features * 2)
+        assert hidden_features % 2 == 0
+        self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.fc2 = nn.Linear(hidden_features // 2, out_features)
         self.drop = nn.Dropout(drop)
 
     def forward(self, x):
@@ -47,3 +48,32 @@ def forward(self, x):
         x = self.fc2(x)
         x = self.drop(x)
         return x
+
+
+class GatedMlp(nn.Module):
+    """ MLP as used in gMLP
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+                 gate_layer=None, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        if gate_layer is not None:
+            assert hidden_features % 2 == 0
+            self.gate = gate_layer(hidden_features)
+            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
+        else:
+            self.gate = nn.Identity()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.gate(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
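
A note on the first hunk: `hidden_features` for the GLU-style MLP now denotes the full pre-gating width of `fc1` (and must be even), with `fc2` consuming the halved width, rather than `fc1` silently doubling the requested width. Below is a minimal shape check under the new convention; it assumes the class modified in that hunk is `GluMlp` from this file (the class name sits outside the hunk) and that the file is `timm/models/layers/mlp.py`:

```python
# Assumption: the class changed in the first hunk is GluMlp in timm/models/layers/mlp.py.
from timm.models.layers.mlp import GluMlp

# Old convention: GluMlp(192, hidden_features=384) built fc1: 192 -> 768.
# New convention: the caller passes the full (even) gated width directly.
mlp = GluMlp(in_features=192, hidden_features=768)
assert mlp.fc1.out_features == 768   # fc1: in_features -> hidden_features
assert mlp.fc2.in_features == 384    # fc2: hidden_features // 2 -> out_features
```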
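
And a usage sketch for the new `GatedMlp`: `gate_layer` is expected to be a module constructor that takes the hidden width and returns a module whose output width is `hidden_features // 2` (hence the FIXME about basing the reduction on a gate property). The `ToyGate` below is a hypothetical stand-in for illustration only; in gMLP the gate would be a spatial gating unit over tokens, which is not part of this diff.

```python
import torch
import torch.nn as nn
# Assumption: this diff targets timm/models/layers/mlp.py, so GatedMlp imports from there.
from timm.models.layers.mlp import GatedMlp

class ToyGate(nn.Module):
    """Hypothetical gate: splits channels in half and multiplies the halves,
    so its output width is dim // 2, matching what GatedMlp.fc2 expects."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim // 2)

    def forward(self, x):
        u, v = x.chunk(2, dim=-1)
        return u * self.norm(v)

mlp = GatedMlp(in_features=192, hidden_features=768, gate_layer=ToyGate)
x = torch.randn(2, 196, 192)   # (batch, tokens, channels)
print(mlp(x).shape)            # torch.Size([2, 196, 192])
```

With `gate_layer=None` (the default), the gate falls back to `nn.Identity` and `fc2` keeps the full hidden width, so the module behaves like a plain two-layer MLP.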