|
1 | 1 | import inspect |
| 2 | +from itertools import product |
2 | 3 |
|
3 | 4 | import torch |
4 | 5 | from compressed_tensors.quantization import disable_quantization |
@@ -610,60 +611,59 @@ def _compute_best_scale( |
610 | 611 | w_mean = w_mean.view(-1).to(device) |
611 | 612 |
|
612 | 613 | if self.duo_scaling is None: |
613 | | - # if self.duo_scaling is unsert, perform half the grid search with |
| 614 | + # if self.duo_scaling is unset, perform half the grid search with |
614 | 615 | # duo_scaling off and half with duo_scaling on |
615 | 616 | n_grid = int(self.n_grid / 2) |
616 | 617 | duo_scalings = [False, True] |
617 | 618 | else: |
618 | 619 | n_grid = self.n_grid |
619 | 620 | duo_scalings = [self.duo_scaling] |
620 | | - for ratio in range(n_grid): |
621 | | - for duo_scaling in duo_scalings: |
622 | | - # create new scales |
623 | | - ratio = ratio / n_grid |
624 | | - |
625 | | - # NOTE: s^-1 * x is fused here, according to paper |
626 | | - if duo_scaling: |
627 | | - scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( |
628 | | - min=1e-4 |
629 | | - ) |
630 | | - else: |
631 | | - scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) |
632 | | - scales = scales / (scales.max() * scales.min()).sqrt() |
633 | | - _scalesview = scales.view(1, -1).to(device) |
634 | | - |
635 | | - # avoid scaling values that overflow |
636 | | - scales[torch.isinf(scales)] = 1 |
637 | | - scales[torch.isnan(scales)] = 1 |
638 | | - |
639 | | - # Q(W * s) |
640 | | - for linear in linears2scale: |
641 | | - linear.weight.mul_(_scalesview) |
642 | | - update_offload_parameter( |
643 | | - linear, |
644 | | - "weight", |
645 | | - _pseudo_quantize_tensor( |
646 | | - w=linear.weight.data, |
647 | | - symmetric=self._symmetric, |
648 | | - bit_width=self._num_bits, |
649 | | - group_size=self._group_size, |
650 | | - )[0] |
651 | | - / _scalesview, |
652 | | - ) |
| 621 | + for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings): |
| 622 | + # create new scales |
| 623 | + ratio = grid_idx / n_grid |
| 624 | + |
| 625 | + # NOTE: s^-1 * x is fused here, according to paper |
| 626 | + if use_duo_scaling: |
| 627 | + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( |
| 628 | + min=1e-4 |
| 629 | + ) |
| 630 | + else: |
| 631 | + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) |
| 632 | + scales = scales / (scales.max() * scales.min()).sqrt() |
| 633 | + _scalesview = scales.view(1, -1).to(device) |
| 634 | + |
| 635 | + # avoid scaling values that overflow |
| 636 | + scales[torch.isinf(scales)] = 1 |
| 637 | + scales[torch.isnan(scales)] = 1 |
| 638 | + |
| 639 | + # Q(W * s) |
| 640 | + for linear in linears2scale: |
| 641 | + linear.weight.mul_(_scalesview) |
| 642 | + update_offload_parameter( |
| 643 | + linear, |
| 644 | + "weight", |
| 645 | + _pseudo_quantize_tensor( |
| 646 | + w=linear.weight.data, |
| 647 | + symmetric=self._symmetric, |
| 648 | + bit_width=self._num_bits, |
| 649 | + group_size=self._group_size, |
| 650 | + )[0] |
| 651 | + / _scalesview, |
| 652 | + ) |
653 | 653 |
|
654 | | - # W * X |
655 | | - int_w_outputs = self._run_samples(parent_module) |
| 654 | + # W * X |
| 655 | + int_w_outputs = self._run_samples(parent_module) |
656 | 656 |
|
657 | | - # compute mean squared error (L2 norm) |
658 | | - loss = self._compute_loss(fp16_outputs, int_w_outputs, device) |
| 657 | + # compute mean squared error (L2 norm) |
| 658 | + loss = self._compute_loss(fp16_outputs, int_w_outputs, device) |
659 | 659 |
|
660 | | - history.append(loss) |
661 | | - if loss < best_error: |
662 | | - best_error = loss |
663 | | - best_ratio = ratio |
664 | | - best_scales = scales.clone() |
| 660 | + history.append(loss) |
| 661 | + if loss < best_error: |
| 662 | + best_error = loss |
| 663 | + best_ratio = ratio |
| 664 | + best_scales = scales.clone() |
665 | 665 |
|
666 | | - parent_module.load_state_dict(org_sd, strict=False) |
| 666 | + parent_module.load_state_dict(org_sd, strict=False) |
667 | 667 |
|
668 | 668 | if best_ratio == -1: |
669 | 669 | logger.debug(history) |
|
0 commit comments