     type=str,
     default=None,
     choices=["bf16", "fp16", "fp32"],
-    help="If set to one of the choices, overrides the model checkpoint weight format by setting the default pytorch format",
+    help="If set to one of the choices, overrides the model checkpoint weight format by setting the default pytorch format. This will break quantized checkpoints.",
 )
 parser.add_argument(
     "--cast_bf16_to_fp16",
     action="store_true",
     help="If set, cast any bf16 weights in the model to fp16 for AIU compiler. Doesn't touch fp32 or quantized",
 )
+parser.add_argument(
+    "--cast_fp16_to_bf16",
+    action="store_true",
+    help="If set, cast any fp16 weights in the model to bf16 for GPU. Doesn't touch fp32 or quantized",
+)
 parser.add_argument(
     "--compile",
     action="store_true",
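
For reference, a minimal sketch of how the two `store_true` cast flags surface on `args` once parsed. The bare parser and the example argv below are illustrative stand-ins, not the inference script's full argument setup.

```python
# Illustrative only: the two cast flags default to False and flip to True when passed.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cast_bf16_to_fp16",
    action="store_true",
    help="Cast any bf16 weights to fp16 for the AIU compiler; fp32 and quantized weights are untouched",
)
parser.add_argument(
    "--cast_fp16_to_bf16",
    action="store_true",
    help="Cast any fp16 weights to bf16 for GPU; fp32 and quantized weights are untouched",
)

args = parser.parse_args(["--cast_fp16_to_bf16"])
print(args.cast_bf16_to_fp16, args.cast_fp16_to_bf16)  # False True
```

Per the help texts, the two flags are complementary: `--cast_bf16_to_fp16` targets the AIU compiler path, while the new `--cast_fp16_to_bf16` covers the GPU path; neither touches fp32 or quantized weights.
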
@@ -483,11 +488,36 @@ def select_int8_module(
     fused_weights=fused_weights,
 )
 
+### Quantization
+
+# FP8 model checks
+has_fp8_weights = False
+has_bf16_weights = False
+has_fp16_weights = False
+for param in model.parameters():
+    if param.dtype == torch.float8_e4m3fn:
+        has_fp8_weights = True
+    elif param.dtype == torch.bfloat16:
+        has_bf16_weights = True
+    elif param.dtype == torch.float16:
+        has_fp16_weights = True
+
+if has_fp8_weights:
+    if is_aiu_backend and has_bf16_weights and not args.cast_bf16_to_fp16:
+        raise ValueError("FP8 checkpoints on AIU with bf16 weights require casting to fp16 using --cast_bf16_to_fp16. Do not use --default_dtype!")
+    elif device.type == "cuda" and has_fp16_weights and not args.cast_fp16_to_bf16:
+        raise ValueError("FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16. Do not use --default_dtype!")
+
 if args.cast_bf16_to_fp16:
     for param in model.parameters():
         if param.dtype == torch.bfloat16:
             param.data = param.data.to(dtype=torch.float16)
 
+if args.cast_fp16_to_bf16:
+    for param in model.parameters():
+        if param.dtype == torch.float16:
+            param.data = param.data.to(dtype=torch.bfloat16)
+
 if args.quantization in ["gptq", "int8"]:
     if rank == 0 and args.verbose > 0:
         dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
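
Taken together, the added block scans the checkpoint's parameter dtypes once, rejects dtype/backend combinations that FP8 checkpoints cannot run with unless the matching cast flag is passed, and then casts matching parameters in place. Below is a self-contained sketch of that scan/validate/cast pattern on a toy module; the stand-in model, the hard-coded flag, and the GPU-only condition are assumptions for illustration, not the script's actual wiring (which also tracks bf16 weights and consults `is_aiu_backend` and `device.type`).

```python
# Minimal sketch of the scan -> validate -> cast pattern, on a toy module.
import torch
import torch.nn as nn

model = nn.Linear(4, 4, dtype=torch.float16)   # stand-in for the loaded checkpoint
cast_fp16_to_bf16 = True                       # stand-in for args.cast_fp16_to_bf16

# 1) Scan parameter dtypes once.
has_fp8_weights = any(p.dtype == torch.float8_e4m3fn for p in model.parameters())
has_fp16_weights = any(p.dtype == torch.float16 for p in model.parameters())

# 2) Fail fast on combinations FP8 checkpoints cannot run with (mirrors the diff's check).
if has_fp8_weights and has_fp16_weights and not cast_fp16_to_bf16:
    raise ValueError("FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16.")

# 3) Cast matching parameters in place; fp32 and quantized tensors are never touched
#    because the loop only rewrites tensors whose dtype is exactly torch.float16.
if cast_fp16_to_bf16:
    for p in model.parameters():
        if p.dtype == torch.float16:
            p.data = p.data.to(dtype=torch.bfloat16)

print({name: p.dtype for name, p in model.named_parameters()})  # both params now torch.bfloat16
```

Casting via `param.data = param.data.to(...)` rewrites the weight storage without replacing the parameter objects, so optimizer/state references and quantized or fp32 tensors are left exactly as loaded.
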