@@ -47,9 +47,33 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4747 self .unary_fn_ = unary_dp_impl_fn
4848 self .__doc__ = docs
4949
50- def __call__ (self , x , order = "K" ):
50+ def __call__ (self , x , out = None , order = "K" ):
5151 if not isinstance (x , dpt .usm_ndarray ):
5252 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
53+
54+ if out is not None :
55+ if not isinstance (out , dpt .usm_ndarray ):
56+ raise TypeError (
57+ f"output array must be of usm_ndarray type, got { type (out )} "
58+ )
59+
60+ if out .shape != x .shape :
61+ raise TypeError (
62+ "The shape of input and output arrays are inconsistent. "
63+ f"Expected output shape is { x .shape } , got { out .shape } "
64+ )
65+
66+ if ti ._array_overlap (x , out ):
67+ raise TypeError ("Input and output arrays have memory overlap" )
68+
69+ if (
70+ dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
71+ is None
72+ ):
73+ raise TypeError (
74+ "Input and output allocation queues are not compatible"
75+ )
76+
5377 if order not in ["C" , "F" , "K" , "A" ]:
5478 order = "K"
5579 buf_dt , res_dt = _find_buf_dtype (
@@ -59,17 +83,24 @@ def __call__(self, x, order="K"):
5983 raise RuntimeError
6084 exec_q = x .sycl_queue
6185 if buf_dt is None :
62- if order == "K" :
63- r = _empty_like_orderK (x , res_dt )
86+ if out is None :
87+ if order == "K" :
88+ out = _empty_like_orderK (x , res_dt )
89+ else :
90+ if order == "A" :
91+ order = "F" if x .flags .f_contiguous else "C"
92+ out = dpt .empty_like (x , dtype = res_dt , order = order )
6493 else :
65- if order == "A" :
66- order = "F" if x .flags .f_contiguous else "C"
67- r = dpt .empty_like (x , dtype = res_dt , order = order )
94+ if res_dt != out .dtype :
95+ raise TypeError (
96+ f"Output array of type { res_dt } is needed,"
97+ f" got { out .dtype } "
98+ )
6899
69- ht , _ = self .unary_fn_ (x , r , sycl_queue = exec_q )
100+ ht , _ = self .unary_fn_ (x , out , sycl_queue = exec_q )
70101 ht .wait ()
71102
72- return r
103+ return out
73104 if order == "K" :
74105 buf = _empty_like_orderK (x , buf_dt )
75106 else :
@@ -80,16 +111,22 @@ def __call__(self, x, order="K"):
80111 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
81112 src = x , dst = buf , sycl_queue = exec_q
82113 )
83- if order == "K" :
84- r = _empty_like_orderK (buf , res_dt )
114+ if out is None :
115+ if order == "K" :
116+ out = _empty_like_orderK (buf , res_dt )
117+ else :
118+ out = dpt .empty_like (buf , dtype = res_dt , order = order )
85119 else :
86- r = dpt .empty_like (buf , dtype = res_dt , order = order )
120+ if res_dt != out .dtype :
121+ raise TypeError (
122+ f"Output array of type { res_dt } is needed, got { out .dtype } "
123+ )
87124
88- ht , _ = self .unary_fn_ (buf , r , sycl_queue = exec_q , depends = [copy_ev ])
125+ ht , _ = self .unary_fn_ (buf , out , sycl_queue = exec_q , depends = [copy_ev ])
89126 ht_copy_ev .wait ()
90127 ht .wait ()
91128
92- return r
129+ return out
93130
94131
95132def _get_queue_usm_type (o ):
@@ -281,7 +318,7 @@ def __str__(self):
281318 def __repr__ (self ):
282319 return f"<BinaryElementwiseFunc '{ self .name_ } '>"
283320
284- def __call__ (self , o1 , o2 , order = "K" ):
321+ def __call__ (self , o1 , o2 , out = None , order = "K" ):
285322 if order not in ["K" , "C" , "F" , "A" ]:
286323 order = "K"
287324 q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -358,6 +395,31 @@ def __call__(self, o1, o2, order="K"):
358395 "supported types according to the casting rule ''safe''."
359396 )
360397
398+ if out is not None :
399+ if not isinstance (out , dpt .usm_ndarray ):
400+ raise TypeError (
401+ f"output array must be of usm_ndarray type, got { type (out )} "
402+ )
403+
404+ if out .shape != res_shape :
405+ raise TypeError (
406+ "The shape of input and output arrays are inconsistent. "
407+ f"Expected output shape is { res_shape } , got { out .shape } "
408+ )
409+
410+ if ti ._array_overlap (o1 , out ) or ti ._array_overlap (o2 , out ):
411+ raise TypeError ("Input and output arrays have memory overlap" )
412+
413+ if (
414+ dpctl .utils .get_execution_queue (
415+ (o1 .sycl_queue , o2 .sycl_queue , out .sycl_queue )
416+ )
417+ is None
418+ ):
419+ raise TypeError (
420+ "Input and output allocation queues are not compatible"
421+ )
422+
361423 if isinstance (o1 , dpt .usm_ndarray ):
362424 src1 = o1
363425 else :
@@ -368,37 +430,45 @@ def __call__(self, o1, o2, order="K"):
368430 src2 = dpt .asarray (o2 , dtype = o2_dtype , sycl_queue = exec_q )
369431
370432 if buf1_dt is None and buf2_dt is None :
371- if order == "K" :
372- r = _empty_like_pair_orderK (
373- src1 , src2 , res_dt , res_usm_type , exec_q
374- )
375- else :
376- if order == "A" :
377- order = (
378- "F"
379- if all (
380- arr .flags .f_contiguous
381- for arr in (
382- src1 ,
383- src2 ,
433+ if out is None :
434+ if order == "K" :
435+ out = _empty_like_pair_orderK (
436+ src1 , src2 , res_dt , res_usm_type , exec_q
437+ )
438+ else :
439+ if order == "A" :
440+ order = (
441+ "F"
442+ if all (
443+ arr .flags .f_contiguous
444+ for arr in (
445+ src1 ,
446+ src2 ,
447+ )
384448 )
449+ else "C"
385450 )
386- else "C"
451+ out = dpt .empty (
452+ res_shape ,
453+ dtype = res_dt ,
454+ usm_type = res_usm_type ,
455+ sycl_queue = exec_q ,
456+ order = order ,
387457 )
388- r = dpt . empty (
389- res_shape ,
390- dtype = res_dt ,
391- usm_type = res_usm_type ,
392- sycl_queue = exec_q ,
393- order = order ,
394- )
458+ else :
459+ if res_dt != out . dtype :
460+ raise TypeError (
461+ f"Output array of type { res_dt } is needed,"
462+ f" got { out . dtype } "
463+ )
464+
395465 src1 = dpt .broadcast_to (src1 , res_shape )
396466 src2 = dpt .broadcast_to (src2 , res_shape )
397467 ht_ , _ = self .binary_fn_ (
398- src1 = src1 , src2 = src2 , dst = r , sycl_queue = exec_q
468+ src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
399469 )
400470 ht_ .wait ()
401- return r
471+ return out
402472 elif buf1_dt is None :
403473 if order == "K" :
404474 buf2 = _empty_like_orderK (src2 , buf2_dt )
@@ -409,30 +479,38 @@ def __call__(self, o1, o2, order="K"):
409479 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
410480 src = src2 , dst = buf2 , sycl_queue = exec_q
411481 )
412- if order == "K" :
413- r = _empty_like_pair_orderK (
414- src1 , buf2 , res_dt , res_usm_type , exec_q
415- )
482+ if out is None :
483+ if order == "K" :
484+ out = _empty_like_pair_orderK (
485+ src1 , buf2 , res_dt , res_usm_type , exec_q
486+ )
487+ else :
488+ out = dpt .empty (
489+ res_shape ,
490+ dtype = res_dt ,
491+ usm_type = res_usm_type ,
492+ sycl_queue = exec_q ,
493+ order = order ,
494+ )
416495 else :
417- r = dpt .empty (
418- res_shape ,
419- dtype = res_dt ,
420- usm_type = res_usm_type ,
421- sycl_queue = exec_q ,
422- order = order ,
423- )
496+ if res_dt != out .dtype :
497+ raise TypeError (
498+ f"Output array of type { res_dt } is needed,"
499+ f" got { out .dtype } "
500+ )
501+
424502 src1 = dpt .broadcast_to (src1 , res_shape )
425503 buf2 = dpt .broadcast_to (buf2 , res_shape )
426504 ht_ , _ = self .binary_fn_ (
427505 src1 = src1 ,
428506 src2 = buf2 ,
429- dst = r ,
507+ dst = out ,
430508 sycl_queue = exec_q ,
431509 depends = [copy_ev ],
432510 )
433511 ht_copy_ev .wait ()
434512 ht_ .wait ()
435- return r
513+ return out
436514 elif buf2_dt is None :
437515 if order == "K" :
438516 buf1 = _empty_like_orderK (src1 , buf1_dt )
@@ -443,30 +521,38 @@ def __call__(self, o1, o2, order="K"):
443521 ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
444522 src = src1 , dst = buf1 , sycl_queue = exec_q
445523 )
446- if order == "K" :
447- r = _empty_like_pair_orderK (
448- buf1 , src2 , res_dt , res_usm_type , exec_q
449- )
524+ if out is None :
525+ if order == "K" :
526+ out = _empty_like_pair_orderK (
527+ buf1 , src2 , res_dt , res_usm_type , exec_q
528+ )
529+ else :
530+ out = dpt .empty (
531+ res_shape ,
532+ dtype = res_dt ,
533+ usm_type = res_usm_type ,
534+ sycl_queue = exec_q ,
535+ order = order ,
536+ )
450537 else :
451- r = dpt .empty (
452- res_shape ,
453- dtype = res_dt ,
454- usm_type = res_usm_type ,
455- sycl_queue = exec_q ,
456- order = order ,
457- )
538+ if res_dt != out .dtype :
539+ raise TypeError (
540+ f"Output array of type { res_dt } is needed,"
541+ f" got { out .dtype } "
542+ )
543+
458544 buf1 = dpt .broadcast_to (buf1 , res_shape )
459545 src2 = dpt .broadcast_to (src2 , res_shape )
460546 ht_ , _ = self .binary_fn_ (
461547 src1 = buf1 ,
462548 src2 = src2 ,
463- dst = r ,
549+ dst = out ,
464550 sycl_queue = exec_q ,
465551 depends = [copy_ev ],
466552 )
467553 ht_copy_ev .wait ()
468554 ht_ .wait ()
469- return r
555+ return out
470556
471557 if order in ["K" , "A" ]:
472558 if src1 .flags .f_contiguous and src2 .flags .f_contiguous :
@@ -489,26 +575,33 @@ def __call__(self, o1, o2, order="K"):
489575 ht_copy2_ev , copy2_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
490576 src = src2 , dst = buf2 , sycl_queue = exec_q
491577 )
492- if order == "K" :
493- r = _empty_like_pair_orderK (
494- buf1 , buf2 , res_dt , res_usm_type , exec_q
495- )
578+ if out is None :
579+ if order == "K" :
580+ out = _empty_like_pair_orderK (
581+ buf1 , buf2 , res_dt , res_usm_type , exec_q
582+ )
583+ else :
584+ out = dpt .empty (
585+ res_shape ,
586+ dtype = res_dt ,
587+ usm_type = res_usm_type ,
588+ sycl_queue = exec_q ,
589+ order = order ,
590+ )
496591 else :
497- r = dpt .empty (
498- res_shape ,
499- dtype = res_dt ,
500- usm_type = res_usm_type ,
501- sycl_queue = exec_q ,
502- order = order ,
503- )
592+ if res_dt != out .dtype :
593+ raise TypeError (
594+ f"Output array of type { res_dt } is needed, got { out .dtype } "
595+ )
596+
504597 buf1 = dpt .broadcast_to (buf1 , res_shape )
505598 buf2 = dpt .broadcast_to (buf2 , res_shape )
506599 ht_ , _ = self .binary_fn_ (
507600 src1 = buf1 ,
508601 src2 = buf2 ,
509- dst = r ,
602+ dst = out ,
510603 sycl_queue = exec_q ,
511604 depends = [copy1_ev , copy2_ev ],
512605 )
513606 dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
514- return r
607+ return out
0 commit comments