Commit a26de67: "testing infra"
1 parent 9322ff6

6 files changed: +388 -142 lines

torchtitan/models/moe/moe.py

Lines changed: 13 additions & 95 deletions
@@ -413,14 +413,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         bs, slen, dim = x.shape
         x = x.view(-1, dim)
 
-        # top_scores shape (bs*slen, top_k)
-        # selected_experts_indices shape (bs*slen, top_k)
+        # top_scores and selected_experts_indices shape (bs*slen, top_k)
         # num_tokens_per_expert shape (num_experts,)
         (
             top_scores,
             selected_experts_indices,
             num_tokens_per_expert,
         ) = self.router(x, self.expert_bias)
+        top_k = selected_experts_indices.shape[-1]
 
         # tokens_per_expert will be used to update the expert bias for load balancing.
         # and also to count the expert usage
@@ -461,22 +461,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # to "implicitly" overlap the shared expert compute with token combine communication
         out = self.shared_experts(x) if self.shared_experts is not None else None
 
+        # Unsort routed outputs
+        routed_output_unsorted = torch.zeros(
+            (bs * slen * top_k, dim),
+            dtype=routed_output.dtype,
+            device=routed_output.device,
+        )
+        routed_output_unsorted[token_indices_experts_sorted] = routed_output
+        routed_output_unsorted = routed_output_unsorted.reshape(-1, top_k, dim)
         if not self.score_before_experts:
-            # Unsort scores and routed outputs. Also save some allocations: store unsorted scores
-            # and outputs in top_scores and routed_input, respectively.
-            routed_input[token_indices_experts_sorted] = routed_output
-            routed_input = routed_input.reshape(-1, self.router.top_k, dim)
             out_experts = (
                 torch.bmm(
-                    top_scores.reshape(-1, 1, self.router.top_k), routed_input.float()
+                    top_scores.reshape(-1, 1, self.router.top_k),
+                    routed_output_unsorted.float(),
                 )
                 .to(x.dtype)
                 .squeeze(1)
             )
         else:
-            # Unsort routed outputs and save an allocation: store unsorted outputs in routed_input
-            routed_input[token_indices_experts_sorted] = routed_output
-            out_experts = routed_input.reshape(-1, self.router.top_k, dim).sum(dim=1)
+            out_experts = routed_output_unsorted.sum(dim=1)
 
         if out is None:
             return out_experts.reshape(bs, slen, dim)
@@ -500,88 +503,3 @@ def init_weights(
                 self.expert_bias = torch.zeros(
                     self.experts.num_experts, dtype=torch.float32
                 )
-
-
-# For testing
-class MoEOld(MoE):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.
-
-        Returns:
-            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
-        """
-        bs, slen, dim = x.shape
-        x = x.view(-1, dim)
-
-        # top_scores and selected_experts_indices shape (bs*slen*top_k,)
-        # num_tokens_per_expert shape (num_experts,)
-        (
-            top_scores,
-            selected_experts_indices,
-            num_tokens_per_expert,
-        ) = self.router(x, self.expert_bias)
-
-        # tokens_per_expert will be used to update the expert bias for load balancing.
-        # and also to count the expert usage
-        # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
-        # first in the forward pass, and then in the backward pass. However, this has no
-        # effect on the expert bias update thanks to the torch.sign() operator.
-        with torch.no_grad():
-            self.tokens_per_expert.add_(num_tokens_per_expert)
-
-        # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
-        # num_tokens_per_expert shape (num_experts,)
-        # NOTE: the reason we need to compute num_tokens_per_expert again is:
-        #       1st computation in router is to update self.tokens_per_expert
-        #       which would be the same across all TP ranks.
-        #       2nd computation in reorderer is for the actual routing and experts computation
-        #       which would be sharded over TP ranks if expert_tensor_parallel_degree==1.
-        #       If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree.
-        (
-            top_scores_experts_sorted,
-            token_indices_experts_sorted,
-            num_tokens_per_expert,
-        ) = self.reorderer(top_scores, selected_experts_indices)
-        # NOTE: @goon - adjust for redefined reorderer output and divide by top_k
-        token_indices_experts_sorted = (
-            token_indices_experts_sorted // self.reorderer.top_k
-        )
-
-        # shape (bs*slen*top_k, dim)
-        token_indices_experts_sorted = token_indices_experts_sorted.reshape(
-            -1, 1
-        ).expand(-1, dim)
-
-        # shape (bs*slen*top_k, dim)
-        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
-
-        if self.score_before_experts:
-            routed_input = (
-                routed_input.to(torch.float32)
-                * top_scores_experts_sorted.reshape(-1, 1)
-            ).to(x.dtype)
-
-        # shape (bs*slen*top_k, dim)
-        routed_output = self.experts(routed_input, num_tokens_per_expert)
-
-        # shared expert
-        # Note: we execute the shared expert before scoring the output of the routed expert
-        # to "implicitly" overlap the shared expert compute with token combine communication
-        if self.shared_experts is not None:
-            out = self.shared_experts(x)
-        else:
-            out = torch.zeros_like(x)
-
-        if not self.score_before_experts:
-            routed_output = (
-                routed_output.to(torch.float32)
-                * top_scores_experts_sorted.reshape(-1, 1)
-            ).to(x.dtype)
-
-        out = out.scatter_add(
-            dim=0, index=token_indices_experts_sorted, src=routed_output
-        )
-        out = out.reshape(bs, slen, dim)
-        return out
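The new combine path above scatters the expert-sorted outputs into a freshly allocated buffer and then applies the routing scores with a batched matmul. The following standalone sketch (toy shapes and random stand-ins for the expert outputs, router scores, and sort permutation; not part of the commit) illustrates the pattern:

```python
import torch

# Toy sizes; in the real model these come from the batch and MoE config.
bs, slen, dim, top_k = 2, 3, 4, 2
num_slots = bs * slen * top_k

# Stand-ins for what the experts and reorderer produce: outputs sorted by expert,
# plus the permutation mapping each sorted position back to its original slot.
routed_output = torch.randn(num_slots, dim)
token_indices_experts_sorted = torch.randperm(num_slots)
top_scores = torch.rand(bs * slen, top_k)

# Unsort: entry i of the expert-sorted output is written back to slot
# token_indices_experts_sorted[i] of a zero-initialized buffer.
routed_output_unsorted = torch.zeros(num_slots, dim)
routed_output_unsorted[token_indices_experts_sorted] = routed_output
routed_output_unsorted = routed_output_unsorted.reshape(-1, top_k, dim)

# Combine the top_k expert outputs per token, weighted by the router scores:
# (N, 1, top_k) @ (N, top_k, dim) -> (N, 1, dim).
out_experts = torch.bmm(
    top_scores.reshape(-1, 1, top_k), routed_output_unsorted.float()
).squeeze(1)
assert out_experts.shape == (bs * slen, dim)
```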
Lines changed: 170 additions & 0 deletions (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import types

import torch
import torch.nn as nn

from torchtitan.models.moe.moe import MoE, TokenReorderer


# NOTE: @goon - torch.testing.assert_close is very strict and hard to pass. Use the more-lenient
# assert_close from FLA, slightly modified to remove their CI related code.
# https://github.com/fla-org/flash-linear-attention/blob/3ddba2a043100837a1f6499b5eb6692de71a477b/fla/utils.py?plain=1#L82
def get_abs_err(x, y):
    return (x.detach() - y.detach()).flatten().abs().max().item()


def get_err_ratio(x, y):
    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
    base = (x.detach()).flatten().square().mean().sqrt().item()
    return err / (base + 1e-8)


def assert_close(prefix, ref, tri, ratio, err_atol=1e-6):
    abs_atol = get_abs_err(ref, tri)
    msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
    error_rate = get_err_ratio(ref, tri)
    if abs_atol <= err_atol:
        return
    assert error_rate < ratio, msg


# For testing: copy over old router and MoE impls and use these to monkey patch models in tests.
# Code copied from
# https://github.com/pytorch/torchtitan/blob/3819737fab042fdfd5443b1d99753b951b59696d/torchtitan/models/moe/moe.py?plain=1#L298
class MoEOld(MoE):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
        """
        bs, slen, dim = x.shape
        x = x.view(-1, dim)

        # top_scores and selected_experts_indices shape (bs*slen*top_k,)
        # num_tokens_per_expert shape (num_experts,)
        (
            top_scores,
            selected_experts_indices,
            num_tokens_per_expert,
        ) = self.router(x, self.expert_bias)

        # tokens_per_expert will be used to update the expert bias for load balancing.
        # and also to count the expert usage
        # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
        # first in the forward pass, and then in the backward pass. However, this has no
        # effect on the expert bias update thanks to the torch.sign() operator.
        with torch.no_grad():
            self.tokens_per_expert.add_(num_tokens_per_expert)

        # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
        # num_tokens_per_expert shape (num_experts,)
        # NOTE: the reason we need to compute num_tokens_per_expert again is:
        #       1st computation in router is to update self.tokens_per_expert
        #       which would be the same across all TP ranks.
        #       2nd computation in reorderer is for the actual routing and experts computation
        #       which would be sharded over TP ranks if expert_tensor_parallel_degree==1.
        #       If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree.
        (
            top_scores_experts_sorted,
            token_indices_experts_sorted,
            num_tokens_per_expert,
        ) = self.reorderer(top_scores, selected_experts_indices)

        # shape (bs*slen*top_k, dim)
        token_indices_experts_sorted = token_indices_experts_sorted.reshape(
            -1, 1
        ).expand(-1, dim)

        # shape (bs*slen*top_k, dim)
        routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)

        if self.score_before_experts:
            routed_input = (
                routed_input.to(torch.float32)
                * top_scores_experts_sorted.reshape(-1, 1)
            ).to(x.dtype)

        # shape (bs*slen*top_k, dim)
        routed_output = self.experts(routed_input, num_tokens_per_expert)

        # shared expert
        # Note: we execute the shared expert before scoring the output of the routed expert
        # to "implicitly" overlap the shared expert compute with token combine communication
        if self.shared_experts is not None:
            out = self.shared_experts(x)
        else:
            out = torch.zeros_like(x)

        if not self.score_before_experts:
            routed_output = (
                routed_output.to(torch.float32)
                * top_scores_experts_sorted.reshape(-1, 1)
            ).to(x.dtype)

        out = out.scatter_add(
            dim=0, index=token_indices_experts_sorted, src=routed_output
        )
        out = out.reshape(bs, slen, dim)
        return out


class TokenReordererOld(TokenReorderer):
    def forward(
        self,
        top_scores: torch.Tensor,
        selected_experts_indices: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reorders token indices to match the order of experts for MoE routing.

        Args:
            top_scores (torch.Tensor): Routing scores for selected experts,
                shape (batch_size * seq_len, top_k)
            selected_experts_indices (torch.Tensor): Expert indices selected for each token,
                shape (batch_size * seq_len, top_k)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - top_scores_experts_sorted: Scores reordered to match expert ordering
                - token_indices_experts_sorted: Token indices reordered to match expert ordering
                - num_tokens_per_expert: Number of tokens assigned to each expert
        """
        # group tokens together by expert indices from 0 to num_experts and pass that to experts forward
        num_tokens_per_expert = torch.histc(
            selected_experts_indices.view(-1),
            bins=self.num_experts,
            min=0,
            max=self.num_experts,
        )

        # Reorder the token indices to match the order of the experts
        # token_indices_experts_sorted shape (bs*slen*top_k,)
        token_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )

        top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted]
        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

        return (
            top_scores_experts_sorted,
            token_indices_experts_sorted,
            num_tokens_per_expert,
        )


def apply_old_moe_monkey_patches(module: nn.Module) -> None:
    for mod in module.modules():
        if isinstance(mod, MoE):
            mod.forward = types.MethodType(MoEOld.forward, mod)
        if isinstance(mod, TokenReorderer):
            mod.forward = types.MethodType(TokenReordererOld.forward, mod)
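These helpers are intended to be monkey patched into a model so the old and new MoE paths can be compared. A minimal parity-test sketch is shown below; the MoEArgs defaults, CPU execution, and importing assert_close from torchtitan.moe_bench_and_test are assumptions, so adjust to the actual package layout:

```python
import copy

import torch

from torchtitan.models.moe.moe import MoE, MoEArgs
from torchtitan.moe_bench_and_test import apply_old_moe_monkey_patches, assert_close


def test_new_moe_matches_old(dim: int = 256, hidden_dim: int = 512) -> None:
    torch.manual_seed(0)
    # Assumed: the remaining MoEArgs fields have usable defaults.
    moe = MoE(MoEArgs(use_grouped_mm=False), dim=dim, hidden_dim=hidden_dim)
    moe.init_weights(1 / dim**0.5, torch.device("cpu"))

    # Reference copy that runs the old forward implementations via monkey patching.
    moe_old = copy.deepcopy(moe)
    apply_old_moe_monkey_patches(moe_old)

    x = torch.randn(2, 16, dim)
    assert_close("moe fwd", moe_old(x), moe(x), ratio=5e-3)
```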

moe_bench_and_test/moe_memory.py renamed to torchtitan/moe_bench_and_test/moe_memory.py

Lines changed: 9 additions & 9 deletions
@@ -9,7 +9,9 @@
 import torch
 
 from torchtitan.components.metrics import DeviceMemoryMonitor
-from torchtitan.models.moe.moe import MoE, MoEArgs, MoEOld
+from torchtitan.models.moe.moe import MoE, MoEArgs
+
+from torchtitan.moe_bench_and_test import apply_old_moe_monkey_patches
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -43,17 +45,13 @@
         use_grouped_mm=use_grouped_mm,
     )
 
-    if args.cls == "moe":
-        cls = MoE
-    elif args.cls == "moe_old":
-        cls = MoEOld
-    else:
-        raise ValueError
-
     torch.manual_seed(42)
-    moe = cls(moe_args, dim=dim, hidden_dim=moe_inter_dim).to(
+    moe = MoE(moe_args, dim=dim, hidden_dim=moe_inter_dim).to(
         device=device, dtype=torch.bfloat16
     )
+    if args.cls == "moe_old":
+        apply_old_moe_monkey_patches(moe)
+
     moe.init_weights(1 / dim**0.5, device)
     inputs = torch.randn(
         args.bsz,
@@ -66,6 +64,8 @@
     for _ in range(args.iters):
        moe(inputs).sum().backward()
        moe.zero_grad()
+
+    torch.cuda.synchronize()
     print(f"\n{args=}")
     peak_stats = mem_monitor.get_peak_stats()
     print(f"{peak_stats.max_active_gib=}")
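The measurement loop in the script reduces to the pattern sketched below. This version substitutes plain torch.cuda counters for torchtitan's DeviceMemoryMonitor, so the reported numbers may differ slightly; it is a sketch, not the script's implementation:

```python
import torch


def peak_active_gib(model: torch.nn.Module, inputs: torch.Tensor, iters: int = 10) -> float:
    """Run fwd+bwd `iters` times and return the peak allocated CUDA memory in GiB."""
    torch.cuda.reset_peak_memory_stats()
    for _ in range(iters):
        model(inputs).sum().backward()
        model.zero_grad()
    # Wait for all queued kernels to finish before reading the counters, as the script does.
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30
```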

moe_bench_and_test/moe_timing.py renamed to torchtitan/moe_bench_and_test/moe_timing.py

Lines changed: 5 additions & 9 deletions
@@ -9,7 +9,8 @@
 import torch
 from triton.testing import do_bench
 
-from torchtitan.models.moe.moe import MoE, MoEArgs, MoEOld
+from torchtitan.models.moe.moe import MoE, MoEArgs
+from torchtitan.moe_bench_and_test import apply_old_moe_monkey_patches
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -43,17 +44,12 @@
         use_grouped_mm=use_grouped_mm,
     )
 
-    if args.cls == "moe":
-        cls = MoE
-    elif args.cls == "moe_old":
-        cls = MoEOld
-    else:
-        raise ValueError
-
     torch.manual_seed(42)
-    moe = cls(moe_args, dim=dim, hidden_dim=moe_inter_dim).to(
+    moe = MoE(moe_args, dim=dim, hidden_dim=moe_inter_dim).to(
         device=device, dtype=torch.bfloat16
     )
+    if args.cls == "moe_old":
+        apply_old_moe_monkey_patches(moe)
     moe.init_weights(1 / dim**0.5, device)
     inputs = torch.randn(
         args.bsz,
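The timing script drives triton.testing.do_bench, which returns a latency in milliseconds. A compact sketch of comparing the patched and unpatched forward passes (module construction omitted; built as in the memory script above) could look like this:

```python
import copy

from triton.testing import do_bench

from torchtitan.moe_bench_and_test import apply_old_moe_monkey_patches


def bench_forward(moe, inputs) -> tuple[float, float]:
    """Return (new_ms, old_ms) forward latencies for the given MoE module."""
    moe_old = copy.deepcopy(moe)
    apply_old_moe_monkey_patches(moe_old)
    new_ms = do_bench(lambda: moe(inputs))
    old_ms = do_bench(lambda: moe_old(inputs))
    return new_ms, old_ms
```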
