Commit e32ea99 (parent 3ba5ea0)

Improve fp8 quantization handling

Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>

1 file changed: 31 additions, 1 deletion

scripts/inference.py

@@ -103,13 +103,18 @@
     type=str,
     default=None,
     choices=["bf16", "fp16", "fp32"],
-    help="If set to one of the choices, overrides the model checkpoint weight format by setting the default pytorch format",
+    help="If set to one of the choices, overrides the model checkpoint weight format by setting the default pytorch format. This will break quantized checkpoints.",
 )
 parser.add_argument(
     "--cast_bf16_to_fp16",
     action="store_true",
     help="If set, cast any bf16 weights in the model to fp16 for AIU compiler. Doesn't touch fp32 or quantized",
 )
+parser.add_argument(
+    "--cast_fp16_to_bf16",
+    action="store_true",
+    help="If set, cast any fp16 weights in the model to bf16 for GPU. Doesn't touch fp32 or quantized",
+)
 parser.add_argument(
     "--compile",
     action="store_true",
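
For illustration only, here is a hedged, self-contained sketch of how the two casting flags from this hunk parse on their own. The real parser in scripts/inference.py defines many more arguments that are not reproduced here, so this toy parser is an assumption for demonstration, not the script's actual argument setup. The second hunk below adds the runtime checks and the actual casting.

```python
# Toy reproduction of only the two casting flags touched in this hunk.
# The real scripts/inference.py parser has many more options not shown here.
import argparse

parser = argparse.ArgumentParser(description="illustrative subset of the inference.py flags")
parser.add_argument(
    "--cast_bf16_to_fp16",
    action="store_true",
    help="If set, cast any bf16 weights in the model to fp16 for AIU compiler. Doesn't touch fp32 or quantized",
)
parser.add_argument(
    "--cast_fp16_to_bf16",
    action="store_true",
    help="If set, cast any fp16 weights in the model to bf16 for GPU. Doesn't touch fp32 or quantized",
)

# e.g. running an FP8 checkpoint on GPU would pass --cast_fp16_to_bf16:
args = parser.parse_args(["--cast_fp16_to_bf16"])
print(args.cast_bf16_to_fp16, args.cast_fp16_to_bf16)  # False True
```
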
@@ -483,11 +488,36 @@ def select_int8_module(
     fused_weights=fused_weights,
 )
 
+### Quantization
+
+# FP8 model checks
+has_fp8_weights = False
+has_bf16_weights = False
+has_fp16_weights = False
+for param in model.parameters():
+    if param.dtype == torch.float8_e4m3fn:
+        has_fp8_weights = True
+    elif param.dtype == torch.bfloat16:
+        has_bf16_weights = True
+    elif param.dtype == torch.float16:
+        has_fp16_weights = True
+
+if has_fp8_weights:
+    if is_aiu_backend and has_bf16_weights and not args.cast_bf16_to_fp16:
+        raise ValueError("FP8 checkpoints on AIU with bf16 weights require casting to fp16 using --cast_bf16_to_fp16. Do not use --default_dtype!")
+    elif device.type == "cuda" and has_fp16_weights and not args.cast_fp16_to_bf16:
+        raise ValueError("FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16. Do not use --default_dtype!")
+
 if args.cast_bf16_to_fp16:
     for param in model.parameters():
         if param.dtype == torch.bfloat16:
             param.data = param.data.to(dtype=torch.float16)
 
+if args.cast_fp16_to_bf16:
+    for param in model.parameters():
+        if param.dtype == torch.float16:
+            param.data = param.data.to(dtype=torch.bfloat16)
+
 if args.quantization in ["gptq", "int8"]:
     if rank == 0 and args.verbose > 0:
         dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
