2626 _integer_dtypes ,
2727 _integer_or_boolean_dtypes ,
2828 _floating_dtypes ,
29+ _real_floating_dtypes ,
2930 _complex_floating_dtypes ,
3031 _numeric_dtypes ,
3132 _result_type ,
3233 _dtype_categories ,
34+ _real_to_complex_map ,
3335)
3436from ._flags import get_array_api_strict_flags , set_array_api_strict_flags
3537
@@ -243,6 +245,7 @@ def _promote_scalar(self, scalar):
243245 """
244246 from ._data_type_functions import iinfo
245247
248+ target_dtype = self .dtype
246249 # Note: Only Python scalar types that match the array dtype are
247250 # allowed.
248251 if isinstance (scalar , bool ):
@@ -268,10 +271,13 @@ def _promote_scalar(self, scalar):
268271 "Python float scalars can only be promoted with floating-point arrays."
269272 )
270273 elif isinstance (scalar , complex ):
271- if self .dtype not in _complex_floating_dtypes :
274+ if self .dtype not in _floating_dtypes :
272275 raise TypeError (
273- "Python complex scalars can only be promoted with complex floating-point arrays."
276+ "Python complex scalars can only be promoted with floating-point arrays."
274277 )
278+ # 1j * array(floating) is allowed
279+ if self .dtype in _real_floating_dtypes :
280+ target_dtype = _real_to_complex_map [self .dtype ]
275281 else :
276282 raise TypeError ("'scalar' must be a Python scalar" )
277283
@@ -282,7 +288,7 @@ def _promote_scalar(self, scalar):
282288 # behavior for integers within the bounds of the integer dtype.
283289 # Outside of those bounds we use the default NumPy behavior (either
284290 # cast or raise OverflowError).
285- return Array ._new (np .array (scalar , dtype = self . dtype ._np_dtype ), device = self .device )
291+ return Array ._new (np .array (scalar , dtype = target_dtype ._np_dtype ), device = self .device )
286292
287293 @staticmethod
288294 def _normalize_two_args (x1 , x2 ) -> Tuple [Array , Array ]:
0 commit comments