
Commit 073fd75

Improve warnings
Signed-off-by: Antoni Viros i Martin <aviros@ibm.com>
Parent: e32ea99

File tree: 1 file changed (+3, −5)

scripts/inference.py

@@ -288,10 +288,6 @@
     if default_dtype is not None:
         torch.set_default_dtype(default_dtype)
 
-    if default_dtype is not None or args.cast_bf16_to_fp16:
-        dprint("You may be casting your checkpoint to a data type with lower dynamic range." \
-            " This can lead to a loss of model accuracy.")
-
     dprint(f"{args}")
 
     is_aiu_backend = "aiu" in args.device_type
@@ -509,8 +505,10 @@ def select_int8_module(
         raise ValueError("FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16. Do not use --default_dtype!")
 
     if args.cast_bf16_to_fp16:
-        for param in model.parameters():
+        for name, param in model.named_parameters():
             if param.dtype == torch.bfloat16:
+                if param.max() > torch.finfo(torch.float16).max:
+                    dprint(f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy. You can ignore this warning if this is intended.")
                 param.data = param.data.to(dtype=torch.float16)
 
     if args.cast_fp16_to_bf16:
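
For context on why the per-parameter guard matters: bf16 and fp32 share an 8-bit exponent, so a bf16 checkpoint can hold magnitudes far beyond fp16's representable maximum (torch.finfo(torch.float16).max, i.e. 65504). Casting such a value down overflows to inf. Below is a minimal standalone sketch of the same finfo-based check; the tensor values are invented for illustration and are not taken from the commit:

    import torch

    fp16_max = torch.finfo(torch.float16).max  # 65504.0

    # A bf16 tensor whose largest entry exceeds the fp16 range
    # (hypothetical values, standing in for checkpoint weights).
    weights = torch.tensor([1.0, 1.0e5], dtype=torch.bfloat16)

    if weights.max() > fp16_max:
        print(f"[WARNING] max value {weights.max().item():.1f} exceeds fp16 max {fp16_max}")

    # The cast silently overflows: prints tensor([1., inf], dtype=torch.float16)
    print(weights.to(torch.float16))

Note that the committed check compares only param.max(), so it flags positive overflow; a parameter whose most extreme entry is a large negative value (below -65504) would still overflow to -inf without triggering the warning.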
