@@ -2443,14 +2443,18 @@ def _get_analyze_compat_dtype(arr):
24432443 return np .dtype ('int16' if arr .max () <= np .iinfo (np .int16 ).max else 'int32' )
24442444
24452445 mn , mx = arr .min (), arr .max ()
2446- if (isinstance (mn , int ) and isinstance (mx , int )) or (
2447- np .can_cast (mn , np .int32 ) and np .can_cast (mx , np .int32 )
2448- ):
2446+ if np .can_cast (mn , np .int32 ) and np .can_cast (mx , np .int32 ):
24492447 return np .dtype ('int32' )
2450- if (isinstance (mn , float ) and isinstance (mx , float )) or (
2451- np .can_cast (mn , np .float32 ) and np .can_cast (mx , np .float32 )
2452- ):
2448+ elif (isinstance (mn , int ) and isinstance (mx , int )):
2449+ info = np .finfo ('int32' )
2450+ if mn >= info .min and mx <= info .max :
2451+ return np .dtype ('int32' )
2452+ if np .can_cast (mn , np .float32 ) and np .can_cast (mx , np .float32 ):
24532453 return np .dtype ('float32' )
2454+ elif (isinstance (mn , float ) and isinstance (mx , float ):
2455+ info = np .finfo ('float32' )
2456+ if mn >= info .min and mx <= info .max :
2457+ return np .dtype ('float32' )
24542458
24552459 raise ValueError (
24562460 f'Cannot find analyze-compatible dtype for array with dtype={ dtype } (min={ mn } , max={ mx } )'
0 commit comments