|
| 1 | +# Standard |
1 | 2 | import argparse |
| 3 | +from functools import partial |
2 | 4 | import itertools |
3 | 5 | import json |
4 | 6 | import os |
| 7 | +from pathlib import Path |
5 | 8 | import random |
6 | | -import sys |
7 | 9 | import time |
8 | | -from pathlib import Path |
9 | 10 |
|
| 11 | +# Third Party |
10 | 12 | from aiu_fms_testing_utils.utils import aiu_setup |
11 | 13 | from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size |
12 | 14 | import numpy as np |
13 | 15 | import torch |
14 | | -import torch._inductor.config |
| 16 | +from torch import distributed as dist |
15 | 17 | from fms.models import get_model, register_model |
16 | 18 | from fms.models.llama import LLaMAConfig, _llama_factory_factory |
17 | | -from fms.utils import fusion, generation, tokenizers |
| 19 | +from fms.utils import generation, tokenizers |
18 | 20 | from fms.utils.generation import generate, pad_input_ids |
19 | | -from torch import distributed as dist |
| 21 | + |
20 | 22 |
|
21 | 23 | # This example script validates the LLaMA implementation by running inference on a couple of prompts. |
22 | 24 | # |
|
59 | 61 | parser.add_argument( |
60 | 62 | "--quantization", |
61 | 63 | type=str, |
62 | | - choices=["gptq"], |
| 64 | + choices=["gptq", "int8"], |
63 | 65 | default=None, |
64 | 66 | help="Type of quantization of the model checkpoint", |
65 | 67 | ) |
| 68 | +parser.add_argument( |
| 69 | + "--int8_weight_per_channel", |
| 70 | + action="store_true", |
| 71 | + help="Enable per-channel weight quantization in INT8 quantized model", |
| 72 | +) |
| 73 | +parser.add_argument( |
| 74 | + "--int8_activ_quant_type", |
| 75 | + default="per_token", |
| 76 | + choices=["per_token", "per_tensor_symm", "per_tensor_asymm"], |
| 77 | + type=str, |
| 78 | + help="Define strategy for activation quantization in INT8 quantized model", |
| 79 | +) |
| 80 | +parser.add_argument( |
| 81 | + "--int8_smoothquant", |
| 82 | + action="store_true", |
| 83 | + help="Enable smoothquant in INT8 quantized model", |
| 84 | +) |
66 | 85 | parser.add_argument( |
67 | 86 | "--tokenizer", |
68 | 87 | type=str, |
|
196 | 215 | args = parser.parse_args() |
197 | 216 |
|
198 | 217 | if args.quantization == "gptq": |
199 | | - GPTQ_ENABLED = True |
200 | | - try: |
201 | | - if "aiu" in args.device_type: |
| 218 | + if "aiu" in args.device_type: |
| 219 | + try: |
202 | 220 | from fms_mo.aiu_addons.gptq import gptq_aiu_adapter, gptq_aiu_linear |
203 | 221 | print("Loaded `aiu_addons` functionalities") |
204 | | - elif args.device_type != "cpu": |
205 | | - raise ValueError(f"Device {args.device_type} unsupported for GPTQ run") |
206 | | - except ImportError as e: |
207 | | - print(f"Failed to import addon packages: {e}") |
208 | | - GPTQ_ENABLED = False |
209 | | - |
210 | | - if not GPTQ_ENABLED: |
211 | | - raise Exception("GPTQ not enabled") |
| 222 | + except ImportError as e: |
| 223 | + raise ImportError("Failed to import GPTQ addons from fms-mo.") from e |
| 224 | +elif args.quantization == "int8": |
| 225 | + try: |
| 226 | + from fms_mo.aiu_addons.i8i8 import i8i8_aiu_adapter, i8i8_aiu_linear |
| 227 | + print("Loaded `aiu_addons` functionalities") |
| 228 | + except ImportError as e: |
| 229 | + raise ImportError("Failed to import INT8 addons from fms-mo.") from e |
212 | 230 |
|
213 | 231 | # this is a test model config |
214 | 232 | config = LLaMAConfig( |
|
319 | 337 |
|
320 | 338 | fused_weights = not args.unfuse_weights |
321 | 339 | if args.quantization == "gptq": |
| 340 | + if fused_weights and is_aiu_backend: |
| 341 | + raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights") |
| 342 | + if default_dtype is not None: |
| 343 | + raise ValueError( |
| 344 | + "GPTQ default_dtype must be None to preserve the checkpoint data types." |
| 345 | + ) |
| 346 | + |
322 | 347 | if "aiu" in args.device_type: |
323 | 348 | linear_type = "gptq_aiu" |
324 | 349 | elif args.device_type == "cpu": |
|
352 | 377 | "group_size": group_size, |
353 | 378 | "desc_act": desc_act, |
354 | 379 | } |
355 | | - # [ATTENTION] for GPTQ on AIU, we must always instantiate an unfused |
356 | | - # model, the adapter will take care of converting key/values from |
357 | | - # ckpt into the appropriate form for the model |
358 | | - if fused_weights: |
359 | | - raise ValueError("GPTQ checkpoints on AIU must always run with --unfuse_weights") |
360 | | - default_dtype = None # GPTQ dtype always comes from ckpt, can't be enforced |
| 380 | +elif args.quantization == "int8": |
| 381 | + if fused_weights and is_aiu_backend: |
| 382 | + raise ValueError("INT8 checkpoints on AIU must always run with --unfuse_weights") |
| 383 | + if default_dtype is not None: |
| 384 | + raise ValueError( |
| 385 | + "INT8 default_dtype must be None to preserve the checkpoint data types." |
| 386 | + ) |
| 387 | + |
| 388 | + def select_int8_module( |
| 389 | + module_name: str | None = None, |
| 390 | + smoothquant: bool = True, |
| 391 | + smoothquant_layers: list[str] | None = None, |
| 392 | + ) -> str: |
| 393 | + if module_name is None: |
| 394 | + return "int8_aiu" |
| 395 | + smoothquant_on_module = ( |
| 396 | + any(m in module_name for m in smoothquant_layers) |
| 397 | + if smoothquant_layers is not None |
| 398 | + else True |
| 399 | + ) |
| 400 | + use_smoothquant = smoothquant and smoothquant_on_module |
| 401 | + return "int8_smoothquant_aiu" if use_smoothquant else "int8_aiu" |
| 402 | + |
| 403 | + if args.int8_smoothquant: |
| 404 | + # TODO: consider saving this info into config during quantization |
| 405 | + if any("granite" in p.lower() for p in [args.model_path, args.architecture]): |
| 406 | + smoothquant_layers = ["key", "value", "w1", "wg"] |
| 407 | + elif any("roberta" in p.lower() for p in [args.model_path, args.architecture]): |
| 408 | + smoothquant_layers = ["query", "key", "value", "w1"] |
| 409 | + else: |
| 410 | + raise NotImplementedError( |
| 411 | + "INT8 architecture does not support smoothquant." |
| 412 | + ) |
| 413 | + else: |
| 414 | + smoothquant_layers = [] |
| 415 | + |
| 416 | + linear_config = { |
| 417 | + "linear_type": partial( |
| 418 | + select_int8_module, |
| 419 | + smoothquant=args.int8_smoothquant, |
| 420 | + smoothquant_layers=smoothquant_layers, |
| 421 | + ), |
| 422 | + "weight_per_channel": args.int8_weight_per_channel, |
| 423 | + "activ_quant_type": args.int8_activ_quant_type, |
| 424 | + } |
361 | 425 | else: |
362 | 426 | linear_config = {"linear_type": "torch_linear"} |
363 | 427 |
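For context, a minimal, illustrative sketch of how the partial-based selector above resolves linear types. It assumes `select_int8_module` as defined in the diff; the module names and the Granite-style smoothquant layer list are made up for illustration:

```python
from functools import partial

# Assumes select_int8_module as defined above; module names below are hypothetical.
selector = partial(
    select_int8_module,
    smoothquant=True,
    smoothquant_layers=["key", "value", "w1", "wg"],
)
print(selector("layers.0.attn.key"))    # -> "int8_smoothquant_aiu" (name matches a smoothquant layer)
print(selector("layers.0.attn.dense"))  # -> "int8_aiu" (no match, plain INT8 linear)
print(selector())                       # -> "int8_aiu" (default when no module name is given)
```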
|
|
381 | 445 | fused_weights=fused_weights, |
382 | 446 | ) |
383 | 447 |
|
384 | | -if args.quantization == "gptq": |
| 448 | +if args.quantization in ["gptq", "int8"]: |
385 | 449 | if rank == 0 and args.verbose > 0: |
386 | 450 | dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters())) |
387 | 451 | dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers())) |
388 | 452 | dprint("="*60 + "\n") |
389 | 453 | if args.architecture == "llama": |
390 | | - dprint("[NOTE] It's OK for unused keys to contain bias and rotary embeddings, in GPTQ LLaMA models") |
| 454 | + dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.") |
391 | 455 | dprint(model) |
392 | 456 | dprint("="*60 + "\n") |
393 | 457 |
|
@@ -522,6 +586,8 @@ def truncate_prompts_to_max_length(prompts, max_len, max_allowed_length): |
522 | 586 | ids, extra_generation_kwargs = pad_input_ids(prompts, min_pad_length=padding_length) |
523 | 587 | else: |
524 | 588 | ids = prompts |
| 589 | + if isinstance(ids, list) and len(ids) == 1: |
| 590 | + ids = ids[0].unsqueeze(0) |
525 | 591 | extra_generation_kwargs = None |
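The new single-prompt branch assumes a one-element list holding a 1-D token tensor; a minimal sketch of the intended shape handling (token values are made up):

```python
import torch

# A single pre-tokenized prompt arrives as a one-element list of a 1-D tensor.
prompt = torch.tensor([101, 2023, 2003])   # shape: [seq_len]
ids = [prompt]
if isinstance(ids, list) and len(ids) == 1:
    ids = ids[0].unsqueeze(0)              # shape: [1, seq_len], i.e. a batch of one
print(ids.shape)  # torch.Size([1, 3])
```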
526 | 592 |
|
527 | 593 |
|
|