@@ -320,33 +320,23 @@ def elemwise_wrapper(*inputs):
320320
321321 # Pure python implementation, that will be used in tests
322322 def elemwise (* inputs ):
323- inputs = [ np . asarray ( input ) for input in inputs ]
323+ Elemwise . _check_runtime_broadcast ( node , inputs )
324324 inputs_bc = np .broadcast_arrays (* inputs )
325- shape = inputs [0 ].shape
326- for input , bc in zip (inputs , input_bc_patterns , strict = True ):
327- for length , allow_bc , iter_length in zip (
328- input .shape , bc , shape , strict = True
329- ):
330- if length == 1 and shape and iter_length != 1 and not allow_bc :
331- raise ValueError ("Broadcast not allowed." )
332-
333- outputs = [np .empty (shape , dtype = dtype ) for dtype in output_dtypes ]
334-
335- for idx in np .ndindex (shape ):
336- vals = [input [idx ] for input in inputs_bc ]
337- outs = scalar_op_fn (* vals )
338- if not isinstance (outs , tuple ):
339- outs = (outs ,)
340- for out , out_val in zip (outputs , outs , strict = True ):
341- out [idx ] = out_val
342-
343- outputs_summed = []
344- for output , bc in zip (outputs , output_bc_patterns , strict = True ):
345- axes = tuple (np .nonzero (bc )[0 ])
346- outputs_summed .append (output .sum (axes , keepdims = True ))
347- if len (outputs_summed ) != 1 :
348- return tuple (outputs_summed )
349- return outputs_summed [0 ]
325+ shape = inputs_bc [0 ].shape
326+
327+ if len (output_dtypes ) == 1 :
328+ output = np .empty (shape , dtype = output_dtypes [0 ])
329+ for idx in np .ndindex (shape ):
330+ output [idx ] = scalar_op_fn (* (inp [idx ] for inp in inputs_bc ))
331+ return output
332+
333+ else :
334+ outputs = [np .empty (shape , dtype = dtype ) for dtype in output_dtypes ]
335+ for idx in np .ndindex (shape ):
336+ outs_vals = scalar_op_fn (* (inp [idx ] for inp in inputs_bc ))
337+ for out , out_val in zip (outputs , outs_vals ):
338+ out [idx ] = out_val
339+ return outputs
350340
351341 @overload (elemwise )
352342 def ov_elemwise (* inputs ):
@@ -594,7 +584,7 @@ def numba_funcify_Argmax(op, node, **kwargs):
594584
595585 if x_ndim == 0 :
596586
597- @numba_basic .numba_njit ( inline = "always" )
587+ @numba_basic .numba_njit
598588 def argmax (x ):
599589 return np .array (0 , dtype = "int64" )
600590
0 commit comments