
Commit 05237b2

add more tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 665e7ee commit 05237b2

3 files changed: +164 additions, -4 deletions

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 20 additions & 4 deletions
@@ -289,11 +289,11 @@ def initialize_attn_qparams(
     kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None)
 
     if impl is None and kv_cache is None:
-        raise ValueError("Attention module has quantization scheme but no attached ")
+        raise ValueError("Attention module has quantization scheme but no attached")
 
-    config: PretrainedConfig = getattr(impl, "config", None) or getattr(
-        kv_cache, "config", None
-    )
+    _validate_attention_scheme(scheme)
+
+    config: PretrainedConfig = getattr(kv_cache, "config")
     head_dim = get_head_dim(config)
     observed_shape = (head_dim,)  # (batch_size, num_attention_heads, slen, head_dim)
     observed_dtype = next(module.parameters()).dtype
@@ -325,3 +325,19 @@ def initialize_attn_qparams(
         observed_dtype=observed_dtype,
         force_zero_point=force_zero_point,
     )
+
+
+def _validate_attention_scheme(scheme: QuantizationScheme):
+    if scheme.weights is not None:
+        raise ValueError(
+            "Cannot apply weight quantization to attention. "
+            "Instead, target (q|k|v)_proj"
+        )
+
+    if scheme.input_activations is None:
+        raise ValueError(
+            "Cannot apply attention quantization without specifying input activations"
+        )
+
+    if scheme.output_activations is not None:
+        raise ValueError("Cannot apply output quantization to attention")
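The new `_validate_attention_scheme` helper accepts only schemes that quantize input activations and nothing else. Below is a minimal sketch, not part of the commit, of the accepted and rejected cases; the `compressed_tensors.quantization` import path and the direct call to the private helper are assumptions made for illustration.

```python
# Sketch only: exercises the validation rules added in this commit.
# The import locations of QuantizationArgs / QuantizationScheme are assumed.
import pytest
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle.initialize import (
    _validate_attention_scheme,
)

fp8 = QuantizationArgs(num_bits=8, type="float", strategy="tensor")

# accepted: input-activation-only scheme, as used for attention/kv-cache quantization
_validate_attention_scheme(
    QuantizationScheme(targets=["LlamaAttention"], input_activations=fp8)
)

# rejected: weight quantization must target the (q|k|v)_proj Linear layers instead
with pytest.raises(ValueError):
    _validate_attention_scheme(
        QuantizationScheme(targets=["LlamaAttention"], weights=fp8)
    )

# rejected: output activation quantization is not supported on attention
with pytest.raises(ValueError):
    _validate_attention_scheme(
        QuantizationScheme(
            targets=["LlamaAttention"], input_activations=fp8, output_activations=fp8
        )
    )
```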
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.modeling import (
+    IMPL_ATTR,
+    KV_CACHE_ATTR,
+    QuantizedAttentionImpl,
+    QuantizedKVCache,
+    initialize_hooked_attention,
+    initialize_hooked_kv_cache,
+    register_key_hook,
+    register_query_hook,
+    register_value_hook,
+)
+from tests.testing_utils import requires_gpu
+from transformers import AutoModelForCausalLM
+
+
+@requires_gpu
+def test_attention_cache():
+    model = AutoModelForCausalLM.from_pretrained(
+        "nm-testing/llama2.c-stories15M", device_map="cuda"
+    )
+    inputs = {key: value.to("cuda") for key, value in model.dummy_inputs.items()}
+    true_outputs = model(**inputs)
+    layers = model.model.layers
+
+    # check if hooks work
+    k_called = [False for _ in range(len(layers))]
+    v_called = [False for _ in range(len(layers))]
+
+    # apply kv cache quantization
+    _apply_kv_cache(model, layers, k_called, v_called)
+
+    # check kv cache quantization
+    outputs = model(**inputs)
+    assert torch.equal(outputs.logits, true_outputs.logits)
+    assert all(k_called) and all(v_called)
+
+    ## apply attention quantization after kv cache quantization ##
+
+    # check if hooks work
+    q_called = [False for _ in range(len(layers))]
+    k_called = [False for _ in range(len(layers))]
+    v_called = [False for _ in range(len(layers))]
+
+    _apply_attention(model, layers, q_called, k_called, v_called)
+    outputs = model(**inputs)
+    assert torch.equal(outputs.logits, true_outputs.logits)
+    assert all(q_called) and all(k_called) and all(v_called)
+
+
+def _apply_kv_cache(model, layers, k_called, v_called):
+    for layer_index, layer in enumerate(layers):
+        module = layer.self_attn
+        initialize_hooked_kv_cache(model, module)
+        assert isinstance(getattr(module, KV_CACHE_ATTR), QuantizedKVCache)
+
+        # reapply is no-op
+        initialize_hooked_kv_cache(model, module)
+
+        def k_hook(_module, _states, layer_index=layer_index):  # NOTE: capture by value
+            k_called[layer_index] = True
+
+        def v_hook(_module, _states, layer_index=layer_index):
+            my_index = layer_index
+            v_called[my_index] = True
+
+        register_key_hook(module, k_hook)
+        register_value_hook(module, v_hook)
+
+
+def _apply_attention(model, layers, q_called, k_called, v_called):
+    for layer_index, layer in enumerate(layers):
+        module = layer.self_attn
+        initialize_hooked_attention(model, module)
+        assert isinstance(getattr(module, IMPL_ATTR), QuantizedAttentionImpl)
+
+        # reapply is no-op
+        initialize_hooked_attention(model, module)
+
+        def q_hook(_module, _states, layer_index=layer_index):
+            q_called[layer_index] = True
+
+        def k_hook(_module, _states, layer_index=layer_index):
+            k_called[layer_index] = True
+
+        def v_hook(_module, _states, layer_index=layer_index):
+            v_called[layer_index] = True
+
+        register_query_hook(module, q_hook)
+        register_key_hook(module, k_hook)
+        register_value_hook(module, v_hook)
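The hooks in this test bind `layer_index=layer_index` as a default argument (the "capture by value" note) because Python closures capture loop variables by reference; without the default, every hook created in the loop would report against the last layer. A standalone sketch, not from the commit, of the difference:

```python
# Closures created in a loop share the loop variable unless it is frozen
# as a default argument at definition time.
late = [lambda: i for i in range(3)]       # all three share the final i
early = [lambda i=i: i for i in range(3)]  # capture by value, as in the tests above

assert [f() for f in late] == [2, 2, 2]
assert [f() for f in early] == [0, 1, 2]
```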

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 39 additions & 0 deletions
@@ -366,3 +366,42 @@ def test_multi_apply_quantization_config():
         weight_zero_point is not None
         and weight_zero_point.shape == torch.Size([1])
     )
+
+@requires_accelerate()
+def test_apply_kv_cache():
+    from accelerate import init_empty_weights
+
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    args = QuantizationArgs(num_bits=8, type="float", strategy="tensor")
+    config = QuantizationConfig(config_groups={}, kv_cache_scheme=args)
+
+    apply_quantization_config(model, config)
+
+    for layer in model.model.layers:
+        assert getattr(layer.self_attn, "quantization_scheme").input_activations == args
+        assert hasattr(layer.self_attn, "k_scale")
+        assert hasattr(layer.self_attn, "v_scale")
+
+
+@requires_accelerate()
+def test_apply_attention():
+    from accelerate import init_empty_weights
+
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    scheme = QuantizationScheme(
+        targets=["LlamaAttention"],
+        input_activations=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
+    )
+    config = QuantizationConfig(config_groups={"attention": scheme})
+
+    apply_quantization_config(model, config)
+
+    for layer in model.model.layers:
+        assert getattr(layer.self_attn, "quantization_scheme") == scheme
+        assert hasattr(layer.self_attn, "q_scale")
+        assert hasattr(layer.self_attn, "k_scale")
+        assert hasattr(layer.self_attn, "v_scale")
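These two tests cover the two entry points for quantizing attention: the `kv_cache_scheme` shortcut, which per the assertions above attaches only `k_scale` and `v_scale`, and an explicit `config_groups` entry targeting the attention class, which also attaches `q_scale`. A combined usage sketch, not from the commit and with assumed import paths, follows.

```python
# Sketch: the two configuration paths exercised by the new tests.
# Import locations are assumed; the model and scheme values mirror the tests.
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)
from transformers import AutoModelForCausalLM

fp8 = QuantizationArgs(num_bits=8, type="float", strategy="tensor")

# path 1: kv-cache-only quantization (k_scale / v_scale on each attention module)
kv_only = QuantizationConfig(config_groups={}, kv_cache_scheme=fp8)

# path 2: full attention quantization (q_scale / k_scale / v_scale)
attn = QuantizationConfig(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"], input_activations=fp8
        )
    }
)

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
apply_quantization_config(model, kv_only)  # or: apply_quantization_config(model, attn)
```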
