|
72 | 72 | _QUANTIZE = flags.DEFINE_bool( |
73 | 73 |     "quantize", False, "When set to true, produces quantized weights" |
74 | 74 | ) |
| 75 | +_MODEL_TYPE = flags.DEFINE_string("model_name", "llama", "Type of the model.") |
75 | 76 |
|
76 | 77 | # ParallelEmbedding is col partitioned across the shards. |
77 | 78 | # ColumnParallelLinear is row partitioned across shards due to transpose. |
@@ -403,16 +404,71 @@ def merge_weights( |
403 | 404 |   print(f"Export outputs takes {end - start} seconds") |
404 | 405 |
|
405 | 406 |
|
| 407 | +def convert_hf_gemma_weights( |
| 408 | +    input_ckpt_dir: epath.Path, output_ckpt_dir: epath.Path |
| 409 | +): |
| 410 | +  """Convert Gemma weights from Hugging Face to be compatible with JetStream: |
| 411 | +  1. Map attention weights to new names. |
| 412 | +  2. Split the fused qkv projection into separate q, k, and v weights. |
| 413 | +  """ |
| 414 | +  ckpt_file = list(input_ckpt_dir.glob("*.ckpt")) |
| 415 | +  assert len(ckpt_file) == 1, "Expected exactly 1 ckpt file for Gemma model." |
| 416 | +  ckpt_file = ckpt_file[0] |
| 417 | +  state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))[ |
| 418 | +      "model_state_dict" |
| 419 | +  ] |
| 420 | +  model_config = json.loads((input_ckpt_dir / "config.json").read_text()) |
| 421 | +  for key in list(state_dict.keys()): |
| 422 | +    if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value: |
| 423 | +      assert ( |
| 424 | +          key == "freqs_cis" |
| 425 | +      ), "Only 'freqs_cis' in the state_dict is expected to have a complex dtype." |
| 426 | +      # Remove "freqs_cis": it has a complex dtype, which safetensors doesn't support. |
| 427 | +      # "freqs_cis" is reconstructed when the checkpoint is loaded by the inference engine. |
| 428 | +      state_dict.pop(key) |
| 429 | +      continue |
| 430 | +    prefix_to_remove = "model." |
| 431 | +    new_key = key |
| 432 | +    if key.startswith(prefix_to_remove): |
| 433 | +      new_key = new_key.removeprefix(prefix_to_remove) |
| 434 | +    if "qkv_proj" in key: |
| 435 | +      q_dim = model_config["num_attention_heads"] * model_config["head_dim"] |
| 436 | +      kv_dim = model_config["num_key_value_heads"] * model_config["head_dim"] |
| 437 | +      qkv = state_dict.pop(key) |
| 438 | +      q, k, v = qkv.split( |
| 439 | +          [ |
| 440 | +              q_dim, |
| 441 | +              kv_dim, |
| 442 | +              kv_dim, |
| 443 | +          ], |
| 444 | +          dim=0, |
| 445 | +      ) |
| 446 | +      state_dict[new_key.replace("qkv_proj", "wq")] = q |
| 447 | +      state_dict[new_key.replace("qkv_proj", "wk")] = k |
| 448 | +      state_dict[new_key.replace("qkv_proj", "wv")] = v |
| 449 | +      continue |
| 450 | +    if "o_proj" in key: |
| 451 | +      new_key = new_key.replace("o_proj", "wo") |
| 452 | +    if new_key != key: |
| 453 | +      state_dict[new_key] = state_dict.pop(key) |
| 454 | +  _export_to_local(output_ckpt_dir, model_config, state_dict) |
| 455 | + |
| 456 | + |
406 | 457 | def main(argv: Sequence[str]) -> None: |
407 | 458 |   """convert checkpoint main function""" |
408 | 459 |   if len(argv) > 1: |
409 | 460 |     raise app.UsageError("Too many command-line arguments.") |
410 | | -  merge_weights( |
411 | | -      _INPUT_CHECKPOINT_DIR.value, |
412 | | -      _OUTPUT_CHECKPOINT_DIR.value, |
413 | | -      _MINIMIZE_MEMORY_FOOTPRINT.value, |
414 | | -      _ENABLE_FLOAT32.value, |
415 | | -  ) |
| 461 | +  if "gemma" in _MODEL_TYPE.value: |
| 462 | +    convert_hf_gemma_weights( |
| 463 | +        _INPUT_CHECKPOINT_DIR.value, _OUTPUT_CHECKPOINT_DIR.value |
| 464 | +    ) |
| 465 | +  else: |
| 466 | +    merge_weights( |
| 467 | +        _INPUT_CHECKPOINT_DIR.value, |
| 468 | +        _OUTPUT_CHECKPOINT_DIR.value, |
| 469 | +        _MINIMIZE_MEMORY_FOOTPRINT.value, |
| 470 | +        _ENABLE_FLOAT32.value, |
| 471 | +    ) |
416 | 472 |
|
417 | 473 |
|
418 | 474 | if __name__ == "__main__": |
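For readers skimming the diff, here is a minimal sketch of the key renaming in step 1 of `convert_hf_gemma_weights`, run on a toy `state_dict`. The keys and tensor shapes below are made up for illustration; a real Gemma checkpoint has many more entries.

```python
import torch

# Toy state_dict with made-up keys; shapes are placeholders.
state_dict = {
    "model.layers.0.self_attn.o_proj.weight": torch.zeros(2, 2),
    "model.embedder.weight": torch.zeros(2, 2),
}

prefix_to_remove = "model."
for key in list(state_dict.keys()):
  new_key = key
  if key.startswith(prefix_to_remove):
    new_key = new_key.removeprefix(prefix_to_remove)
  if "o_proj" in key:
    new_key = new_key.replace("o_proj", "wo")
  if new_key != key:
    state_dict[new_key] = state_dict.pop(key)

# The "model." prefix is stripped and o_proj becomes wo:
# ['embedder.weight', 'layers.0.self_attn.wo.weight']
print(sorted(state_dict))
```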
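And a sketch of the qkv split in step 2: the fused `qkv_proj` weight stacks the query, key, and value projections along dim 0, so splitting by row count recovers `wq`, `wk`, and `wv`. The head counts and dimensions below are hypothetical stand-ins for what the diff reads from `config.json`.

```python
import torch

# Hypothetical config values for illustration only; the real ones come from config.json.
num_attention_heads = 8
num_key_value_heads = 1
head_dim = 256
hidden_size = 2048

q_dim = num_attention_heads * head_dim   # rows of the query projection
kv_dim = num_key_value_heads * head_dim  # rows of each of the key/value projections

# Fused qkv_proj weight of shape (q_dim + 2 * kv_dim, hidden_size).
qkv = torch.randn(q_dim + 2 * kv_dim, hidden_size)

# Splitting along dim 0 recovers the three separate projection matrices.
q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=0)
assert q.shape == (q_dim, hidden_size)
assert k.shape == (kv_dim, hidden_size)
assert v.shape == (kv_dim, hidden_size)
```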
|