11""" Activation Factory
22Hacked together by / Copyright 2020 Ross Wightman
33"""
4+ from typing import Union , Callable , Type
5+
46from .activations import *
57from .activations_jit import *
68from .activations_me import *
79from .config import is_exportable , is_scriptable , is_no_jit
810
9- # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
10- # will use native version if present. Eventually, the custom Swish layers will be removed
11- # and only native 'silu' will be used.
11+ # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
12+ # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
13+ # Eventually, the custom SiLU, Mish, and Hard* layers will be removed and only native variants will be used.
# Feature-detect the native activation ops on torch.nn.functional. Each flag is
# True when the running PyTorch exposes the op, and is used below to choose
# between the native implementation and the custom fallback.
# hasattr() probes the single attribute directly instead of materialising the
# module's whole (sorted) attribute listing with dir() for every check.
_has_silu = hasattr(torch.nn.functional, 'silu')
_has_hardswish = hasattr(torch.nn.functional, 'hardswish')
_has_hardsigmoid = hasattr(torch.nn.functional, 'hardsigmoid')
_has_mish = hasattr(torch.nn.functional, 'mish')
18+
1319
1420_ACT_FN_DEFAULT = dict (
1521 silu = F .silu if _has_silu else swish ,
1622 swish = F .silu if _has_silu else swish ,
17- mish = mish ,
23+ mish = F . mish if _has_mish else mish ,
1824 relu = F .relu ,
1925 relu6 = F .relu6 ,
2026 leaky_relu = F .leaky_relu ,
2430 gelu = gelu ,
2531 sigmoid = sigmoid ,
2632 tanh = tanh ,
27- hard_sigmoid = hard_sigmoid ,
28- hard_swish = hard_swish ,
33+ hard_sigmoid = F . hardsigmoid if _has_hardsigmoid else hard_sigmoid ,
34+ hard_swish = F . hardswish if _has_hardswish else hard_swish ,
2935 hard_mish = hard_mish ,
3036)
3137
# Name -> activation function table backed by the `*_jit` implementations.
# Wherever PyTorch ships a native op (see the `_has_*` flags) the native
# function is used in place of the custom one. 'swish' is an alias of 'silu'.
_ACT_FN_JIT = {
    'silu': F.silu if _has_silu else swish_jit,
    'swish': F.silu if _has_silu else swish_jit,
    'mish': F.mish if _has_mish else mish_jit,
    'hard_sigmoid': F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
    'hard_swish': F.hardswish if _has_hardswish else hard_swish_jit,
    'hard_mish': hard_mish_jit,  # no native counterpart is selected for hard_mish
}
4046
# Name -> activation function table backed by the `*_me` (memory-efficient,
# custom autograd) implementations. Wherever PyTorch ships a native op (see
# the `_has_*` flags) the native function is used instead. 'swish' is an
# alias of 'silu'.
_ACT_FN_ME = {
    'silu': F.silu if _has_silu else swish_me,
    'swish': F.silu if _has_silu else swish_me,
    'mish': F.mish if _has_mish else mish_me,
    'hard_sigmoid': F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
    'hard_swish': F.hardswish if _has_hardswish else hard_swish_me,
    'hard_mish': hard_mish_me,  # no native counterpart is selected for hard_mish
}
4955
# All function tables, grouped so the alias registration below covers each one.
_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
for a in _ACT_FNS:
    # Register the underscore-free spellings as aliases of the existing
    # 'hard_sigmoid' / 'hard_swish' entries. setdefault() leaves a table
    # untouched if it already has an entry under the alias.
    a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
    a.setdefault('hardswish', a.get('hard_swish'))
60+
61+
5062_ACT_LAYER_DEFAULT = dict (
5163 silu = nn .SiLU if _has_silu else Swish ,
5264 swish = nn .SiLU if _has_silu else Swish ,
53- mish = Mish ,
65+ mish = nn . Mish if _has_mish else Mish ,
5466 relu = nn .ReLU ,
5567 relu6 = nn .ReLU6 ,
5668 leaky_relu = nn .LeakyReLU ,
6173 gelu = GELU ,
6274 sigmoid = Sigmoid ,
6375 tanh = Tanh ,
64- hard_sigmoid = HardSigmoid ,
65- hard_swish = HardSwish ,
76+ hard_sigmoid = nn . Hardsigmoid if _has_hardsigmoid else HardSigmoid ,
77+ hard_swish = nn . Hardswish if _has_hardswish else HardSwish ,
6678 hard_mish = HardMish ,
6779)
6880
# Name -> activation layer class table backed by the `*Jit` layer classes.
# Wherever PyTorch ships a native module (see the `_has_*` flags) the native
# class is used in place of the custom one. 'swish' is an alias of 'silu'.
_ACT_LAYER_JIT = {
    'silu': nn.SiLU if _has_silu else SwishJit,
    'swish': nn.SiLU if _has_silu else SwishJit,
    'mish': nn.Mish if _has_mish else MishJit,
    'hard_sigmoid': nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
    'hard_swish': nn.Hardswish if _has_hardswish else HardSwishJit,
    'hard_mish': HardMishJit,  # no native counterpart is selected for hard_mish
}
7789
# Name -> activation layer class table backed by the `*Me` layer classes.
# Wherever PyTorch ships a native module (see the `_has_*` flags) the native
# class is used in place of the custom one. 'swish' is an alias of 'silu'.
_ACT_LAYER_ME = {
    'silu': nn.SiLU if _has_silu else SwishMe,
    'swish': nn.SiLU if _has_silu else SwishMe,
    'mish': nn.Mish if _has_mish else MishMe,
    'hard_sigmoid': nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
    'hard_swish': nn.Hardswish if _has_hardswish else HardSwishMe,
    'hard_mish': HardMishMe,  # no native counterpart is selected for hard_mish
}
8698
# All layer tables, grouped so the alias registration below covers each one.
_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
for a in _ACT_LAYERS:
    # Register the underscore-free spellings as aliases of the existing
    # 'hard_sigmoid' / 'hard_swish' entries. setdefault() leaves a table
    # untouched if it already has an entry under the alias.
    a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
    a.setdefault('hardswish', a.get('hard_swish'))
103+
87104
88- def get_act_fn (name = 'relu' ):
105+ def get_act_fn (name : Union [ Callable , str ] = 'relu' ):
89106 """ Activation Function Factory
90107 Fetching activation fns by name with this function allows export or torch script friendly
91108 functions to be returned dynamically based on current config.
92109 """
93110 if not name :
94111 return None
112+ if isinstance (name , Callable ):
113+ return name
95114 if not (is_no_jit () or is_exportable () or is_scriptable ()):
96115 # If not exporting or scripting the model, first look for a memory-efficient version with
97116 # custom autograd, then fallback
@@ -106,13 +125,15 @@ def get_act_fn(name='relu'):
106125 return _ACT_FN_DEFAULT [name ]
107126
108127
109- def get_act_layer (name = 'relu' ):
128+ def get_act_layer (name : Union [ Type [ nn . Module ], str ] = 'relu' ):
110129 """ Activation Layer Factory
111130 Fetching activation layers by name with this function allows export or torch script friendly
112131 functions to be returned dynamically based on current config.
113132 """
114133 if not name :
115134 return None
135+ if isinstance (name , type ):
136+ return name
116137 if not (is_no_jit () or is_exportable () or is_scriptable ()):
117138 if name in _ACT_LAYER_ME :
118139 return _ACT_LAYER_ME [name ]
@@ -125,9 +146,8 @@ def get_act_layer(name='relu'):
125146 return _ACT_LAYER_DEFAULT [name ]
126147
127148
def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
    """ Activation Layer Instance Factory
    Resolves `name` with `get_act_layer` and instantiates the resulting layer class.

    Args:
        name: Activation name, or an activation layer class. Both are resolved by
            `get_act_layer`, which passes classes through unchanged.
        inplace: Forwarded to the layer constructor as `inplace=` only when it is not
            None, so layer classes whose constructor takes no `inplace` argument can
            still be created with the default.
        **kwargs: Extra keyword arguments forwarded to the layer constructor.

    Returns:
        The constructed activation layer, or None when `get_act_layer` returns None.
    """
    act_layer = get_act_layer(name)
    if act_layer is None:
        return None
    if inplace is not None:
        # `inplace` is a named parameter, so it can never already be present in kwargs.
        kwargs['inplace'] = inplace
    return act_layer(**kwargs)
0 commit comments