
Commit 3910f27

[mxfp] handle values close to max correctly w/o overflow (#8356)
Need to clamp since, due to rounding, we can get overflow from values that were within range before quantization. For example, with bfloat16 input 3.3895e+38:

- `log2(3.3895e+38 / max_fp8e4m3=448) ~= 119.17`, which rounds up to 120; adding `exp_bias=127` gives `scale=247`.
- `3.3895e+38 / 2**120 ~= 254.9976`, which rounds to 256 in fp8e4m3fn.
- Dequantization: `256 * 2**120 > 3.4e38`, overflowing the bfloat16 maximum of 3.38953139e38.

(A runnable sketch of this arithmetic follows the checklist below.)

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
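A minimal sketch of the overflow arithmetic above, in plain Python. The exact rounding behavior of `downcast_to_mxfp` is approximated here with `math.ceil` and a hand-rounded fp8 value, so treat those intermediate steps as assumptions:

```python
import math

FP8_E4M3_MAX = 448.0                 # largest finite float8_e4m3fn value
BF16_MAX = 3.3895313892515355e+38    # largest finite bfloat16 value

x = 3.3895e+38                       # finite before quantization

# Shared e8m0 scale: round the exponent up so the block max fits in fp8.
e = math.ceil(math.log2(x / FP8_E4M3_MAX))   # ~119.2 -> 120
scale = e + 127                              # biased scale byte = 247

# Quantize: x / 2**120 ~= 254.9976; fp8e4m3fn rounds this to 256.
q = 256.0

# Dequantize: 256 * 2**120 = 2**128, which exceeds BF16_MAX -> inf in
# bfloat16 unless the upcast clamps to the finite maximum.
dq = q * 2.0**e
print(dq > BF16_MAX)                 # True
```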
1 parent 5d84a91 · commit 3910f27

3 files changed: +37 −0 lines changed

python/triton_kernels/tests/test_mxfp.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -45,6 +45,22 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
     assert_equal(dequant_torch, dequant)
 
 
+@pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
+@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
+def test_mxfp_extreme_values(src_dtype, dst_dtype, device):
+    if "float8" in src_dtype and (is_cuda() and torch.cuda.get_device_capability()[0] < 9):
+        pytest.skip("Float8 not tested on A100")
+    src_dtype = dtype_str_to_torch(src_dtype)
+    dst_dtype = dtype_str_to_torch(dst_dtype)
+    BIG_VALUE = 65470 if dst_dtype == torch.float16 else 3.3895e38
+    x = torch.tensor([BIG_VALUE, BIG_VALUE], dtype=dst_dtype, device=device)
+    xq_value, xq_scale = downcast_to_mxfp(x, src_dtype, axis=-1)
+    xdq = upcast_from_mxfp(xq_value, xq_scale, dst_dtype, axis=-1)
+    xdq_ref = upcast_from_mxfp_torch(xq_value, xq_scale, dst_dtype, axis=-1)
+    assert_equal(xdq_ref, xdq)
+    assert not xdq.isinf().any()
+
+
 @pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
 @pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
 def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
```
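The float16 branch of the test (`BIG_VALUE = 65470`) exercises the same failure mode near the float16 maximum of 65504. A rough sketch of the analogous arithmetic for the fp8e4m3fn path, mirroring the bfloat16 example from the commit message; the intermediate rounding steps are assumptions, not a trace of `downcast_to_mxfp`:

```python
import math

FP16_MAX = 65504.0   # largest finite float16 value
x = 65470.0          # BIG_VALUE used by the test for float16

# Exponent rounds up: log2(65470 / 448) ~= 7.19 -> 8, so the scale is 2**8.
e = math.ceil(math.log2(x / 448.0))

# 65470 / 256 ~= 255.7; fp8e4m3fn rounds this up to 256.
dq = 256.0 * 2.0**e   # 65536 > FP16_MAX -> inf in float16 without clamping
print(dq > FP16_MAX)  # True
```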

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -297,6 +297,17 @@ def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dty
     padded_tensor = padded_tensor.view(*new_shape)
     dq_scale_padded = dq_scale.unsqueeze(-1)  # shape: [..., ceil(axis_shape/32), 1]
     out_padded = padded_tensor * dq_scale_padded
+    # Need to clamp since, due to rounding, we can get overflow from values
+    # that were within range before quantization.
+    # e.g., 3.3895e+38 -> log2(3.3895e+38 / max_fp8e4m3=448) ~= 119.17 -> round
+    # up to 120 + exp_bias=127 -> scale=247
+    # 3.3895e+38 / 2**120 ~= 254.9976 -> round to 256 in fp8e4m3fn
+    # Dequantization: 256 * 2**120 > 3.4e38, overflowing 3.38953139e38
+    finfo = torch.finfo(target_dtype)
+    out_padded = (padded_tensor * dq_scale_padded).clamp(finfo.min, finfo.max)
+    if tensor.dtype == torch.float8_e5m2:
+        # fp8e5m2 can encode inf, which we want to preserve, so handle it separately.
+        out_padded = out_padded.where(~padded_tensor.isinf(), padded_tensor.to(target_dtype))
 
     # Flatten back and remove the padded tail
     out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
```
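In isolation, the clamp-plus-`where` pattern added above behaves like this. A minimal standalone sketch with made-up inputs, not the library's actual call path:

```python
import torch

target_dtype = torch.bfloat16
finfo = torch.finfo(target_dtype)

# One in-range product, one product that overflows after scaling, and one
# genuine infinity such as fp8e5m2 can encode.
padded_tensor = torch.tensor([1.0, 256.0, float("inf")])
dq_scale_padded = torch.tensor([1.0, 2.0**120, 1.0])

out = (padded_tensor * dq_scale_padded).clamp(finfo.min, finfo.max)
# Restore real infinities so clamping only affects rounding overflow.
out = out.where(~padded_tensor.isinf(), padded_tensor.to(target_dtype))
print(out)  # tensor([1.0000e+00, 3.3895e+38, inf])
```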

python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -119,6 +119,16 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
     scale = scale.reshape(dst_scale.shape)
 
     out_tensor = dst_tensor * dst_scale
+    if dst_dtype == tl.float32:
+        max_fin = 3.4028234663852886e+38
+    elif dst_dtype == tl.bfloat16:
+        max_fin = 3.3895313892515355e+38
+    else:
+        tl.static_assert(dst_dtype == tl.float16)
+        max_fin = 65504
+    # TODO: handle infinity the same as upcast_from_mxfp_torch, together with
+    # the above FIXME
+    out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin)
     # Correct any NaNs encoded via the scale.
     out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
     out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
```
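The hard-coded `max_fin` constants match the finite maxima that `torch.finfo` reports for each destination dtype; a quick sanity check:

```python
import torch

assert torch.finfo(torch.float32).max == 3.4028234663852886e+38
assert torch.finfo(torch.bfloat16).max == 3.3895313892515355e+38
assert torch.finfo(torch.float16).max == 65504.0
```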
