@@ -1081,44 +1081,12 @@ string_isnan_resolve_descriptors(
10811081 * Copied from NumPy, because NumPy doesn't always use it :)
10821082 */
10831083static int
1084- ufunc_promoter_internal (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
1085- PyArray_DTypeMeta * signature [],
1086- PyArray_DTypeMeta * new_op_dtypes [],
1087- PyArray_DTypeMeta * final_dtype )
1084+ string_inputs_promoter (PyUFuncObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
1085+ PyArray_DTypeMeta * signature [],
1086+ PyArray_DTypeMeta * new_op_dtypes [],
1087+ PyArray_DTypeMeta * final_dtype )
10881088{
1089- /* If nin < 2 promotion is a no-op, so it should not be registered */
1090- assert (ufunc -> nin > 1 );
1091- if (op_dtypes [0 ] == NULL ) {
1092- assert (ufunc -> nin == 2 && ufunc -> nout == 1 ); /* must be reduction */
1093- Py_INCREF (op_dtypes [1 ]);
1094- new_op_dtypes [0 ] = op_dtypes [1 ];
1095- Py_INCREF (op_dtypes [1 ]);
1096- new_op_dtypes [1 ] = op_dtypes [1 ];
1097- Py_INCREF (op_dtypes [1 ]);
1098- new_op_dtypes [2 ] = op_dtypes [1 ];
1099- return 0 ;
1100- }
1101- PyArray_DTypeMeta * common = NULL ;
1102- /*
1103- * If a signature is used and homogeneous in its outputs use that
1104- * (Could/should likely be rather applied to inputs also, although outs
1105- * only could have some advantage and input dtypes are rarely enforced.)
1106- */
1107- for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
1108- if (signature [i ] != NULL ) {
1109- if (common == NULL ) {
1110- Py_INCREF (signature [i ]);
1111- common = signature [i ];
1112- }
1113- else if (common != signature [i ]) {
1114- Py_CLEAR (common ); /* Not homogeneous, unset common */
1115- break ;
1116- }
1117- }
1118- }
1119- Py_XDECREF (common );
1120-
1121- /* Otherwise, set all input operands to final_dtype */
1089+ /* set all input operands to final_dtype */
11221090 for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
11231091 PyArray_DTypeMeta * tmp = final_dtype ;
11241092 if (signature [i ]) {
@@ -1127,6 +1095,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
11271095 Py_INCREF (tmp );
11281096 new_op_dtypes [i ] = tmp ;
11291097 }
1098+ /* don't touch output dtypes */
11301099 for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
11311100 Py_XINCREF (op_dtypes [i ]);
11321101 new_op_dtypes [i ] = op_dtypes [i ];
@@ -1140,19 +1109,50 @@ string_object_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
11401109 PyArray_DTypeMeta * signature [],
11411110 PyArray_DTypeMeta * new_op_dtypes [])
11421111{
1143- return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
1144- signature , new_op_dtypes ,
1145- (PyArray_DTypeMeta * )& PyArray_ObjectDType );
1112+ return string_inputs_promoter ((PyUFuncObject * )ufunc , op_dtypes , signature ,
1113+ new_op_dtypes ,
1114+ (PyArray_DTypeMeta * )& PyArray_ObjectDType );
11461115}
11471116
11481117static int
11491118string_unicode_promoter (PyObject * ufunc , PyArray_DTypeMeta * op_dtypes [],
11501119 PyArray_DTypeMeta * signature [],
11511120 PyArray_DTypeMeta * new_op_dtypes [])
11521121{
1153- return ufunc_promoter_internal ((PyUFuncObject * )ufunc , op_dtypes ,
1154- signature , new_op_dtypes ,
1155- (PyArray_DTypeMeta * )& StringDType );
1122+ return string_inputs_promoter ((PyUFuncObject * )ufunc , op_dtypes , signature ,
1123+ new_op_dtypes ,
1124+ (PyArray_DTypeMeta * )& StringDType );
1125+ }
1126+
1127+ static int
1128+ string_multiply_promoter (PyObject * ufunc_obj , PyArray_DTypeMeta * op_dtypes [],
1129+ PyArray_DTypeMeta * signature [],
1130+ PyArray_DTypeMeta * new_op_dtypes [])
1131+ {
1132+ PyUFuncObject * ufunc = (PyUFuncObject * )ufunc_obj ;
1133+ for (int i = 0 ; i < ufunc -> nargs ; i ++ ) {
1134+ PyArray_DTypeMeta * tmp = NULL ;
1135+ if (signature [i ]) {
1136+ tmp = signature [i ];
1137+ }
1138+ else if (op_dtypes [i ] == & PyArray_PyIntAbstractDType ) {
1139+ tmp = & PyArray_Int64DType ;
1140+ }
1141+ else if (op_dtypes [i ]) {
1142+ tmp = op_dtypes [i ];
1143+ }
1144+ else {
1145+ tmp = (PyArray_DTypeMeta * )& StringDType ;
1146+ }
1147+ Py_INCREF (tmp );
1148+ new_op_dtypes [i ] = tmp ;
1149+ }
1150+ /* don't touch output dtypes */
1151+ for (int i = ufunc -> nin ; i < ufunc -> nargs ; i ++ ) {
1152+ Py_XINCREF (op_dtypes [i ]);
1153+ new_op_dtypes [i ] = op_dtypes [i ];
1154+ }
1155+ return 0 ;
11561156}
11571157
11581158// Register a ufunc.
@@ -1161,14 +1161,18 @@ string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
11611161int
11621162init_ufunc (PyObject * numpy , const char * ufunc_name , PyArray_DTypeMeta * * dtypes ,
11631163 resolve_descriptors_function * resolve_func ,
1164- PyArrayMethod_StridedLoop * loop_func , const char * loop_name ,
1165- int nin , int nout , NPY_CASTING casting , NPY_ARRAYMETHOD_FLAGS flags )
1164+ PyArrayMethod_StridedLoop * loop_func , int nin , int nout ,
1165+ NPY_CASTING casting , NPY_ARRAYMETHOD_FLAGS flags )
11661166{
11671167 PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
11681168 if (ufunc == NULL ) {
11691169 return -1 ;
11701170 }
11711171
1172+ char loop_name [256 ] = {0 };
1173+
1174+ snprintf (loop_name , sizeof (loop_name ), "string_%s" , ufunc_name );
1175+
11721176 PyArrayMethod_Spec spec = {
11731177 .name = loop_name ,
11741178 .nin = nin ,
@@ -1208,7 +1212,7 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
12081212 PyArray_DTypeMeta * ldtype , PyArray_DTypeMeta * rdtype ,
12091213 PyArray_DTypeMeta * edtype , promoter_function * promoter_impl )
12101214{
1211- PyObject * ufunc = PyObject_GetAttrString (numpy , ufunc_name );
1215+ PyObject * ufunc = PyObject_GetAttrString (( PyObject * ) numpy , ufunc_name );
12121216
12131217 if (ufunc == NULL ) {
12141218 return -1 ;
@@ -1251,8 +1255,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
12511255 \
12521256 if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \
12531257 &multiply_resolve_descriptors, \
1254- &multiply_right_##shortname##_strided_loop, \
1255- "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
1258+ &multiply_right_##shortname##_strided_loop, 2, 1, \
1259+ NPY_NO_CASTING, 0) < 0) { \
12561260 goto error; \
12571261 } \
12581262 \
@@ -1262,8 +1266,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
12621266 \
12631267 if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \
12641268 &multiply_resolve_descriptors, \
1265- &multiply_left_##shortname##_strided_loop, \
1266- "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
1269+ &multiply_left_##shortname##_strided_loop, 2, 1, \
1270+ NPY_NO_CASTING, 0) < 0) { \
12671271 goto error; \
12681272 }
12691273
@@ -1279,53 +1283,23 @@ init_ufuncs(void)
12791283 "greater" , "greater_equal" ,
12801284 "less" , "less_equal" };
12811285
1286+ static PyArrayMethod_StridedLoop * strided_loops [6 ] = {
1287+ & string_equal_strided_loop , & string_not_equal_strided_loop ,
1288+ & string_greater_strided_loop , & string_greater_equal_strided_loop ,
1289+ & string_less_strided_loop , & string_less_equal_strided_loop ,
1290+ };
1291+
12821292 PyArray_DTypeMeta * comparison_dtypes [] = {
12831293 (PyArray_DTypeMeta * )& StringDType ,
12841294 (PyArray_DTypeMeta * )& StringDType , & PyArray_BoolDType };
12851295
1286- if (init_ufunc (numpy , "equal" , comparison_dtypes ,
1287- & string_comparison_resolve_descriptors ,
1288- & string_equal_strided_loop , "string_equal" , 2 , 1 ,
1289- NPY_NO_CASTING , 0 ) < 0 ) {
1290- goto error ;
1291- }
1292-
1293- if (init_ufunc (numpy , "not_equal" , comparison_dtypes ,
1294- & string_comparison_resolve_descriptors ,
1295- & string_not_equal_strided_loop , "string_not_equal" , 2 , 1 ,
1296- NPY_NO_CASTING , 0 ) < 0 ) {
1297- goto error ;
1298- }
1299-
1300- if (init_ufunc (numpy , "greater" , comparison_dtypes ,
1301- & string_comparison_resolve_descriptors ,
1302- & string_greater_strided_loop , "string_greater" , 2 , 1 ,
1303- NPY_NO_CASTING , 0 ) < 0 ) {
1304- goto error ;
1305- }
1306-
1307- if (init_ufunc (numpy , "greater_equal" , comparison_dtypes ,
1308- & string_comparison_resolve_descriptors ,
1309- & string_greater_equal_strided_loop , "string_greater_equal" ,
1310- 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
1311- goto error ;
1312- }
1313-
1314- if (init_ufunc (numpy , "less" , comparison_dtypes ,
1315- & string_comparison_resolve_descriptors ,
1316- & string_less_strided_loop , "string_less" , 2 , 1 ,
1317- NPY_NO_CASTING , 0 ) < 0 ) {
1318- goto error ;
1319- }
1320-
1321- if (init_ufunc (numpy , "less_equal" , comparison_dtypes ,
1322- & string_comparison_resolve_descriptors ,
1323- & string_less_equal_strided_loop , "string_less_equal" , 2 , 1 ,
1324- NPY_NO_CASTING , 0 ) < 0 ) {
1325- goto error ;
1326- }
1327-
13281296 for (int i = 0 ; i < 6 ; i ++ ) {
1297+ if (init_ufunc (numpy , comparison_ufunc_names [i ], comparison_dtypes ,
1298+ & string_comparison_resolve_descriptors ,
1299+ strided_loops [i ], 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
1300+ goto error ;
1301+ }
1302+
13291303 if (add_promoter (numpy , comparison_ufunc_names [i ],
13301304 (PyArray_DTypeMeta * )& StringDType ,
13311305 & PyArray_UnicodeDType , & PyArray_BoolDType ,
@@ -1360,8 +1334,7 @@ init_ufuncs(void)
13601334
13611335 if (init_ufunc (numpy , "isnan" , isnan_dtypes ,
13621336 & string_isnan_resolve_descriptors ,
1363- & string_isnan_strided_loop , "string_isnan" , 1 , 1 ,
1364- NPY_NO_CASTING , 0 ) < 0 ) {
1337+ & string_isnan_strided_loop , 1 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
13651338 goto error ;
13661339 }
13671340
@@ -1372,20 +1345,17 @@ init_ufuncs(void)
13721345 };
13731346
13741347 if (init_ufunc (numpy , "maximum" , binary_dtypes , binary_resolve_descriptors ,
1375- & maximum_strided_loop , "string_maximum" , 2 , 1 ,
1376- NPY_NO_CASTING , 0 ) < 0 ) {
1348+ & maximum_strided_loop , 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
13771349 goto error ;
13781350 }
13791351
13801352 if (init_ufunc (numpy , "minimum" , binary_dtypes , binary_resolve_descriptors ,
1381- & minimum_strided_loop , "string_minimum" , 2 , 1 ,
1382- NPY_NO_CASTING , 0 ) < 0 ) {
1353+ & minimum_strided_loop , 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
13831354 goto error ;
13841355 }
13851356
13861357 if (init_ufunc (numpy , "add" , binary_dtypes , binary_resolve_descriptors ,
1387- & add_strided_loop , "string_add" , 2 , 1 , NPY_NO_CASTING ,
1388- 0 ) < 0 ) {
1358+ & add_strided_loop , 2 , 1 , NPY_NO_CASTING , 0 ) < 0 ) {
13891359 goto error ;
13901360 }
13911361
@@ -1414,6 +1384,20 @@ init_ufuncs(void)
14141384 INIT_MULTIPLY (ULongLong , ulonglong );
14151385#endif
14161386
1387+ if (add_promoter (numpy , "multiply" , (PyArray_DTypeMeta * )& StringDType ,
1388+ & PyArray_PyIntAbstractDType ,
1389+ (PyArray_DTypeMeta * )& StringDType ,
1390+ string_multiply_promoter ) < 0 ) {
1391+ goto error ;
1392+ }
1393+
1394+ if (add_promoter (numpy , "multiply" , & PyArray_PyIntAbstractDType ,
1395+ (PyArray_DTypeMeta * )& StringDType ,
1396+ (PyArray_DTypeMeta * )& StringDType ,
1397+ string_multiply_promoter ) < 0 ) {
1398+ goto error ;
1399+ }
1400+
14171401 Py_DECREF (numpy );
14181402 return 0 ;
14191403
0 commit comments