Commit 37434ed

feat: patch sm103 for 3xfp4 moe generation (#2082)
## 📌 Description

Patch sm103 for 3xfp4 moe generation.

## 🔍 Related Issues

Follow-up of #2020 and #1925.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

```
$ ls csrc/nv_internal/tensorrt_llm/cutlass_instantiations/103/gemm_grouped
100  103  80
$ pytest tests/moe/test_trtllm_cutlass_fused_moe.py
22 passed, 3 skipped, 1 warning in 771.89s (0:12:51)
```

## Summary by CodeRabbit

* **New Features**
  * Added support for the Blackwell (SM103) GPU architecture in MOE (Mixture of Experts) operations with specialized CUTLASS-optimized modules.
1 parent 636a3ab commit 37434ed

4 files changed, +26 -1 lines changed

flashinfer/aot.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -43,6 +43,7 @@
 from .jit.fp8_quantization import gen_mxfp8_quantization_sm100_module
 from .jit.fused_moe import (
     gen_cutlass_fused_moe_sm120_module,
+    gen_cutlass_fused_moe_sm103_module,
     gen_cutlass_fused_moe_sm100_module,
     gen_cutlass_fused_moe_sm90_module,
     gen_trtllm_gen_fused_moe_sm100_module,
@@ -495,6 +496,7 @@ def gen_all_modules(
         jit_specs.append(gen_tgv_gemm_sm10x_module(torch.float16, use_sm_100f=True))
     if has_sm103:
         jit_specs.append(gen_fp4_quantization_sm103_module())
+        jit_specs.append(gen_cutlass_fused_moe_sm103_module())
     if has_sm110:
         jit_specs.append(gen_fp4_quantization_sm110_module())
     if has_sm120:
```

flashinfer/fused_moe/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -21,6 +21,7 @@
     convert_to_block_layout,
     cutlass_fused_moe,
     gen_cutlass_fused_moe_sm120_module,
+    gen_cutlass_fused_moe_sm103_module,
     gen_cutlass_fused_moe_sm100_module,
     gen_cutlass_fused_moe_sm90_module,
     gen_trtllm_gen_fused_moe_sm100_module,
@@ -39,6 +40,7 @@
     "convert_to_block_layout",
     "cutlass_fused_moe",
     "gen_cutlass_fused_moe_sm120_module",
+    "gen_cutlass_fused_moe_sm103_module",
     "gen_cutlass_fused_moe_sm100_module",
     "gen_cutlass_fused_moe_sm90_module",
     "gen_trtllm_gen_fused_moe_sm100_module",
```

flashinfer/fused_moe/core.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -34,6 +34,7 @@
 )
 from ..jit.fused_moe import (
     gen_cutlass_fused_moe_sm120_module,
+    gen_cutlass_fused_moe_sm103_module,
     gen_cutlass_fused_moe_sm100_module,
     gen_cutlass_fused_moe_sm90_module,
     gen_cutlass_fused_moe_sm89_module,
@@ -315,7 +316,9 @@ def convert_to_block_layout(input_tensor: torch.Tensor, blockK: int) -> torch.Te
 def get_cutlass_fused_moe_module(backend: str = "100", use_fast_build: bool = False):
     if backend in ("120", "121"):
         module = gen_cutlass_fused_moe_sm120_module(use_fast_build).build_and_load()
-    elif backend in ("100", "103", "110"):
+    elif backend == "103":
+        module = gen_cutlass_fused_moe_sm103_module(use_fast_build).build_and_load()
+    elif backend in ("100", "110"):
         module = gen_cutlass_fused_moe_sm100_module(use_fast_build).build_and_load()
     elif backend == "90":
         module = gen_cutlass_fused_moe_sm90_module(use_fast_build).build_and_load()
```
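With this change, a backend string of `"103"` dispatches to the dedicated SM103 build instead of falling back to the SM100 module. The sketch below is not part of the commit: it shows one way to derive the backend string from the device's compute capability and load the matching module. The capability-to-string mapping is an assumption for illustration; `get_cutlass_fused_moe_module` is used exactly as patched above.

```python
# Illustrative dispatch sketch (not from this commit): map the device's compute
# capability to the backend string expected by get_cutlass_fused_moe_module and
# JIT-build the matching CUTLASS fused-MoE module. Assumes a CUDA device is
# present; on an SM103 GPU this now loads the dedicated SM103 build.
import torch

from flashinfer.fused_moe.core import get_cutlass_fused_moe_module

major, minor = torch.cuda.get_device_capability()  # e.g. (10, 3) on SM103
backend = f"{major}{minor}"                        # "103", "100", "90", ...
module = get_cutlass_fused_moe_module(backend=backend, use_fast_build=False)
```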

flashinfer/jit/fused_moe.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -47,6 +47,24 @@ def gen_cutlass_fused_moe_sm120_module(use_fast_build: bool = False) -> JitSpec:
     return gen_cutlass_fused_moe_module(nvcc_flags, "120", use_fast_build)
 
 
+def gen_cutlass_fused_moe_sm103_module(use_fast_build: bool = False) -> JitSpec:
+    nvcc_flags = [
+        "-DCOMPILE_BLACKWELL_TMA_GEMMS",
+        "-DCOMPILE_BLACKWELL_TMA_GROUPED_GEMMS",
+        "-DENABLE_BF16",
+        "-DENABLE_FP8",
+        "-DENABLE_FP4",
+        "-DUSING_OSS_CUTLASS_MOE_GEMM",
+        "-DCOMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS",
+    ]
+
+    nvcc_flags += current_compilation_context.get_nvcc_flags_list(
+        supported_major_versions=[10]
+    )
+
+    return gen_cutlass_fused_moe_module(nvcc_flags, "103", use_fast_build)
+
+
 def gen_cutlass_fused_moe_sm100_module(use_fast_build: bool = False) -> JitSpec:
     nvcc_flags = [
         "-DCOMPILE_BLACKWELL_TMA_GEMMS",
```
