11import os
22import pprint
3+ import struct
34import unittest
45import warnings
56import numpy
67import pandas
8+ from onnx import TensorProto
79from onnx_array_api .validation .f8 import (
810 CastFloat8 ,
911 UndefinedCastError ,
@@ -285,6 +287,15 @@ def test_search_float32_into_fe4m3fn(self):
285287 ok = "" if b == nf else "WRONG" ,
286288 true = value ,
287289 add = add ,
290+ exponent = (
291+ int .from_bytes (
292+ struct .pack ("<f" , numpy .float32 (v )), "little"
293+ )
294+ & 0x7F800000
295+ )
296+ >> 23 ,
297+ d1 = v - fe4m3_to_float32_float (nf ),
298+ d2 = v - fe4m3_to_float32_float (b ),
288299 )
289300 )
290301 if wrong > 0 :
@@ -449,10 +460,13 @@ def test_search_e4m3_pow(self):
449460 continue
450461 r2 = float32_to_fe4m3 (v )
451462 if r1 != r2 :
463+ ex = abs (v - fe4m3_to_float32 (r1 )) == abs (v - fe4m3_to_float32 (r2 ))
452464 raise AssertionError (
453465 f"p={ p } , v={ v } , "
454466 f"search={ r1 } :{ display_fe4m3 (r1 )} ={ fe4m3_to_float32 (r1 )} != "
455- f"bit={ r2 } :{ display_fe4m3 (r2 )} ={ fe4m3_to_float32 (r2 )} "
467+ f"bit={ r2 } :{ display_fe4m3 (r2 )} ={ fe4m3_to_float32 (r2 )} "
468+ f"d1={ v - fe4m3_to_float32 (r1 )} d2={ v - fe4m3_to_float32 (r2 )} "
469+ f"|d1|==|d2|={ ex } "
456470 )
457471 for p in range (1 , 40 ):
458472 v = - (2 ** (- p ))
@@ -462,10 +476,13 @@ def test_search_e4m3_pow(self):
462476 continue
463477 r2 = float32_to_fe4m3 (v )
464478 if r1 != r2 :
479+ ex = abs (v - fe4m3_to_float32 (r1 )) == abs (v - fe4m3_to_float32 (r2 ))
465480 raise AssertionError (
466481 f"p={ p } , v={ v } , "
467482 f"search={ r1 } :{ display_fe4m3 (r1 )} ={ fe4m3_to_float32 (r1 )} != "
468- f"bit={ r2 } :{ display_fe4m3 (r2 )} ={ fe4m3_to_float32 (r2 )} "
483+ f"bit={ r2 } :{ display_fe4m3 (r2 )} ={ fe4m3_to_float32 (r2 )} "
484+ f"d1={ v - fe4m3_to_float32 (r1 )} d2={ v - fe4m3_to_float32 (r2 )} "
485+ f"|d1|==|d2|={ ex } "
469486 )
470487
471488 def test_search_e5m2_pow (self ):
@@ -478,10 +495,13 @@ def test_search_e5m2_pow(self):
478495 continue
479496 r2 = float32_to_fe5m2 (v )
480497 if r1 != r2 :
498+ ex = abs (v - fe5m2_to_float32 (r1 )) == abs (v - fe5m2_to_float32 (r2 ))
481499 raise AssertionError (
482500 f"p={ p } , v={ v } , "
483501 f"search={ r1 } :{ display_fe5m2 (r1 )} ={ fe5m2_to_float32 (r1 )} != "
484- f"bit={ r2 } :{ display_fe5m2 (r2 )} ={ fe5m2_to_float32 (r2 )} "
502+ f"bit={ r2 } :{ display_fe5m2 (r2 )} ={ fe5m2_to_float32 (r2 )} "
503+ f"d1={ v - fe4m3_to_float32 (r1 )} d2={ v - fe5m2_to_float32 (r2 )} "
504+ f"|d1|==|d2|={ ex } "
485505 )
486506 for p in range (1 , 40 ):
487507 v = - (2 ** (- p ))
@@ -491,10 +511,13 @@ def test_search_e5m2_pow(self):
491511 continue
492512 r2 = float32_to_fe5m2 (v )
493513 if r1 != r2 :
514+ ex = abs (v - fe5m2_to_float32 (r1 )) == abs (v - fe5m2_to_float32 (r2 ))
494515 raise AssertionError (
495516 f"p={ p } , v={ v } , "
496517 f"search={ r1 } :{ display_fe5m2 (r1 )} ={ fe5m2_to_float32 (r1 )} != "
497- f"bit={ r2 } :{ display_fe5m2 (r2 )} ={ fe5m2_to_float32 (r2 )} "
518+ f"bit={ r2 } :{ display_fe5m2 (r2 )} ={ fe5m2_to_float32 (r2 )} "
519+ f"d1={ v - fe4m3_to_float32 (r1 )} d2={ v - fe5m2_to_float32 (r2 )} "
520+ f"|d1|==|d2|={ ex } "
498521 )
499522
500523 def test_float32_to_fe4m3fn_inf (self ):
@@ -1152,13 +1175,50 @@ def test_float8_e5m2fnuz_negative_nan(self):
11521175 self .assertTrue (numpy .isnan (back ))
11531176
11541177 def test_fe4m3fn_to_float32_bug (self ):
1155- cases = [(1.8131605 , 1.875 )]
1156- for val , expected in cases :
1157- with self .subTest (value = val , expected = expected ):
1158- res = fe4m3_to_float32 (search_float32_into_fe4m3 (val ))
1159- self .assertEqual (expected , res )
1160- res = fe4m3_to_float32 (float32_to_fe4m3 (val ))
1161- self .assertEqual (expected , res )
1178+ cases = [
1179+ (0.00439453125 , 0.00390625 , TensorProto .FLOAT8E4M3FN ),
1180+ (0.005859375 , 0.005859375 , TensorProto .FLOAT8E4M3FN ),
1181+ (0.005759375 , 0.005859375 , TensorProto .FLOAT8E4M3FN ),
1182+ (0.0046875 , 0.00390625 , TensorProto .FLOAT8E4M3FN ),
1183+ (0.001953125 , 0.001953125 , TensorProto .FLOAT8E4M3FN ),
1184+ (0.0029296875 , 0.00390625 , TensorProto .FLOAT8E4M3FN ),
1185+ (0.002053125 , 0.001953125 , TensorProto .FLOAT8E4M3FN ),
1186+ (0.00234375 , 0.001953125 , TensorProto .FLOAT8E4M3FN ),
1187+ (0.0087890625 , 0.0078125 , TensorProto .FLOAT8E4M3FN ),
1188+ (0.001171875 , 0.001953125 , TensorProto .FLOAT8E4M3FN ),
1189+ (1.8131605 , 1.875 , TensorProto .FLOAT8E4M3FN ),
1190+ (- 100 , - 96 , TensorProto .FLOAT8E4M3FNUZ ),
1191+ (416 , 384 , TensorProto .FLOAT8E5M2FNUZ ),
1192+ ]
1193+ for val , expected , pt in cases :
1194+ with self .subTest (value = val , expected = expected , proto = pt ):
1195+ if pt == TensorProto .FLOAT8E4M3FN :
1196+ res = fe4m3_to_float32 (search_float32_into_fe4m3 (val ))
1197+ self .assertEqual (expected , res )
1198+ res = fe4m3_to_float32 (float32_to_fe4m3 (val ))
1199+ self .assertEqual (expected , res )
1200+ continue
1201+ if pt == TensorProto .FLOAT8E4M3FNUZ :
1202+ res = fe4m3_to_float32 (
1203+ search_float32_into_fe4m3 (val , uz = True ), uz = True
1204+ )
1205+ self .assertEqual (expected , res )
1206+ res = fe4m3_to_float32 (float32_to_fe4m3 (val , uz = True ), uz = True )
1207+ self .assertEqual (expected , res )
1208+ continue
1209+ if pt == TensorProto .FLOAT8E5M2FNUZ :
1210+ res = fe5m2_to_float32 (
1211+ search_float32_into_fe5m2 (val , fn = True , uz = True ),
1212+ fn = True ,
1213+ uz = True ,
1214+ )
1215+ self .assertEqual (expected , res )
1216+ res = fe5m2_to_float32 (
1217+ float32_to_fe5m2 (val , fn = True , uz = True ), fn = True , uz = True
1218+ )
1219+ self .assertEqual (expected , res )
1220+ continue
1221+ raise AssertionError (f"Unexpected value for pt={ pt } ." )
11621222
11631223
11641224if __name__ == "__main__" :
0 commit comments