Commit 57e6fcf

Fix convert script cannot generate bf16 weights (#104)
fix convert script to generate bf16 weights
1 parent: 2880904

File tree

2 files changed: +3 -2 lines changed


README.md

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@ Need to manually modify the `config.json` in the checkpoint folder to make it a
 export input_ckpt_dir=Original llama weights directory
 export output_ckpt_dir=The output directory
 export model_name="llama-3" # or "llama-2", "gemma"
-export quantize_type="int8_per_channel" # Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}; setting this will quantize the weights
+export quantize_weights=True # Whether to quantize weights
+export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}; "int8_per_channel" is the default if not specified.
 python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type
 ```
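For reference, a minimal usage sketch of the bf16 path this commit unblocks. It assumes the script exposes a `--quantize_weights` boolean flag backing `FLAGS.quantize_weights` (absl-style boolean flags accept `--quantize_weights=False`); the command-line spelling of the flag is an assumption, since the diff does not show the flag definition:

```
export quantize_weights=False # Leave weights unquantized (bf16)
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_weights=$quantize_weights
```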

convert_checkpoints.py

Lines changed: 1 addition & 1 deletion
@@ -391,7 +391,7 @@ def main(argv) -> None:
       llama_model.Transformer.get_quantized_embedding_weight_to_scaler_map()
   )
 
-  if FLAGS.quantize_type:
+  if FLAGS.quantize_weights:
     quantize_num_bits = 8 if "int8" in FLAGS.quantize_type else 4
     is_blockwise = "blockwise" in FLAGS.quantize_type
     weight_axis = lambda x: 0 if x in quantize_embedding_weight_map else 1
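Why the one-line change fixes bf16 output: per the updated README, `quantize_type` defaults to "int8_per_channel", so `if FLAGS.quantize_type:` was always truthy and the converter always quantized, leaving the unquantized (bf16) path unreachable. Gating on the boolean `quantize_weights` restores it. A minimal, self-contained sketch of the pattern, assuming absl-style flags (names mirror convert_checkpoints.py; defaults and prints are illustrative, not the script's actual behavior):

```python
# Sketch of the flag-gating pattern; absl assumed, output illustrative.
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_bool("quantize_weights", False, "Whether to quantize weights.")
flags.DEFINE_string(
    "quantize_type",
    "int8_per_channel",  # Non-empty default: always truthy, hence the old bug.
    'One of {"int8", "int4"} x {"per_channel", "blockwise"}.',
)


def main(argv):
  del argv  # Unused.
  # The old check `if FLAGS.quantize_type:` never took the else-branch,
  # because the string flag has a non-empty default.
  if FLAGS.quantize_weights:
    quantize_num_bits = 8 if "int8" in FLAGS.quantize_type else 4
    is_blockwise = "blockwise" in FLAGS.quantize_type
    print(f"quantizing: bits={quantize_num_bits}, blockwise={is_blockwise}")
  else:
    print("keeping weights unquantized (bf16)")


if __name__ == "__main__":
  app.run(main)
```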
