From 05354c59399bfc33f2d8293e2cc24bf97ec5fcc9 Mon Sep 17 00:00:00 2001
From: Charles Hernandez
Date: Thu, 6 Nov 2025 20:43:06 +0000
Subject: [PATCH] early and better error for divisibility issues

Summary:

Add a divisibility check for block and group quantization so that
incompatible configs fail early with an actionable error message.

Signed-off-by: HDCharles
---
 .../quantization/lifecycle/apply.py           | 126 +++++++-
 .../test_quantization/lifecycle/test_apply.py | 298 ++++++++++++++++++
 2 files changed, 422 insertions(+), 2 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 28c8a7b9..a3f40692 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -28,7 +28,10 @@
     initialize_module_for_quantization,
     is_attention_module,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -110,7 +113,10 @@ def load_pretrained_quantization_parameters(
 
 
 def apply_quantization_config(
-    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+    model: Module,
+    config: Union[QuantizationConfig, None],
+    run_compressed: bool = False,
+    validate_group_or_block_size: bool = True,
 ):
     """
     Initializes the model for quantization in-place based on the given config.
@@ -120,6 +126,8 @@
     :param config: quantization config
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
+    :param validate_group_or_block_size: if True, validates that weight dimensions are
+        evenly divisible by group_size or block_structure. Defaults to True.
     """
     from compressed_tensors.linear.compressed_linear import CompressedLinear
 
@@ -182,6 +190,11 @@
 
         submodule.quantization_status = config.quantization_status
 
+    # Validate group/block size divisibility if enabled
+    if validate_group_or_block_size:
+        matched_modules = match_named_modules(model, target_to_scheme, config.ignore)
+        _validate_group_or_block_size(list(matched_modules))
+
 
 def _apply_kv_cache_scheme(
     model: torch.nn.Module,
@@ -258,3 +271,112 @@ def _scheme_from_targets(
     # return the first scheme (the prioritized one,
     # since the order of target_to_scheme matters)
     return target_to_scheme[targets[0]]
+
+
+def _validate_group_or_block_size(modules: list[tuple[str, Module]]) -> None:
+    """
+    Validates quantization parameter divisibility for all modules:
+    - GROUP strategy: weight columns must be evenly divisible by group_size
+    - BLOCK strategy: weight dimensions must be evenly divisible by block_structure
+
+    Raises a ValueError if validation fails, providing a comprehensive error
+    message with a suggested fix.
+
+    :param modules: List of (fqn, module) tuples to validate
+    :raises ValueError: If any module has dimension divisibility issues
+    """
+    problematic_layers = []
+
+    for fqn, module in modules:
+        if _check_module_divisibility(fqn, module) is not None:
+            problematic_layers.append(fqn)
+
+    if problematic_layers:
+        error_msg = _generate_divisibility_error_message(problematic_layers)
+        raise ValueError(error_msg)
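+
+
+# Illustrative example (not executed, shapes taken from the tests below): a
+# Linear layer whose weight has 288 input columns fails GROUP validation with
+# group_size=128, since 288 % 128 == 32; it passes with group_size=96
+# (288 % 96 == 0), or when the layer is added to the config's ignore list.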
+
+
+def _check_module_divisibility(fqn: str, module: Module) -> Optional[str]:
+    """
+    Checks a single module for group size divisibility (GROUP strategy) and
+    block structure divisibility (BLOCK strategy).
+
+    :param fqn: Fully qualified name of the module
+    :param module: Module to check
+    :return: fqn if there is an issue, None otherwise
+    """
+    quant_scheme = getattr(module, "quantization_scheme", None)
+    if quant_scheme is None:
+        return None
+
+    quant_args = quant_scheme.weights
+    if quant_args is None:
+        return None
+
+    # Check if module has weight
+    if not hasattr(module, "weight"):
+        return None
+
+    weight = module.weight
+
+    # Validate for GROUP strategy
+    if quant_args.strategy == QuantizationStrategy.GROUP:
+        group_size = quant_args.group_size
+        if group_size is None:
+            return None
+
+        # Get number of columns based on module type
+        if isinstance(module, torch.nn.Conv2d):
+            num_columns = weight.shape[1]
+        else:
+            num_columns = weight.shape[-1]
+
+        # Check divisibility
+        if num_columns % group_size != 0:
+            return fqn
+
+    # Validate for BLOCK strategy
+    elif quant_args.strategy == QuantizationStrategy.BLOCK:
+        block_structure = quant_args.block_structure
+        if block_structure is None:
+            return None
+
+        block_height, block_width = block_structure
+
+        if isinstance(module, torch.nn.Conv2d):
+            num_rows, num_columns = weight.shape[:2]
+        else:
+            num_rows, num_columns = weight.shape[-2:]
+
+        # Check divisibility for both dimensions
+        if num_rows % block_height != 0 or num_columns % block_width != 0:
+            return fqn
+
+    return None
+
+
+def _generate_divisibility_error_message(problematic_layers: List[str]) -> str:
+    """
+    Generate an error message for quantization divisibility validation failures.
+
+    :param problematic_layers: List of layer names with divisibility issues
+    :return: Formatted error message
+    """
+    header = "ERROR: Quantization divisibility validation failed!\n"
+    description = (
+        "Found layers with weight dimensions that are not evenly divisible\n"
+        "by the specified group_size or block_structure.\n\n"
+    )
+
+    error_msg = "\n" + "=" * 80 + "\n" + header + "=" * 80 + "\n\n" + description
+
+    error_msg += "\n" + "-" * 80 + "\n"
+    error_msg += "SUGGESTED FIX: Add the following to your quantization config:\n\n"
+    error_msg += f"    ignore: {problematic_layers}\n"
+    error_msg += "\n" + "-" * 80 + "\n"
+    error_msg += "\nNote: modules such as lm_head that are ignored by default\n"
+    error_msg += "also need to be listed manually in the ignore list.\n"
+    error_msg += "=" * 80 + "\n"
+
+    return error_msg
diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 8d57ca40..4ffd849e 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -479,3 +479,301 @@ def test_apply_attention():
     assert hasattr(layer.self_attn, "q_scale")
     assert hasattr(layer.self_attn, "k_scale")
     assert hasattr(layer.self_attn, "v_scale")
+
+
+def test_group_size_validation_raises_error():
+    """Test that GROUP strategy validation raises an error for non-divisible layers"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    # Create config with GROUP strategy and group_size=128.
+    # Most layers have 288 input features (NOT divisible by 128);
+    # down_proj layers have 768 input features (divisible by 128).
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="group",
+                    group_size=128,
+                ),
+            )
+        },
+    )
+
+    # Should raise ValueError because most layers (q_proj, k_proj, v_proj,
+    # o_proj, gate_proj, up_proj) have 288 input features, which is not
+    # divisible by group_size=128
+    with pytest.raises(ValueError, match="Quantization divisibility validation failed"):
+        apply_quantization_config(model, config, validate_group_or_block_size=True)
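+
+
+# NOTE: these validation tests rely on the layer shapes of the
+# nm-testing/llama2.c-stories15M checkpoint, as described in the comments:
+# the attention projections and gate/up_proj have 288 input features, while
+# down_proj has 768. Those numbers determine which group_size and
+# block_structure values are expected to pass or fail validation.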
+
+
+def test_group_size_validation_error_message():
+    """Test that the validation error message contains helpful information"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="group",
+                    group_size=128,
+                ),
+            )
+        },
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        apply_quantization_config(model, config, validate_group_or_block_size=True)
+
+    error_msg = str(exc_info.value)
+    # Check that the error message contains the expected components
+    assert "Quantization divisibility validation failed" in error_msg
+    # Should contain some layer names with 288 input features
+    assert "q_proj" in error_msg or "k_proj" in error_msg or "gate_proj" in error_msg
+    assert "ignore:" in error_msg
+    assert "SUGGESTED FIX" in error_msg
+
+
+def test_group_size_validation_disabled():
+    """Test that validation can be disabled"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="group",
+                    group_size=128,
+                ),
+            )
+        },
+    )
+
+    # Should NOT raise an error when validation is disabled
+    apply_quantization_config(model, config, validate_group_or_block_size=False)
+
+    # Verify that quantization was still applied
+    assert hasattr(model.model.layers[0].self_attn.q_proj, "quantization_scheme")
+    assert hasattr(model.model.layers[0].mlp.down_proj, "quantization_scheme")
+
+
+def test_group_size_validation_with_ignore_list():
+    """Test that validation respects the ignore list"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    # Put the problematic layers (288 input features) in the ignore list and
+    # only quantize down_proj layers, which have 768 input features
+    # (divisible by 128)
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="group",
+                    group_size=128,
+                ),
+            )
+        },
+        ignore=[
+            "re:.*q_proj",
+            "re:.*k_proj",
+            "re:.*v_proj",
+            "re:.*o_proj",
+            "re:.*gate_proj",
+            "re:.*up_proj",
+            "lm_head",
+        ],
+    )
+
+    # Should NOT raise an error because the problematic layers are ignored
+    apply_quantization_config(model, config, validate_group_or_block_size=True)
+
+    # Verify that only down_proj layers were quantized
+    assert hasattr(model.model.layers[0].mlp.down_proj, "quantization_scheme")
+    assert not hasattr(model.model.layers[0].self_attn.q_proj, "quantization_scheme")
+    assert not hasattr(model.model.layers[0].mlp.gate_proj, "quantization_scheme")
+
+
+def test_channel_strategy_no_validation():
+    """Test that validation doesn't trigger for non-GROUP strategies"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    # Create config with CHANNEL strategy (not GROUP)
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="channel",
+                ),
+            )
+        },
+    )
+
+    # Should NOT raise an error for CHANNEL strategy
+    apply_quantization_config(model, config, validate_group_or_block_size=True)
+
+    # Verify that quantization was applied
+    assert hasattr(model.model.layers[0].self_attn.q_proj, "quantization_scheme")
+    assert hasattr(model.model.layers[0].mlp.down_proj, "quantization_scheme")
+
+
+def test_tensor_strategy_no_validation():
+    """Test that validation doesn't trigger for TENSOR strategy"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    # Create config with TENSOR strategy (not GROUP)
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=8,
+                    type="int",
+                    symmetric=True,
+                    strategy="tensor",
+                ),
+            )
+        },
+    )
+
+    # Should NOT raise an error for TENSOR strategy
+    apply_quantization_config(model, config, validate_group_or_block_size=True)
+
+    # Verify that quantization was applied
+    assert hasattr(model.model.layers[0].self_attn.q_proj, "quantization_scheme")
+    assert hasattr(model.model.layers[0].mlp.down_proj, "quantization_scheme")
+
+
+def test_block_strategy_validation_raises_error():
+    """Test that BLOCK strategy validation raises an error for non-divisible layers"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    # Create config with BLOCK strategy, using block_structure [100, 100],
+    # which does not divide layers with 288 input features
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="block",
+                    block_structure=[100, 100],
+                ),
+            )
+        },
+    )
+
+    # Should raise ValueError because layers with 288 input features
+    # are not divisible by block_structure [100, 100]
+    with pytest.raises(ValueError, match="Quantization divisibility validation failed"):
+        apply_quantization_config(model, config, validate_group_or_block_size=True)
+
+
+def test_block_strategy_validation_passes():
+    """Test that BLOCK strategy validation passes when dimensions are divisible"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    # Create config with BLOCK strategy, using block_structure [96, 96], which
+    # divides both dimensions of the down_proj layers
+    # (288 % 96 = 0, 768 % 96 = 0); all other Linear layers are ignored
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="block",
+                    block_structure=[96, 96],
+                ),
+            )
+        },
+        ignore=[
+            "re:.*q_proj",
+            "re:.*k_proj",
+            "re:.*v_proj",
+            "re:.*o_proj",
+            "re:.*gate_proj",
+            "re:.*up_proj",
+            "lm_head",
+        ],
+    )
+
+    # Should NOT raise an error
+    apply_quantization_config(model, config, validate_group_or_block_size=True)
+
+    # Verify that only down_proj layers were quantized
+    assert hasattr(model.model.layers[0].mlp.down_proj, "quantization_scheme")
+    assert not hasattr(model.model.layers[0].self_attn.q_proj, "quantization_scheme")
+
+
+def test_group_size_validation_with_divisible_group_size():
+    """Test that validation passes when all layers are divisible by group_size"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    # Using group_size=96, which divides both 288 and 768
+    # (288 % 96 = 0, 768 % 96 = 0)
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="group",
+                    group_size=96,
+                ),
+            )
+        },
+        ignore=["lm_head"],
+    )
+
+    # Should NOT raise an error
+    apply_quantization_config(model, config, validate_group_or_block_size=True)
+
+    # Verify that quantization was applied to the layers
+    assert hasattr(model.model.layers[0].self_attn.q_proj, "quantization_scheme")
+    assert model.model.layers[0].self_attn.q_proj.quantization_scheme.weights.group_size == 96
+    assert hasattr(model.model.layers[0].mlp.down_proj, "quantization_scheme")
+    assert model.model.layers[0].mlp.down_proj.quantization_scheme.weights.group_size == 96
+
+
+def test_group_size_validation_with_partial_ignore():
+    """Test validation with a partial ignore list"""
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+
+    # Ignore only gate_proj, so the other layers with 288 input features
+    # should still cause an error
+    config = QuantizationConfig(
+        config_groups={
+            "group_0": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(
+                    num_bits=4,
+                    type="int",
+                    symmetric=True,
+                    strategy="group",
+                    group_size=128,
+                ),
+            )
+        },
+        ignore=["re:.*gate_proj", "lm_head"],
+    )
+
+    # Should raise ValueError because layers like q_proj and k_proj are not
+    # divisible by the group size and are not ignored
+    with pytest.raises(ValueError) as exc_info:
+        apply_quantization_config(model, config, validate_group_or_block_size=True)
+
+    # The error should mention some of the non-ignored 288-input layers
+    error_msg = str(exc_info.value)
+    assert any(name in error_msg for name in ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj"])
+    # It should NOT mention gate_proj, since that layer is ignored
+    assert "gate_proj" not in error_msg