Skip to content

Commit ac6f6d2

Browse files
use itertools.product
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent f94df6f commit ac6f6d2

File tree

1 file changed

+44
-44
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+44
-44
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 44 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
from itertools import product
23

34
import torch
45
from compressed_tensors.quantization import disable_quantization
@@ -610,60 +611,59 @@ def _compute_best_scale(
610611
w_mean = w_mean.view(-1).to(device)
611612

612613
if self.duo_scaling is None:
613-
# if self.duo_scaling is unsert, perform half the grid search with
614+
# if self.duo_scaling is unset, perform half the grid search with
614615
# duo_scaling off and half with duo_scaling on
615616
n_grid = int(self.n_grid / 2)
616617
duo_scalings = [False, True]
617618
else:
618619
n_grid = self.n_grid
619620
duo_scalings = [self.duo_scaling]
620-
for ratio in range(n_grid):
621-
for duo_scaling in duo_scalings:
622-
# create new scales
623-
ratio = ratio / n_grid
624-
625-
# NOTE: s^-1 * x is fused here, according to paper
626-
if duo_scaling:
627-
scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
628-
min=1e-4
629-
)
630-
else:
631-
scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
632-
scales = scales / (scales.max() * scales.min()).sqrt()
633-
_scalesview = scales.view(1, -1).to(device)
634-
635-
# avoid scaling values that overflow
636-
scales[torch.isinf(scales)] = 1
637-
scales[torch.isnan(scales)] = 1
638-
639-
# Q(W * s)
640-
for linear in linears2scale:
641-
linear.weight.mul_(_scalesview)
642-
update_offload_parameter(
643-
linear,
644-
"weight",
645-
_pseudo_quantize_tensor(
646-
w=linear.weight.data,
647-
symmetric=self._symmetric,
648-
bit_width=self._num_bits,
649-
group_size=self._group_size,
650-
)[0]
651-
/ _scalesview,
652-
)
621+
for grid_idx, use_duo_scaling in product(range(n_grid), duo_scalings):
622+
# create new scales
623+
ratio = grid_idx / n_grid
624+
625+
# NOTE: s^-1 * x is fused here, according to paper
626+
if use_duo_scaling:
627+
scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(
628+
min=1e-4
629+
)
630+
else:
631+
scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
632+
scales = scales / (scales.max() * scales.min()).sqrt()
633+
_scalesview = scales.view(1, -1).to(device)
634+
635+
# avoid scaling values that overflow
636+
scales[torch.isinf(scales)] = 1
637+
scales[torch.isnan(scales)] = 1
638+
639+
# Q(W * s)
640+
for linear in linears2scale:
641+
linear.weight.mul_(_scalesview)
642+
update_offload_parameter(
643+
linear,
644+
"weight",
645+
_pseudo_quantize_tensor(
646+
w=linear.weight.data,
647+
symmetric=self._symmetric,
648+
bit_width=self._num_bits,
649+
group_size=self._group_size,
650+
)[0]
651+
/ _scalesview,
652+
)
653653

654-
# W * X
655-
int_w_outputs = self._run_samples(parent_module)
654+
# W * X
655+
int_w_outputs = self._run_samples(parent_module)
656656

657-
# compute mean squared error (L2 norm)
658-
loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
657+
# compute mean squared error (L2 norm)
658+
loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
659659

660-
history.append(loss)
661-
if loss < best_error:
662-
best_error = loss
663-
best_ratio = ratio
664-
best_scales = scales.clone()
660+
history.append(loss)
661+
if loss < best_error:
662+
best_error = loss
663+
best_ratio = ratio
664+
best_scales = scales.clone()
665665

666-
parent_module.load_state_dict(org_sd, strict=False)
666+
parent_module.load_state_dict(org_sd, strict=False)
667667

668668
if best_ratio == -1:
669669
logger.debug(history)

0 commit comments

Comments (0)