@@ -12,18 +12,17 @@ class UndefinedCastError(FloatingPointError):
1212 pass
1313
1414
15- def display_float32 ( value , sign = 1 , exponent = 8 , mantissa = 23 ):
15+ def display_int ( ival , sign = 1 , exponent = 8 , mantissa = 23 ):
1616 """
17- Displays a float32 into b .
17+ Displays an integer as a bit string split into sign, exponent and mantissa fields.
1818
19- :param value : value to display (float32)
19+ :param ival : value to display (integer holding the raw bits)
2020 :param sign: number of bits for the sign
2121 :param exponent: number of bits for the exponent
2222 :param mantissa: number of bits for the mantissa
2323 :return: string
2424 """
2525 t = sign + exponent + mantissa
26- ival = int .from_bytes (struct .pack ("<f" , numpy .float32 (value )), "little" )
2726 s = bin (ival )[2 :]
2827 s = "0" * (t - len (s )) + s
2928 s1 = s [:sign ]
@@ -32,6 +31,24 @@ def display_float32(value, sign=1, exponent=8, mantissa=23):
3231 return "." .join ([s1 , s2 , s3 ])
3332
3433
34+ def display_float32 (value , sign = 1 , exponent = 8 , mantissa = 23 ):
35+ """
36+ Displays the bits of a float32 as a string (sign.exponent.mantissa).
37+
38+ :param value: value to display (float32)
39+ :param sign: number of bits for the sign
40+ :param exponent: number of bits for the exponent
41+ :param mantissa: number of bits for the mantissa
42+ :return: string
43+ """
44+ return display_int (
45+ int .from_bytes (struct .pack ("<f" , numpy .float32 (value )), "little" ),
46+ sign = sign ,
47+ exponent = exponent ,
48+ mantissa = mantissa ,
49+ )
50+
51+
3552def display_float16 (value , sign = 1 , exponent = 5 , mantissa = 10 ):
3653 """
3754 Displays a float32 into b.
@@ -42,14 +59,9 @@ def display_float16(value, sign=1, exponent=5, mantissa=10):
4259 :param mantissa: number of bits for the mantissa
4360 :return: string
4461 """
45- t = sign + exponent + mantissa
46- ival = numpy .float16 (value ).view ("H" ) # pylint: disable=E1121
47- s = bin (ival )[2 :]
48- s = "0" * (t - len (s )) + s
49- s1 = s [:sign ]
50- s2 = s [sign : sign + exponent ]
51- s3 = s [sign + exponent :]
52- return "." .join ([s1 , s2 , s3 ])
62+ return display_int (
63+ numpy .float16 (value ).view ("H" ), sign = sign , exponent = exponent , mantissa = mantissa
64+ )
5365
5466
5567def display_fexmx (value , sign , exponent , mantissa ):
@@ -64,14 +76,7 @@ def display_fexmx(value, sign, exponent, mantissa):
6476 :param mantissa: number of bits for the mantissa
6577 :return: string
6678 """
67- t = sign + exponent + mantissa
68- ival = value
69- s = bin (ival )[2 :]
70- s = "0" * (t - len (s )) + s
71- s1 = s [:sign ]
72- s2 = s [sign : sign + exponent ]
73- s3 = s [sign + exponent :]
74- return "." .join ([s1 , s2 , s3 ])
79+ return display_int (value , sign = sign , exponent = exponent , mantissa = mantissa )
7580
7681
7782def display_fe4m3 (value , sign = 1 , exponent = 4 , mantissa = 3 ):
@@ -534,7 +539,9 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
534539 else :
535540 ret |= ex << 3
536541 ret |= m >> 20
537- if m & 0x80000 :
542+ if (m & 0x80000 ) and (
543+ (m & 0x100000 ) or (m & 0x7FFFF )
544+ ): # round to nearest even
538545 if (ret & 0x7F ) < 0x7F :
539546 # rounding
540547 ret += 1
@@ -584,7 +591,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
584591 if (ret & 0x7F ) == 0x7F :
585592 ret &= 0xFE
586593 if (m & 0x80000 ) and (
587- (m & 0x100000 ) or (m & 0x7C000 )
594+ (m & 0x100000 ) or (m & 0x7FFFF )
588595 ): # round to nearest even
589596 if (ret & 0x7F ) < 0x7E :
590597 # rounding
@@ -642,7 +649,9 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
642649 ex = e - 111 # 127 - 16
643650 ret |= ex << 2
644651 ret |= m >> 21
645- if m & 0x100000 :
652+ if m & 0x100000 and (
653+ (m & 0xFFFFF ) or (m & 0x200000 )
654+ ): # round to nearest even
646655 if (ret & 0x7F ) < 0x7F :
647656 # rounding
648657 ret += 1
0 commit comments