|
6 | 6 | from typing import cast, overload |
7 | 7 |
|
8 | 8 | import numpy as np |
| 9 | +from numpy.lib.array_utils import normalize_axis_tuple |
9 | 10 |
|
10 | 11 | import pytensor |
11 | 12 | from pytensor import scalar as ps |
|
18 | 19 | from pytensor.graph.utils import MethodNotDefined |
19 | 20 | from pytensor.link.c.op import COp |
20 | 21 | from pytensor.link.c.params_type import ParamsType |
21 | | -from pytensor.npy_2_compat import normalize_axis_tuple, numpy_version, using_numpy_2 |
22 | 22 | from pytensor.printing import Printer, pprint, set_precedence |
23 | 23 | from pytensor.scalar.basic import ScalarConstant, ScalarVariable |
24 | 24 | from pytensor.tensor import ( |
@@ -2330,199 +2330,6 @@ def copy_of_x(self, x): |
2330 | 2330 | return f"""(PyArrayObject*)PyArray_FromAny(py_{x}, NULL, 0, 0, |
2331 | 2331 | NPY_ARRAY_ENSURECOPY, NULL)""" |
2332 | 2332 |
|
2333 | | - def c_support_code(self, **kwargs): |
2334 | | - if numpy_version < "1.8.0" or using_numpy_2: |
2335 | | - return None |
2336 | | - |
2337 | | - types = [ |
2338 | | - "npy_" + t |
2339 | | - for t in [ |
2340 | | - "int8", |
2341 | | - "int16", |
2342 | | - "int32", |
2343 | | - "int64", |
2344 | | - "uint8", |
2345 | | - "uint16", |
2346 | | - "uint32", |
2347 | | - "uint64", |
2348 | | - "float16", |
2349 | | - "float32", |
2350 | | - "float64", |
2351 | | - ] |
2352 | | - ] |
2353 | | - |
2354 | | - complex_types = ["npy_" + t for t in ("complex32", "complex64", "complex128")] |
2355 | | - |
2356 | | - inplace_map_template = """ |
2357 | | - #if defined(%(typen)s) |
2358 | | - static void %(type)s_inplace_add(PyArrayMapIterObject *mit, |
2359 | | - PyArrayIterObject *it, int inc_or_set) |
2360 | | - { |
2361 | | - int index = mit->size; |
2362 | | - while (index--) { |
2363 | | - %(op)s |
2364 | | -
|
2365 | | - PyArray_MapIterNext(mit); |
2366 | | - PyArray_ITER_NEXT(it); |
2367 | | - } |
2368 | | - } |
2369 | | - #endif |
2370 | | - """ |
2371 | | - |
2372 | | - floatadd = ( |
2373 | | - "((%(type)s*)mit->dataptr)[0] = " |
2374 | | - "(inc_or_set ? ((%(type)s*)mit->dataptr)[0] : 0)" |
2375 | | - " + ((%(type)s*)it->dataptr)[0];" |
2376 | | - ) |
2377 | | - complexadd = """ |
2378 | | - ((%(type)s*)mit->dataptr)[0].real = |
2379 | | - (inc_or_set ? ((%(type)s*)mit->dataptr)[0].real : 0) |
2380 | | - + ((%(type)s*)it->dataptr)[0].real; |
2381 | | - ((%(type)s*)mit->dataptr)[0].imag = |
2382 | | - (inc_or_set ? ((%(type)s*)mit->dataptr)[0].imag : 0) |
2383 | | - + ((%(type)s*)it->dataptr)[0].imag; |
2384 | | - """ |
2385 | | - |
2386 | | - fns = "".join( |
2387 | | - [ |
2388 | | - inplace_map_template |
2389 | | - % {"type": t, "typen": t.upper(), "op": floatadd % {"type": t}} |
2390 | | - for t in types |
2391 | | - ] |
2392 | | - + [ |
2393 | | - inplace_map_template |
2394 | | - % {"type": t, "typen": t.upper(), "op": complexadd % {"type": t}} |
2395 | | - for t in complex_types |
2396 | | - ] |
2397 | | - ) |
2398 | | - |
2399 | | - def gen_binop(type, typen): |
2400 | | - return f""" |
2401 | | - #if defined({typen}) |
2402 | | - {type}_inplace_add, |
2403 | | - #endif |
2404 | | - """ |
2405 | | - |
2406 | | - fn_array = ( |
2407 | | - "static inplace_map_binop addition_funcs[] = {" |
2408 | | - + "".join(gen_binop(type=t, typen=t.upper()) for t in types + complex_types) |
2409 | | - + "NULL};\n" |
2410 | | - ) |
2411 | | - |
2412 | | - def gen_num(typen): |
2413 | | - return f""" |
2414 | | - #if defined({typen}) |
2415 | | - {typen}, |
2416 | | - #endif |
2417 | | - """ |
2418 | | - |
2419 | | - type_number_array = ( |
2420 | | - "static int type_numbers[] = {" |
2421 | | - + "".join(gen_num(typen=t.upper()) for t in types + complex_types) |
2422 | | - + "-1000};" |
2423 | | - ) |
2424 | | - |
2425 | | - code = ( |
2426 | | - """ |
2427 | | - typedef void (*inplace_map_binop)(PyArrayMapIterObject *, |
2428 | | - PyArrayIterObject *, int inc_or_set); |
2429 | | - """ |
2430 | | - + fns |
2431 | | - + fn_array |
2432 | | - + type_number_array |
2433 | | - + """ |
2434 | | - static int |
2435 | | - map_increment(PyArrayMapIterObject *mit, PyArrayObject *op, |
2436 | | - inplace_map_binop add_inplace, int inc_or_set) |
2437 | | - { |
2438 | | - PyArrayObject *arr = NULL; |
2439 | | - PyArrayIterObject *it; |
2440 | | - PyArray_Descr *descr; |
2441 | | - if (mit->ait == NULL) { |
2442 | | - return -1; |
2443 | | - } |
2444 | | - descr = PyArray_DESCR(mit->ait->ao); |
2445 | | - Py_INCREF(descr); |
2446 | | - arr = (PyArrayObject *)PyArray_FromAny((PyObject *)op, descr, |
2447 | | - 0, 0, NPY_ARRAY_FORCECAST, NULL); |
2448 | | - if (arr == NULL) { |
2449 | | - return -1; |
2450 | | - } |
2451 | | - if ((mit->subspace != NULL) && (mit->consec)) { |
2452 | | - PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0); |
2453 | | - if (arr == NULL) { |
2454 | | - return -1; |
2455 | | - } |
2456 | | - } |
2457 | | - it = (PyArrayIterObject*) |
2458 | | - PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd); |
2459 | | - if (it == NULL) { |
2460 | | - Py_DECREF(arr); |
2461 | | - return -1; |
2462 | | - } |
2463 | | -
|
2464 | | - (*add_inplace)(mit, it, inc_or_set); |
2465 | | -
|
2466 | | - Py_DECREF(arr); |
2467 | | - Py_DECREF(it); |
2468 | | - return 0; |
2469 | | - } |
2470 | | -
|
2471 | | -
|
2472 | | - static int |
2473 | | - inplace_increment(PyArrayObject *a, PyObject *index, PyArrayObject *inc, |
2474 | | - int inc_or_set) |
2475 | | - { |
2476 | | - inplace_map_binop add_inplace = NULL; |
2477 | | - int type_number = -1; |
2478 | | - int i = 0; |
2479 | | - PyArrayMapIterObject * mit; |
2480 | | -
|
2481 | | - if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) { |
2482 | | - return -1; |
2483 | | - } |
2484 | | -
|
2485 | | - if (PyArray_NDIM(a) == 0) { |
2486 | | - PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed."); |
2487 | | - return -1; |
2488 | | - } |
2489 | | - type_number = PyArray_TYPE(a); |
2490 | | -
|
2491 | | - while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){ |
2492 | | - if (type_number == type_numbers[i]) { |
2493 | | - add_inplace = addition_funcs[i]; |
2494 | | - break; |
2495 | | - } |
2496 | | - i++ ; |
2497 | | - } |
2498 | | -
|
2499 | | - if (add_inplace == NULL) { |
2500 | | - PyErr_SetString(PyExc_TypeError, "unsupported type for a"); |
2501 | | - return -1; |
2502 | | - } |
2503 | | - mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index); |
2504 | | - if (mit == NULL) { |
2505 | | - goto fail; |
2506 | | - } |
2507 | | - if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) { |
2508 | | - goto fail; |
2509 | | - } |
2510 | | -
|
2511 | | - Py_DECREF(mit); |
2512 | | -
|
2513 | | - Py_INCREF(Py_None); |
2514 | | - return 0; |
2515 | | -
|
2516 | | - fail: |
2517 | | - Py_XDECREF(mit); |
2518 | | -
|
2519 | | - return -1; |
2520 | | - } |
2521 | | - """ |
2522 | | - ) |
2523 | | - |
2524 | | - return code |
2525 | | - |
2526 | 2333 | def c_code(self, node, name, input_names, output_names, sub): |
2527 | 2334 | x, y, idx = input_names |
2528 | 2335 | [out] = output_names |
@@ -2636,34 +2443,7 @@ def c_code(self, node, name, input_names, output_names, sub): |
2636 | 2443 | """ |
2637 | 2444 | return code |
2638 | 2445 |
|
2639 | | - if numpy_version < "1.8.0" or using_numpy_2: |
2640 | | - raise NotImplementedError |
2641 | | - |
2642 | | - return f""" |
2643 | | - PyObject* rval = NULL; |
2644 | | - if ({params}->inplace) |
2645 | | - {{ |
2646 | | - if ({x} != {out}) |
2647 | | - {{ |
2648 | | - Py_XDECREF({out}); |
2649 | | - Py_INCREF({x}); |
2650 | | - {out} = {x}; |
2651 | | - }} |
2652 | | - }} |
2653 | | - else |
2654 | | - {{ |
2655 | | - Py_XDECREF({out}); |
2656 | | - {out} = {copy_of_x}; |
2657 | | - if (!{out}) {{ |
2658 | | - // Exception already set |
2659 | | - {fail} |
2660 | | - }} |
2661 | | - }} |
2662 | | - if (inplace_increment({out}, (PyObject *){idx}, {y}, (1 - {params}->set_instead_of_inc))) {{ |
2663 | | - {fail}; |
2664 | | - }} |
2665 | | - Py_XDECREF(rval); |
2666 | | - """ |
| 2446 | + raise NotImplementedError |
2667 | 2447 |
|
2668 | 2448 | def c_code_cache_version(self): |
2669 | 2449 | return (10,) |
|
0 commit comments