@@ -52,6 +52,20 @@ def __call__(self, x, out=None, order="K"):
5252 if not isinstance (x , dpt .usm_ndarray ):
5353 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
5454
55+ if order not in ["C" , "F" , "K" , "A" ]:
56+ order = "K"
57+ buf_dt , res_dt = _find_buf_dtype (
58+ x .dtype , self .result_type_resolver_fn_ , x .sycl_device
59+ )
60+ if res_dt is None :
61+ raise TypeError (
62+ f"function '{ self .name_ } ' does not support input type "
63+ f"({ x .dtype } ), "
64+ "and the input could not be safely coerced to any "
65+ "supported types according to the casting rule ''safe''."
66+ )
67+
68+ orig_out = out
5569 if out is not None :
5670 if not isinstance (out , dpt .usm_ndarray ):
5771 raise TypeError (
@@ -64,8 +78,21 @@ def __call__(self, x, out=None, order="K"):
6478 f"Expected output shape is { x .shape } , got { out .shape } "
6579 )
6680
67- if ti ._array_overlap (x , out ):
68- raise TypeError ("Input and output arrays have memory overlap" )
81+ if res_dt != out .dtype :
82+ raise TypeError (
83+ f"Output array of type { res_dt } is needed,"
84+ f" got { out .dtype } "
85+ )
86+
87+ if (
88+ buf_dt is None
89+ and ti ._array_overlap (x , out )
90+ and not ti ._same_logical_tensors (x , out )
91+ ):
92+ # Allocate a temporary buffer to avoid memory overlapping.
93+ # Note if `buf_dt` is not None, a temporary copy of `x` will be
94+ # created, so the array overlap check isn't needed.
95+ out = dpt .empty_like (out )
6996
7097 if (
7198 dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
@@ -75,18 +102,6 @@ def __call__(self, x, out=None, order="K"):
75102 "Input and output allocation queues are not compatible"
76103 )
77104
78- if order not in ["C" , "F" , "K" , "A" ]:
79- order = "K"
80- buf_dt , res_dt = _find_buf_dtype (
81- x .dtype , self .result_type_resolver_fn_ , x .sycl_device
82- )
83- if res_dt is None :
84- raise TypeError (
85- f"function '{ self .name_ } ' does not support input type "
86- f"({ x .dtype } ), "
87- "and the input could not be safely coerced to any "
88- "supported types according to the casting rule ''safe''."
89- )
90105 exec_q = x .sycl_queue
91106 if buf_dt is None :
92107 if out is None :
@@ -96,17 +111,20 @@ def __call__(self, x, out=None, order="K"):
96111 if order == "A" :
97112 order = "F" if x .flags .f_contiguous else "C"
98113 out = dpt .empty_like (x , dtype = res_dt , order = order )
99- else :
100- if res_dt != out .dtype :
101- raise TypeError (
102- f"Output array of type { res_dt } is needed,"
103- f" got { out .dtype } "
104- )
105114
106- ht , _ = self .unary_fn_ (x , out , sycl_queue = exec_q )
107- ht .wait ()
115+ ht_unary_ev , unary_ev = self .unary_fn_ (x , out , sycl_queue = exec_q )
116+
117+ if not (orig_out is None or orig_out is out ):
118+ # Copy the out data from temporary buffer to original memory
119+ ht_copy_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
120+ src = out , dst = orig_out , sycl_queue = exec_q , depends = [unary_ev ]
121+ )
122+ ht_copy_ev .wait ()
123+ out = orig_out
108124
125+ ht_unary_ev .wait ()
109126 return out
127+
110128 if order == "K" :
111129 buf = _empty_like_orderK (x , buf_dt )
112130 else :
@@ -122,11 +140,6 @@ def __call__(self, x, out=None, order="K"):
122140 out = _empty_like_orderK (buf , res_dt )
123141 else :
124142 out = dpt .empty_like (buf , dtype = res_dt , order = order )
125- else :
126- if buf_dt != out .dtype :
127- raise TypeError (
128- f"Output array of type { buf_dt } is needed, got { out .dtype } "
129- )
130143
131144 ht , _ = self .unary_fn_ (buf , out , sycl_queue = exec_q , depends = [copy_ev ])
132145 ht_copy_ev .wait ()
0 commit comments