@@ -1146,6 +1146,10 @@ def test_argmax(self):
11461146 keras .config .backend () == "openvino" ,
11471147 reason = "OpenVINO doesn't support this change" ,
11481148 )
1149+ @pytest .mark .skipif (
1150+ keras .config .backend () == "mlx" ,
1151+ reason = "Wrong results due to MLX flushing denormal numbers to 0 on GPU" ,
1152+ )
11491153 def test_argmax_negative_zero (self ):
11501154 input_data = np .array (
11511155 [- 1.0 , - 0.0 , 1.401298464324817e-45 ], dtype = np .float32
@@ -1161,6 +1165,10 @@ def test_argmax_negative_zero(self):
11611165 evaluation and may change within this PR
11621166 """ ,
11631167 )
1168+ @pytest .mark .skipif (
1169+ keras .config .backend () == "mlx" ,
1170+ reason = "Wrong results due to MLX flushing denormal numbers to 0 on GPU" ,
1171+ )
11641172 def test_argmin_negative_zero (self ):
11651173 input_data = np .array (
11661174 [
@@ -5391,10 +5399,16 @@ def setUp(self):
53915399
53925400 self .jax_enable_x64 = enable_x64 ()
53935401 self .jax_enable_x64 .__enter__ ()
5402+
5403+ if backend .backend () == "mlx" :
5404+ self .mlx_cpu_context = backend .core .enable_float64 ()
5405+ self .mlx_cpu_context .__enter__ ()
53945406 return super ().setUp ()
53955407
53965408 def tearDown (self ):
53975409 self .jax_enable_x64 .__exit__ (None , None , None )
5410+ if backend .backend () == "mlx" :
5411+ self .mlx_cpu_context .__exit__ (None , None , None )
53985412 return super ().tearDown ()
53995413
54005414 @parameterized .named_parameters (
@@ -5598,6 +5612,13 @@ def test_matmul(self, dtypes):
55985612 import jax .numpy as jnp
55995613
56005614 dtype1 , dtype2 = dtypes
5615+ if (
5616+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
5617+ and backend .backend () == "mlx"
5618+ ):
5619+ # This must be removed once mlx.core.matmul supports integer dtypes
5620+ self .skipTest ("mlx doesn't support integer dot product" )
5621+
56015622 # The shape of the matrix needs to meet the requirements of
56025623 # torch._int_mm to test hardware-accelerated matmul
56035624 x1 = knp .ones ((17 , 16 ), dtype = dtype1 )
@@ -6620,6 +6641,13 @@ def test_dot(self, dtypes):
66206641 import jax .numpy as jnp
66216642
66226643 dtype1 , dtype2 = dtypes
6644+ if (
6645+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
6646+ and backend .backend () == "mlx"
6647+ ):
6648+ # This must be removed once mlx.core.matmul supports integer dtypes
6649+ self .skipTest ("mlx doesn't support integer dot product" )
6650+
66236651 x1 = knp .ones ((2 , 3 , 4 ), dtype = dtype1 )
66246652 x2 = knp .ones ((4 , 3 ), dtype = dtype2 )
66256653 x1_jax = jnp .ones ((2 , 3 , 4 ), dtype = dtype1 )
@@ -6648,6 +6676,13 @@ def get_input_shapes(subscripts):
66486676 return x1_shape , x2_shape
66496677
66506678 dtype1 , dtype2 = dtypes
6679+ if (
6680+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
6681+ and backend .backend () == "mlx"
6682+ ):
6683+ # This must be removed once mlx.core.matmul supports integer dtypes
6684+ self .skipTest ("mlx doesn't support integer dot product" )
6685+
66516686 subscripts = "ijk,lkj->il"
66526687 x1_shape , x2_shape = get_input_shapes (subscripts )
66536688 x1 = knp .ones (x1_shape , dtype = dtype1 )
@@ -8312,6 +8347,13 @@ def test_tensordot(self, dtypes):
83128347 import jax .numpy as jnp
83138348
83148349 dtype1 , dtype2 = dtypes
8350+ if (
8351+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
8352+ and backend .backend () == "mlx"
8353+ ):
8354+ # This must be removed once mlx.core.matmul supports integer dtypes
8355+ self .skipTest ("mlx doesn't support integer dot product" )
8356+
83158357 x1 = knp .ones ((1 , 1 ), dtype = dtype1 )
83168358 x2 = knp .ones ((1 , 1 ), dtype = dtype2 )
83178359 x1_jax = jnp .ones ((1 , 1 ), dtype = dtype1 )
@@ -8522,6 +8564,13 @@ def test_inner(self, dtypes):
85228564 import jax .numpy as jnp
85238565
85248566 dtype1 , dtype2 = dtypes
8567+ if (
8568+ all (dtype not in self .FLOAT_DTYPES for dtype in dtypes )
8569+ and backend .backend () == "mlx"
8570+ ):
8571+ # This must be removed once mlx.core.matmul supports integer dtypes
8572+ self .skipTest ("mlx doesn't support integer dot product" )
8573+
85258574 x1 = knp .ones ((1 ,), dtype = dtype1 )
85268575 x2 = knp .ones ((1 ,), dtype = dtype2 )
85278576 x1_jax = jnp .ones ((1 ,), dtype = dtype1 )
0 commit comments