Commit 811d718

Support converting hf gemma weights (#71)
* add convert hf gemma weights
1 parent 93c8f8d commit 811d718

File tree

7 files changed: +123 -31 lines changed

README.md

Lines changed: 25 additions & 6 deletions
@@ -43,46 +43,65 @@ NOTE: the above script will export PYTHONPATH, so sourcing will make it to take
 
 # Download and convert weights
 
-## Get official llama weights from meta-llama
+## LLaMA
+### Get official llama weights from meta-llama
 
 Following instructions here: https://github.com/meta-llama/llama#download
 
 After you have downloaded the weights, it will also download a `tokenizer.model` file that is
 the tokenizer that we will use.
 
+## Gemma
+### Get Gemma Checkpoint from HuggingFace
+
+Please sign agreement on Huggingface website to access Gemma checkpoints. Download Gemma PyTorch checkpoint using huggingface-cli. Gemma Tokenizer is included in the checkpoint.
+
+```bash
+huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
+```
+
+Need to manually modify the `config.json` in the checkpoint folder to make it a valid JSON file. (Replace `'` with `"`, remove the excessive `,` after the last item in the JSON object)
+
 ## Run weight safetensor convert
 
 ```bash
 export input_ckpt_dir=Original llama weights directory
 export output_ckpt_dir=The output directory
 export quantize=True #whether to quantize
-python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
+export model_name="llama-2" # or "gemma"
+python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
 ```
 
 
 # Local run
 
 Set tokenizer path
 ```bash
-export tokenizer_path=tokenizer model file path from meta-llama
+export tokenizer_path=tokenizer model file path
 ```
 
 ## Llama 7b
 ```bash
-python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
+python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
 ## Llama 13b
 ```bash
-python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
+python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+```
+
+
+## Gemma 7b
+```bash
+python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
 ```
 
 
 # Run the server
 NOTE: the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`)
 
 ```bash
-python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
+python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
 ```
 Now you can fire gRPC to it
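The manual `config.json` fix described above can also be scripted. Below is a minimal sketch (not part of the repo), assuming the only problems in the file are single-quoted strings and a trailing comma; it reuses the `$input_ckpt_dir` variable from the README.

```python
# Minimal sketch: normalize the Python-dict-style config.json into valid JSON.
import ast
import json
import os
from pathlib import Path

config_path = Path(os.environ["input_ckpt_dir"]) / "config.json"
# ast.literal_eval accepts single quotes and trailing commas, so it can
# parse the almost-JSON file as a Python dict literal.
config = ast.literal_eval(config_path.read_text())
config_path.write_text(json.dumps(config, indent=2))
print(f"Rewrote {config_path} with {len(config)} keys")
```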

convert_checkpoints.py

Lines changed: 62 additions & 6 deletions
@@ -72,6 +72,7 @@
 _QUANTIZE = flags.DEFINE_bool(
     "quantize", False, "When set to true, produces quantized weights"
 )
+_MODEL_TYPE = flags.DEFINE_string("model_name", "llama", "Type of the model.")
 
 # ParallelEmbedding is col partitioned across the shards.
 # ColumnParallelLinear is row partitioned across shards due to transpose.
@@ -403,16 +404,71 @@ def merge_weights(
   print(f"Export outputs takes {end - start} seconds")
 
 
+def convert_hf_gemma_weights(
+    input_ckpt_dir: epath.Path, output_ckpt_dir: epath.Path
+):
+  """Convert gemma weights from Huggingface to be compatible with JetStream
+  1. Map attention weights to new names.
+  2. Split qkv fusion.
+  """
+  ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
+  assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
+  ckpt_file = ckpt_file[0]
+  state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))[
+      "model_state_dict"
+  ]
+  model_config = json.loads((input_ckpt_dir / "config.json").read_text())
+  for key in list(state_dict.keys()):
+    if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
+      assert (
+          key == "freqs_cis"
+      ), "Only expect key 'freqs_cis' in the state_dict has complex dtype."
+      # Remove "freqs_cis" since it has complex dtype, and safetensor doesn't support it.
+      # The "freqs_cis" will be reconstructed when it's loaded by inference engine.
+      state_dict.pop(key)
+      continue
+    prefix_to_remove = "model."
+    new_key = key
+    if key.startswith(prefix_to_remove):
+      new_key = new_key.removeprefix(prefix_to_remove)
+    if "qkv_proj" in key:
+      q_dim = model_config["num_attention_heads"] * model_config["head_dim"]
+      kv_dim = model_config["num_key_value_heads"] * model_config["head_dim"]
+      qkv = state_dict.pop(key)
+      q, k, v = qkv.split(
+          [
+              q_dim,
+              kv_dim,
+              kv_dim,
+          ],
+          dim=0,
+      )
+      state_dict[new_key.replace("qkv_proj", "wq")] = q
+      state_dict[new_key.replace("qkv_proj", "wk")] = k
+      state_dict[new_key.replace("qkv_proj", "wv")] = v
+      continue
+    if "o_proj" in key:
+      new_key = new_key.replace("o_proj", "wo")
+    if new_key != key:
+      state_dict[new_key] = state_dict.pop(key)
+  _export_to_local(output_ckpt_dir, model_config, state_dict)
+
+
 def main(argv: Sequence[str]) -> None:
   """convert checkpoint main function"""
   if len(argv) > 1:
     raise app.UsageError("Too many command-line arguments.")
-  merge_weights(
-      _INPUT_CHECKPOINT_DIR.value,
-      _OUTPUT_CHECKPOINT_DIR.value,
-      _MINIMIZE_MEMORY_FOOTPRINT.value,
-      _ENABLE_FLOAT32.value,
-  )
+  if "gemma" in _MODEL_TYPE.value:
+    convert_hf_gemma_weights(
+        _INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value
+    )
+  else:
+    merge_weights(
+        _INPUT_CHECKPOINT_DIR.value,
+        _OUTPUT_CHECKPOINT_DIR.value,
+        _MINIMIZE_MEMORY_FOOTPRINT.value,
+        _ENABLE_FLOAT32.value,
+    )
 
 
 if __name__ == "__main__":
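For intuition on the fused-qkv split that `convert_hf_gemma_weights` performs, here is a standalone sketch on a dummy weight. The head counts and hidden size below are made up for illustration; the real values come from `config.json`.

```python
# Standalone sketch of splitting a fused [q; k; v] projection along dim 0.
import torch

num_attention_heads, num_key_value_heads, head_dim, hidden = 8, 1, 16, 128
q_dim = num_attention_heads * head_dim   # rows for the query projection
kv_dim = num_key_value_heads * head_dim  # rows each for key and value

qkv_weight = torch.randn(q_dim + 2 * kv_dim, hidden)  # fused qkv weight
wq, wk, wv = qkv_weight.split([q_dim, kv_dim, kv_dim], dim=0)
print(wq.shape, wk.shape, wv.shape)
# torch.Size([128, 128]) torch.Size([16, 128]) torch.Size([16, 128])
```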

jetstream_pt/engine.py

Lines changed: 26 additions & 5 deletions
@@ -552,13 +552,35 @@ def _load_from_safetensors(self, path):
 
     return weights
 
+  def _load_from_state_dict(self, path):
+    state_dict = torch.load(path, map_location=torch.device("cpu"))
+    weights = {}
+    for key, model_weights in self.pt_model.state_dict().items():
+      assert key in state_dict, f"key: {key} not found"
+      arr = jax.device_put(
+          torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
+      )
+      assert tuple(model_weights.shape) == tuple(
+          arr.shape
+      ), f"key: {key} error: {model_weights.shape} != {arr.shape}"
+      weights[key] = arr
+
+    for k, v in weights.items():
+      if k.startswith("layers") and not k.startswith("layers.0"):
+        continue
+      print(f"Name: {k}, shape: {v.shape} x {v.dtype}")
+
+    return weights
+
   # pylint: disable-next=all
   def load_params(self) -> Params:
     # We want to fix this: load from files
     with jax.default_device(self.colocated_cpus):
       if self.env.checkpoint_path:
         if self.env.checkpoint_format == "safetensors":
           return self._load_from_safetensors(self.env.checkpoint_path)
+        elif self.env.checkpoint_format == "state_dict":
+          return self._load_from_state_dict(self.env.checkpoint_path)
       else:
         jax_weights = self._make_state_dict_jax(self.pt_model.state_dict())
         jax_weights = {
@@ -643,7 +665,7 @@ def create_pytorch_engine(
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
-  supported_models = ["llama-2", "llama-3"]
+  supported_models = ["llama-2", "llama-3", "gemma"]
   if model_name not in supported_models:
     raise NotImplementedError(
         f"Model name should be one of{','.join(supported_models)}"
@@ -664,10 +686,9 @@
   elif ".safetensors" in ckpt_path:
     checkpoint_format = "safetensors"
     checkpoint_path = ckpt_path
-  elif ".pth" in ckpt_path:
-    raise NotImplementedError(
-        "Loading from Pytorch raw checkpoint is not supported!"
-    )
+  elif ".pth" in ckpt_path or ".ckpt" in ckpt_path:
+    checkpoint_format = "state_dict"
+    checkpoint_path = ckpt_path
   else:
     path = epath.Path(ckpt_path) if ckpt_path and ckpt_path is not None else ""
     if not path.exists():
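In effect, checkpoint-format selection in `create_pytorch_engine` now dispatches on the file extension, with raw PyTorch `.pth`/`.ckpt` files routed to the new state-dict loader. A simplified sketch of that dispatch (an illustration, not the repo's actual helper) is:

```python
# Simplified illustration of the extension-based checkpoint format dispatch.
def detect_checkpoint_format(ckpt_path: str) -> str:
  if ".safetensors" in ckpt_path:
    return "safetensors"
  if ".pth" in ckpt_path or ".ckpt" in ckpt_path:
    return "state_dict"  # raw PyTorch checkpoints are now supported
  return "none"  # directory handling elided in this sketch

assert detect_checkpoint_format("gemma-7b.ckpt") == "state_dict"
```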

jetstream_pt/layers.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def forward(self, input):
 
 class WeightOnlyInt8Linear(torch.nn.Module):
 
-  def __init__(self, in_features, out_features, bias, device):
+  def __init__(self, in_features, out_features, bias=None, device=None):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features

jetstream_pt/third_party/gemma/model.py

Lines changed: 9 additions & 3 deletions
@@ -76,9 +76,15 @@ def __init__(
         if env.enable_weight_quantization
         else torch.nn.Linear
     )
-    self.gate_proj = Linear(hidden_size, intermediate_size, device)
-    self.up_proj = Linear(hidden_size, intermediate_size, device)
-    self.down_proj = Linear(intermediate_size, hidden_size, device)
+    self.gate_proj = Linear(
+        hidden_size, intermediate_size, bias=False, device=device
+    )
+    self.up_proj = Linear(
+        hidden_size, intermediate_size, bias=False, device=device
+    )
+    self.down_proj = Linear(
+        intermediate_size, hidden_size, bias=False, device=device
+    )
 
   def forward(self, x):
     gate = self.gate_proj(x)
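For context, the three bias-free projections configured above feed a GeGLU-style MLP (gate, up, down). A minimal standalone sketch of that pattern, independent of the repo's classes and with illustrative sizes, is:

```python
# Minimal GeGLU-style MLP sketch of the gate/up/down projection pattern;
# an illustration, not the repo's GemmaMLP implementation.
import torch
import torch.nn.functional as F
from torch import nn

class GegluMLP(nn.Module):
  def __init__(self, hidden_size: int, intermediate_size: int):
    super().__init__()
    self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
    self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
    self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

  def forward(self, x):
    # GELU-activated gate multiplied elementwise with the up projection,
    # then projected back down to the hidden size.
    return self.down_proj(F.gelu(self.gate_proj(x)) * self.up_proj(x))

out = GegluMLP(16, 64)(torch.randn(2, 16))  # -> shape (2, 16)
```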

run_interactive.py

Lines changed: 0 additions & 6 deletions
@@ -72,11 +72,6 @@
 _MAX_CACHE_LENGTH = flags.DEFINE_integer(
     "max_cache_length", 1024, "kv_cache_quantize"
 )
-_MODEL_NAME = flags.DEFINE_string(
-    "model",
-    "llama-2",
-    "name of the model. Supported options are llama-2 and llama-3",
-)
 _SHARDING_CONFIG = flags.DEFINE_string(
     "sharding_config", "", "config file for sharding"
 )
@@ -98,7 +93,6 @@ def create_engine():
       param_size=_SIZE.value,
       context_length=_CONTEXT_LENGTH.value,
       batch_size=_BATCH_SIZE.value,
-      model_name=_MODEL_NAME.value,
       quantize_weights=_QUANTIZE_WEIGHTS.value,
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,

run_server.py

Lines changed: 0 additions & 4 deletions
@@ -89,9 +89,6 @@
 _SHARDING_CONFIG = flags.DEFINE_string(
     "sharding_config", "", "config file for sharding"
 )
-_MODEL_NAME = flags.DEFINE_string(
-    "model_name", "llama-2", "model name, defaults to llama-2"
-)
 
 
 # pylint: disable-next=all
@@ -119,7 +116,6 @@ def main(argv: Sequence[str]):
       quantize_kv=_QUANTIZE_KV_CACHE.value,
       max_cache_length=_MAX_CACHE_LENGTH.value,
      sharding_config=sharding_config_path,
-      model_name=_MODEL_NAME.value,
   )
   server_config = ServerConfig(
       interleaved_slices=(_PLATFORM.value,),
