Skip to content

Commit fe8dbde

Browse files
authored
Remove JSON config mangling for Gemma ckpt (#124)
update gemma convert
1 parent 8a125b6 commit fe8dbde

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,12 @@ the tokenizer that we will use.
5959
Please sign agreement on Huggingface website to access Gemma checkpoints. Download Gemma PyTorch checkpoint using huggingface-cli. Gemma Tokenizer is included in the checkpoint.
6060

6161
```bash
62+
# Install huggingface-cli and login if it's not set up.
63+
pip install -U "huggingface_hub[cli]"
64+
huggingface-cli login
6265
huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
6366
```
6467

65-
Need to manually modify the `config.json` in the checkpoint folder to make it a valid JSON file. (Replace `'` with `"`, remove the excessive `,` after the last item in the JSON object)
66-
6768
## Mixtral
6869
### Get Mixtral Checkpoint from HuggingFace
6970

convert_checkpoints.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -428,21 +428,14 @@ def _get_llama_state_dict(input_ckpt_dir):
428428
return state_dict, params
429429

430430

431-
def fix_json(text):
432-
text = text.replace("'", '"')
433-
lines = text.split("\n")
434-
lines[-3] = lines[-3].replace(",", "")
435-
return "\n".join(lines)
436-
437-
438431
def _get_gemma_state_dict(input_ckpt_dir):
439432
ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
440433
assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
441434
ckpt_file = ckpt_file[0]
442435
state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
443436
"model_state_dict"
444437
]
445-
config_text = fix_json((input_ckpt_dir / "config.json").read_text())
438+
config_text = (input_ckpt_dir / "config.json").read_text()
446439
model_config = json.loads(config_text)
447440
for key in list(state_dict.keys()):
448441
if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:

0 commit comments

Comments
 (0)