
Commit d6bf068

wang2yn84 and qihqi authored
Mixtral enablement. (#120)
* Initial Mixtral enablement.
* Adds the Mistral tokenizer model.
* Updates the convert checkpoint file to handle the Mistral model.
* Renames the typo in the model name.
* Fixes checkpoint loading. Still has some issues; pushed to debug.
* Running on CPU works; temporarily disables the generate jit to see it's moving, but the outputs don't make sense yet because weights are not loaded.
* Fixes the checkpoint loading issue. Right now loading from the gpt-fast converter with qkv fusion.
* Fixes the ckpt conversion script for the Mistral model. Fixes the freqs_cis for loading the pth file.
* Adds a quantized layer for MoE quantization.
* Adds the HuggingFace download script. Improves the convert checkpoints logging.
* Cleans up and fixes lint errors.
* Missing cleanups.
* Adds instructions for Mixtral.
* Renames everything from mistral to mixtral.
* Fixes more lint errors.
* Removes the unnecessary checkpoint name mapping from the original Mixtral checkpoints.
* Fixes the model calling arg sequence; fixes the checkpoint convert script.

---------

Co-authored-by: Han Qi <hanq@google.com>
1 parent 87b8d92 commit d6bf068

11 files changed: +918 −8 lines

README.md

Lines changed: 15 additions & 1 deletion
@@ -64,12 +64,21 @@ huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
 
 Need to manually modify the `config.json` in the checkpoint folder to make it a valid JSON file. (Replace `'` with `"`, remove the excessive `,` after the last item in the JSON object)
 
+## Mixtral
+### Get Mixtral Checkpoint from HuggingFace
+
+Please sign the agreement on the Hugging Face website to access the Mixtral checkpoints. Download the Mixtral PyTorch checkpoint using huggingface-cli. The Mixtral tokenizer is included in the checkpoint.
+
+```bash
+huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
+```
+
 ## Run weight safetensor convert
 
 ```bash
 export input_ckpt_dir=Original llama weights directory
 export output_ckpt_dir=The output directory
-export model_name="llama-3" # or "llama-2", "gemma"
+export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
 export quantize_weights=True # Whether to quantize weights
 export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Available quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}, "int8_per_channel" is the default option if not specified.
 python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type
@@ -108,6 +117,11 @@ python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --m
 python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
+## Mixtral 8x7B
+```bash
+python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+```
+
 
 # Run the server
 Here is an example to run the server with llama2 7B config.

convert_checkpoints.py

Lines changed: 100 additions & 0 deletions
@@ -26,6 +26,7 @@
 import hashlib
 import json
 import os
+import re
 import time
 
 import torch
@@ -37,6 +38,8 @@
 from jetstream_pt.config import FLAGS
 from jetstream_pt.third_party.gemma import model as gemma_model
 from jetstream_pt.third_party.llama import model_exportable as llama_model
+from jetstream_pt.third_party.mixtral import model as mixtral_model
+
 from safetensors import safe_open
 from safetensors.torch import save_file
 
@@ -123,6 +126,12 @@ def _quantize_state_dict(
     block_size = orig_block_size
     n_bit = orig_n_bit
   state_dict.update(updated_weights)
+  for k, v in state_dict.items():
+    if "layers" in k and "layers.0" not in k:
+      continue
+    print(
+        f"After quantization the converted key: {k} and value: {v.shape} {v.dtype}"
+    )
   return state_dict
 
 
@@ -470,6 +479,89 @@ def _get_gemma_state_dict(input_ckpt_dir):
   return state_dict, model_config
 
 
+def _get_mixtral_state_dict(input_ckpt_dir):
+  ckpt_files = list(input_ckpt_dir.glob("*.pt"))
+  assert len(ckpt_files) == 8, "only expect 8 ckpt files for the Mixtral model."
+
+  start = time.perf_counter()
+  state_dict = {}
+  for file in sorted(ckpt_files):
+    ckpt = torch.load(
+        str(file), map_location="cpu", mmap=True, weights_only=True
+    )
+    state_dict.update(ckpt)
+  end = time.perf_counter()
+  print(f"Loading checkpoints takes {end - start} seconds")
+
+  for k, v in state_dict.items():
+    if "layers" in k and "layers.0" not in k:
+      continue
+    print(f"The loaded key: {k} and value: {v.shape} {v.dtype}")
+
+  config = json.loads((input_ckpt_dir / "config.json").read_text())
+  print(f"Loaded config: {config}")
+  weight_map = {
+      "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
+      "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
+      "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
+  }
+  for key in list(state_dict.keys()):
+    if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
+      assert (
+          key == "freqs_cis"
+      ), "Only expect key 'freqs_cis' in the state_dict to have complex dtype."
+      # Remove "freqs_cis" since it has complex dtype, and safetensors doesn't support it.
+      # The "freqs_cis" will be reconstructed when it's loaded by the inference engine.
+      state_dict.pop(key)
+      continue
+    prefix_to_remove = "model."
+    new_key = key
+    if key.startswith(prefix_to_remove):
+      new_key = new_key.removeprefix(prefix_to_remove)
+
+    if "layers" in key:
+      abstract_key = re.sub(r".(\d+).", ".{}.", key)
+      layer_num = re.search(r"\d+", key).group(0)
+      new_key = weight_map.get(abstract_key)
+      if new_key is None:
+        continue
+      new_key = new_key.format(layer_num)
+
+    if new_key == key:
+      continue
+
+    if "w1" in key or "w3" in key:
+      state_dict[new_key] = (
+          state_dict.pop(key)
+          .reshape(
+              config["num_local_experts"],
+              config["intermediate_size"],
+              config["hidden_size"],
+          )
+          .contiguous()
+      )
+    elif "w2" in key:
+      state_dict[new_key] = (
+          state_dict.pop(key)
+          .reshape(
+              config["num_local_experts"],
+              config["intermediate_size"],
+              config["hidden_size"],
+          )
+          .permute(0, 2, 1)
+          .contiguous()
+      )
+    elif "gate" in key:
+      state_dict[new_key] = state_dict.pop(key).contiguous()
+    else:
+      state_dict[new_key] = state_dict.pop(key)
+  for k, v in state_dict.items():
+    if "layers" in k and "layers.0" not in k:
+      continue
+    print(f"The converted key: {k} and value: {v.shape} {v.dtype}")
+  return state_dict, config
+
+
 def main(argv) -> None:
   """merge weights"""
 
@@ -481,6 +573,14 @@ def main(argv) -> None:
     quantize_embedding_weight_map = (
        gemma_model.GemmaModel.get_quantized_embedding_weight_to_scaler_map()
     )
+  elif FLAGS.model_name == "mixtral":
+    state_dict, params = _get_mixtral_state_dict(_INPUT_CHECKPOINT_DIR.value)
+    quantize_linear_weight_map = (
+        mixtral_model.Transformer.get_quantized_linear_weight_to_scaler_map()
+    )
+    quantize_embedding_weight_map = (
+        mixtral_model.Transformer.get_quantized_embedding_weight_to_scaler_map()
+    )
   else:
     state_dict, params = _get_llama_state_dict(_INPUT_CHECKPOINT_DIR.value)
     quantize_linear_weight_map = (
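
For readers following the conversion logic, the new `_get_mixtral_state_dict` renames the per-layer expert weights from the checkpoint's `block_sparse_moe.w{1,2,3}` names to the fused `cond_ffn` layout and reshapes each into a single 3-D tensor per projection. Below is a minimal, standalone sketch of that remapping on a dummy tensor; the flat input shape is assumed for illustration only, and the sketch is not part of the commit.

```python
# Standalone illustration of the expert-weight remapping in _get_mixtral_state_dict.
# The dummy tensor shape below is an assumption for this sketch.
import re
import torch

config = {"num_local_experts": 8, "intermediate_size": 14336, "hidden_size": 4096}
weight_map = {
    "layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
    "layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
    "layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
}
# One fused w1 entry: all experts stacked along the first dimension.
state_dict = {
    "layers.0.block_sparse_moe.w1": torch.zeros(
        config["num_local_experts"] * config["intermediate_size"],
        config["hidden_size"],
    )
}

for key in list(state_dict.keys()):
  abstract_key = re.sub(r".(\d+).", ".{}.", key)  # -> layers.{}.block_sparse_moe.w1
  layer_num = re.search(r"\d+", key).group(0)     # -> "0"
  new_key = weight_map[abstract_key].format(layer_num)
  # w1/w3 become (experts, intermediate, hidden); per the code above,
  # w2 is additionally permuted with .permute(0, 2, 1) before being stored.
  state_dict[new_key] = (
      state_dict.pop(key)
      .reshape(
          config["num_local_experts"],
          config["intermediate_size"],
          config["hidden_size"],
      )
      .contiguous()
  )

print({k: tuple(v.shape) for k, v in state_dict.items()})
# {'layers.0.block_sparse_moe.cond_ffn.w1': (8, 14336, 4096)}
```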

default_shardings/mixtral.yaml

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+
+# Sharding config for mixtral
+# Sharding should either be an int between 0 and rank - 1
+# signifying the axis to shard or -1 / null signifying replicated
+
+
+freqs_cis : -1 # torch.complex64 (2048, 64)
+tok_embeddings.weight : 1 # torch.float32 (vocab_size, 4096)
+tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
+layers.*.attention.wo.weight_scaler : -1 # torch.bfloat16 (4096,)
+layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096)
+layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096)
+layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096)
+layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.attention.wqkv.weight : 0 # torch.int8 (4096, 4096)
+layers.*.attention.wqkv.weight_scaler : 0 # torch.bfloat16 (4096,)
+layers.*.block_sparse_moe.gate.weight: -1
+layers.*.block_sparse_moe.gate.weight_scaler: -1
+layers.*.block_sparse_moe.cond_ffn.w1: 1
+layers.*.block_sparse_moe.cond_ffn.w1_scaler: 1
+layers.*.block_sparse_moe.cond_ffn.w2: 2
+layers.*.block_sparse_moe.cond_ffn.w2_scaler: -1
+layers.*.block_sparse_moe.cond_ffn.w3: 1
+layers.*.block_sparse_moe.cond_ffn.w3_scaler: 1
+layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
+layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
+norm.weight : -1 # torch.float32 (4096,)
+output.weight : 0 # torch.float32 (vocab_size, 4096)
+output.weight_scaler : 0 # torch.float32 (4096,)
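
As the header comments state, each value picks the tensor dimension to shard across the device mesh, and -1 or null means replicate. Purely as an illustration of that convention (not jetstream-pt's actual loader), here is a sketch that maps such entries onto `jax.sharding.NamedSharding` objects over an assumed 1-D mesh whose axis name "x" is made up for this example:

```python
# Illustrative only: interpret {tensor_name: axis | -1 | null} entries as JAX shardings.
# The mesh layout and axis name "x" are assumptions of this sketch, not repo code.
import jax
import numpy as np
import yaml
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def build_shardings(yaml_path, mesh, axis_name="x"):
  with open(yaml_path) as f:
    spec = yaml.safe_load(f)
  shardings = {}
  for name, axis in spec.items():
    if axis is None or axis == -1:
      pspec = PartitionSpec()  # replicated
    else:
      # Shard tensor dimension `axis` across the mesh axis; other dims replicated.
      pspec = PartitionSpec(*([None] * axis + [axis_name]))
    shardings[name] = NamedSharding(mesh, pspec)
  return shardings


mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
shardings = build_shardings("default_shardings/mixtral.yaml", mesh)
# e.g. layers.*.block_sparse_moe.cond_ffn.w2 -> PartitionSpec(None, None, 'x')
```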

jetstream_pt/config.py

Lines changed: 7 additions & 1 deletion
@@ -108,7 +108,13 @@ def create_engine_from_config_flags():
   sharding_file_name = FLAGS.sharding_config
   if not sharding_file_name:
     sharding_file_name = (
-        "llama" if FLAGS.model_name.startswith("llama") else "gemma"
+        "llama"
+        if FLAGS.model_name.startswith("llama")
+        else "gemma"
+        if FLAGS.model_name.startswith("gemma")
+        else "mixtral"
+        if FLAGS.model_name.startswith("mixtral")
+        else None
     )
   if (
       quant_config.enable_weight_quantization
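
The nested conditional above extends the old llama-or-gemma default. Spelled out, the selection behaves like the small self-contained sketch below (the helper name is just for illustration):

```python
# Equivalent behavior of the default-sharding selection above (illustrative helper).
def default_sharding_name(model_name: str):
  return (
      "llama"
      if model_name.startswith("llama")
      else "gemma"
      if model_name.startswith("gemma")
      else "mixtral"
      if model_name.startswith("mixtral")
      else None
  )


assert default_sharding_name("llama-3") == "llama"
assert default_sharding_name("gemma") == "gemma"
assert default_sharding_name("mixtral") == "mixtral"
# Unknown names now fall through to None instead of silently defaulting to "gemma".
assert default_sharding_name("unknown") is None
```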

jetstream_pt/engine.py

Lines changed: 27 additions & 5 deletions
@@ -37,6 +37,7 @@
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig
 from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
+from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model
 
 
 Mesh = jax.sharding.Mesh
@@ -359,7 +360,6 @@ def _insert_wrap(
 
     start_insert = decode_state.current_position - prefix.seq_len
     tokens = decode_state.tokens.at[slot].set(prefix.token)
-
     start_insert = start_insert % self.env.cache_sequence_length
     # pos < 0
     update_indexes = (
@@ -641,12 +641,17 @@ def _load_from_safetensors(self, path):
   def _load_from_state_dict(self, path):
     state_dict = torch.load(path, map_location=torch.device("cpu"))
     weights = {}
+    print(f"Loaded keys are : {state_dict.keys()}")
     for key, model_weights in self.pt_model.state_dict().items():
+      if key == "freqs_cis":
+        continue
       assert key in state_dict, f"key: {key} not found"
-      weights[key] = torchjax.from_torch(state_dict[key])
+      weights[key] = torch_xla2.tensor.t2j(state_dict[key])
       assert tuple(model_weights.shape) == tuple(
           weights[key].shape
       ), f"key: {key} error: {model_weights.shape} != {weights[key].shape}"
+
+    weights["freqs_cis"] = torch_xla2.tensor.t2j(self.pt_model.freqs_cis)
     return weights
 
   # pylint: disable-next=all
@@ -760,7 +765,7 @@
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
-  supported_models = ["llama-2", "llama-3", "gemma"]
+  supported_models = ["llama-2", "llama-3", "gemma", "mixtral"]
   if model_name not in supported_models:
     raise NotImplementedError(
         f"Model name should be one of{','.join(supported_models)}"
@@ -772,7 +777,6 @@
   jax.config.update("jax_traceback_filtering", "off")
   torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
   torch.set_default_dtype(torch_dtype)
-
   checkpoint_format = ""
   checkpoint_path = ""
 
@@ -797,8 +801,14 @@
 
   pt_model = None
 
+  sharding_file_name = ""
   if not sharding_config:
-    sharding_file_name = "llama" if model_name.startswith("llama") else "gemma"
+    if model_name.startswith("llama"):
+      sharding_file_name = "llama"
+    elif model_name.startswith("gemma"):
+      sharding_file_name = "gemma"
+    elif model_name.startswith("mixtral"):
+      sharding_file_name = "mixtral"
     sharding_config = os.path.join(
         "default_shardings", sharding_file_name + ".yaml"
     )
@@ -851,6 +861,18 @@
     env = JetEngineEnvironment(env_data)
     print(f"Enviroment variables: {vars(env)}")
     pt_model = gemma_model.GemmaModel(args, env)
+  elif model_name == "mixtral":
+    args = mixtral_config.ModelArgs.from_name("Mixtral-8x7B-v0.1")
+    args.device = "meta"
+    env_data.cache_shape = (
+        batch_size,
+        args.n_local_heads,
+        max_cache_length,
+        args.dim // args.n_head,
+    )
+    env_data.num_layers = args.n_layer
+    env = JetEngineEnvironment(env_data)
+    pt_model = mixtral_model.Transformer(args, env)
   else:
     raise RuntimeError(f"Model with name {model_name} not found")
 
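
Because `_get_mixtral_state_dict` drops the complex-valued `freqs_cis` before the safetensors conversion, the loader change above skips that key and instead converts the buffer already materialized on `self.pt_model`. For context, here is a sketch of the standard Llama/Mixtral-style rotary table that such a buffer holds; this is a common reference formulation, not necessarily the exact helper in this repo's model code.

```python
# Standard rotary-embedding table: complex64 of shape (seq_len, head_dim // 2).
# Shown for context only; the repo's own precompute helper may differ in detail.
import torch


def precompute_freqs_cis(head_dim: int, seq_len: int, theta: float = 10000.0):
  freqs = 1.0 / (
      theta ** (torch.arange(0, head_dim, 2)[: head_dim // 2].float() / head_dim)
  )
  t = torch.arange(seq_len, device=freqs.device)
  freqs = torch.outer(t, freqs)
  return torch.polar(torch.ones_like(freqs), freqs)  # complex64


# Mixtral-8x7B has head_dim = 4096 // 32 = 128 and rope_base = 1e6 (see the config
# added in this commit); for a 2048-token cache this matches the (2048, 64)
# torch.complex64 shape noted in default_shardings/mixtral.yaml.
freqs_cis = precompute_freqs_cis(128, 2048, theta=1000000.0)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([2048, 64]) torch.complex64
```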

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 1 addition & 1 deletion
@@ -200,10 +200,10 @@ def forward(
   ):
     """
     tokens: the input token for decoding
+    input_pos: the decoding position relative to the start, which is the length of the decoding results
     caches: kv caches
     mask: causal mask to filter the attention results
     start: the starting position for each slot
-    input_pos: the decoding position relative to the start, which is the length of the decoding results
     ragged_batch_index: precomputed batch index for ragged attention
     ragged_block_index: precomputed block index for ragged attention
     """

jetstream_pt/third_party/mixtral/__init__.py

Whitespace-only changes.

jetstream_pt/third_party/mixtral/config.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+# pylint: disable-all
+# # Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Mixtral model config
+import dataclasses
+from dataclasses import dataclass
+
+
+def find_multiple(n: int, k: int) -> int:
+  if n % k == 0:
+    return n
+  return n + k - (n % k)
+
+
+@dataclass
+class ModelArgs:
+  block_size: int = 2048
+  vocab_size: int = 32000
+  n_layer: int = 32
+  n_head: int = 32
+  dim: int = 4096
+  intermediate_size: int = None
+  n_local_heads: int = -1
+  head_dim: int = 64
+  rope_base: float = 10000
+  norm_eps: float = 1e-5
+  num_experts: int = 8
+  num_activated_experts: int = 2
+  device: str = "meta"
+
+  def __post_init__(self):
+    if self.n_local_heads == -1:
+      self.n_local_heads = self.n_head
+    if self.intermediate_size is None:
+      hidden_dim = 4 * self.dim
+      n_hidden = int(2 * hidden_dim / 3)
+      self.intermediate_size = find_multiple(n_hidden, 256)
+    self.head_dim = self.dim // self.n_head
+
+  @classmethod
+  def from_name(cls, name: str):
+    if name in transformer_configs:
+      return cls(**transformer_configs[name])
+    # fuzzy search
+    config = [
+        config
+        for config in transformer_configs
+        if config in str(name).upper() or config in str(name)
+    ]
+    assert len(config) == 1, name
+    return cls(**transformer_configs[config[0]])
+
+
+transformer_configs = {
+    "Mixtral-8x7B-v0.1": dict(
+        block_size=32768,
+        n_layer=32,
+        n_head=32,
+        n_local_heads=8,
+        dim=4096,
+        intermediate_size=14336,
+        rope_base=1000000.0,
+        num_experts=8,
+        num_activated_experts=2,
+    ),
+}
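
A quick usage note for the config above: `from_name` accepts either the exact key or any string containing it (the fuzzy search), and `__post_init__` recomputes `head_dim` from `dim` and `n_head`. A small self-contained example follows; the checkpoint path in the second call is hypothetical.

```python
# Quick check of the config above; the import path matches the new module location.
from jetstream_pt.third_party.mixtral.config import ModelArgs

# Exact-name lookup.
args = ModelArgs.from_name("Mixtral-8x7B-v0.1")
assert args.n_local_heads == 8 and args.intermediate_size == 14336
assert args.head_dim == 4096 // 32 == 128  # recomputed in __post_init__
assert args.block_size == 32768 and args.rope_base == 1000000.0

# Fuzzy search: any name containing the config key also resolves, e.g. a local
# checkpoint directory name (path shown here is hypothetical).
args2 = ModelArgs.from_name("/tmp/ckpts/Mixtral-8x7B-v0.1-instruct")
assert args2.dim == args.dim
```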

0 commit comments
