
Commit 315e9b4

[mxfp8] fix test nan != nan issue (#3273)
1 parent f856d36

1 file changed: +2, -8 lines

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 2 additions & 8 deletions
@@ -116,8 +116,6 @@ def test_some_zeros(elem_dtype):
     _test_mx(data, elem_dtype, block_size)
 
 
-# TODO(future PR): fix and reenable this test
-@pytest.mark.skip(reason="does not pass on B200 yet")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_to_mx_rceil():
     # nan
@@ -131,11 +129,7 @@ def test_to_mx_rceil():
         ],
         dtype=torch.uint32,
     ).view(torch.float32)
-    # fmt: on
-    ground_truth_scale = torch.tensor([255], dtype=torch.uint8).view(
-        torch.float8_e8m0fnu
-    )
-    # fmt: off
+
     ground_truth_fp8 = torch.tensor(
         [
             127, 0, 0, 0, 0, 0, 0, 0,
@@ -149,7 +143,7 @@ def test_to_mx_rceil():
     data_mx = MXTensor.to_mx(
         data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
     )
-    torch.testing.assert_close(data_mx.scale, ground_truth_scale)
+    assert torch.isnan(data_mx.scale)
     assert torch.isnan(data_mx.qdata[0])
     assert torch.all(data_mx.qdata[1:] == 0)
     # fp32 denorm
