Skip to content

Commit 38907d7

Browse files
committed
Add sign, maximum, minimum
1 parent 03ead57 commit 38907d7

13 files changed

+448
-214
lines changed

ADDFUNCS.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ Add clauses to generate the FUNC_CODES from the ``functions.hpp`` header, making
171171
};
172172
#endif
173173
174-
Some functions (e.g. ``fmod``, ``isnan``) are not available in MKL, and so must be hard-coded here as well:
174+
Some functions (e.g. ``fmod``, ``isnan``) are not available in MKL, and so must be hard-coded in ``bespoke_functions.hpp`` as well:
175175

176176
.. code-block:: cpp
177177
@@ -186,7 +186,7 @@ Some functions (e.g. ``fmod``, ``isnan``) are not available in MKL, and so must
186186
};
187187
#endif
188188
189-
The complex case is slightlñy different (see other examples in the same file).
189+
The complex case is slightly different (see other examples in the same file).
190190

191191
Add case handling to the ``check_program`` function
192192

numexpr/bespoke_functions.hpp

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
#include <numpy/npy_cpu.h>
2+
#include <math.h>
3+
#include <string.h>
4+
#include <assert.h>
5+
#include <vector>
6+
#include "numexpr_config.hpp" // isnan definitions
7+
8+
// Generic sign function
9+
// Sign of an int: +1 for positive, -1 for negative, 0 for zero.
inline int signi(int x) {
    if (x > 0) {
        return 1;
    }
    return (x < 0) ? -1 : 0;
}
10+
// Sign of a long: +1 for positive, -1 for negative, 0 for zero.
inline long signl(long x) {
    return (x > 0) ? 1L : ((x < 0) ? -1L : 0L);
}
11+
inline double sign(double x){
12+
// Floats: -1.0, 0.0, +1.0, NaN stays NaN
13+
if (isnand(x)) {return NAN;}
14+
if (x > 0) {return 1;}
15+
if (x < 0) {return -1;}
16+
return 0; // handles +0.0 and -0.0
17+
}
18+
inline float signf(float x){
19+
// Floats: -1.0, 0.0, +1.0, NaN stays NaN
20+
if (isnanf_(x)) {return NAN;}
21+
if (x > 0) {return 1;}
22+
if (x < 0) {return -1;}
23+
return 0; // handles +0.0 and -0.0
24+
}
25+
26+
27+
#ifdef USE_VML
28+
/* Fake vsConj function just for casting purposes inside numexpr */
29+
static void vsConj(MKL_INT n, const float* x1, float* dest)
{
    // Element-wise identity copy of n floats from x1 into dest.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = x1[i];
    }
}
36+
37+
/* fmod not available in VML */
38+
static void vsfmod(MKL_INT n, const float* x1, const float* x2, float* dest)
{
    // dest[i] = fmod(x1[i], x2[i]) for each of the n elements.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = fmod(x1[i], x2[i]);
    }
}
45+
/* no isnan, isfinite, isinf or signbit in VML */
46+
static void vsIsfinite(MKL_INT n, const float* x1, bool* dest)
{
    // Stores the result of isfinitef_ for each of the n inputs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isfinitef_(x1[i]);
    }
}
53+
static void vsIsinf(MKL_INT n, const float* x1, bool* dest)
{
    // Stores the result of isinff_ for each of the n inputs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isinff_(x1[i]);
    }
}
60+
static void vsIsnan(MKL_INT n, const float* x1, bool* dest)
{
    // Stores the result of isnanf_ for each of the n inputs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isnanf_(x1[i]);
    }
}
67+
static void vsSignBit(MKL_INT n, const float* x1, bool* dest)
{
    // Stores the result of signbitf for each of the n inputs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = signbitf(x1[i]);
    }
}
74+
75+
/* no isnan, isfinite, isinf, signbit in VML */
76+
static void vdIsfinite(MKL_INT n, const double* x1, bool* dest)
{
    // Stores the result of isfinited for each of the n inputs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isfinited(x1[i]);
    }
}
83+
static void vdIsinf(MKL_INT n, const double* x1, bool* dest)
{
    // Stores the result of isinfd for each of the n inputs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isinfd(x1[i]);
    }
}
90+
static void vdIsnan(MKL_INT n, const double* x1, bool* dest)
{
    // Stores the result of isnand for each of the n inputs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isnand(x1[i]);
    }
}
97+
static void vdSignBit(MKL_INT n, const double* x1, bool* dest)
{
    // Stores the result of signbit for each of the n inputs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = signbit(x1[i]);
    }
}
104+
105+
/* no isnan, isfinite or isinf in VML */
106+
static void vzIsfinite(MKL_INT n, const MKL_Complex16* x1, bool* dest)
{
    // A complex value counts as finite only when both components are finite.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isfinited(x1[i].real) && isfinited(x1[i].imag);
    }
}
113+
static void vzIsinf(MKL_INT n, const MKL_Complex16* x1, bool* dest)
{
    // A complex value counts as infinite when either component is infinite.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isinfd(x1[i].real) || isinfd(x1[i].imag);
    }
}
120+
static void vzIsnan(MKL_INT n, const MKL_Complex16* x1, bool* dest)
{
    // A complex value counts as NaN when either component is NaN.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = isnand(x1[i].real) || isnand(x1[i].imag);
    }
}
127+
128+
/* Fake vdConj function just for casting purposes inside numexpr */
129+
static void vdConj(MKL_INT n, const double* x1, double* dest)
{
    // Element-wise identity copy of n doubles from x1 into dest.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = x1[i];
    }
}
136+
137+
/* fmod not available in VML */
138+
static void vdfmod(MKL_INT n, const double* x1, const double* x2, double* dest)
{
    // dest[i] = fmod(x1[i], x2[i]) for each of the n elements.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = fmod(x1[i], x2[i]);
    }
}
145+
146+
/* various functions not available in VML */
147+
static void vzExpm1(MKL_INT n, const MKL_Complex16* x1, MKL_Complex16* dest)
{
    // First fill dest via vzExp, then subtract 1 from each real part.
    vzExp(n, x1, dest);
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i].real = dest[i].real - 1.0;
    }
}
155+
156+
static void vzLog1p(MKL_INT n, const MKL_Complex16* x1, MKL_Complex16* dest)
{
    // Stage 1 + x in dest, then apply vzLn to dest in place.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i].real = x1[i].real + 1;
        dest[i].imag = x1[i].imag;
    }
    vzLn(n, dest, dest);
}
165+
166+
static void vzLog2(MKL_INT n, const MKL_Complex16* x1, MKL_Complex16* dest)
{
    // Fill dest via vzLn, then scale both components by M_LOG2_E.
    vzLn(n, x1, dest);
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i].real = dest[i].real * M_LOG2_E;
        dest[i].imag = dest[i].imag * M_LOG2_E;
    }
}
175+
176+
static void vzRint(MKL_INT n, const MKL_Complex16* x1, MKL_Complex16* dest)
{
    // Apply rint() independently to the real and imaginary parts.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i].real = rint(x1[i].real);
        dest[i].imag = rint(x1[i].imag);
    }
}
184+
185+
/* Use this instead of native vzAbs in VML as it seems to work badly */
186+
static void vzAbs_(MKL_INT n, const MKL_Complex16* x1, MKL_Complex16* dest)
{
    // Magnitude goes into the real part; the imaginary part is set to 0.
    for (MKL_INT i = 0; i < n; ++i) {
        // Read both components before writing dest[i].
        const double re = x1[i].real;
        const double im = x1[i].imag;
        dest[i].real = sqrt(re*re + im*im);
        dest[i].imag = 0;
    }
}
194+
195+
/*sign functions*/
196+
static void vsSign(MKL_INT n, const float* x1, float* dest)
{
    // Element-wise signf over n floats.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = signf(x1[i]);
    }
}
203+
static void vdSign(MKL_INT n, const double* x1, double* dest)
{
    // Element-wise sign over n doubles.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = sign(x1[i]);
    }
}
210+
static void viSign(MKL_INT n, const int* x1, int* dest)
{
    // Element-wise signi over n ints.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = signi(x1[i]);
    }
}
217+
static void vlSign(MKL_INT n, const long* x1, long* dest)
{
    // Element-wise signl over n longs.
    for (MKL_INT i = 0; i < n; ++i) {
        dest[i] = signl(x1[i]);
    }
}
224+
static void vzSign(MKL_INT n, const MKL_Complex16* x1, MKL_Complex16* dest)
{
    // Complex sign: x / |x|, with 0 for zero magnitude and NaN when the
    // magnitude is NaN.
    for (MKL_INT i = 0; i < n; ++i) {
        // Read both components before writing dest[i].
        const double re = x1[i].real;
        const double im = x1[i].imag;
        const double mag = sqrt(re*re + im*im);

        if (isnand(mag)) {
            dest[i].real = NAN;
            dest[i].imag = NAN;
            continue;
        }
        if (mag == 0) {
            dest[i].real = 0;
            dest[i].imag = 0;
            continue;
        }
        dest[i].real = re / mag;
        dest[i].imag = im / mag;
    }
}
244+
#endif

numexpr/complex_functions.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
**********************************************************************/
1212

1313
// Replace npy_cdouble with std::complex<double>
14+
#include <math.h> // NAN
1415
#include <complex>
1516

1617
/* constants */
@@ -471,4 +472,24 @@ nc_isfinite(std::complex<double> *x)
471472
br = isfinited(xr);
472473
return bi && br;
473474
}
475+
476+
static void
477+
nc_sign(std::complex<double> *x, std::complex<double> *r)
478+
{
479+
if (nc_isnan(x)){
480+
r->real(NAN);
481+
r->imag(NAN);
482+
}
483+
std::complex<double> mag;
484+
nc_abs(x, &mag);
485+
if (mag.real() == 0){
486+
r->real(0);
487+
r->imag(0);
488+
}
489+
else{
490+
r->real(x->real()/mag.real());
491+
r->imag(x->imag()/mag.real());
492+
}
493+
}
494+
474495
#endif // NUMEXPR_COMPLEX_FUNCTIONS_HPP

numexpr/expressions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,9 @@ def multiply(x, y):
351351
'hypot': func(numpy.hypot, 'double'),
352352
'nextafter': func(numpy.nextafter, 'double'),
353353
'copysign': func(numpy.copysign, 'double'),
354+
'maximum': func(numpy.maximum, 'double'),
355+
'minimum': func(numpy.minimum, 'double'),
356+
354357

355358
'log': func(numpy.log, 'float'),
356359
'log1p': func(numpy.log1p, 'float'),
@@ -364,6 +367,7 @@ def multiply(x, y):
364367
'floor': func(numpy.floor, 'float', 'double'),
365368
'round': func(numpy.round, 'double'),
366369
'trunc': func(numpy.trunc, 'double'),
370+
'sign': func(numpy.sign, 'double'),
367371

368372
'where': where_func,
369373

numexpr/functions.hpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ FUNC_FF(FUNC_CONJ_FF, "conjugate_ff",fconjf, fconjf2, vsConj)
3838
FUNC_FF(FUNC_CEIL_FF, "ceil_ff", ceilf, ceilf2, vsCeil)
3939
FUNC_FF(FUNC_FLOOR_FF, "floor_ff", floorf, floorf2, vsFloor)
4040
FUNC_FF(FUNC_TRUNC_FF, "trunc_ff", truncf, truncf2, vsTrunc)
41+
FUNC_FF(FUNC_SIGN_FF, "sign_ff", signf, signf2, vsSign)
4142
//rint rounds to nearest even integer, matching NumPy (round doesn't)
4243
FUNC_FF(FUNC_ROUND_FF, "round_ff", rintf, rintf2, vsRint)
4344
FUNC_FF(FUNC_FF_LAST, NULL, NULL, NULL, NULL)
@@ -55,6 +56,8 @@ FUNC_FFF(FUNC_ARCTAN2_FFF, "arctan2_fff", atan2f, atan2f2, vsAtan2)
5556
FUNC_FFF(FUNC_HYPOT_FFF, "hypot_fff", hypotf, hypotf2, vsHypot)
5657
FUNC_FFF(FUNC_NEXTAFTER_FFF, "nextafter_fff", nextafterf, nextafterf2, vsNextAfter)
5758
FUNC_FFF(FUNC_COPYSIGN_FFF, "copysign_fff", copysignf, copysignf2, vsCopySign)
59+
FUNC_FFF(FUNC_MAXIMUM_FFF, "maximum_fff", fmaxf, fmaxf2, vsFmax)
60+
FUNC_FFF(FUNC_MINIMUM_FFF, "minimum_fff", fminf, fminf2, vsFmin)
5861
FUNC_FFF(FUNC_FFF_LAST, NULL, NULL, NULL, NULL)
5962
#ifdef ELIDE_FUNC_FFF
6063
#undef ELIDE_FUNC_FFF
@@ -81,15 +84,16 @@ FUNC_DD(FUNC_ARCTANH_DD, "arctanh_dd", atanh, vdAtanh)
8184
FUNC_DD(FUNC_LOG_DD, "log_dd", log, vdLn)
8285
FUNC_DD(FUNC_LOG1P_DD, "log1p_dd", log1p, vdLog1p)
8386
FUNC_DD(FUNC_LOG10_DD, "log10_dd", log10, vdLog10)
84-
FUNC_DD(FUNC_LOG2_DD, "log2_dd", log2, vdLog2)
87+
FUNC_DD(FUNC_LOG2_DD, "log2_dd", log2, vdLog2)
8588
FUNC_DD(FUNC_EXP_DD, "exp_dd", exp, vdExp)
8689
FUNC_DD(FUNC_EXPM1_DD, "expm1_dd", expm1, vdExpm1)
8790
FUNC_DD(FUNC_ABS_DD, "absolute_dd", fabs, vdAbs)
8891
FUNC_DD(FUNC_CONJ_DD, "conjugate_dd",fconj, vdConj)
8992
FUNC_DD(FUNC_CEIL_DD, "ceil_dd", ceil, vdCeil)
9093
FUNC_DD(FUNC_FLOOR_DD, "floor_dd", floor, vdFloor)
9194
FUNC_DD(FUNC_TRUNC_DD, "trunc_dd", trunc, vdTrunc)
92-
//rint rounds to nearest even integer, matching NumPy (round doesn't)
95+
FUNC_DD(FUNC_SIGN_DD, "sign_dd", sign, vdSign)
96+
//rint rounds to nearest even integer, matching NumPy (round doesn't)
9397
FUNC_DD(FUNC_ROUND_DD, "round_dd", rint, vdRint)
9498
FUNC_DD(FUNC_DD_LAST, NULL, NULL, NULL)
9599
#ifdef ELIDE_FUNC_DD
@@ -136,6 +140,8 @@ FUNC_DDD(FUNC_ARCTAN2_DDD, "arctan2_ddd", atan2, vdAtan2)
136140
FUNC_DDD(FUNC_HYPOT_DDD, "hypot_ddd", hypot, vdHypot)
137141
FUNC_DDD(FUNC_NEXTAFTER_DDD, "nextafter_ddd", nextafter, vdNextAfter)
138142
FUNC_DDD(FUNC_COPYSIGN_DDD, "copysign_ddd", copysign, vdCopySign)
143+
FUNC_DDD(FUNC_MAXIMUM_DDD, "maximum_ddd", fmax, vdFmax)
144+
FUNC_DDD(FUNC_MINIMUM_DDD, "minimum_ddd", fmin, vdFmin)
139145
FUNC_DDD(FUNC_DDD_LAST, NULL, NULL, NULL)
140146
#ifdef ELIDE_FUNC_DDD
141147
#undef ELIDE_FUNC_DDD
@@ -167,6 +173,7 @@ FUNC_CC(FUNC_EXP_CC, "exp_cc", nc_exp, vzExp)
167173
FUNC_CC(FUNC_EXPM1_CC, "expm1_cc", nc_expm1, vzExpm1)
168174
FUNC_CC(FUNC_ABS_CC, "absolute_cc", nc_abs, vzAbs_)
169175
FUNC_CC(FUNC_CONJ_CC, "conjugate_cc",nc_conj, vzConj)
176+
FUNC_CC(FUNC_SIGN_CC, "sign_cc", nc_sign, vzSign)
170177
// rint rounds to nearest even integer, matches NumPy behaviour (round doesn't)
171178
FUNC_CC(FUNC_ROUND_CC, "round_cc", nc_rint, vzRint)
172179
FUNC_CC(FUNC_CC_LAST, NULL, NULL, NULL)
@@ -199,3 +206,26 @@ FUNC_BC(FUNC_BC_LAST, NULL, NULL, NULL)
199206
#undef ELIDE_FUNC_BC
200207
#undef FUNC_BC
201208
#endif
209+
210+
// int -> int functions
211+
#ifndef FUNC_II
212+
#define ELIDE_FUNC_II
213+
#define FUNC_II(...)
214+
#endif
215+
FUNC_II(FUNC_SIGN_II, "sign_ii", signi, viSign)
216+
FUNC_II(FUNC_II_LAST, NULL, NULL, NULL)
217+
#ifdef ELIDE_FUNC_II
218+
#undef ELIDE_FUNC_II
219+
#undef FUNC_II
220+
#endif
221+
222+
#ifndef FUNC_LL
223+
#define ELIDE_FUNC_LL
224+
#define FUNC_LL(...)
225+
#endif
226+
FUNC_LL(FUNC_SIGN_LL, "sign_LL", signl, vlSign)
227+
FUNC_LL(FUNC_LL_LAST, NULL, NULL, NULL)
228+
#ifdef ELIDE_FUNC_LL
229+
#undef ELIDE_FUNC_LL
230+
#undef FUNC_LL
231+
#endif

0 commit comments

Comments
 (0)