@@ -2541,7 +2541,7 @@ def perform(self, node, inputs, output_storage):
25412541 )
25422542
25432543 def c_code_cache_version (self ):
2544- return (6 ,)
2544+ return (7 ,)
25452545
25462546 def c_code (self , node , name , inputs , outputs , sub ):
25472547 axis , * arrays = inputs
@@ -2580,16 +2580,86 @@ def c_code(self, node, name, inputs, outputs, sub):
25802580 code = f"""
25812581 int axis = { axis_def }
25822582 PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2583- PyObject* arrays_tuple = PyTuple_New( { n } ) ;
2583+ int out_is_valid = { out } != NULL ;
25842584
25852585 { axis_check }
25862586
2587- Py_XDECREF({ out } );
2588- { copy_arrays_to_tuple }
2589- { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2590- Py_DECREF(arrays_tuple);
2591- if(!{ out } ){{
2592- { fail }
2587+ if (out_is_valid) {{
2588+ // Check if we can reuse output
2589+ npy_intp join_size = 0;
2590+ npy_intp out_shape[{ ndim } ];
2591+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2592+
2593+ for (int i = 0; i < { n } ; i++) {{
2594+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2595+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2596+ { fail }
2597+ }}
2598+
2599+ join_size += PyArray_SHAPE(arrays[i])[axis];
2600+
2601+ if (i > 0){{
2602+ for (int j = 0; j < { ndim } ; j++) {{
2603+ if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2604+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2605+ { fail }
2606+ }}
2607+ }}
2608+ }}
2609+ }}
2610+
2611+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2612+ out_shape[axis] = join_size;
2613+
2614+ for (int i = 0; i < { ndim } ; i++) {{
2615+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2616+ }}
2617+ }}
2618+
2619+ if (!out_is_valid) {{
2620+ // Use PyArray_Concatenate
2621+ Py_XDECREF({ out } );
2622+ PyObject* arrays_tuple = PyTuple_New({ n } );
2623+ { copy_arrays_to_tuple }
2624+ { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2625+ Py_DECREF(arrays_tuple);
2626+ if(!{ out } ){{
2627+ { fail }
2628+ }}
2629+ }}
2630+ else {{
2631+ // Copy the data to the pre-allocated output buffer
2632+
2633+ // Create view into output buffer
2634+ PyArrayObject_fields *view;
2635+
2636+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2637+ Py_INCREF(PyArray_DESCR({ out } ));
2638+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2639+ PyArray_DESCR({ out } ),
2640+ { ndim } ,
2641+ PyArray_SHAPE(arrays[0]),
2642+ PyArray_STRIDES({ out } ),
2643+ PyArray_DATA({ out } ),
2644+ NPY_ARRAY_WRITEABLE,
2645+ NULL);
2646+ if (view == NULL) {{
2647+ { fail }
2648+ }}
2649+
2650+ // Copy data into output buffer
2651+ for (int i = 0; i < { n } ; i++) {{
2652+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2653+
2654+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2655+ Py_DECREF(view);
2656+ { fail }
2657+ }}
2658+
2659+ view->data += (view->dimensions[axis] * view->strides[axis]);
2660+ }}
2661+
2662+ Py_DECREF(view);
25932663 }}
25942664 """
25952665 return code
0 commit comments