@@ -986,6 +986,27 @@ def refimpl(_x, _min, _max):
986986 if _max is not None and _max not in dh .dtype_ranges [x .dtype ]:
987987 return None
988988
989+ # If min or max are float64 and x is float32, they will need to be
990+ # downcast to float32. This could result in a round in the wrong
991+ # direction meaning the resulting clipped value might not actually be
992+ # between min and max. This behavior is unspecified, so skip any cases
993+ # where x is within the rounding error of downcasting min or max.
994+ if x .dtype == xp .float32 :
995+ if min is not None and not dh .is_scalar (min ) and min .dtype == xp .float64 and math .isfinite (_min ):
996+ _min_float32 = float (xp .asarray (_min , dtype = xp .float32 ))
997+ if math .isinf (_min_float32 ):
998+ return None
999+ tol = abs (_min - _min_float32 )
1000+ if math .isclose (_min , _min_float32 , abs_tol = tol ):
1001+ return None
1002+ if max is not None and not dh .is_scalar (max ) and max .dtype == xp .float64 and math .isfinite (_max ):
1003+ _max_float32 = float (xp .asarray (_max , dtype = xp .float32 ))
1004+ if math .isinf (_max_float32 ):
1005+ return None
1006+ tol = abs (_max - _max_float32 )
1007+ if math .isclose (_max , _max_float32 , abs_tol = tol ):
1008+ return None
1009+
9891010 if (math .isnan (_x )
9901011 or (_min is not None and math .isnan (_min ))
9911012 or (_max is not None and math .isnan (_max ))):
0 commit comments