3131 _empty_like_pair_orderK ,
3232 _find_buf_dtype ,
3333 _find_buf_dtype2 ,
34+ _find_inplace_dtype ,
3435 _to_device_supported_dtype ,
3536)
3637
@@ -331,11 +332,19 @@ class BinaryElementwiseFunc:
331332 Class that implements binary element-wise functions.
332333 """
333334
334- def __init__ (self , name , result_type_resolver_fn , binary_dp_impl_fn , docs ):
335+ def __init__ (
336+ self ,
337+ name ,
338+ result_type_resolver_fn ,
339+ binary_dp_impl_fn ,
340+ docs ,
341+ binary_inplace_fn = None ,
342+ ):
335343 self .__name__ = "BinaryElementwiseFunc"
336344 self .name_ = name
337345 self .result_type_resolver_fn_ = result_type_resolver_fn
338346 self .binary_fn_ = binary_dp_impl_fn
347+ self .binary_inplace_fn_ = binary_inplace_fn
339348 self .__doc__ = docs
340349
341350 def __str__ (self ):
@@ -345,6 +354,13 @@ def __repr__(self):
345354 return f"<BinaryElementwiseFunc '{ self .name_ } '>"
346355
347356 def __call__ (self , o1 , o2 , out = None , order = "K" ):
357+ # FIXME: replace with check against base array
358+ # when views can be identified
359+ if o1 is out :
360+ return self ._inplace (o1 , o2 )
361+ elif o2 is out :
362+ return self ._inplace (o2 , o1 )
363+
348364 if order not in ["K" , "C" , "F" , "A" ]:
349365 order = "K"
350366 q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -388,6 +404,7 @@ def __call__(self, o1, o2, out=None, order="K"):
388404 raise TypeError (
389405 "Shape of arguments can not be inferred. "
390406 "Arguments are expected to be "
407+ "lists, tuples, or both"
391408 )
392409 try :
393410 res_shape = _broadcast_shape_impl (
@@ -415,7 +432,7 @@ def __call__(self, o1, o2, out=None, order="K"):
415432
416433 if res_dt is None :
417434 raise TypeError (
418- "function 'add ' does not support input types "
435+ f "function '{ self . name_ } ' does not support input types "
419436 f"({ o1_dtype } , { o2_dtype } ), "
420437 "and the inputs could not be safely coerced to any "
421438 "supported types according to the casting rule ''safe''."
@@ -631,3 +648,116 @@ def __call__(self, o1, o2, out=None, order="K"):
631648 )
632649 dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
633650 return out
651+
652+ def _inplace (self , lhs , val ):
653+ if self .binary_inplace_fn_ is None :
654+ raise ValueError (
655+ f"In-place operation not supported for ufunc '{ self .name_ } '"
656+ )
657+ if not isinstance (lhs , dpt .usm_ndarray ):
658+ raise TypeError (
659+ f"Expected dpctl.tensor.usm_ndarray, got { type (lhs )} "
660+ )
661+ q1 , lhs_usm_type = _get_queue_usm_type (lhs )
662+ q2 , val_usm_type = _get_queue_usm_type (val )
663+ if q2 is None :
664+ exec_q = q1
665+ usm_type = lhs_usm_type
666+ else :
667+ exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
668+ if exec_q is None :
669+ raise ExecutionPlacementError (
670+ "Execution placement can not be unambiguously inferred "
671+ "from input arguments."
672+ )
673+ usm_type = dpctl .utils .get_coerced_usm_type (
674+ (
675+ lhs_usm_type ,
676+ val_usm_type ,
677+ )
678+ )
679+ dpctl .utils .validate_usm_type (usm_type , allow_none = False )
680+ lhs_shape = _get_shape (lhs )
681+ val_shape = _get_shape (val )
682+ if not all (
683+ isinstance (s , (tuple , list ))
684+ for s in (
685+ lhs_shape ,
686+ val_shape ,
687+ )
688+ ):
689+ raise TypeError (
690+ "Shape of arguments can not be inferred. "
691+ "Arguments are expected to be "
692+ "lists, tuples, or both"
693+ )
694+ try :
695+ res_shape = _broadcast_shape_impl (
696+ [
697+ lhs_shape ,
698+ val_shape ,
699+ ]
700+ )
701+ except ValueError :
702+ raise ValueError (
703+ "operands could not be broadcast together with shapes "
704+ f"{ lhs_shape } and { val_shape } "
705+ )
706+ if res_shape != lhs_shape :
707+ raise ValueError (
708+ f"output shape { lhs_shape } does not match "
709+ f"broadcast shape { res_shape } "
710+ )
711+ sycl_dev = exec_q .sycl_device
712+ lhs_dtype = lhs .dtype
713+ val_dtype = _get_dtype (val , sycl_dev )
714+ if not _validate_dtype (val_dtype ):
715+ raise ValueError ("Input operand of unsupported type" )
716+
717+ lhs_dtype , val_dtype = _resolve_weak_types (
718+ lhs_dtype , val_dtype , sycl_dev
719+ )
720+
721+ buf_dt = _find_inplace_dtype (
722+ lhs_dtype , val_dtype , self .result_type_resolver_fn_ , sycl_dev
723+ )
724+
725+ if buf_dt is None :
726+ raise TypeError (
727+ f"In-place '{ self .name_ } ' does not support input types "
728+ f"({ lhs_dtype } , { val_dtype } ), "
729+ "and the inputs could not be safely coerced to any "
730+ "supported types according to the casting rule ''safe''."
731+ )
732+
733+ if isinstance (val , dpt .usm_ndarray ):
734+ rhs = val
735+ overlap = ti ._array_overlap (lhs , rhs )
736+ else :
737+ rhs = dpt .asarray (val , dtype = val_dtype , sycl_queue = exec_q )
738+ overlap = False
739+
740+ if buf_dt == val_dtype and overlap is False :
741+ rhs = dpt .broadcast_to (rhs , res_shape )
742+ ht_ , _ = self .binary_inplace_fn_ (
743+ lhs = lhs , rhs = rhs , sycl_queue = exec_q
744+ )
745+ ht_ .wait ()
746+
747+ else :
748+ buf = dpt .empty_like (rhs , dtype = buf_dt )
749+ ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
750+ src = rhs , dst = buf , sycl_queue = exec_q
751+ )
752+
753+ buf = dpt .broadcast_to (buf , res_shape )
754+ ht_ , _ = self .binary_inplace_fn_ (
755+ lhs = lhs ,
756+ rhs = buf ,
757+ sycl_queue = exec_q ,
758+ depends = [copy_ev ],
759+ )
760+ ht_copy_ev .wait ()
761+ ht_ .wait ()
762+
763+ return lhs
0 commit comments