Commit 171133f

[Bugfix] Fix test fused quant layernorm tests (vllm-project#27865)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
1 parent 32787d0 commit 171133f

File tree

csrc/quantization/w8a8/int8/scaled_quant.cu
tests/kernels/core/test_fused_quant_layernorm.py

2 files changed: +28 -10 lines changed

csrc/quantization/w8a8/int8/scaled_quant.cu

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>
 
@@ -275,6 +276,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -306,6 +308,7 @@ void dynamic_scaled_int8_quant(
   int const num_tokens = input.numel() / hidden_size;
   dim3 const grid(num_tokens);
   dim3 const block(std::min(hidden_size, 256));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
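The guard added to both quant entry points pins the current CUDA device to the device of `input` before the current stream is queried, so the kernel launches on the stream of the tensor's own device rather than whatever device happens to be current. Below is a minimal sketch of the scenario this protects, assuming a host with at least two GPUs; the import path mirrors what the kernel tests use and is an assumption here, not part of this commit.

# Sketch only: exercises the path guarded by OptionalCUDAGuard(device_of(input)).
# Assumptions: >= 2 visible GPUs, and that `ops` is vLLM's custom-ops module
# (imported here the way the kernel tests import it; assumed import path).
import torch
from vllm import _custom_ops as ops  # assumed import path

torch.cuda.set_device("cuda:0")  # current device is cuda:0 ...
x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda:1")  # ... input lives on cuda:1

# With the guard, getCurrentCUDAStream() inside the op refers to cuda:1,
# so the quant kernel runs on the stream of the device that owns `x`.
out, scales, _ = ops.scaled_int8_quant(x)
assert out.device == x.device and scales.device == x.device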

tests/kernels/core/test_fused_quant_layernorm.py

Lines changed: 25 additions & 10 deletions
@@ -11,7 +11,7 @@
 
 DTYPES = [torch.bfloat16, torch.float]
 QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
-VEC_HIDDEN_SIZES = range(1024, 1030)
+VEC_HIDDEN_SIZES = [1024, 1025, 1027, 1029]
 # Avoid combinatorial explosion with full Cartesian product
 NUM_TOKENS_HIDDEN_SIZES = [
     *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
@@ -65,7 +65,7 @@ def ref_dynamic_per_token_quant(
         )
     else:
         assert quant_dtype == torch.int8
-        torch_out, scales = ops.scaled_int8_quant(torch_out)
+        torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
 
     return torch_out, scales, residual
 
@@ -109,7 +109,7 @@ def ops_impl(
 
 @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
-@pytest.mark.parametrize("scale_ub", SCALE_UBS)
+@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
@@ -119,7 +119,7 @@ def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
-    scale_ub: bool,
+    has_scale_ub: bool,
     dtype: torch.dtype,
     quant_dtype: torch.dtype,
     seed: int,
@@ -130,7 +130,7 @@ def test_rms_norm(
     torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
 
-    if scale_ub is not None and quant_dtype != torch.float8_e4m3fn:
+    if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
         # skip
         return
 
@@ -143,9 +143,11 @@ def test_rms_norm(
     scale = 1 / (hidden_size)
     x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
     residual = torch.randn_like(x) * scale if add_residual else None
-    if scale_ub is not None:
+    if has_scale_ub:
         rms_x, _ = ref_rms_norm(layer, x, residual)
         scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
+    else:
+        scale_ub = None
 
     ref_out, ref_scales, ref_residual = ref_impl(
         layer, x, quant_dtype, residual, scale_ub
@@ -156,14 +158,27 @@
 
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
-    assert torch.allclose(ref_scales, ops_scales)
     if quant_dtype == torch.int8:
+        assert torch.allclose(ref_scales, ops_scales, atol=1e-6)
         # big atol to account for round-off errors.
         assert torch.allclose(ref_out, ops_out, atol=1)
     else:
-        assert torch.allclose(
-            ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
-        )
+        assert torch.allclose(ref_scales, ops_scales)
+        a = ref_out.to(dtype=torch.float32)
+        b = ops_out.to(dtype=torch.float32)
+        ok = torch.allclose(a, b)
+        if not ok:
+            # fallback: compare dequantized values with relaxed tolerance
+            a_deq = a * ref_scales.view(-1, 1)
+            b_deq = b * ops_scales.view(-1, 1)
+            # NOTE: It is possible that some future test cases trigger this
+            # max diff due to precision issues. If such an error is
+            # encountered, it's recommended to inspect the differences between
+            # all corresponding elements from each tensor (e.g. by looping over
+            # them) and checking how many the max diff error shows up on (just
+            # a few bad elements should still be considered acceptable).
+            ok = torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)
+        assert ok
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)
 
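For fp8 outputs the test now tolerates small quantization disagreements by falling back to a dequantized comparison: each row is multiplied by its own per-token scale and the results are compared with rtol/atol of 5e-2. A standalone sketch of that check follows; the helper name and signature are made up for illustration and are not part of the repository.

# Illustrative helper mirroring the fallback comparison added to test_rms_norm.
# The name `dequant_allclose` is hypothetical, not from the vLLM codebase.
import torch


def dequant_allclose(
    q_ref: torch.Tensor,  # reference quantized output, shape [tokens, hidden]
    s_ref: torch.Tensor,  # reference per-token scales, one per row
    q_ops: torch.Tensor,  # kernel quantized output
    s_ops: torch.Tensor,  # kernel per-token scales
) -> bool:
    a = q_ref.to(torch.float32)
    b = q_ops.to(torch.float32)
    if torch.allclose(a, b):
        return True  # raw quantized values already match
    # Fallback: dequantize each row with its own scale and allow ~5% error,
    # the same relaxed tolerance the updated test applies.
    a_deq = a * s_ref.view(-1, 1)
    b_deq = b * s_ops.view(-1, 1)
    return torch.allclose(a_deq, b_deq, rtol=5e-2, atol=5e-2)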
