From 5b7e4b04d053ed9dfe76e7610c1d72af751fdb70 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 15 Sep 2025 20:50:43 +0200 Subject: [PATCH 1/2] Replace bfloat16 dtype from `bfloat16` package by `ml_dtypes` package --- kernel_tuner/accuracy.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/kernel_tuner/accuracy.py b/kernel_tuner/accuracy.py index 491541909..508140035 100644 --- a/kernel_tuner/accuracy.py +++ b/kernel_tuner/accuracy.py @@ -58,6 +58,12 @@ def __call__(self, params): def _find_bfloat16_if_available(): # Try to get bfloat16 if available. + try: + from ml_dtypes import bfloat16 + return bfloat16 + except ImportError: + pass + try: from bfloat16 import bfloat16 return bfloat16 From 4cb3f052804836de71080a48feb91ac653a715a1 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 6 Oct 2025 16:41:23 +0200 Subject: [PATCH 2/2] Add bfloat16 dtype import from jax --- kernel_tuner/accuracy.py | 51 ++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/kernel_tuner/accuracy.py b/kernel_tuner/accuracy.py index 508140035..b647947c1 100644 --- a/kernel_tuner/accuracy.py +++ b/kernel_tuner/accuracy.py @@ -58,28 +58,43 @@ def __call__(self, params): def _find_bfloat16_if_available(): # Try to get bfloat16 if available. - try: - from ml_dtypes import bfloat16 - return bfloat16 - except ImportError: - pass - - try: - from bfloat16 import bfloat16 - return bfloat16 - except ImportError: - pass + dtype = None + # get it via numpy if available try: - from tensorflow import bfloat16 - return bfloat16.as_numpy_dtype - except ImportError: + dtype = np.dtype("bfloat16") + except TypeError: pass - logging.warning( - "could not find `bfloat16` data type for numpy, " - + "please install either the package `bfloat16` or `tensorflow`" - ) + # otherwise, try ml_dtypes + if dtype is None: + try: + from ml_dtypes import bfloat16 + dtype = bfloat16 + except ImportError: + pass + + # otherwise, try jax + if dtype is None: + try: + from jax.numpy import bfloat16 + dtype = bfloat16 + except ImportError: + pass + + # otherwise, try tensorflow + if dtype is None: + try: + from tensorflow import bfloat16 + dtype = bfloat16.as_numpy_dtype + except ImportError: + pass + + if dtype is None: + logging.warning( + "could not find `bfloat16` data type for numpy, " + + "please install either the package `ml_dtypes`, `jax`, or `tensorflow`" + ) return None