Commit 1f29141

[Refactor] Use DeepGEMM Col Major TMA Aligned Tensor (vllm-project#25517)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 6160ba4 commit 1f29141

6 files changed, +34 -78 lines changed

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 6 additions & 2 deletions
@@ -8,12 +8,16 @@

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    get_col_major_tma_aligned_tensor,
     per_token_group_quant_fp8,
     w8a8_triton_block_scaled_mm,
 )
 from vllm.triton_utils import triton
-from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
+from vllm.utils.deep_gemm import (
+    calc_diff,
+    fp8_gemm_nt,
+    get_col_major_tma_aligned_tensor,
+    per_block_cast_to_fp8,
+)


 def benchmark_shape(m: int,

tests/kernels/quantization/test_block_fp8.py

Lines changed: 4 additions & 3 deletions
@@ -11,11 +11,12 @@
     native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
-    per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
+    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
+from vllm.utils.deep_gemm import (fp8_gemm_nt,
+                                  get_col_major_tma_aligned_tensor,
+                                  per_block_cast_to_fp8)

 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 3 additions & 3 deletions
@@ -34,8 +34,7 @@
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
-    requant_weight_ue8m0_inplace)
+    expert_weight_is_col_major, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
@@ -50,7 +49,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.scalar_type import scalar_types
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor,
+                                  is_deep_gemm_e8m0_used)

 logger = init_logger(__name__)

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 6 additions & 4 deletions
@@ -34,9 +34,9 @@
     W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, expert_weight_is_col_major,
-    get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block,
-    process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy,
-    requant_weight_ue8m0_inplace, validate_fp8_block_shape)
+    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
+    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
+    validate_fp8_block_shape)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
     prepare_moe_fp8_layer_for_marlin)
@@ -53,7 +53,9 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
+from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor,
+                                  is_deep_gemm_e8m0_used,
+                                  is_deep_gemm_supported)
 from vllm.utils.flashinfer import has_flashinfer_moe

 if TYPE_CHECKING:

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 1 addition & 65 deletions
@@ -23,7 +23,7 @@
     PerTensorScaleParameter)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
                                   is_deep_gemm_supported,
                                   should_use_deepgemm_for_fp8_linear)
@@ -749,70 +749,6 @@ def grid(META):
     return C


-# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
-# TODO(wentao): remove this function when DeepGEMM exposes this function
-def get_tma_aligned_size(x: int, element_size: int) -> int:
-    """
-    Global memory address of TMA must be 16-byte aligned.
-    Since we use column-major layout for the LHS scaling tensor,
-    the M-axis of the LHS scaling tensor needs to be padded to a multiple of
-    16 bytes.
-
-    Arguments:
-        x: original M-axis shape of the LHS scaling tensor.
-        element_size: element size of the LHS scaling tensor.
-
-    Returns:
-        M-axis shape of the LHS scaling tensor after padding.
-    """
-    tma_alignment_bytes = 16
-    assert tma_alignment_bytes % element_size == 0
-    alignment = tma_alignment_bytes // element_size
-    return cdiv(x, alignment) * alignment
-
-
-# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
-# TODO(wentao): remove this function when DeepGEMM exposes this function
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
-    """
-    Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
-    will be called if necessary.
-    If the input tensor is already column-major layout and 16-byte aligned along
-    the M axis (thus meets the requirement of LHS scaling tensor in
-    DeepGEMM), this function will do nothing.
-
-    Arguments:
-        x: usually the LHS scaling tensor in GEMM.
-
-    Returns:
-        The LHS scaling tensor of TMA-aligned transposed format.
-    """
-    # NOTES: for the extreme performance, you may rewrite/fuse this function in
-    # CUDA
-    assert x.dim() in (2, 3)
-    remove_dim = False
-    m, n = x.shape[-2], x.shape[-1]
-    aligned_m = get_tma_aligned_size(m, x.element_size())
-    if x.dim() == 2:
-        if x.stride(0) == 1 and x.stride(1) == aligned_m:
-            return x
-        x, remove_dim = x.unsqueeze(0), True
-
-    b = x.shape[0]
-
-    # The last kernel gives a column-major TMA aligned layout
-    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
-            2) == aligned_m:
-        return x.squeeze(0) if remove_dim else x
-
-    # Normal layout requires transposing
-    aligned_x = torch.transpose(
-        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
-    aligned_x[:, :m, :] = x
-    aligned_x = aligned_x[:, :m, :]
-    return aligned_x.squeeze(0) if remove_dim else aligned_x
-
-
 def requant_weight_ue8m0_inplace(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
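
For reference, the padding rule that the deleted helpers implemented (per the docstring above): the M axis of the column-major LHS scale tensor is rounded up so every column starts on a 16-byte boundary. A minimal standalone sketch of that arithmetic, assuming float32 scales; the name aligned_m is illustrative and not part of vLLM:

# Sketch of the alignment math from the removed get_tma_aligned_size:
# TMA needs 16-byte alignment, so M is padded to a multiple of
# 16 // element_size elements (4 elements for float32 scales).
def aligned_m(m: int, element_size: int = 4) -> int:
    alignment = 16 // element_size
    return (m + alignment - 1) // alignment * alignment

assert aligned_m(129) == 132  # 129 rows pad up to 132
assert aligned_m(128) == 128  # already a multiple of 4, unchanged

DeepGEMM's get_mn_major_tma_aligned_tensor now provides this padding and the column-major transpose, which is why the local copy could be dropped (per the TODO above).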

vllm/utils/deep_gemm.py

Lines changed: 14 additions & 1 deletion
@@ -70,11 +70,13 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
 _fp8_gemm_nt_impl: Callable[..., Any] | None = None
 _grouped_impl: Callable[..., Any] | None = None
 _grouped_masked_impl: Callable[..., Any] | None = None
+_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None


 def _lazy_init() -> None:
     """Import deep_gemm and resolve symbols on first use."""
-    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
+    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl,\
+        _get_mn_major_tma_aligned_tensor_impl

     # fast path
     if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
@@ -95,6 +97,16 @@ def _lazy_init() -> None:
     _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
     _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
     _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
+    _get_mn_major_tma_aligned_tensor_impl = getattr(
+        _dg, "get_mn_major_tma_aligned_tensor", None)
+
+
+def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
+    _lazy_init()
+    if _get_mn_major_tma_aligned_tensor_impl is None:
+        return _missing()
+    return _get_mn_major_tma_aligned_tensor_impl(x)


 def fp8_gemm_nt(*args, **kwargs):
@@ -191,4 +203,5 @@ def should_use_deepgemm_for_fp8_linear(
     "is_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
     "should_use_deepgemm_for_fp8_linear",
+    "get_col_major_tma_aligned_tensor",
 ]
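
With the wrapper exported from vllm.utils.deep_gemm, callers import the helper from there instead of fp8_utils, as the import changes in the files above show. A hedged usage sketch; the tensor shape and CUDA device are illustrative and not taken from this commit:

import torch

from vllm.utils.deep_gemm import get_col_major_tma_aligned_tensor

# Example LHS scale tensor for a blockwise FP8 GEMM (shape chosen for illustration).
a_scales = torch.randn(129, 56, device="cuda", dtype=torch.float32)

# Pads/transposes the scales into the column-major, TMA-aligned layout DeepGEMM
# expects; the wrapper lazily imports deep_gemm and falls back to _missing()
# (a NoReturn helper) if get_mn_major_tma_aligned_tensor is unavailable.
a_scales_aligned = get_col_major_tma_aligned_tensor(a_scales)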
