diff --git a/smoe/models/mixtral/modeling_mixtral.py b/smoe/models/mixtral/modeling_mixtral.py index 15e155b..93fbed9 100644 --- a/smoe/models/mixtral/modeling_mixtral.py +++ b/smoe/models/mixtral/modeling_mixtral.py @@ -3,15 +3,16 @@ import inspect import math import warnings +from packaging import version from dataclasses import dataclass from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -import stk import torch import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint -from packaging import version +from megablocks.layers.arguments import Arguments as MBArgs +from megablocks.layers.dmoe import ParallelDroplessMLP from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import ( @@ -1725,6 +1726,19 @@ def forward(self, hidden_states): return current_hidden_states +class ParallelDroplessMLPWithoutLBLSaving(ParallelDroplessMLP): + def forward(self, x, expert_weights, top_experts): + in_shape = x.size() + # Compute the experts. + x, _ = self.forward_fn(x, expert_weights, top_experts) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + MISTRAL_ATTENTION_CLASSES = { "eager": MixtralAttention, "flash_attention_2": MixtralFlashAttention2, @@ -1774,6 +1788,27 @@ def __init__(self, config): for _ in range(self.num_experts) ] # 🔍 ) + elif self.moe_type == "megablocks": + config: MixtralConfig + is_fp16 = self.gate.weight.data.dtype == torch.float16 + is_bf16 = self.gate.weight.data.dtype == torch.bfloat16 + mb_args = MBArgs( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_layers=config.num_hidden_layers, + bias=False, + return_bias=False, + activation_fn=nn.SiLU(), + moe_num_experts=config.num_local_experts, + moe_top_k=config.num_experts_per_tok, + memory_optimized_mlp=False, + mlp_type='glu', + mlp_impl='sparse', + fp16=is_fp16, + bf16=is_bf16, + device=torch.cuda.current_device(), + ) + self.experts = ParallelDroplessMLPWithoutLBLSaving(mb_args) else: raise NotImplementedError(f"Unsupported moe_type: {self.moe_type}") @@ -1790,45 +1825,50 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_experts - ).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) + if self.moe_type == "modulelist": + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) - if ( - top_x.shape[0] == 0 and not self.training - ): # skip during training will lead to asynchrony among different GPUs and blocks the training! 
- continue + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * ( - routing_weights[top_x_list, idx_list, None] * self.scale_factor - ) + if ( + top_x.shape[0] == 0 and not self.training + ): # skip during training will lead to asynchrony among different GPUs and blocks the training! + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * ( + routing_weights[top_x_list, idx_list, None] * self.scale_factor + ) - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + elif self.moe_type == "megablocks": + final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts) + else: + raise NotImplementedError(f"Unsupported moe_type: {self.moe_type}") final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim diff --git a/smoe/utils/expert_construction/convert_llama_to_mixtral_mb.py b/smoe/utils/expert_construction/convert_llama_to_mixtral_mb.py new file mode 100644 index 0000000..c2ed823 --- /dev/null +++ b/smoe/utils/expert_construction/convert_llama_to_mixtral_mb.py @@ -0,0 +1,240 @@ +""" +Convert the original llama weights into mixtral weights with megablocks support. 
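+
+The target layout depends on moe_type (see FFN_TYPE_MAP below):
+
+- "modulelist": each expert keeps its own slice of the original FFN, stored as
+  experts.{i}.w1 / w3 / w2 (gate / up / down).
+- "megablocks": expert weights are stacked into single experts.mlp.w1 / v1 / w2
+  tensors, with the down projection transposed so that all three are stored as
+  (num_experts * ffn_hidden_size, hidden_size).
+
+A minimal invocation sketch (paths here are illustrative placeholders; the
+remaining arguments keep their defaults):
+
+    convert_safetensors(
+        model_dir="/path/to/Meta-Llama-3-8B-Instruct",
+        dump_dir="/path/to/converted/mb_8top2",
+        num_experts=8,
+        top_k=2,
+        moe_type="megablocks",
+    )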
+""" + +import math +import os.path +import re +import shutil +from collections import defaultdict +from pathlib import Path + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from torch.nn import init +from transformers.modeling_utils import dtype_byte_size + +from smoe.models.mixtral.configuration_mixtral import MixtralConfig +from smoe.models.mixtral.modeling_mixtral import MixtralForCausalLM +from smoe.utils.io import dump_json, load_json + + +def is_safetensors_file(filepath): + if isinstance(filepath, str): + filepath = Path(filepath) + string = filepath.name + return re.match(r"model-\d{5}-of-\d{5}.safetensors", string) is not None + + +FFN_TYPE_MAP = { + "modulelist": { + "gate": "w1", + "down": "w2", + "up": "w3", + }, + "megablocks": { + # mlp.(w1|v1|w2) + "gate": "w1", # (ffn_hidden_size x num_experts) x hsz + "up": "v1", # (ffn_hidden_size x num_experts) x hsz + "down": "w2", # (ffn_hidden_size x num_experts) x hsz + }, +} + + +def convert_safetensors( + model_dir, + dump_dir, + num_experts: int, + top_k: int, + scale_factor: float = 1.0, + num_moe_contract_layers: int = 0, + moe_type: str = "modulelist", + neuron_indices: dict = None, + gate_weights: dict = None, +): + # fmt: off + model_folder = Path(model_dir) + dump_folder = Path(dump_dir) + dump_folder.mkdir(parents=True, exist_ok=True) + ffn_type_map = FFN_TYPE_MAP[moe_type] + + raw_total_size = -1 + tensor_filepaths = [] + for filepath in model_folder.glob("*"): + if not os.path.isdir(filepath): + if is_safetensors_file(filepath): + tensor_filepaths.append(filepath) + if filepath.name == "config.json": + config = MixtralConfig.from_pretrained(filepath) + config.architectures = ["MixtralForCausalLM"] + config.num_experts_per_tok = top_k + config.num_local_experts = num_experts + config.router_aux_loss_coef = 1e-2 + config.scale_factor = scale_factor + config.moe_type = moe_type + config.num_moe_contract_layers=num_moe_contract_layers + config.intermediate_size = config.intermediate_size // num_experts + config.auto_map = { + "AutoConfig": "configuration_mixtral.MixtralConfig", + "AutoModel": "modeling_mixtral.MixtralModel", + "AutoModelForCausalLM": "modeling_mixtral.MixtralForCausalLM", + } + config.save_pretrained(dump_folder) + for filename in [ + "configuration_mixtral.py", + "modeling_mixtral.py", + ]: + shutil.copy2(f"smoe/models/mixtral/{filename}", dump_folder / filename) + (dump_folder / "__init__.py").touch() + elif filepath.name == "model.safetensors.index.json": + raw_total_size = load_json(filepath)["metadata"]["total_size"] + else: + # cp to dump_dir + shutil.copy2(filepath, dump_folder / filepath.name) + + router_records = set() + weight_map = {} + total_size = 0 + total_gate_size = 0 + visited_layers = set() + for fi, filepath in enumerate(tensor_filepaths): + with safe_open(filepath, framework="pt", device="cpu") as f: + tensors = {} + contained_layers = set() + for key in f.keys(): + tensor = f.get_tensor(key) + if ".mlp." 
in key: + # preparation + layer_idx, ffn_type = re.search( + r"model.layers.(\d+).mlp.(gate|up|down)_proj.weight", key + ).groups() + layer_idx = int(layer_idx) + + is_moe = (layer_idx >= num_moe_contract_layers) and (layer_idx < config.num_hidden_layers - num_moe_contract_layers) + + if is_moe: + contained_layers.add(layer_idx) + + if ffn_type == "down": + hsz, mid = tensor.shape + mid_idx = 1 + else: + mid, hsz = tensor.shape + mid_idx = 0 + + # initialize gate weights + if layer_idx not in router_records: + if gate_weights is None: # use newly initialized gate weights + gate_weight = torch.zeros(num_experts, hsz) + init.kaiming_uniform_(gate_weight, a=math.sqrt(5)) + tensors[ + f"model.layers.{layer_idx}.block_sparse_moe.gate.weight" + ] = gate_weight + else: # use provided gate weights + print(f"Initializing layer {layer_idx} gate weights using {gate_weights[layer_idx]}...") + tensors[ + f"model.layers.{layer_idx}.block_sparse_moe.gate.weight" + ] = gate_weights[layer_idx].clone() + router_records.add(layer_idx) + new_ffn_type = ffn_type_map[ffn_type] + + # initialize expert weights + if moe_type == "modulelist": + expert_size = mid // num_experts + for expert_idx in range(num_experts): + if mid_idx == 0: + if neuron_indices is None: # sequential split + expert_tensor = tensor[expert_idx * expert_size: (expert_idx + 1) * expert_size].clone() + else: # split according to the given indices + this_layer_indices: list = neuron_indices[layer_idx] + print(f"Initializing layer {layer_idx} expert {expert_idx} {ffn_type} using neurons with indices {this_layer_indices[expert_idx]}...") + expert_tensor = tensor[this_layer_indices[expert_idx]].clone() + else: + if neuron_indices is None: # sequential split + expert_tensor = tensor[:, expert_idx * expert_size: (expert_idx + 1) * expert_size].clone() + else: # split according to the given indices + this_layer_indices: list = neuron_indices[layer_idx] + print(f"Initializing layer {layer_idx} expert {expert_idx} {ffn_type} using neurons with indices {this_layer_indices[expert_idx]}...") + expert_tensor = tensor[:, this_layer_indices[expert_idx]].clone() + tensors[ + f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.{new_ffn_type}.weight" + ] = expert_tensor + + elif moe_type == "megablocks": + expert_size = mid // num_experts + tname = f"model.layers.{layer_idx}.block_sparse_moe.experts.mlp.{new_ffn_type}" + if mid_idx == 0: + # up & gate + tensors[tname] = tensor + else: + # down + tensors[tname] = tensor.t() + + else: + raise NotImplementedError + + else: + tensors[key] = tensor + + else: + tensors[key] = tensor + + for key in tensors: + tensors[key] = tensors[key].contiguous() + save_file(tensors, dump_folder / filepath.name, metadata={"format": "pt"}) + for key, tensor in tensors.items(): + weight_size = tensor.numel() * dtype_byte_size(tensor.dtype) + total_size += weight_size + weight_map[key] = filepath.name + if ".block_sparse_moe.gate." 
in key: + total_gate_size += weight_size + print(key, tensor.shape) + + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + dump_json(index, dump_folder / "model.safetensors.index.json", indent=2) + assert total_size - total_gate_size == raw_total_size + + +if __name__ == "__main__": + num_experts = 8 + top_k = 2 + + # src_model_dir = "/mnt/petrelfs/share_data/quxiaoye/models/Meta-Llama-3-8B" + src_model_dir = "/mnt/petrelfs/share_data/quxiaoye/models/Meta-Llama-3-8B-Instruct" + # tgt_model_dir_prefix = f"/mnt/petrelfs/share_data/quxiaoye/llama_moe_v2/converted_models/split-sequential-Top{top_k}" + + + # moe_type = "modulelist" + # tgt_model_dir_prefix = "/mnt/petrelfs/share_data/quxiaoye/llama_moe_v2/converted_models/tzhu_mixtral_mb/ml_8top2" + + moe_type = "megablocks" + tgt_model_dir_prefix = "/mnt/petrelfs/share_data/quxiaoye/llama_moe_v2/converted_models/tzhu_mixtral_mb/mb_8top2" + + neuron_indices_file = "" + gate_weights_file = "" + + print(f"converting {moe_type}") + convert_safetensors( + src_model_dir, + f"{tgt_model_dir_prefix}", + num_experts=num_experts, + top_k=top_k, + moe_type=moe_type, + neuron_indices=None + if neuron_indices_file == "" + else torch.load(neuron_indices_file), + gate_weights=None + if gate_weights_file == "" + else torch.load(gate_weights_file), + ) + + print(f"testing {moe_type}") + m = MixtralForCausalLM.from_pretrained(f"{tgt_model_dir_prefix}", torch_dtype=torch.bfloat16) + + print(f"Re-saving {moe_type}") + m.save_pretrained(f"{tgt_model_dir_prefix}") + + print("Done") + # fmt: on
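# A minimal smoke-test sketch for a converted "megablocks" checkpoint: load it and
# run a single forward pass, which exercises the `moe_type == "megablocks"` branch
# of the sparse-MoE forward added above. The checkpoint path and prompt are
# placeholders, and a CUDA device is assumed since the megablocks kernels and the
# MBArgs construction expect a GPU.
import torch
from transformers import AutoTokenizer

from smoe.models.mixtral.modeling_mixtral import MixtralForCausalLM

ckpt_dir = "/path/to/converted/mb_8top2"  # placeholder: the dump_dir used above
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)  # tokenizer files are copied over by the converter
model = MixtralForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.bfloat16).cuda()
model.eval()

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # expected: (1, sequence_length, vocab_size)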