Commit a425dc2

[Bugfix] [ROCm] [AITER]: Fix aiter block quant not compatible with torch compile dynamo (#28716)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
1 parent 964d65d commit a425dc2
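
The fix wraps the AITER per-1x128 block quantization in a registered custom op (torch.ops.vllm.rocm_aiter_group_fp8_quant) with a matching fake implementation, so torch.compile/Dynamo can trace the call from output shapes and dtypes alone instead of having to trace into the opaque HIP kernel. Below is a minimal sketch of that registration pattern in plain PyTorch; the `demo::toy_group_quant` op and its toy absmax math are illustrative stand-ins, not vLLM's `direct_register_custom_op` helper or the real AITER kernel.

import torch

# Define the op schema once; two tensor outputs (quantized data, per-group scales).
torch.library.define(
    "demo::toy_group_quant",
    "(Tensor x, int group_size) -> (Tensor, Tensor)",
)


# Real implementation: a toy per-group absmax "scale", standing in for the AITER kernel.
@torch.library.impl("demo::toy_group_quant", "CompositeExplicitAutograd")
def _toy_group_quant_impl(x: torch.Tensor, group_size: int):
    M, N = x.shape
    scales = x.view(M, N // group_size, group_size).abs().amax(dim=-1).float()
    return x.clone(), scales


# Fake (meta) implementation: only output shapes/dtypes, never runs the kernel.
# This is what Dynamo/FakeTensor uses at trace time.
@torch.library.register_fake("demo::toy_group_quant")
def _toy_group_quant_fake(x: torch.Tensor, group_size: int):
    M, N = x.shape
    return torch.empty_like(x), x.new_empty((M, N // group_size), dtype=torch.float32)


@torch.compile(fullgraph=True)
def quantize(x: torch.Tensor):
    # The custom op appears as a single opaque node in the captured graph.
    return torch.ops.demo.toy_group_quant(x, 128)


out, scales = quantize(torch.randn(4, 256))  # scales.shape == (4, 2)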

File tree: 3 files changed, +180 -7 lines changed

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This is a test for the AITER group_fp8_quant op.
# It tests that the AITER op
# 1. correctly defines the relationship between the
#    implementation and the fake function
# 2. can be used with torch.compile
# 3. can be used with CUDA graphs
# This file will be skipped if AITER is not installed
# or the platform is not ROCm.

import importlib.util

import pytest
import torch

# this import statement is needed to ensure the ops are registered
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform

# Check if aiter package is installed
aiter_available = importlib.util.find_spec("aiter") is not None

pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and aiter_available),
    reason="AITER ops are only available on ROCm with aiter package installed",
)


def test_rocm_aiter_group_fp8_quant_fake_implementation():
    """Test that the fake implementation is correctly
    defined for torch.ops.vllm.rocm_aiter_group_fp8_quant."""
    # Create test tensors
    M = 128
    N = 4096
    group_size = 128

    input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")

    # Verify the op's fake implementation using torch.library.opcheck
    # This checks that the fake function returns tensors with correct shapes and dtypes
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_group_fp8_quant,
        (input_tensor, group_size),
        test_utils=("test_faketensor",),
    )


def test_rocm_aiter_group_fp8_quant_torch_compile_with_cudagraph():
    """Test that rocm_aiter_ops.group_fp8_quant
    with group size 128 can be used with
    torch.compile in cudagraph mode."""
    # Create test tensors
    M = 128
    N = 4096
    group_size = 128

    input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")

    # Define a function that uses the op
    def group_fp8_quant_fn(x):
        return rocm_aiter_ops.group_fp8_quant(x, group_size)

    # Compile with cudagraph mode
    compiled_fn = torch.compile(
        group_fp8_quant_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    # Run eager mode
    x_fp8_eager, scales_eager = group_fp8_quant_fn(input_tensor)

    # Run compiled version (first run will trigger compilation)
    x_fp8_compiled, scales_compiled = compiled_fn(input_tensor)

    # Verify shapes match
    assert x_fp8_compiled.shape == x_fp8_eager.shape
    assert scales_compiled.shape == scales_eager.shape

    # Verify expected shapes
    assert x_fp8_compiled.shape == (M, N)
    expected_scale_cols = (N + group_size - 1) // group_size
    assert scales_compiled.shape == (M, expected_scale_cols)

    # Verify results match
    assert torch.allclose(
        x_fp8_compiled.to(torch.float32),
        x_fp8_eager.to(torch.float32),
        rtol=1e-2,
        atol=1e-2,
    )
    assert torch.allclose(scales_compiled, scales_eager, rtol=1e-3, atol=1e-3)

    # Test with different input (reusing compiled graph)
    input_tensor_2 = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")
    x_fp8_eager_2, scales_eager_2 = group_fp8_quant_fn(input_tensor_2)
    x_fp8_compiled_2, scales_compiled_2 = compiled_fn(input_tensor_2)

    # Verify second run also produces correct results
    assert torch.allclose(
        x_fp8_compiled_2.to(torch.float32),
        x_fp8_eager_2.to(torch.float32),
        rtol=1e-2,
        atol=1e-2,
    )
    assert torch.allclose(scales_compiled_2, scales_eager_2, rtol=1e-3, atol=1e-3)


def test_rocm_aiter_group_fp8_quant_different_shapes():
    """Test rocm_aiter_ops.group_fp8_quant with different input shapes."""
    group_size = 128

    test_shapes = [
        (64, 2048),
        (256, 8192),
        (32, 1024),
        (512, 4096),
    ]

    for M, N in test_shapes:
        input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")

        x_fp8, scales = rocm_aiter_ops.group_fp8_quant(input_tensor, group_size)

        # Verify shapes
        assert x_fp8.shape == (M, N)
        expected_scale_cols = (N + group_size - 1) // group_size
        assert scales.shape == (M, expected_scale_cols)

        # Verify dtypes
        from aiter import dtypes

        assert x_fp8.dtype == dtypes.fp8
        assert scales.dtype == torch.float32

vllm/_aiter_ops.py

Lines changed: 42 additions & 6 deletions
@@ -43,6 +43,36 @@ def wrapper(*args, **kwargs):
     return wrapper


+def _rocm_aiter_group_fp8_quant_impl(
+    x: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
+    from aiter import QuantType, dtypes, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
+    return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
+
+
+def _rocm_aiter_group_fp8_quant_fake(
+    x: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from aiter import dtypes
+
+    M, N = x.shape
+    x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
+    out_bs = torch.empty(
+        (
+            M,
+            (N + group_size - 1) // group_size,
+        ),
+        dtype=torch.float32,
+        device=x.device,
+    )
+    return x_fp8, out_bs
+
+
 def _rocm_aiter_fused_moe_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -512,6 +542,14 @@ def register_ops_once() -> None:
         )

         # register all the custom ops here
+        direct_register_custom_op(
+            op_name="rocm_aiter_group_fp8_quant",
+            op_func=_rocm_aiter_group_fp8_quant_impl,
+            mutates_args=[],
+            fake_impl=_rocm_aiter_group_fp8_quant_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+
         direct_register_custom_op(
             op_name="rocm_aiter_asm_moe_tkw1",
             op_func=_rocm_aiter_asm_moe_tkw1_impl,
@@ -887,14 +925,12 @@ def triton_gemm_a8w8_blockscale(
         return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)

     @staticmethod
-    def per_1x128_fp8_quant(
+    def group_fp8_quant(
         input_2d: torch.Tensor,
+        group_size: int = 128,
     ) -> tuple[torch.Tensor, ...]:
-        """Only applies quantization method for fp8 data type only."""
-        from aiter import QuantType, dtypes, get_hip_quant
-
-        aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
-        return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
+        assert group_size == 128, "Group size must be 128"
+        return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)

     @staticmethod
     def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
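
For context, here is a minimal usage sketch of the renamed wrapper under torch.compile, mirroring the test file above; it assumes a ROCm GPU with the aiter package installed, and the tensor shape is illustrative.

import torch

from vllm._aiter_ops import rocm_aiter_ops  # importing this registers the custom ops


@torch.compile(fullgraph=True)
def quantize(x: torch.Tensor):
    # Dispatches to torch.ops.vllm.rocm_aiter_group_fp8_quant, which Dynamo can
    # trace because the op is registered with a fake implementation.
    return rocm_aiter_ops.group_fp8_quant(x, 128)


x = torch.randn((128, 4096), dtype=torch.bfloat16, device="cuda")
x_fp8, scales = quantize(x)  # scales: shape (128, 4096 // 128), dtype float32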

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 1 addition & 1 deletion
@@ -342,7 +342,7 @@ def _run_aiter(
         )
     # MI300 uses tuned AITER ASM/C++ kernel
     else:
-        q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
+        q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d)

     return gemm_a8w8_blockscale_op(
         q_input,
