Commit 7e68d5e

[mxfp8 moe training] make compile vs triton for dim0 cast configurable (#3219)

1 parent beee153

File tree: 3 files changed, +51 -32 lines

benchmarks/prototype/moe_training/bench_moe_layer.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def warmup(model, input, labels):
     parser.add_argument(
         "--local_batch_size",
         type=int,
-        default=8,
+        default=12,
     )
     parser.add_argument(
         "--hidden_dim",

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 8 additions & 6 deletions
@@ -33,6 +33,7 @@
     _emulated_mxfp8_scaled_grouped_mm_2d_2d,
     _emulated_mxfp8_scaled_grouped_mm_2d_3d,
     _quantize_then_scaled_grouped_mm,
+    _to_mxfp8_then_scaled_grouped_mm,
 )
 from torchao.prototype.moe_training.utils import (
     _to_mxfp8_per_group_colwise,
@@ -317,11 +318,10 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
     "M,K,N", [(16640, 5120, 8192), (131072, 5120, 8192), (131072, 8192, 5120)]
 )
 @pytest.mark.parametrize("num_experts", (2, 4, 8, 16))
-def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
-    from torchao.prototype.moe_training.scaled_grouped_mm import (
-        _MXFP8GroupedMM,
-    )
-
+@pytest.mark.parametrize("use_triton_for_dim0_cast", (True, False))
+def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
+    M, K, N, num_experts, use_triton_for_dim0_cast
+):
     block_size = 32
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
     w = torch.randn(
@@ -340,7 +340,9 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
     )
 
     # Forward
-    out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
+    out = _to_mxfp8_then_scaled_grouped_mm(
+        x, w_t, offs, block_size, torch.bfloat16, use_triton_for_dim0_cast
+    )
     ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
     sqnr = compute_error(ref_out, out)
     min_sqnr = 27.0
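For context, a minimal usage sketch of the updated alias that this test exercises. The call mirrors the forward call in the test above; the shapes and group offsets are illustrative (not taken from the commit), and a CUDA GPU with mxfp8 grouped-GEMM support is assumed.

import torch

from torchao.prototype.moe_training.scaled_grouped_mm import (
    _to_mxfp8_then_scaled_grouped_mm,
)

M, K, N, num_experts, block_size = 1024, 512, 256, 4, 32
use_triton_for_dim0_cast = False  # the new toggle, passed positionally as in the test

# 2D activations and 3D (already transposed) expert weights, as in the test.
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w_t = torch.randn(
    num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
)

# Per-expert token-group boundaries; equal groups whose sizes are multiples of block_size.
offs = torch.arange(
    M // num_experts, M + 1, M // num_experts, dtype=torch.int32, device="cuda"
)

out = _to_mxfp8_then_scaled_grouped_mm(
    x, w_t, offs, block_size, torch.bfloat16, use_triton_for_dim0_cast
)
out.sum().backward()  # runs the backward path, which uses the same cast choice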

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 42 additions & 25 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from functools import partial
 from typing import Optional
 
 import torch
@@ -34,6 +33,7 @@
     ScaleCalculationMode,
 )
 from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -79,15 +79,6 @@ def _quantize_then_scaled_grouped_mm(
         raise ValueError(f"Unsupported scaling type {scaling_type}")
 
 
-# Aliases for convenience/clarity
-_to_mxfp8_then_scaled_grouped_mm = partial(
-    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.MXFP8
-)
-_to_fp8_rowwise_then_scaled_grouped_mm = partial(
-    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.FP8_ROWWISE
-)
-
-
 class _Float8GroupedMM(torch.autograd.Function):
     """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""
 
@@ -304,6 +295,7 @@ def forward(
         block_size: int = 32,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
+        use_triton_for_dim0_cast: bool = False,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -313,17 +305,28 @@ def forward(
 
         # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_data, A_scale = triton_to_mxfp8_dim0(
-            A,
-            inner_block_size=block_size,
-        )
-
-        # B_data shape: (E, N, K)
-        # B_scale shape: (E, N, K//block_size)
-        B_data, B_scales = triton_to_mxfp8_dim0(
-            B_t.transpose(-2, -1),
-            inner_block_size=block_size,
-        )
+        if use_triton_for_dim0_cast:
+            A_data, A_scale = triton_to_mxfp8_dim0(
+                A,
+                inner_block_size=block_size,
+            )
+            # B_data shape: (E, N, K)
+            # B_scale shape: (E, N, K//block_size)
+            B_data, B_scales = triton_to_mxfp8_dim0(
+                B_t.transpose(-2, -1),
+                inner_block_size=block_size,
+            )
+        else:
+            A_scale, A_data = to_mx(
+                A,
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )
+            B_scales, B_data = to_mx(
+                B_t.transpose(-2, -1),
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )
 
         # Convert scales to blocked format for 2d-3d grouped mm
         _, blocked_scales_group_offsets_2d3d = (
@@ -351,19 +354,28 @@ def forward(
         ctx.block_size = block_size
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
+        ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
         return out
 
     @staticmethod
     def backward(ctx, grad_out: torch.Tensor):
         A, B_t, offs, blocked_scales_group_offsets_2d3d = ctx.saved_tensors
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
+        use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
 
         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
-            grad_out, inner_block_size=block_size
-        )
+        if use_triton_for_dim0_cast:
+            grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
+                grad_out, inner_block_size=block_size
+            )
+        else:
+            grad_out_scale, grad_out_data = to_mx(
+                grad_out,
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )
 
         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
         # (E, K, N) -> (E, N, K)
@@ -449,7 +461,7 @@ def backward(ctx, grad_out: torch.Tensor):
         )
         # grad_B_t shape = (E,K,N)
         grad_B_t = grad_B.transpose(-2, -1)
-        return grad_A, grad_B_t, None, None, None
+        return grad_A, grad_B_t, None, None, None, None
 
 
 def _to_mxfp8_dim1_3d(
@@ -659,3 +671,8 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
 
 def round_up(x, y):
     return ((x + y - 1) // y) * y
+
+
+# Aliases for convenience/clarity
+_to_mxfp8_then_scaled_grouped_mm = _MXFP8GroupedMM.apply
+_to_fp8_rowwise_then_scaled_grouped_mm = _Float8GroupedMM.apply
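To make the new knob concrete, here is a small sketch (not part of the commit) contrasting the two dim0 casts the flag selects between. Note the flipped return order: triton_to_mxfp8_dim0 returns (data, scale) while to_mx returns (scale, data), which is why the assignment order differs between the branches above. Shapes are illustrative and a CUDA device is assumed.

import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
A = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# Path taken when use_triton_for_dim0_cast=True: handwritten Triton kernel.
A_data_triton, A_scale_triton = triton_to_mxfp8_dim0(A, inner_block_size=block_size)

# Path taken when use_triton_for_dim0_cast=False: to_mx, presumably the path meant
# to run under torch.compile (the "compile" option in the commit title).
A_scale_mx, A_data_mx = to_mx(
    A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
)

# Per the comments in the diff, the dim0 cast yields data of shape (M, K) and
# scales of shape (M, K // block_size); print both layouts to compare the paths.
print(A_data_triton.shape, A_scale_triton.shape)
print(A_data_mx.shape, A_scale_mx.shape)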
