Skip to content

Commit fb1450d

Browse files
[BE] [moe training] bench script for single device moe layer (#3126)
[moe training] bench script for single device moe layer
1 parent 53a66f8 commit fb1450d

File tree

1 file changed

+220
-0
lines changed

1 file changed

+220
-0
lines changed
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
######################################################################
7+
8+
import argparse
9+
import copy
10+
import logging
11+
import sys
12+
13+
import torch
14+
from torch import nn
15+
from torch.nn import functional as F
16+
17+
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
18+
from torchao.prototype.moe_training.conversion_utils import (
19+
MoEScalingType,
20+
MoETrainingConfig,
21+
)
22+
from torchao.quantization.quant_api import quantize_
23+
24+
# this benchmark requires torchtitan
25+
try:
26+
from torchtitan.distributed.expert_parallel import (
27+
set_token_group_alignment_size_m,
28+
)
29+
from torchtitan.models.moe import MoE, MoEArgs
30+
except ImportError:
31+
logging.warning(
32+
"please pip install torchtitan to run this benchmark: https://github.com/pytorch/torchtitan"
33+
)
34+
sys.exit(0)
35+
36+
37+
def bench_moe_training_fsdp(args: argparse.Namespace):
    """Benchmark forward+backward of a bf16 MoE layer against a quantized copy.

    Builds a reference ``MoE`` module in bfloat16 on the CUDA device, deep-copies
    it, converts the copy with ``quantize_`` using ``MoETrainingConfig`` for the
    requested recipe, then times both models with
    ``bench_fwd_bwd_microseconds`` and prints the timings and the speedup.
    Despite the ``_fsdp`` suffix in the name, nothing in this function sets up
    FSDP or any distributed state.

    Args:
        args: Parsed CLI namespace. The attributes read are ``recipe``
            ("fp8_rowwise" or "mxfp8"), ``profile``, ``local_num_experts``,
            ``local_batch_size``, ``seq_len``, ``dim`` and ``hidden_dim``.

    Returns:
        None. If the GPU's compute capability does not match the one required
        by the recipe, a warning is logged and the function returns early
        without benchmarking.

    Raises:
        ValueError: If ``args.recipe`` is not a supported recipe name.
        RuntimeError: If no CUDA device is available.
    """
    # Recipe name -> (label used in the skip message, required compute capability).
    supported_recipes = {
        "fp8_rowwise": ("FP8 rowwise", (9, 0)),
        "mxfp8": ("MXFP8", (10, 0)),
    }

    # Validate with real exceptions rather than asserts: asserts vanish under
    # `python -O` and carry no message for the user.
    recipe_name = args.recipe
    if recipe_name not in supported_recipes:
        raise ValueError(
            f"unsupported recipe {recipe_name!r}; expected one of {sorted(supported_recipes)}"
        )
    if not torch.cuda.is_available():
        raise RuntimeError("this benchmark requires a CUDA device")

    recipe = MoEScalingType[recipe_name.upper()]
    recipe_label, required_capability = supported_recipes[recipe_name]
    found_capability = torch.cuda.get_device_capability()
    if found_capability != required_capability:
        major, minor = required_capability
        logging.warning(
            f"Skipping {recipe_label} benchmarks, only supported on compute capability {major}.{minor} and found {found_capability}"
        )
        return

    enable_profile = args.profile
    local_batch_size = args.local_batch_size
    seq_len = args.seq_len
    dim = args.dim
    hidden_dim = args.hidden_dim

    # define model args
    target_fqns = ["experts"]
    model_args = MoEArgs(num_experts=args.local_num_experts)
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE; weights are initialized under a fixed seed
    ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # Token group alignment size: 32 for MXFP8, 16 otherwise (fp8 rowwise)
    alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16
    set_token_group_alignment_size_m(alignment_size)

    # sanity check: both models must start from identical params
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2), "deepcopy produced differing parameters"

    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        """Select modules whose FQN contains any of the target substrings."""
        return any(target_fqn in cur_fqn for target_fqn in target_fqns)

    # quantize test model (only modules accepted by the filter are converted)
    config = MoETrainingConfig(scaling_type=recipe)
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # inputs: each model gets its own leaf tensor holding the same values
    ref_x = torch.randn(
        local_batch_size,
        seq_len,
        dim,
        dtype=torch.bfloat16,
        requires_grad=True,
        device=device,
    )
    x = ref_x.detach().clone().requires_grad_(True)

    def warmup(module, inputs, targets):
        """Run three forward/backward passes, then wait for the GPU to finish."""
        for _ in range(3):
            out = module(inputs)
            loss = F.mse_loss(out, targets)
            loss.backward()
        torch.cuda.synchronize()

    labels = torch.ones_like(x)

    # Keyword arguments shared by every bench/profile call below.
    run_kwargs = {"labels": labels, "use_compile": True, "fullgraph": False}

    # Warmup + bench bf16
    warmup(ref_model, ref_x, labels)
    bf16_us = bench_fwd_bwd_microseconds(ref_model, ref_x, **run_kwargs)
    bf16_ms = bf16_us / 1e3
    if enable_profile:
        print("Profiling bf16 training")
        profile_fwd_bwd(
            ref_model,
            ref_x,
            profile_name="bf16_profile",
            **run_kwargs,
        )

    # Warmup + bench quantized
    warmup(model, x, labels)
    scaled_us = bench_fwd_bwd_microseconds(model, x, **run_kwargs)
    scaled_ms = scaled_us / 1e3
    if enable_profile:
        print("Profiling quantized training")
        profile_fwd_bwd(
            model,
            x,
            profile_name=f"{recipe_name}_profile",
            **run_kwargs,
        )

    print(f"total_M: {local_batch_size * seq_len}, N: {hidden_dim}, K: {dim}")
    print(f"bf16 time: {bf16_ms:.3f} ms")
    print(f"{recipe_name} time: {scaled_ms:.3f} ms")
    print(f"speedup: {bf16_us / scaled_us:.3f}x")
181+
182+
183+
if __name__ == "__main__":
    # CLI entry point: parse the benchmark configuration and run it once.
    parser = argparse.ArgumentParser(
        description="Benchmark a single-device MoE layer: bf16 vs. quantized training"
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable PyTorch profiling and save results to file",
    )
    # `choices` makes argparse reject unknown recipes with a usage message
    # instead of failing later inside the benchmark.
    parser.add_argument(
        "--recipe",
        type=str,
        choices=["fp8_rowwise", "mxfp8"],
        required=True,
        help="Quantized training recipe to compare against bf16",
    )
    parser.add_argument(
        "--local_num_experts",
        type=int,
        default=8,
        help="Number of experts in the MoE layer",
    )
    parser.add_argument(
        "--seq_len",
        type=int,
        default=8192,
        help="Sequence length of the input tensor",
    )
    parser.add_argument(
        "--local_batch_size",
        type=int,
        default=8,
        help="Batch size of the input tensor",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=8192,
        help="Hidden dimension passed to the MoE layer",
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=5120,
        help="Model dimension (last dimension of the input tensor)",
    )

    args = parser.parse_args()
    bench_moe_training_fsdp(args)

0 commit comments

Comments
 (0)