Commit 82ae011

[moe training] change api _scaled_grouped_mm -> _quantize_then_scaled_grouped_mm (#3218)
Parent: bd69a73

6 files changed (+32 -20 lines)
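The change is a pure rename of the prototype's public entry point, so migrating a call site is a one-line import change. Below is a minimal usage sketch of the renamed API; the tensor shapes, device, and offsets are illustrative assumptions, following the 2D-A/3D-B_t layout and multiple-of-16 dim constraints asserted in scaled_grouped_mm.py below:

import torch
from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm  # was: _scaled_grouped_mm
from torchao.prototype.moe_training.conversion_utils import MoEScalingType

# Illustrative inputs: 4 expert groups, M=256 total tokens, K=128, N=64.
# offs marks the end row of each token group in A; K and N are multiples of 16.
A = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(4, 64, 128, dtype=torch.bfloat16, device="cuda")
B_t = B.contiguous().transpose(-2, -1).requires_grad_(True)  # (4, 128, 64), as in the tests
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")

out = _quantize_then_scaled_grouped_mm(
    A, B_t, offs=offs, out_dtype=torch.bfloat16, scaling_type=MoEScalingType.FP8_ROWWISE
)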

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 4 additions & 4 deletions

@@ -19,7 +19,7 @@
     bench_fwd_microseconds,
     profile_fwd_bwd,
 )
-from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
@@ -158,7 +158,7 @@ def run_experiment(
 
     # fwd_bwd scaled benchmark + profiling
     scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
-        _scaled_grouped_mm,
+        _quantize_then_scaled_grouped_mm,
         A,
         B_t,
         offs,
@@ -169,7 +169,7 @@ def run_experiment(
     )
     if args.profile:
         profile_fwd_bwd(
-            _scaled_grouped_mm,
+            _quantize_then_scaled_grouped_mm,
             A,
             B_t,
             offs,
@@ -190,7 +190,7 @@ def run_experiment(
         fullgraph=True,
     )
     scaled_fwd_us = bench_fwd_microseconds(
-        _scaled_grouped_mm,
+        _quantize_then_scaled_grouped_mm,
         A,
         B_t,
         offs,

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 4 additions & 4 deletions

@@ -32,7 +32,7 @@
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _emulated_mxfp8_scaled_grouped_mm_2d_2d,
     _emulated_mxfp8_scaled_grouped_mm_2d_3d,
-    _scaled_grouped_mm,
+    _quantize_then_scaled_grouped_mm,
 )
 from torchao.prototype.moe_training.utils import (
     _to_mxfp8_per_group_colwise,
@@ -73,7 +73,7 @@ def test_valid_scaled_grouped_mm_2d_3d(m, n, k, n_groups):
     b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)
 
     # Compute output.
-    out = _scaled_grouped_mm(
+    out = _quantize_then_scaled_grouped_mm(
         a,
         b_t,
         offs=offs,
@@ -142,7 +142,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
 
     # Compute output.
     with pytest.raises(AssertionError):
-        _scaled_grouped_mm(
+        _quantize_then_scaled_grouped_mm(
             a,
             b_t,
             offs=offs,
@@ -199,7 +199,7 @@ def compute_reference_forward(
         result_list.append(result[start : offs_cpu[i]])
         start = offs_cpu[i]
 
-    # Validate each actual result group from the _scaled_grouped_mm is equal to:
+    # Validate each actual result group from the _quantize_then_scaled_grouped_mm is equal to:
     # 1. A manual _scaled_mm for the group.
     # 2. A matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients).
     outputs = []

torchao/prototype/moe_training/README.md

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ This prototype provides:
 import torch
 from torch.nn import functional as F
 from torchao.prototype.moe_training import (
-    _scaled_grouped_mm as torchao_scaled_grouped_mm
+    _quantize_then_scaled_grouped_mm as torchao_scaled_grouped_mm
 )
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs

torchao/prototype/moe_training/__init__.py

Lines changed: 4 additions & 2 deletions

@@ -1,3 +1,5 @@
-from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training.scaled_grouped_mm import (
+    _quantize_then_scaled_grouped_mm,
+)
 
-__all__ = ["_scaled_grouped_mm"]
+__all__ = ["_quantize_then_scaled_grouped_mm"]

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 15 additions & 5 deletions

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+from functools import partial
 from typing import Optional
 
 import torch
@@ -38,15 +39,15 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-def _scaled_grouped_mm(
+def _quantize_then_scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE,
 ) -> torch.Tensor:
     """
-    This function performs dynamic float8 quantization with row-wise scaling
+    This function performs dynamic quantization with the given recipe
     on the input tensors A and B, then performs a scaled grouped GEMM and returns the results.
 
     Args:
@@ -78,6 +79,15 @@ def _scaled_grouped_mm(
         raise ValueError(f"Unsupported scaling type {scaling_type}")
 
 
+# Aliases for convenience/clarity
+_to_mxfp8_then_scaled_grouped_mm = partial(
+    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.MXFP8
+)
+_to_fp8_rowwise_then_scaled_grouped_mm = partial(
+    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.FP8_ROWWISE
+)
+
+
 class _Float8GroupedMM(torch.autograd.Function):
     """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""
@@ -89,7 +99,7 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D.
+        # torchao _quantize_then_scaled_grouped_mm only supports A=2D|3D and B=3D.
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"
 
@@ -113,7 +123,7 @@ def forward(
 
         # Assert A and B dims are compatible for a scaled grouped GEMM.
         assert A.size(-1) == B_t.size(-2), (
-            f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"
+            f"shape {A.shape} and {B_t.shape} are not compatible for _quantize_then_scaled_grouped_mm"
         )
 
         # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
@@ -295,7 +305,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D and B=3D.
+        # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"
         assert block_size == 32, "Only block_size=32 is supported"
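The two partial aliases added above only pin scaling_type, so calling an alias is equivalent by construction to passing the recipe explicitly. A minimal sketch, reusing illustrative tensors shaped like those in the earlier sketch (MXFP8 additionally assumes supported hardware and the block_size=32 constraint asserted above):

from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _quantize_then_scaled_grouped_mm,
    _to_mxfp8_then_scaled_grouped_mm,
)

# Equivalent by construction of the functools.partial alias:
out_alias = _to_mxfp8_then_scaled_grouped_mm(A, B_t, offs=offs)
out_explicit = _quantize_then_scaled_grouped_mm(
    A, B_t, offs=offs, scaling_type=MoEScalingType.MXFP8
)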

torchao/prototype/moe_training/tensor.py

Lines changed: 4 additions & 4 deletions

@@ -15,7 +15,7 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 
-from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ class ScaledGroupedMMTensor(torch.Tensor):
     """
     ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
     and overrides the torch._grouped_mm op by dispatching to the
-    differentiable _scaled_grouped_mm autograd function.
+    differentiable _quantize_then_scaled_grouped_mm autograd function.
     """
 
     scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE
@@ -77,7 +77,7 @@ def __init__(
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        # override the grouped mm op to use the differentiable _scaled_grouped_mm
+        # override the grouped mm op to use the differentiable _quantize_then_scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
@@ -99,7 +99,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             other_args = args[2:]
             if A_is_2d and B_is_2d_or_3d and has_offs:
-                return _scaled_grouped_mm(
+                return _quantize_then_scaled_grouped_mm(
                     A,
                     B,
                     *other_args,
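For context on why tensor.py changes: the subclass exists so the stock grouped-GEMM op transparently routes through the quantized path. A hypothetical sketch of that dispatch; constructing ScaledGroupedMMTensor directly and calling torch._grouped_mm are assumptions for illustration only (in practice the conversion utilities wrap module weights):

import torch
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# Illustrative tensors as in the earlier sketch; wrapping the weight operand
# is what triggers the __torch_function__ interception shown above.
A = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = ScaledGroupedMMTensor(  # assumed direct construction, for illustration only
    torch.randn(4, 64, 128, dtype=torch.bfloat16, device="cuda").transpose(-2, -1)
)
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")

# With A 2D and offs provided, __torch_function__ redirects this call to the
# differentiable _quantize_then_scaled_grouped_mm autograd function.
out = torch._grouped_mm(A, B_t, offs=offs)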
