|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +###################################################################### |
| 7 | + |
| 8 | +import argparse |
| 9 | +import copy |
| 10 | +import logging |
| 11 | +import sys |
| 12 | + |
| 13 | +import torch |
| 14 | +from torch import nn |
| 15 | +from torch.nn import functional as F |
| 16 | + |
| 17 | +from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd |
| 18 | +from torchao.prototype.moe_training.conversion_utils import ( |
| 19 | + MoEScalingType, |
| 20 | + MoETrainingConfig, |
| 21 | +) |
| 22 | +from torchao.quantization.quant_api import quantize_ |
| 23 | + |
| 24 | +# this benchmark requires torchtitan |
| 25 | +try: |
| 26 | + from torchtitan.distributed.expert_parallel import ( |
| 27 | + set_token_group_alignment_size_m, |
| 28 | + ) |
| 29 | + from torchtitan.models.moe import MoE, MoEArgs |
| 30 | +except ImportError: |
| 31 | + logging.warning( |
| 32 | + "please pip install torchtitan to run this benchmark: https://github.com/pytorch/torchtitan" |
| 33 | + ) |
| 34 | + sys.exit(0) |
| 35 | + |
| 36 | + |
| 37 | +def bench_moe_training_fsdp(args: argparse.Namespace): |
| 38 | + ( |
| 39 | + recipe_name, |
| 40 | + enable_profile, |
| 41 | + local_num_experts, |
| 42 | + local_batch_size, |
| 43 | + seq_len, |
| 44 | + dim, |
| 45 | + hidden_dim, |
| 46 | + ) = ( |
| 47 | + args.recipe, |
| 48 | + args.profile, |
| 49 | + args.local_num_experts, |
| 50 | + args.local_batch_size, |
| 51 | + args.seq_len, |
| 52 | + args.dim, |
| 53 | + args.hidden_dim, |
| 54 | + ) |
| 55 | + assert torch.cuda.is_available() |
| 56 | + assert recipe_name in ["fp8_rowwise", "mxfp8"] |
| 57 | + recipe = MoEScalingType[recipe_name.upper()] |
| 58 | + if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != ( |
| 59 | + 9, |
| 60 | + 0, |
| 61 | + ): |
| 62 | + logging.warning( |
| 63 | + f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}" |
| 64 | + ) |
| 65 | + return |
| 66 | + |
| 67 | + elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != ( |
| 68 | + 10, |
| 69 | + 0, |
| 70 | + ): |
| 71 | + logging.warning( |
| 72 | + f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" |
| 73 | + ) |
| 74 | + return |
| 75 | + |
| 76 | + # define model args |
| 77 | + target_fqns = ["experts"] |
| 78 | + model_args = MoEArgs( |
| 79 | + num_experts=local_num_experts, |
| 80 | + ) |
| 81 | + init_std = 0.02 |
| 82 | + device = torch.device("cuda") |
| 83 | + |
| 84 | + # reference bf16 MoE using llama4 shapes |
| 85 | + ref_model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).cuda() |
| 86 | + torch.manual_seed(42) |
| 87 | + ref_model.init_weights(init_std, device) |
| 88 | + |
| 89 | + # target MoE for testing conversion |
| 90 | + model = copy.deepcopy(ref_model) |
| 91 | + |
| 92 | + # Token group alignment size must be 16 for fp8 rowwise training |
| 93 | + alignment_size = 32 if recipe == MoEScalingType.MXFP8 else 16 |
| 94 | + set_token_group_alignment_size_m(alignment_size) |
| 95 | + |
| 96 | + # assert starting params are identical for both models |
| 97 | + for param1, param2 in zip(model.parameters(), ref_model.parameters()): |
| 98 | + assert torch.equal(param1, param2) |
| 99 | + |
| 100 | + # convert MoE to float8 training |
| 101 | + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: |
| 102 | + for target_fqn in target_fqns: |
| 103 | + if target_fqn in cur_fqn: |
| 104 | + return True |
| 105 | + return False |
| 106 | + |
| 107 | + # quantize test model |
| 108 | + config = MoETrainingConfig(scaling_type=recipe) |
| 109 | + quantize_(model, config=config, filter_fn=moe_module_filter_fn) |
| 110 | + |
| 111 | + # inputs |
| 112 | + ref_x = torch.randn( |
| 113 | + local_batch_size, |
| 114 | + seq_len, |
| 115 | + dim, |
| 116 | + dtype=torch.bfloat16, |
| 117 | + requires_grad=True, |
| 118 | + device=device, |
| 119 | + ) |
| 120 | + x = ref_x.detach().clone().requires_grad_(True) |
| 121 | + |
| 122 | + def warmup(model, input, labels): |
| 123 | + for _ in range(3): |
| 124 | + out = model(input) |
| 125 | + loss = F.mse_loss(out, labels) |
| 126 | + loss.backward() |
| 127 | + torch.cuda.synchronize() |
| 128 | + |
| 129 | + labels = torch.ones_like(x) |
| 130 | + |
| 131 | + # Warmup bf16 |
| 132 | + warmup(ref_model, ref_x, labels) |
| 133 | + |
| 134 | + # Bench bf16 |
| 135 | + bf16_us = bench_fwd_bwd_microseconds( |
| 136 | + ref_model, |
| 137 | + ref_x, |
| 138 | + labels=labels, |
| 139 | + use_compile=True, |
| 140 | + fullgraph=False, |
| 141 | + ) |
| 142 | + bf16_ms = bf16_us / 1e3 |
| 143 | + if enable_profile: |
| 144 | + print("Profiling bf16 training") |
| 145 | + profile_fwd_bwd( |
| 146 | + ref_model, |
| 147 | + ref_x, |
| 148 | + labels=labels, |
| 149 | + use_compile=True, |
| 150 | + fullgraph=False, |
| 151 | + profile_name="bf16_profile", |
| 152 | + ) |
| 153 | + |
| 154 | + # Warmup quantized |
| 155 | + warmup(model, x, labels) |
| 156 | + |
| 157 | + # Bench quantized |
| 158 | + scaled_us = bench_fwd_bwd_microseconds( |
| 159 | + model, |
| 160 | + x, |
| 161 | + labels=labels, |
| 162 | + use_compile=True, |
| 163 | + fullgraph=False, |
| 164 | + ) |
| 165 | + scaled_ms = scaled_us / 1e3 |
| 166 | + if enable_profile: |
| 167 | + print("Profiling quantized training") |
| 168 | + profile_fwd_bwd( |
| 169 | + model, |
| 170 | + x, |
| 171 | + labels=labels, |
| 172 | + use_compile=True, |
| 173 | + fullgraph=False, |
| 174 | + profile_name=f"{recipe_name}_profile", |
| 175 | + ) |
| 176 | + |
| 177 | + print(f"total_M: {local_batch_size * seq_len}, N: {hidden_dim}, K: {dim}") |
| 178 | + print(f"bf16 time: {bf16_ms:.3f} ms") |
| 179 | + print(f"{recipe_name} time: {scaled_ms:.3f} ms") |
| 180 | + print(f"speedup: {bf16_us / scaled_us:.3f}x") |
| 181 | + |
| 182 | + |
| 183 | +if __name__ == "__main__": |
| 184 | + parser = argparse.ArgumentParser(description="Benchmark MoE layer with FSDP2") |
| 185 | + parser.add_argument( |
| 186 | + "--profile", |
| 187 | + action="store_true", |
| 188 | + help="Enable PyTorch profiling and save results to file", |
| 189 | + ) |
| 190 | + parser.add_argument( |
| 191 | + "--recipe", type=str, help="[fp8_rowwise, mxfp8]", required=True |
| 192 | + ) |
| 193 | + parser.add_argument( |
| 194 | + "--local_num_experts", |
| 195 | + type=int, |
| 196 | + default=8, |
| 197 | + ) |
| 198 | + parser.add_argument( |
| 199 | + "--seq_len", |
| 200 | + type=int, |
| 201 | + default=8192, |
| 202 | + ) |
| 203 | + parser.add_argument( |
| 204 | + "--local_batch_size", |
| 205 | + type=int, |
| 206 | + default=8, |
| 207 | + ) |
| 208 | + parser.add_argument( |
| 209 | + "--hidden_dim", |
| 210 | + type=int, |
| 211 | + default=8192, |
| 212 | + ) |
| 213 | + parser.add_argument( |
| 214 | + "--dim", |
| 215 | + type=int, |
| 216 | + default=5120, |
| 217 | + ) |
| 218 | + |
| 219 | + args = parser.parse_args() |
| 220 | + bench_moe_training_fsdp(args) |
0 commit comments