55from torch .nn import Mish
66
77
8- __all__ = ['mish' , 'Mish' , 'mish_jit' , 'MishJit' , 'mish_jit_fwd' , 'mish_jit_bwd' , 'MishJitAutoFn' , 'mish_me' , 'MishMe' ,
9- 'hard_mish_jit' , 'HardMishJit' , 'hard_mish_jit_fwd' , 'hard_mish_jit_bwd' , 'HardMishJitAutoFn' ,
10- 'hard_mish_me' , 'HardMishMe' ]
8+ __all__ = [
9+ "mish" ,
10+ "Mish" ,
11+ "mish_jit" ,
12+ "MishJit" ,
13+ "mish_jit_fwd" ,
14+ "mish_jit_bwd" ,
15+ "MishJitAutoFn" ,
16+ "mish_me" ,
17+ "MishMe" ,
18+ "hard_mish_jit" ,
19+ "HardMishJit" ,
20+ "hard_mish_jit_fwd" ,
21+ "hard_mish_jit_bwd" ,
22+ "HardMishJitAutoFn" ,
23+ "hard_mish_me" ,
24+ "HardMishMe" ,
25+ ]
1126
1227
1328def mish (x , inplace : bool = False ):
@@ -40,7 +55,8 @@ def mish_jit(x, _inplace: bool = False):
4055class MishJit (nn .Module ):
4156 def __init__ (self , inplace : bool = False ):
4257 """Jit version of Mish.
43- Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681"""
58+ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
59+ """
4460 super (MishJit , self ).__init__ ()
4561
4662 def forward (self , x ):
@@ -61,8 +77,9 @@ def mish_jit_bwd(x, grad_output):
6177
6278
6379class MishJitAutoFn (torch .autograd .Function ):
64- """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
80+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
6581 A memory efficient, jit scripted variant of Mish"""
82+
6683 @staticmethod
6784 def forward (ctx , x ):
6885 ctx .save_for_backward (x )
@@ -79,8 +96,9 @@ def mish_me(x, inplace=False):
7996
8097
8198class MishMe (nn .Module ):
82- """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
99+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
83100 A memory efficient, jit scripted variant of Mish"""
101+
84102 def __init__ (self , inplace : bool = False ):
85103 super (MishMe , self ).__init__ ()
86104
@@ -90,18 +108,19 @@ def forward(self, x):
90108
91109@torch .jit .script
92110def hard_mish_jit (x , inplace : bool = False ):
93- """ Hard Mish
111+ """Hard Mish
94112 Experimental, based on notes by Mish author Diganta Misra at
95113 https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
96114 """
97115 return 0.5 * x * (x + 2 ).clamp (min = 0 , max = 2 )
98116
99117
100118class HardMishJit (nn .Module ):
101- """ Hard Mish
119+ """Hard Mish
102120 Experimental, based on notes by Mish author Diganta Misra at
103121 https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
104122 """
123+
105124 def __init__ (self , inplace : bool = False ):
106125 super (HardMishJit , self ).__init__ ()
107126
@@ -116,16 +135,17 @@ def hard_mish_jit_fwd(x):
116135
117136@torch .jit .script
118137def hard_mish_jit_bwd (x , grad_output ):
119- m = torch .ones_like (x ) * (x >= - 2. )
120- m = torch .where ((x >= - 2. ) & (x <= 0. ), x + 1. , m )
138+ m = torch .ones_like (x ) * (x >= - 2.0 )
139+ m = torch .where ((x >= - 2.0 ) & (x <= 0.0 ), x + 1.0 , m )
121140 return grad_output * m
122141
123142
124143class HardMishJitAutoFn (torch .autograd .Function ):
125- """ A memory efficient, jit scripted variant of Hard Mish
144+ """A memory efficient, jit scripted variant of Hard Mish
126145 Experimental, based on notes by Mish author Diganta Misra at
127146 https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
128147 """
148+
129149 @staticmethod
130150 def forward (ctx , x ):
131151 ctx .save_for_backward (x )
@@ -142,10 +162,11 @@ def hard_mish_me(x, inplace: bool = False):
142162
143163
144164class HardMishMe (nn .Module ):
145- """ A memory efficient, jit scripted variant of Hard Mish
165+ """A memory efficient, jit scripted variant of Hard Mish
146166 Experimental, based on notes by Mish author Diganta Misra at
147167 https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
148168 """
169+
149170 def __init__ (self , inplace : bool = False ):
150171 super (HardMishMe , self ).__init__ ()
151172
0 commit comments