|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import types |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | + |
| 12 | +from torchtitan.models.moe.moe import MoE, TokenReorderer |
| 13 | + |
| 14 | + |
| 15 | +# NOTE: @goon - torch.testing.assert_close is very strict and hard to pass. Use the more-lenient |
| 16 | +# assert_close from FLA, slightly modified to remove their CI related code. |
| 17 | +# https://github.com/fla-org/flash-linear-attention/blob/3ddba2a043100837a1f6499b5eb6692de71a477b/fla/utils.py?plain=1#L82 |
| 18 | +def get_abs_err(x, y): |
| 19 | + return (x.detach() - y.detach()).flatten().abs().max().item() |
| 20 | + |
| 21 | + |
| 22 | +def get_err_ratio(x, y): |
| 23 | + err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item() |
| 24 | + base = (x.detach()).flatten().square().mean().sqrt().item() |
| 25 | + return err / (base + 1e-8) |
| 26 | + |
| 27 | + |
| 28 | +def assert_close(prefix, ref, tri, ratio, err_atol=1e-6): |
| 29 | + abs_atol = get_abs_err(ref, tri) |
| 30 | + msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}" |
| 31 | + error_rate = get_err_ratio(ref, tri) |
| 32 | + if abs_atol <= err_atol: |
| 33 | + return |
| 34 | + assert error_rate < ratio, msg |
| 35 | + |
| 36 | + |
| 37 | +# For testing: copy over old router and MoE impls and use these to monkey patch models in tests. |
| 38 | +# Code copied from |
| 39 | +# https://github.com/pytorch/torchtitan/blob/3819737fab042fdfd5443b1d99753b951b59696d/torchtitan/models/moe/moe.py?plain=1#L298 |
| 40 | +class MoEOld(MoE): |
| 41 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 42 | + """ |
| 43 | + Args: |
| 44 | + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. |
| 45 | +
|
| 46 | + Returns: |
| 47 | + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. |
| 48 | + """ |
| 49 | + bs, slen, dim = x.shape |
| 50 | + x = x.view(-1, dim) |
| 51 | + |
| 52 | + # top_scores and selected_experts_indices shape (bs*slen*top_k,) |
| 53 | + # num_tokens_per_expert shape (num_experts,) |
| 54 | + ( |
| 55 | + top_scores, |
| 56 | + selected_experts_indices, |
| 57 | + num_tokens_per_expert, |
| 58 | + ) = self.router(x, self.expert_bias) |
| 59 | + |
| 60 | + # tokens_per_expert will be used to update the expert bias for load balancing. |
| 61 | + # and also to count the expert usage |
| 62 | + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- |
| 63 | + # first in the forward pass, and then in the backward pass. However, this has no |
| 64 | + # effect on the expert bias update thanks to the torch.sign() operator. |
| 65 | + with torch.no_grad(): |
| 66 | + self.tokens_per_expert.add_(num_tokens_per_expert) |
| 67 | + |
| 68 | + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) |
| 69 | + # num_tokens_per_expert shape (num_experts,) |
| 70 | + # NOTE: the reason we need to compute num_tokens_per_expert again is: |
| 71 | + # 1st computation in router is to update self.tokens_per_expert |
| 72 | + # which would be the same across all TP ranks. |
| 73 | + # 2nd computation in reorderer is for the actual routing and experts computation |
| 74 | + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. |
| 75 | + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. |
| 76 | + ( |
| 77 | + top_scores_experts_sorted, |
| 78 | + token_indices_experts_sorted, |
| 79 | + num_tokens_per_expert, |
| 80 | + ) = self.reorderer(top_scores, selected_experts_indices) |
| 81 | + |
| 82 | + # shape (bs*slen*top_k, dim) |
| 83 | + token_indices_experts_sorted = token_indices_experts_sorted.reshape( |
| 84 | + -1, 1 |
| 85 | + ).expand(-1, dim) |
| 86 | + |
| 87 | + # shape (bs*slen*top_k, dim) |
| 88 | + routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) |
| 89 | + |
| 90 | + if self.score_before_experts: |
| 91 | + routed_input = ( |
| 92 | + routed_input.to(torch.float32) |
| 93 | + * top_scores_experts_sorted.reshape(-1, 1) |
| 94 | + ).to(x.dtype) |
| 95 | + |
| 96 | + # shape (bs*slen*top_k, dim) |
| 97 | + routed_output = self.experts(routed_input, num_tokens_per_expert) |
| 98 | + |
| 99 | + # shared expert |
| 100 | + # Note: we execute the shared expert before scoring the output of the routed expert |
| 101 | + # to "implicitly" overlap the shared expert compute with token combine communication |
| 102 | + if self.shared_experts is not None: |
| 103 | + out = self.shared_experts(x) |
| 104 | + else: |
| 105 | + out = torch.zeros_like(x) |
| 106 | + |
| 107 | + if not self.score_before_experts: |
| 108 | + routed_output = ( |
| 109 | + routed_output.to(torch.float32) |
| 110 | + * top_scores_experts_sorted.reshape(-1, 1) |
| 111 | + ).to(x.dtype) |
| 112 | + |
| 113 | + out = out.scatter_add( |
| 114 | + dim=0, index=token_indices_experts_sorted, src=routed_output |
| 115 | + ) |
| 116 | + out = out.reshape(bs, slen, dim) |
| 117 | + return out |
| 118 | + |
| 119 | + |
| 120 | +class TokenReordererOld(TokenReorderer): |
| 121 | + def forward( |
| 122 | + self, |
| 123 | + top_scores: torch.Tensor, |
| 124 | + selected_experts_indices: torch.Tensor, |
| 125 | + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 126 | + """ |
| 127 | + Reorders token indices to match the order of experts for MoE routing. |
| 128 | +
|
| 129 | + Args: |
| 130 | + top_scores (torch.Tensor): Routing scores for selected experts, |
| 131 | + shape (batch_size * seq_len, top_k) |
| 132 | + selected_experts_indices (torch.Tensor): Expert indices selected for each token, |
| 133 | + shape (batch_size*seq_len, top_k) |
| 134 | +
|
| 135 | + Returns: |
| 136 | + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 137 | + - top_scores_experts_sorted: Scores reordered to match expert ordering |
| 138 | + - token_indices_experts_sorted: Token indices reordered to match expert ordering |
| 139 | + - num_tokens_per_expert: Number of tokens assigned to each expert |
| 140 | + """ |
| 141 | + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward |
| 142 | + num_tokens_per_expert = torch.histc( |
| 143 | + selected_experts_indices.view(-1), |
| 144 | + bins=self.num_experts, |
| 145 | + min=0, |
| 146 | + max=self.num_experts, |
| 147 | + ) |
| 148 | + |
| 149 | + # Reorder the token indices to match the order of the experts |
| 150 | + # token_indices_experts_sorted shape (bs*slen*top_k,) |
| 151 | + token_indices_experts_sorted = torch.argsort( |
| 152 | + selected_experts_indices.view(-1), stable=True |
| 153 | + ) |
| 154 | + |
| 155 | + top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] |
| 156 | + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k |
| 157 | + |
| 158 | + return ( |
| 159 | + top_scores_experts_sorted, |
| 160 | + token_indices_experts_sorted, |
| 161 | + num_tokens_per_expert, |
| 162 | + ) |
| 163 | + |
| 164 | + |
| 165 | +def apply_old_moe_monkey_patches(module: nn.Module) -> None: |
| 166 | + for mod in module.modules(): |
| 167 | + if isinstance(mod, MoE): |
| 168 | + mod.forward = types.MethodType(MoEOld.forward, mod) |
| 169 | + if isinstance(mod, TokenReorderer): |
| 170 | + mod.forward = types.MethodType(TokenReordererOld.forward, mod) |
0 commit comments