|
6 | 6 |
|
7 | 7 | import os |
8 | 8 | import warnings |
| 9 | +from textwrap import dedent |
9 | 10 |
|
10 | 11 | import numpy as np |
11 | 12 | import scipy.special |
@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp): |
1134 | 1135 | r""" |
1135 | 1136 | Compute log(1 + exp(x)), also known as softplus or log1pexp |
1136 | 1137 |
|
1137 | | - This function is numerically more stable than the naive approach. |
| 1138 | + This function is faster and numerically more stable than the naive approach, |
| 1139 | + and does not overflow for large values of x. |
1138 | 1140 |
|
1139 | 1141 | For details, see |
1140 | 1142 | https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf |
@@ -1172,52 +1174,38 @@ def grad(self, inp, grads): |
1172 | 1174 | def c_code(self, node, name, inp, out, sub): |
1173 | 1175 | (x,) = inp |
1174 | 1176 | (z,) = out |
1175 | | - # The boundary constants were obtained by looking at the output of |
1176 | | - # python commands like: |
1177 | | - # import numpy, pytensor |
1178 | | - # dt='float32' # or float64 |
1179 | | - # for i in range(750): |
1180 | | - # print i, repr(numpy.log1p(numpy.exp(_asarray([i,-i], dtype=dt)))) |
1181 | | - # the upper boundary check prevents us from generating inf, whereas the |
1182 | | - # the lower boundary check prevents using exp when the result will be 0 anyway. |
1183 | | - # The intermediate constants are taken from Machler (2012). |
1184 | | - |
1185 | | - # We use the float32 limits for float16 for now as the |
1186 | | - # computation will happen in float32 anyway. |
| 1177 | + # We use the same limits for all precisions, which may be suboptimal. The reference |
| 1178 | + # paper (Machler, 2012) only looked at double precision. |
1187 | 1179 | if node.inputs[0].type in float_types: |
1188 | 1180 | if node.inputs[0].type == float64: |
1189 | | - return ( |
1190 | | - """ |
1191 | | - %(z)s = ( |
1192 | | - %(x)s < -745.0 ? 0.0 : |
1193 | | - %(x)s < -37.0 ? exp(%(x)s) : |
1194 | | - %(x)s < 18.0 ? log1p(exp(%(x)s)) : |
1195 | | - %(x)s < 33.3 ? %(x)s + exp(-%(x)s) : |
1196 | | - %(x)s |
| 1181 | + return dedent( |
| 1182 | + f""" |
| 1183 | + {z} = ( |
| 1184 | + {x} < -37.0 ? exp({x}) : |
| 1185 | + {x} < 18.0 ? log1p(exp({x})) : |
| 1186 | + {x} < 33.3 ? {x} + exp(-{x}) : |
| 1187 | + {x} |
1197 | 1188 | ); |
1198 | 1189 | """ |
1199 | | - % locals() |
1200 | 1190 | ) |
1201 | 1191 | else: |
1202 | | - return ( |
1203 | | - """ |
1204 | | - %(z)s = ( |
1205 | | - %(x)s < -103.0f ? 0.0 : |
1206 | | - %(x)s < -37.0f ? exp(%(x)s) : |
1207 | | - %(x)s < 18.0f ? log1p(exp(%(x)s)) : |
1208 | | - %(x)s < 33.3f ? %(x)s + exp(-%(x)s) : |
1209 | | - %(x)s |
| 1192 | + return dedent( |
| 1193 | + f""" |
| 1194 | + {z} = ( |
| 1195 | + {x} < -37.0f ? exp({x}) : |
| 1196 | + {x} < 18.0f ? log1p(exp({x})) : |
| 1197 | + {x} < 33.3f ? {x} + exp(-{x}) : |
| 1198 | + {x} |
1210 | 1199 | ); |
1211 | 1200 | """ |
1212 | | - % locals() |
1213 | 1201 | ) |
1214 | 1202 | else: |
1215 | 1203 | raise NotImplementedError("only floatingpoint is implemented") |
1216 | 1204 |
|
1217 | 1205 | def c_code_cache_version(self): |
1218 | 1206 | v = super().c_code_cache_version() |
1219 | 1207 | if v: |
1220 | | - return (2,) + v |
| 1208 | + return (3,) + v |
1221 | 1209 | else: |
1222 | 1210 | return v |
1223 | 1211 |
|
|
0 commit comments