|
7 | 7 | from pathlib import Path |
8 | 8 | import random |
9 | 9 | import time |
10 | | -import contextlib |
11 | 10 |
|
12 | 11 | # Third Party |
13 | 12 | from aiu_fms_testing_utils.utils import aiu_setup, warmup_model |
|
104 | 103 | type=str, |
105 | 104 | default=None, |
106 | 105 | choices=["bf16", "fp16", "fp32"], |
107 | | - help="If set to one of the choices, overrides the model checkpoint weight format by setting the default pytorch format", |
| 106 | +    help="If set to one of the choices, overrides the model checkpoint weight format by setting the default PyTorch dtype. This will break quantized checkpoints.",
| 107 | +) |
| 108 | +parser.add_argument( |
| 109 | + "--cast_bf16_to_fp16", |
| 110 | + action="store_true", |
| 111 | +    help="If set, cast any bf16 weights in the model to fp16 for the AIU compiler. fp32 and quantized weights are left untouched.",
| 112 | +) |
| 113 | +parser.add_argument( |
| 114 | + "--cast_fp16_to_bf16", |
| 115 | + action="store_true", |
| 116 | +    help="If set, cast any fp16 weights in the model to bf16 for GPU execution. fp32 and quantized weights are left untouched.",
108 | 117 | ) |
109 | 118 | parser.add_argument( |
110 | 119 | "--compile", |
|
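Since the two new cast flags request opposite conversions, it may be worth rejecting them when passed together; a minimal sketch assuming the script's existing `args` namespace (this guard is not part of the diff):

    if args.cast_bf16_to_fp16 and args.cast_fp16_to_bf16:
        raise ValueError(
            "--cast_bf16_to_fp16 and --cast_fp16_to_bf16 are mutually exclusive; pass at most one."
        )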
221 | 230 | parser.add_argument( |
222 | 231 | "--attention_type", |
223 | 232 | type=str, |
224 | | - choices=["sdpa", "paged"], |
| 233 | + choices=["sdpa", "paged", "math_fp8", "paged_fp8"], |
225 | 234 | default="sdpa", |
226 | 235 | help="which backend attention to use in mha", |
227 | 236 | ) |
228 | 237 | args = parser.parse_args() |
229 | 238 |
|
230 | | -if args.attention_type == "paged": |
| 239 | +attention_map = { |
| 240 | + "sdpa": "sdpa_causal", |
| 241 | + "paged": "spyre_paged_attn", |
| 242 | + "math_fp8": "math_fp8", |
| 243 | + "paged_fp8": "spyre_paged_attn_fp8", |
| 244 | +} |
| 245 | + |
| 246 | +attn_name = attention_map[args.attention_type] |
| 247 | + |
| 248 | +if "paged" in attn_name: |
231 | 249 | from aiu_fms_testing_utils.utils.paged import generate |
232 | 250 | else: |
233 | 251 | from fms.utils.generation import generate |
234 | 252 |
|
| 253 | +if "fp8" in attn_name: |
| 254 | +    import fms_mo.aiu_addons.fp8.fp8_attn  # noqa: F401 (imported for its side effects: makes the fp8 attention ops available; not referenced directly)
| 255 | + |
235 | 256 | if args.quantization == "gptq": |
236 | 257 | if "aiu" in args.device_type: |
237 | 258 | try: |
|
329 | 350 | print("must set AIU_WORLD_RANK_0") |
330 | 351 | exit() |
331 | 352 | os.environ.setdefault("FLEX_COMPUTE", "SENTIENT") |
332 | | - os.environ.setdefault("FLEX_DEVICE", "VFIO") |
| 353 | + os.environ.setdefault("FLEX_DEVICE", "PF") |
333 | 354 |
|
334 | 355 | device = torch.device("cpu") |
335 | 356 | else: |
@@ -463,6 +484,38 @@ def select_int8_module( |
463 | 484 | fused_weights=fused_weights, |
464 | 485 | ) |
465 | 486 |
|
| 487 | +### Quantization |
| 488 | + |
| 489 | +# FP8 model checks |
| 490 | +has_fp8_weights = False |
| 491 | +has_bf16_weights = False |
| 492 | +has_fp16_weights = False |
| 493 | +for param in model.parameters(): |
| 494 | + if param.dtype == torch.float8_e4m3fn: |
| 495 | + has_fp8_weights = True |
| 496 | + elif param.dtype == torch.bfloat16: |
| 497 | + has_bf16_weights = True |
| 498 | + elif param.dtype == torch.float16: |
| 499 | + has_fp16_weights = True |
| 500 | + |
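For reference, the flag-setting loop above can be collapsed into a single set-building pass; a functionally equivalent sketch, not part of the diff:

    # Collect the distinct parameter dtypes once, then test membership.
    param_dtypes = {p.dtype for p in model.parameters()}
    has_fp8_weights = torch.float8_e4m3fn in param_dtypes
    has_bf16_weights = torch.bfloat16 in param_dtypes
    has_fp16_weights = torch.float16 in param_dtypes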
| 501 | +if has_fp8_weights: |
| 502 | + if is_aiu_backend and has_bf16_weights and not args.cast_bf16_to_fp16: |
| 503 | + raise ValueError("FP8 checkpoints on AIU with bf16 weights require casting to fp16 using --cast_bf16_to_fp16. Do not use --default_dtype!") |
| 504 | + elif device.type == "cuda" and has_fp16_weights and not args.cast_fp16_to_bf16: |
| 505 | + raise ValueError("FP8 checkpoints on GPU with fp16 weights require casting to bf16 using --cast_fp16_to_bf16. Do not use --default_dtype!") |
| 506 | + |
| 507 | +if args.cast_bf16_to_fp16: |
| 508 | + for name, param in model.named_parameters(): |
| 509 | + if param.dtype == torch.bfloat16: |
| 510 | +            if param.abs().max() > torch.finfo(torch.float16).max:
| 511 | +                dprint(f"[WARNING] Param {name} has values outside the fp16 range; casting it to fp16 will overflow them. Ignore this warning if this is intended.")
| 512 | + param.data = param.data.to(dtype=torch.float16) |
| 513 | + |
| 514 | +if args.cast_fp16_to_bf16: |
| 515 | + for param in model.parameters(): |
| 516 | + if param.dtype == torch.float16: |
| 517 | + param.data = param.data.to(dtype=torch.bfloat16) |
| 518 | + |
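The overflow check in the bf16 branch above exists because bf16 shares fp32's exponent range while fp16 tops out at 65504; a minimal, self-contained illustration (not part of the diff):

    import torch

    # bf16 can hold magnitudes well beyond the fp16 maximum of 65504.
    big = torch.tensor([70000.0, -70000.0], dtype=torch.bfloat16)
    print(torch.finfo(torch.float16).max)  # 65504.0
    print(big.to(torch.float16))           # tensor([inf, -inf], dtype=torch.float16)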
466 | 519 | if args.quantization in ["gptq", "int8"]: |
467 | 520 | if rank == 0 and args.verbose > 0: |
468 | 521 | dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters())) |
@@ -606,7 +659,9 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): |
606 | 659 | ids = prompts |
607 | 660 | if isinstance(ids, list) and len(ids) == 1: |
608 | 661 | ids = ids[0].unsqueeze(0) |
609 | | - extra_generation_kwargs = None |
| 662 | + extra_generation_kwargs = {} |
| 663 | + |
| 664 | +extra_generation_kwargs["attn_name"] = attn_name |
610 | 665 |
|
611 | 666 |
|
612 | 667 | def print_result(result, result_idx: int): |
@@ -648,19 +703,15 @@ def infer(use_cache, do_sample, warmup): |
648 | 703 | global extra_generation_kwargs |
649 | 704 | if extra_generation_kwargs is None: |
650 | 705 | extra_generation_kwargs = {} |
651 | | - extra_generation_kwargs["only_last_token"] = args.attention_type != "paged" |
652 | | - |
653 | | - if args.device_type == "cpu": |
654 | | - # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu impl when padding too much |
655 | | - extra_generation_kwargs["attn_algorithm"] = "math" |
| 706 | + extra_generation_kwargs["only_last_token"] = "paged" not in attn_name |
656 | 707 |
|
657 | 708 | if not args.no_early_termination and not warmup: |
658 | 709 | eos_token_id = tokenizer.eos_token_id |
659 | 710 | else: |
660 | 711 | eos_token_id = None |
661 | 712 |
|
662 | 713 | attention_specific_kwargs = {} |
663 | | - if args.attention_type == "sdpa": |
| 714 | + if attn_name == "sdpa_causal": |
664 | 715 | attention_specific_kwargs["contiguous_cache"] = True |
665 | 716 |
|
666 | 717 | result = generate( |
@@ -706,7 +757,8 @@ def infer(use_cache, do_sample, warmup): |
706 | 757 | dprint(f"compilation warmup") |
707 | 758 | pt_compile_model_time = time.time() |
708 | 759 | if args.device_type == "aiu": # only run warmup for AIU, no need for senulator |
709 | | - warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, attn_type=args.attention_type, **extra_generation_kwargs) |
| 760 | + for cache in use_cache: |
| 761 | + warmup_model(model, ids, args.max_new_tokens, args.compile_dynamic_sendnn, **extra_generation_kwargs) |
710 | 762 | aiu_warmup_time = time.time() |
711 | 763 | for sample, cache in itertools.product(do_sample, use_cache): |
712 | 764 | infer(cache, sample, True) |
|