Commit 3e4f164

[Attention] Attention head quantization strategy (#481)
* refactor
* reduce diff
* reduce diff
* increase num of required observed dims
* remove attention head
* add tests
* remove attn head
* simplify
* refactor
* reduce diff
* increase num of required observed dims
* add tests
* add tests for attn head
* add tests
* reduce diff
* fix shapes
* fix shapes
* revert

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 36c6fe1 commit 3e4f164

File tree: 6 files changed, +79 -18 lines

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ def _process_quantization(
             inv_perm = torch.argsort(perm)
             output = output.index_select(-1, inv_perm)
 
-    else:  # covers channel, token and tensor strategies
+    else:  # covers tensor, channel, token, and attn_head strategies
         if do_quantize:
             output = _quantize(
                 x=x,

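For context, ATTN_HEAD reuses the same quantize/dequantize path as the tensor, channel, and token strategies; only the scale shape differs, with one value per head that broadcasts against (batch_size, num_heads, seq_len, head_dim) inputs. A minimal pure-torch sketch of that broadcast (not the library's _quantize; scale values are illustrative):

import torch

x = torch.randn(1, 8, 16, 64)               # (batch_size, num_heads, seq_len, head_dim)
scale = torch.full((8, 1, 1), 0.02)          # one illustrative scale per head
q = torch.clamp(torch.round(x / scale), -8, 7)   # symmetric int4 range
x_dq = q * scale                             # same shape as x: (1, 8, 16, 64)
assert x_dq.shape == x.shape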
src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 9 additions & 2 deletions
@@ -14,7 +14,7 @@
 
 
 import logging
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization import (
@@ -152,7 +152,7 @@ def initialize_qparams(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    observed_shape: Tuple[int],
+    observed_shape: Tuple[Union[int, None]],
     observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
@@ -234,6 +234,13 @@ def initialize_qparams(
         num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
         expected_shape = (num_rows, num_cols)
 
+    elif strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size, num_attention_heads, seq_len, head_dim)
+        if len(observed_shape) < 3:
+            raise ValueError("Attention quant requires at least 3 observed dimensions")
+
+        expected_shape = (observed_shape[-3], 1, 1)
+
     else:
         assert False, f"Unknown strategy {strategy}"
 
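For an observed attention shape of (batch_size, num_attention_heads, seq_len, head_dim), the branch above allocates one scale/zero-point element per head, with singleton dimensions that broadcast over seq_len and head_dim. A standalone sketch of just that shape rule (the attn_head_qparam_shape helper is illustrative, not part of the library):

from typing import Optional, Tuple

def attn_head_qparam_shape(observed_shape: Tuple[Optional[int], ...]) -> Tuple[int, int, int]:
    # Mirrors the ATTN_HEAD branch above: the observed shape ends in
    # (num_attention_heads, seq_len, head_dim); seq_len may be None because
    # it is unknown at initialization time.
    if len(observed_shape) < 3:
        raise ValueError("Attention quant requires at least 3 observed dimensions")
    return (observed_shape[-3], 1, 1)

print(attn_head_qparam_shape((32, None, 128)))  # (32, 1, 1): one scale per head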
src/compressed_tensors/quantization/quant_args.py

Lines changed: 1 addition & 0 deletions
@@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum):
     BLOCK = "block"
     TOKEN = "token"
     TENSOR_GROUP = "tensor_group"
+    ATTN_HEAD = "attn_head"
 
 
 class DynamicType(str, Enum):
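The new enum member makes "attn_head" a valid strategy string on QuantizationArgs. A minimal sketch of constructing it (imports assumed to match the compressed_tensors.quantization exports used elsewhere in this diff):

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

# Same arguments as the attn_head test case below.
args = QuantizationArgs(num_bits=4, type="int", symmetric=True, strategy="attn_head")
assert args.strategy == QuantizationStrategy.ATTN_HEAD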

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             QuantizationStrategy.TENSOR,
             QuantizationStrategy.GROUP,
             QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.ATTN_HEAD,
         ):
             if (
                 inputs.strategy == QuantizationStrategy.GROUP
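With ATTN_HEAD accepted by the scheme validator, an attention scheme can carry per-head input quantization. A minimal sketch mirroring the test setup further down (assuming QuantizationScheme and QuantizationArgs are importable from compressed_tensors.quantization, as elsewhere in this diff):

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=[],  # the test below uses an empty target list; real configs would list attention modules
    input_activations=QuantizationArgs(
        num_bits=4, type="int", symmetric=True, strategy="attn_head"
    ),
)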

tests/mock_observer.py

Lines changed: 17 additions & 2 deletions
@@ -77,6 +77,8 @@ def flatten_for_quantization(
 
 
 def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (num_rows, num_cols)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (1, 1, num_weight_elems)
         return value.reshape((1, 1, -1))
@@ -110,10 +112,15 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
             .unsqueeze(0)
         )
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to weights")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
 def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, seq_len, hidden_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (batch_size * seq_len, 1, hidden_dim)
         return value.reshape((-1, 1, value.size(-1)))
@@ -134,14 +141,18 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to activations")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        raise ValueError("attention head quantization cannot be applied to linear acts")
+
     assert False, f"Unknown strategy {args.strategy}"
 
 
 def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, num_heads, seq_len, head_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
-        # (batch_size, seq_len, num_heads, head_dim)
         # (batch_size * seq_len, 1, num_heads * head_dim)
-        return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
+        return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
 
     if args.strategy == QuantizationStrategy.TOKEN:
         raise ValueError("Token quantization cannot be applied to attention")
@@ -155,4 +166,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationAr
     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")
 
+    if args.strategy == QuantizationStrategy.ATTN_HEAD:
+        # (batch_size * seq_len, num_heads, 1, 1, head_dim)
+        return value.transpose(1, 2).flatten(0, 1).unsqueeze(-2).unsqueeze(-2)
+
     assert False, f"Unknown strategy {args.strategy}"

tests/test_quantization/lifecycle/test_static_lifecycle.py

Lines changed: 50 additions & 13 deletions
@@ -287,45 +287,82 @@ class MockAttention(torch.nn.Module):
                 strategy="tensor",
             ),
             torch.tensor([0.0]),
-            torch.tensor([11.0]),
+            torch.tensor([23.0]),
             torch.tensor(
                 [
                     [
-                        [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]],
-                        [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
+                        [
+                            [0.0000, 0.0000, 3.0625, 3.0625],
+                            [3.0625, 6.1250, 6.1250, 6.1250],
+                            [9.1875, 9.1875, 9.1875, 12.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
                     ]
                 ]
             ),
-            0.19,
+            0.81,
         ),
         # static token is not supported
         # channel is not supported
         # group is not supported
        # tensor group is not supported
        # block is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="attn_head",
+            ),
+            torch.tensor([[[0.0]], [[12.0]]]),
+            torch.tensor([[[11.0]], [[23.0]]]),
+            torch.tensor(
+                [
+                    [
+                        [
+                            [0.0000, 1.4688, 1.4688, 2.9375],
+                            [4.4062, 4.4062, 5.8750, 7.3438],
+                            [7.3438, 8.8125, 10.2500, 10.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
+                    ]
+                ]
+            ),
+            0.55,
+        ),
     ],
 )
 def test_static_attention_quantization(
     args, exp_min_val, exp_max_val, exp_quant, exp_loss
 ):
     """
-    input = tensor([[[[ 0.,  1.,  2.],
-                      [ 3.,  4.,  5.]],
+    input = tensor([[[[ 0.,  1.,  2.,  3.],
+                      [ 4.,  5.,  6.,  7.],
+                      [ 8.,  9., 10., 11.]],
 
-                     [[ 6.,  7.,  8.],
-                      [ 9., 10., 11.]]]])
+                     [[12., 13., 14., 15.],
+                      [16., 17., 18., 19.],
+                      [20., 21., 22., 23.]]]])
     """
-    # set up activation (and identity weight)
-    batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3
+    # set up attention
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
     input = torch.arange(
-        (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16
-    ).reshape((batch_size, seq_len, num_heads, head_dim))
+        (batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16
+    ).reshape((batch_size, num_heads, seq_len, head_dim))
     attention = MockAttention()
 
     # initialize quantization parameters
     scheme = QuantizationScheme(targets=[], input_activations=args)
     initialize_qparams(
-        attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16
+        attention, "k", args, (num_heads, None, head_dim), observed_dtype=torch.bfloat16
     )
     attention.quantization_scheme = scheme
     attention.quantization_status = QuantizationStatus.INITIALIZED