1212from keras .src .backend .mlx .core import cast
1313from keras .src .backend .mlx .core import convert_to_tensor
1414from keras .src .backend .mlx .core import convert_to_tensors
15+ from keras .src .backend .mlx .core import is_tensor
1516from keras .src .backend .mlx .core import slice
1617from keras .src .backend .mlx .core import to_mlx_dtype
1718
@@ -272,8 +273,20 @@ def bitwise_xor(x, y):
272273
273274def bitwise_left_shift (x , y ):
274275 x = convert_to_tensor (x )
275- y = convert_to_tensor (y )
276- return mx .left_shift (x , y )
276+ if not isinstance (y , int ):
277+ y = convert_to_tensor (y )
278+
279+ # handle result dtype to match other backends
280+ types = [x .dtype ]
281+ if is_tensor (y ):
282+ types .append (y .dtype )
283+ result_dtype = result_type (* types )
284+ mlx_result_dtype = to_mlx_dtype (result_dtype )
285+
286+ result = mx .left_shift (x , y )
287+ if result .dtype != mlx_result_dtype :
288+ return result .astype (mlx_result_dtype )
289+ return result
277290
278291
279292def left_shift (x , y ):
@@ -282,8 +295,20 @@ def left_shift(x, y):
282295
283296def bitwise_right_shift (x , y ):
284297 x = convert_to_tensor (x )
285- y = convert_to_tensor (y )
286- return mx .right_shift (x , y )
298+ if not isinstance (y , int ):
299+ y = convert_to_tensor (y )
300+
301+ # handle result dtype to match other backends
302+ types = [x .dtype ]
303+ if is_tensor (y ):
304+ types .append (y .dtype )
305+ result_dtype = result_type (* types )
306+ mlx_result_dtype = to_mlx_dtype (result_dtype )
307+
308+ result = mx .right_shift (x , y )
309+ if result .dtype != mlx_result_dtype :
310+ return result .astype (mlx_result_dtype )
311+ return result
287312
288313
289314def right_shift (x , y ):
@@ -1567,3 +1592,34 @@ def rot90(array, k=1, axes=(0, 1)):
15671592 array = array [tuple (slices )]
15681593
15691594 return array
1595+
1596+
1597+ def signbit (x ):
1598+ x = convert_to_tensor (x )
1599+
1600+ if x .dtype in (
1601+ mx .float16 ,
1602+ mx .float32 ,
1603+ mx .float64 ,
1604+ mx .bfloat16 ,
1605+ mx .complex64 ,
1606+ ):
1607+ if x .dtype == mx .complex64 :
1608+ # check sign of real part for complex numbers
1609+ real_part = mx .real (x )
1610+ return signbit (real_part )
1611+ zeros = x == 0
1612+ # this works because in mlx 1/0=inf and 1/-0=-inf
1613+ neg_zeros = (1 / x == mx .array (float ("-inf" ))) & zeros
1614+ return mx .where (zeros , neg_zeros , x < 0 )
1615+ elif x .dtype in (mx .uint8 , mx .uint16 , mx .uint32 , mx .uint64 ):
1616+ # unsigned integers never negative
1617+ return mx .zeros_like (x ).astype (mx .bool_ )
1618+ elif x .dtype in (mx .int8 , mx .int16 , mx .int32 , mx .int64 ):
1619+ # for integers, simple negative check
1620+ return x < 0
1621+ elif x .dtype == mx .bool_ :
1622+ # for boolean array, return false
1623+ return mx .zeros_like (x ).astype (mx .bool_ )
1624+ else :
1625+ raise ValueError (f"Unsupported dtype in `signbit`: { x .dtype } " )
0 commit comments