|
28 | 28 | initialize_module_for_quantization, |
29 | 29 | is_attention_module, |
30 | 30 | ) |
| 31 | +from compressed_tensors.quantization.quant_args import QuantizationArgs |
31 | 32 | from compressed_tensors.quantization.quant_config import ( |
32 | 33 | QuantizationConfig, |
33 | 34 | QuantizationStatus, |
@@ -128,21 +129,11 @@ def apply_quantization_config( |
128 | 129 | # force zero points during initialization |
129 | 130 | force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED |
130 | 131 |
|
131 | | - # apply kv cache quantization before any attention quantization |
132 | | - # because attention quantization is a superset of kv cache quantization |
| 132 | + # apply and initialize kv cache quantization |
133 | 133 | if config.kv_cache_scheme is not None: |
134 | | - scheme = QuantizationScheme( |
135 | | - targets=[".*self_attn$"], input_activations=config.kv_cache_scheme |
| 134 | + _apply_kv_cache_scheme( |
| 135 | + model, config.kv_cache_scheme, config.quantization_status, force_zero_point |
136 | 136 | ) |
137 | | - for submodule in model.modules(): |
138 | | - if is_attention_module(submodule): |
139 | | - submodule.quantization_scheme = scheme |
140 | | - initialize_hooked_kv_cache(model, submodule) |
141 | | - initialize_module_for_quantization( |
142 | | - submodule, |
143 | | - force_zero_point=force_zero_point, |
144 | | - ) |
145 | | - submodule.quantization_status = config.quantization_status |
146 | 137 |
|
147 | 138 | # build mapping of targets to schemes for easier matching |
148 | 139 | # use ordered dict to preserve target ordering in config |
@@ -191,6 +182,29 @@ def apply_quantization_config( |
191 | 182 | submodule.quantization_status = config.quantization_status |
192 | 183 |
|
193 | 184 |
|
| 185 | +def _apply_kv_cache_scheme( |
| 186 | + model: torch.nn.Module, |
| 187 | + kv_cache_scheme: QuantizationArgs, |
| 188 | + status: QuantizationStatus, |
| 189 | + force_zero_point: bool, |
| 190 | +): |
| 191 | + # applies and initializes kv cache quantization |
| 192 | + # this step must run before attention apply/initialize; |
| 193 | + # otherwise it would override the attention qparams |
| 194 | + scheme = QuantizationScheme( |
| 195 | + targets=[".*self_attn$"], input_activations=kv_cache_scheme |
| 196 | + ) |
| 197 | + for submodule in model.modules(): |
| 198 | + if is_attention_module(submodule): |
| 199 | + submodule.quantization_scheme = scheme |
| 200 | + initialize_hooked_kv_cache(model, submodule) |
| 201 | + initialize_module_for_quantization( |
| 202 | + submodule, |
| 203 | + force_zero_point=force_zero_point, |
| 204 | + ) |
| 205 | + submodule.quantization_status = status |
| 206 | + |
| 207 | + |
194 | 208 | def _load_quant_args_from_mapping( |
195 | 209 | base_name: str, module_name: str, module: Module, mapping: Dict |
196 | 210 | ): |
|
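For context beyond the diff: a minimal sketch of how the new kv-cache path is exercised from the public API, building a `QuantizationConfig` that carries only a `kv_cache_scheme` and handing it to `apply_quantization_config`, which now delegates to `_apply_kv_cache_scheme`. The model id is a placeholder, `config_groups={}` assumes the config validates with no weight/activation groups, and the `QuantizationArgs` fields follow the imports shown above; exact field names and defaults may differ between compressed-tensors versions.

```python
# Sketch only: drives the kv-cache branch of apply_quantization_config.
# "my-org/my-model" is a placeholder; any checkpoint whose attention modules
# match ".*self_attn$" (e.g. Llama-style decoders) takes this path.
from transformers import AutoModelForCausalLM

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationStatus,
    apply_quantization_config,
)

model = AutoModelForCausalLM.from_pretrained("my-org/my-model")

config = QuantizationConfig(
    config_groups={},  # no weight/activation groups; kv cache quantization only
    kv_cache_scheme=QuantizationArgs(
        num_bits=8,
        type="float",
        strategy="tensor",
        symmetric=True,
        dynamic=False,
    ),
    quantization_status=QuantizationStatus.INITIALIZED,
)

# With kv_cache_scheme set, apply_quantization_config now calls the new
# _apply_kv_cache_scheme helper first: it attaches the ".*self_attn$" scheme,
# hooks the kv cache, and initializes qparams before any attention
# quantization is applied.
apply_quantization_config(model, config)
```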