diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py
index 4c9644998f..f36660bfca 100644
--- a/examples/awq/qwen3_moe_example.py
+++ b/examples/awq/qwen3_moe_example.py
@@ -1,3 +1,4 @@
+import os
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -55,6 +56,8 @@ def tokenize(sample):
         ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
         scheme="W4A16",
         targets=["Linear"],
+        use_auto_awq_mem_hack=os.getenv("USE_HACK", "") == "yes",
+        # GPU VRAM consistently peaks at ~37784MiB regardless of this setting
     ),
 ]
 
diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index bc86ba25f6..01808f881a 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union, Any
 
 import torch
 from compressed_tensors.quantization import disable_quantization
@@ -122,6 +122,7 @@ class AWQModifier(Modifier, QuantizationMixin):
     mappings: Optional[List[AWQMapping]] = None
     offload_device: Optional[torch.device] = None
     duo_scaling: bool = True
+    use_auto_awq_mem_hack: bool = True
 
     # Private vars set during validation
     _num_bits: Optional[int] = PrivateAttr(default=None)
@@ -130,8 +131,10 @@ class AWQModifier(Modifier, QuantizationMixin):
 
     # Private vars set during initialization, cleared during finalization
     _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
-    # Cache list of forward input args for each parent module, one dict for each batch
-    _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr(
+    # Model-wide cache of kwargs shared by all parent modules
+    _model_kwargs_cache: IntermediatesCache = PrivateAttr()
+    # Per-parent cache of forward kwargs, one dict for each batch
+    _parent_kwargs_cache: dict[Module, IntermediatesCache] = PrivateAttr(
         default_factory=dict
     )
     # Dict[smooth layer name, (activation means, activation counts)]
@@ -290,7 +293,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         if not self.ended_:
             self.on_end(state, None)
 
-        self._parent_args_cache.clear()
+        self._parent_kwargs_cache.clear()
+        self._model_kwargs_cache = None
         self._smooth_activation_means.clear()
         self._resolved_mappings.clear()
 
@@ -387,13 +391,35 @@ def _setup_activation_cache_hooks(self) -> None:
         calculate the dynamic range during calibration
         """
 
-        def cache_parent_kwargs_hook(
+        def cache_hidden_states_kwargs_hook(
             module: torch.nn.Module,
             args: Tuple[torch.Tensor, ...],
             kwargs,
         ):
+            batch_idx = len(self._parent_kwargs_cache[module])
+
             values = inspect.signature(module.forward).bind(*args, **kwargs)
-            self._parent_args_cache[module].append(values.arguments)
+
+            # our original impl: all kwargs are cached for each parent
+            # technically the correct way, but with lots of redundancy
+            if not self.use_auto_awq_mem_hack:
+                self._parent_kwargs_cache[module].append(values.arguments)
+                return
+
+            # autoawq impl: only the first param is cached for each parent;
+            # all others are pulled from the model-wide cache
+            # much more memory efficient, but possibly incorrect
+            # depending on the model definition
+            first_param_name, first_arg = next(iter(values.arguments.items()))
+
+            self._parent_kwargs_cache[module].append({first_param_name: first_arg})
+
+            values.arguments.pop(first_param_name)
+
+            if len(self._model_kwargs_cache) == 0:
+                self._model_kwargs_cache.append(values.arguments)
+            else:
+                self._model_kwargs_cache.update(0, values.arguments)
 
         def create_cache_smooth_activations_hook_fn(smooth_name):
             def cache_smooth_activations_hook(
@@ -409,17 +435,19 @@ def cache_smooth_activations_hook(
 
             return cache_smooth_activations_hook
 
+        # Don't offload this, it will be used consistently
+        self._model_kwargs_cache = IntermediatesCache(None, None)
         for mapping in self._resolved_mappings:
             # parent kwargs needed for future forward passes
             # same parent may appear multiple times in resolved mappings
-            if mapping.parent not in self._parent_args_cache:
-                self._parent_args_cache[mapping.parent] = IntermediatesCache(
+            if mapping.parent not in self._parent_kwargs_cache:
+                self._parent_kwargs_cache[mapping.parent] = IntermediatesCache(
                     None,
                     self.offload_device,
                 )
             self.register_hook(
                 mapping.parent,
-                cache_parent_kwargs_hook,
+                cache_hidden_states_kwargs_hook,
                 "forward_pre",
                 with_kwargs=True,
             )
@@ -444,6 +472,15 @@ def _apply_smoothing(self, model: Module) -> None:
         """
         # NOTE: When using SequentialPipeline, not all the mappings
         # will have cached activations in the segment being updated
+
+        print("SIZE", self._model_kwargs_cache.size())
+        try:
+            cache = self._model_kwargs_cache.fetch(0)
+            for k, v in cache.items():
+                print(k, f"{v.shape} {v.device}" if isinstance(v, torch.Tensor) else v)
+        except Exception:
+            pass
+
         mappings_to_smooth = [
             mapping
             for mapping in self._resolved_mappings
@@ -555,19 +592,27 @@ def _smooth(module):
             # remove caches needed to smooth this mapping
             del self._smooth_activation_means[mapping.smooth_name]
 
-        for v in self._parent_args_cache.values():
+        for v in self._parent_kwargs_cache.values():
             v.batch_intermediates.clear()
+
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
-        outputs = [
-            module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
-        ]
-        return [
-            # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
-            for output in outputs
-        ]
+    def _run_samples(self, module: Module) -> list[torch.Tensor]:
+        parameter_keys = inspect.signature(module.forward).parameters.keys()
+
+        outputs = []
+        for batch_idx in range(len(self._parent_kwargs_cache[module])):
+            batch_kwargs = self._model_kwargs_cache.fetch(0, ignore_missing=True)
+            batch_kwargs.update(self._parent_kwargs_cache[module].fetch(batch_idx))
+            batch_kwargs = {
+                k: v for k, v in batch_kwargs.items() if k in parameter_keys
+            }
+
+            output = module(**batch_kwargs)
+            # If the module returns a tuple, assume the first element is the output
+            outputs.append(output[0] if isinstance(output, tuple) else output)
+
+        return outputs
 
     def _compute_best_scale(
         self,
@@ -592,6 +637,10 @@ def _compute_best_scale(
         best_scales = None
         best_error = float("inf")
 
+        # NOTE: this changes the module pointers, so it invalidates the
+        # field `_parent_kwargs_cache: dict[Module, IntermediatesCache]`
+        # parent_module = torch.compile(parent_module)
+
         org_sd = {
             k: v.cpu() for k, v in parent_module.state_dict().items()
diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py
index dd600a0f76..5d67b2c2af 100644
--- a/src/llmcompressor/pipelines/cache.py
+++ b/src/llmcompressor/pipelines/cache.py
@@ -90,15 +90,23 @@ def from_dataloader(
         return cls(batch_intermediates, offload_device)
 
     def fetch(
-        self, batch_index: int, input_names: Optional[List[str]] = None
+        self,
+        batch_index: int,
+        input_names: Optional[List[str]] = None,
+        ignore_missing: bool = False,
     ) -> Dict[str, Any]:
         """
         Fetch values belonging to a batch
 
         :param batch_index: index of batch whose values are being fetched
         :param input_names: list of keys whose values are being fetched
+        :param ignore_missing: if True and no intermediates exist for
+            batch_index, return an empty dict instead of raising an
+            IndexError
         :return: dictionary mapping keys to onloaded values
         """
+        if ignore_missing and batch_index >= len(self.batch_intermediates):
+            return {}
         intermediates = self.batch_intermediates[batch_index]
 
         return {
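
Note (not part of the patch): the trade-off behind `use_auto_awq_mem_hack` is easiest to see outside of llm-compressor. The sketch below is a minimal, hypothetical illustration with a toy `parent_forward(hidden_states, attention_mask, position_ids)` signature: the hack keeps only the first argument per parent per batch and pushes the remaining kwargs into a single model-wide dict that is overwritten every batch, so it only reproduces the original per-batch caching when those shared kwargs are identical across batches.

# Standalone sketch, not llm-compressor code: toy names, hypothetical shapes.
import torch


def parent_forward(hidden_states, attention_mask=None, position_ids=None):
    # stand-in for a parent module (e.g. decoder layer / MoE block) forward
    return hidden_states


batches = [
    {
        "hidden_states": torch.randn(1, 4, 8),
        "attention_mask": torch.ones(1, 4),
        "position_ids": torch.arange(4).unsqueeze(0),
    }
    for _ in range(3)
]

# original impl: every kwarg cached per parent, per batch
full_cache = [dict(b) for b in batches]

# autoawq-style hack: per parent keep only the first arg; all other kwargs
# go into one model-wide dict that is overwritten each batch, so only the
# last batch's values survive (mirrors IntermediatesCache.update(0, ...))
parent_cache = [{"hidden_states": b["hidden_states"]} for b in batches]
model_wide = {}
for b in batches:
    model_wide.update({k: v for k, v in b.items() if k != "hidden_states"})

# replaying batch 0 under the hack merges its hidden_states with the shared
# kwargs; this matches the full cache only if attention_mask / position_ids
# do not vary across batches
replay = {**model_wide, **parent_cache[0]}
print(torch.equal(parent_forward(**full_cache[0]), parent_forward(**replay)))

The `ignore_missing=True` path added to `IntermediatesCache.fetch` supports the same scheme: when the hack is disabled the model-wide cache stays empty, yet `_run_samples` still calls `fetch(0, ignore_missing=True)`, so a missing batch has to return an empty dict rather than raise.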