1 file changed (+3 -5 lines)

```diff
@@ -288,10 +288,6 @@
 if default_dtype is not None:
     torch.set_default_dtype(default_dtype)
 
-if default_dtype is not None or args.cast_bf16_to_fp16:
-    dprint("You may be casting your checkpoint to a data type with lower dynamic range." \
-        " This can lead to a loss of model accuracy.")
-
 dprint(f"{args}")
 
 is_aiu_backend = "aiu" in args.device_type
```
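This hunk drops the blanket warning that fired whenever `--default_dtype` or `--cast_bf16_to_fp16` was set, regardless of whether any weight actually exceeded fp16's range; the second hunk below replaces it with a targeted per-parameter check. The underlying issue is the dynamic-range gap between the two 16-bit formats: bfloat16 keeps float32's 8 exponent bits, while float16 has only 5. A minimal sketch of that gap using plain PyTorch (not code from this repo):

```python
import torch

# bfloat16 keeps float32's 8 exponent bits, so its max is ~3.4e38;
# float16 has only 5 exponent bits and tops out at 65504.
print(torch.finfo(torch.bfloat16).max)  # 3.3895313892515355e+38
print(torch.finfo(torch.float16).max)   # 65504.0

# A bf16 value above fp16's max silently overflows to inf on the cast:
x = torch.tensor([70000.0], dtype=torch.bfloat16)
print(x.to(torch.float16))  # tensor([inf], dtype=torch.float16)
```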
```diff
@@ -509,8 +505,10 @@ def select_int8_module(
         raise ValueError("FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16. Do not use --default_dtype!")
 
 if args.cast_bf16_to_fp16:
-    for param in model.parameters():
+    for name, param in model.named_parameters():
         if param.dtype == torch.bfloat16:
+            if param.max() > torch.finfo(torch.float16).max:
+                dprint(f"[WARNING] You are casting param {name} to fp16, which will cause loss of accuracy. You can ignore this warning if this is intended.")
             param.data = param.data.to(dtype=torch.float16)
 
 if args.cast_fp16_to_bf16:
```
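The new loop warns per parameter, and only when a value actually exceeds float16's maximum. A minimal self-contained sketch of the same pattern (the helper name is hypothetical, not from this repo); it checks `param.abs().max()` so that large negative values, which overflow to `-inf`, are caught as well, whereas the diff checks `param.max()` only:

```python
import torch

def cast_bf16_params_to_fp16(model: torch.nn.Module) -> None:
    """Hypothetical helper mirroring the hunk above: cast every bf16
    parameter to fp16, warning per parameter on real overflow risk."""
    fp16_max = torch.finfo(torch.float16).max
    for name, param in model.named_parameters():
        if param.dtype == torch.bfloat16:
            # Using abs().max() also catches large negative values, which
            # overflow to -inf; the diff itself only checks param.max().
            if param.abs().max() > fp16_max:
                print(f"[WARNING] Casting {name} to fp16 exceeds fp16's range.")
            param.data = param.data.to(dtype=torch.float16)

# Usage on a toy module: fill a weight beyond fp16's max to trigger the warning.
m = torch.nn.Linear(4, 4).to(torch.bfloat16)
with torch.no_grad():
    m.weight.fill_(1e38)
cast_bf16_params_to_fp16(m)
print(m.weight.dtype)  # torch.float16
```

Warning rather than raising matches the diff's intent: the cast may still be deliberate, as the message itself notes ("You can ignore this warning if this is intended").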