From 61fafc1a7e92f94d29f98ec53d7cd2feaa9ee2a1 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 30 Oct 2025 21:10:23 +0000 Subject: [PATCH 1/7] p1 Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 64 ++++++++++++++++++------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index bc86ba25f6..281cf34dde 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,5 +1,5 @@ import inspect -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Any import torch from compressed_tensors.quantization import disable_quantization @@ -130,8 +130,10 @@ class AWQModifier(Modifier, QuantizationMixin): # Private vars set during initialization, cleared during finalization _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list) - # Cache list of forward input args for each parent module, one dict for each batch - _parent_args_cache: Dict[Module, IntermediatesCache] = PrivateAttr( + # Cache of kwargs for parent modules, one dict for each batch + _model_kwargs_cache: list[dict[str, Any]] = PrivateAttr(default_factory=list) + # Cache of forward hidden states for each parent module, one tensor for each batch + _parent_hidden_states_cache: dict[Module, IntermediatesCache] = PrivateAttr( default_factory=dict ) # Dict[smooth layer name, (activation means, activation counts)] @@ -290,7 +292,8 @@ def on_finalize(self, state: State, **kwargs) -> bool: if not self.ended_: self.on_end(state, None) - self._parent_args_cache.clear() + self._parent_hidden_states_cache.clear() + self._model_kwargs_cache = None self._smooth_activation_means.clear() self._resolved_mappings.clear() @@ -387,13 +390,31 @@ def _setup_activation_cache_hooks(self) -> None: calculate the dynamic range during calibration """ - def cache_parent_kwargs_hook( + def cache_hidden_states_kwargs_hook( module: torch.nn.Module, args: Tuple[torch.Tensor, ...], kwargs, ): - values = inspect.signature(module.forward).bind(*args, **kwargs) - self._parent_args_cache[module].append(values.arguments) + signature = inspect.signature(module.forward) + first_param = next(iter(signature.parameters)) + batch_idx = len(self._parent_hidden_states_cache[module]) + + self._parent_hidden_states_cache[module].append( + {first_param: args[0]} if len(args) > 0 else {} + ) + + if len(self._model_kwargs_cache) < batch_idx: + raise ValueError("THIS SHOULDNT HAPPEN") + elif len(self._model_kwargs_cache) == batch_idx: + self._model_kwargs_cache.append(kwargs) + else: + self._model_kwargs_cache[batch_idx].update( + { + k: v + for k, v in kwargs.items() + if k not in self._model_kwargs_cache[batch_idx] + } + ) def create_cache_smooth_activations_hook_fn(smooth_name): def cache_smooth_activations_hook( @@ -409,17 +430,18 @@ def cache_smooth_activations_hook( return cache_smooth_activations_hook + self._model_kwargs_cache = [] for mapping in self._resolved_mappings: # parent kwargs needed for future forward passes # same parent may appear multiple times in resolved mappings - if mapping.parent not in self._parent_args_cache: - self._parent_args_cache[mapping.parent] = IntermediatesCache( + if mapping.parent not in self._parent_hidden_states_cache: + self._parent_hidden_states_cache[mapping.parent] = IntermediatesCache( None, self.offload_device, ) self.register_hook( mapping.parent, - cache_parent_kwargs_hook, + 
cache_hidden_states_kwargs_hook, "forward_pre", with_kwargs=True, ) @@ -560,14 +582,20 @@ def _smooth(module): self._assert_all_activations_consumed() def _run_samples(self, module: Module) -> List[torch.Tensor]: - outputs = [ - module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module] - ] - return [ + outputs = [] + for batch_idx in range(len(self._parent_hidden_states_cache[module])): + batch_kwargs = self._model_kwargs_cache[batch_idx].copy() + batch_kwargs.update( + self._parent_hidden_states_cache[module].fetch(batch_idx) + ) + batch_kwargs = inspect.signature(module.forward).bind( + **batch_kwargs, + ) + output = module(**batch_kwargs) # If Tuple, assume that first argument is the input - output[0] if isinstance(output, Tuple) else output - for output in outputs - ] + outputs.append(output[0] if isinstance(output, Tuple) else output) + + return outputs def _compute_best_scale( self, @@ -592,6 +620,8 @@ def _compute_best_scale( best_scales = None best_error = float("inf") + parent_module = torch.compile(parent_module) + org_sd = { k: v.cpu() for k, v in parent_module.state_dict().items() From 5b57be244366df25f57e8f6aeea9666a68b17f8e Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 30 Oct 2025 22:45:43 +0000 Subject: [PATCH 2/7] seemingly working with flag Signed-off-by: Brian Dellabetta --- examples/awq/qwen3_moe_example.py | 1 + src/llmcompressor/modifiers/awq/base.py | 73 +++++++++++++++---------- src/llmcompressor/pipelines/cache.py | 10 +++- 3 files changed, 53 insertions(+), 31 deletions(-) diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py index 4c9644998f..3209270f8a 100644 --- a/examples/awq/qwen3_moe_example.py +++ b/examples/awq/qwen3_moe_example.py @@ -55,6 +55,7 @@ def tokenize(sample): ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"], scheme="W4A16", targets=["Linear"], + use_auto_awq_mem_hack=False, ), ] diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 281cf34dde..714f0e9dba 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -122,6 +122,7 @@ class AWQModifier(Modifier, QuantizationMixin): mappings: Optional[List[AWQMapping]] = None offload_device: Optional[torch.device] = None duo_scaling: bool = True + use_auto_awq_mem_hack: bool = True # Private vars set during validation _num_bits: Optional[int] = PrivateAttr(default=None) @@ -130,10 +131,10 @@ class AWQModifier(Modifier, QuantizationMixin): # Private vars set during initialization, cleared during finalization _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list) - # Cache of kwargs for parent modules, one dict for each batch - _model_kwargs_cache: list[dict[str, Any]] = PrivateAttr(default_factory=list) + # Model-wise cache of kwargs for all parent modules + _model_kwargs_cache: IntermediatesCache = PrivateAttr() # Cache of forward hidden states for each parent module, one tensor for each batch - _parent_hidden_states_cache: dict[Module, IntermediatesCache] = PrivateAttr( + _parent_kwargs_cache: dict[Module, IntermediatesCache] = PrivateAttr( default_factory=dict ) # Dict[smooth layer name, (activation means, activation counts)] @@ -292,7 +293,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: if not self.ended_: self.on_end(state, None) - self._parent_hidden_states_cache.clear() + self._parent_kwargs_cache.clear() self._model_kwargs_cache = None self._smooth_activation_means.clear() 
self._resolved_mappings.clear() @@ -395,26 +396,32 @@ def cache_hidden_states_kwargs_hook( args: Tuple[torch.Tensor, ...], kwargs, ): - signature = inspect.signature(module.forward) - first_param = next(iter(signature.parameters)) - batch_idx = len(self._parent_hidden_states_cache[module]) + batch_idx = len(self._parent_kwargs_cache[module]) - self._parent_hidden_states_cache[module].append( - {first_param: args[0]} if len(args) > 0 else {} - ) + values = inspect.signature(module.forward).bind(*args, **kwargs) + + # our original impl: all kwargs are cached for each parent + # technically correct way, but probably lots of redundancy + if not self.use_auto_awq_mem_hack: + self._parent_kwargs_cache[module].append(values.arguments) + return + + # autoawq impl: only first param is cached for each parent + # all others are pulled from model-wide cache + # much more memory efficient, but possibly incorrect + # depending on model definition + first_param_name, first_arg = next(iter(values.arguments.items())) + + self._parent_kwargs_cache[module].append({first_param_name: first_arg}) + + values.arguments.pop(first_param_name) if len(self._model_kwargs_cache) < batch_idx: raise ValueError("THIS SHOULDNT HAPPEN") elif len(self._model_kwargs_cache) == batch_idx: - self._model_kwargs_cache.append(kwargs) + self._model_kwargs_cache.append(values.arguments) else: - self._model_kwargs_cache[batch_idx].update( - { - k: v - for k, v in kwargs.items() - if k not in self._model_kwargs_cache[batch_idx] - } - ) + self._model_kwargs_cache.update(batch_idx, values.arguments) def create_cache_smooth_activations_hook_fn(smooth_name): def cache_smooth_activations_hook( @@ -430,12 +437,13 @@ def cache_smooth_activations_hook( return cache_smooth_activations_hook - self._model_kwargs_cache = [] + # Don't offload this, it will be used consistently + self._model_kwargs_cache = IntermediatesCache(None, None) for mapping in self._resolved_mappings: # parent kwargs needed for future forward passes # same parent may appear multiple times in resolved mappings - if mapping.parent not in self._parent_hidden_states_cache: - self._parent_hidden_states_cache[mapping.parent] = IntermediatesCache( + if mapping.parent not in self._parent_kwargs_cache: + self._parent_kwargs_cache[mapping.parent] = IntermediatesCache( None, self.offload_device, ) @@ -577,20 +585,23 @@ def _smooth(module): # remove caches needed to smooth this mapping del self._smooth_activation_means[mapping.smooth_name] - for v in self._parent_args_cache.values(): + for v in self._parent_kwargs_cache.values(): v.batch_intermediates.clear() self._assert_all_activations_consumed() def _run_samples(self, module: Module) -> List[torch.Tensor]: outputs = [] - for batch_idx in range(len(self._parent_hidden_states_cache[module])): - batch_kwargs = self._model_kwargs_cache[batch_idx].copy() - batch_kwargs.update( - self._parent_hidden_states_cache[module].fetch(batch_idx) - ) - batch_kwargs = inspect.signature(module.forward).bind( - **batch_kwargs, + parameter_keys = inspect.signature(module.forward).parameters.keys() + + for batch_idx in range(len(self._parent_kwargs_cache[module])): + batch_kwargs = self._model_kwargs_cache.fetch( + batch_idx, ignore_missing=True ) + batch_kwargs.update(self._parent_kwargs_cache[module].fetch(batch_idx)) + batch_kwargs = { + k: v for k, v in batch_kwargs.items() if k in parameter_keys + } + output = module(**batch_kwargs) # If Tuple, assume that first argument is the input outputs.append(output[0] if isinstance(output, Tuple) else output) 
@@ -620,7 +631,9 @@ def _compute_best_scale( best_scales = None best_error = float("inf") - parent_module = torch.compile(parent_module) + # NOTE: this changes the module pointers, so it invalidates + # field `_parent_kwargs_cache: dict[Module, IntermediatesCache]` + # parent_module = torch.compile(parent_module) org_sd = { k: v.cpu() diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index dd600a0f76..5d67b2c2af 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -90,15 +90,23 @@ def from_dataloader( return cls(batch_intermediates, offload_device) def fetch( - self, batch_index: int, input_names: Optional[List[str]] = None + self, + batch_index: int, + input_names: Optional[List[str]] = None, + ignore_missing: bool = False, ) -> Dict[str, Any]: """ Fetch values belonging to a batch :param batch_index: index of batch whose values are being fetched :param input_names: list of keys whose values are being fetched + :param ignore_missing: if an intermediate for batch_index is not found, + return an empty dict if this is True, otherwise an IndexError + will be raised :return: dictionary mapping keys to onloaded values """ + if ignore_missing and batch_index >= len(self.batch_intermediates): + return {} intermediates = self.batch_intermediates[batch_index] return { From f115d8fbfde5ddca8db847f4808e01fa226c2212 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 30 Oct 2025 23:53:25 +0000 Subject: [PATCH 3/7] comment Signed-off-by: Brian Dellabetta --- examples/awq/qwen3_moe_example.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py index 3209270f8a..5792f297bf 100644 --- a/examples/awq/qwen3_moe_example.py +++ b/examples/awq/qwen3_moe_example.py @@ -55,7 +55,8 @@ def tokenize(sample): ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"], scheme="W4A16", targets=["Linear"], - use_auto_awq_mem_hack=False, + use_auto_awq_mem_hack=False, # GPU VRAM 37784MiB + # use_auto_awq_mem_hack=True, # GPU VRAM 37792MiB ), ] From 115f4273d0b828ffea8ae4ef0f8b354ca7370f0f Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Nov 2025 22:16:08 +0000 Subject: [PATCH 4/7] example showing why this feature won't add much value Signed-off-by: Brian Dellabetta --- examples/awq/qwen3_moe_example.py | 4 ++++ src/llmcompressor/modifiers/awq/base.py | 28 +++++++++++++++---------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py index 5792f297bf..cfd44926bf 100644 --- a/examples/awq/qwen3_moe_example.py +++ b/examples/awq/qwen3_moe_example.py @@ -1,3 +1,4 @@ +import os from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer @@ -55,6 +58,8 @@ def tokenize(sample): targets=["Linear"], use_auto_awq_mem_hack=False, # GPU VRAM 37784MiB # use_auto_awq_mem_hack=True, # GPU VRAM 37792MiB + # use_auto_awq_mem_hack=os.getenv("USE_HACK", "") == "yes", # GPU VRAM 37784MiB + # use_auto_awq_mem_hack=True, # GPU VRAM 37792MiB ), ] @@ -67,6 +70,7 @@ def tokenize(sample): recipe=recipe, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, + moe_calibrate_all_experts=False, ) # Confirm generations of the quantized model look sane. 
diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 714f0e9dba..01808f881a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -416,12 +416,10 @@ def cache_hidden_states_kwargs_hook( values.arguments.pop(first_param_name) - if len(self._model_kwargs_cache) < batch_idx: - raise ValueError("THIS SHOULDNT HAPPEN") - elif len(self._model_kwargs_cache) == batch_idx: + if len(self._model_kwargs_cache) == 0: self._model_kwargs_cache.append(values.arguments) else: - self._model_kwargs_cache.update(batch_idx, values.arguments) + self._model_kwargs_cache.update(0, values.arguments) def create_cache_smooth_activations_hook_fn(smooth_name): def cache_smooth_activations_hook( @@ -474,6 +472,15 @@ def _apply_smoothing(self, model: Module) -> None: """ # NOTE: When using SequentialPipeline, not all the mappings # will have cached activations in the segment being udpated + + print("SIZE", self._model_kwargs_cache.size()) + try: + cache = self._model_kwargs_cache.fetch(0) + for k, v in cache.items(): + print(k, f"{v.shape} {v.device}" if isinstance(v, torch.Tensor) else v) + except: + pass + mappings_to_smooth = [ mapping for mapping in self._resolved_mappings @@ -587,25 +594,24 @@ def _smooth(module): for v in self._parent_kwargs_cache.values(): v.batch_intermediates.clear() + self._assert_all_activations_consumed() - def _run_samples(self, module: Module) -> List[torch.Tensor]: - outputs = [] + def _run_samples(self, module: Module) -> list[torch.Tensor]: parameter_keys = inspect.signature(module.forward).parameters.keys() + outputs = [] for batch_idx in range(len(self._parent_kwargs_cache[module])): - batch_kwargs = self._model_kwargs_cache.fetch( - batch_idx, ignore_missing=True - ) + batch_kwargs = self._model_kwargs_cache.fetch(0, ignore_missing=True) batch_kwargs.update(self._parent_kwargs_cache[module].fetch(batch_idx)) batch_kwargs = { k: v for k, v in batch_kwargs.items() if k in parameter_keys } output = module(**batch_kwargs) - # If Tuple, assume that first argument is the input - outputs.append(output[0] if isinstance(output, Tuple) else output) + # If tuple, assume that first argument is the input + outputs.append(output[0] if isinstance(output, tuple) else output) return outputs def _compute_best_scale( From 4295560acd408761a7eb509b3fea91248bc3e66d Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Nov 2025 22:19:32 +0000 Subject: [PATCH 5/7] rm Signed-off-by: Brian Dellabetta --- examples/awq/qwen3_moe_example.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py index cfd44926bf..eeed51bc96 100644 --- a/examples/awq/qwen3_moe_example.py +++ b/examples/awq/qwen3_moe_example.py @@ -70,7 +70,6 @@ def tokenize(sample): recipe=recipe, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, - moe_calibrate_all_experts=False, ) # Confirm generations of the quantized model look sane. 
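For readers following the discussion in PATCH 2 and PATCH 4, the sketch below is a minimal standalone version of the AutoAWQ-style caching idea that `use_auto_awq_mem_hack` toggles. It is illustrative only and not part of the patches: the class and method names are invented, and it assumes every parent module receives its hidden states as the first forward argument and that all remaining kwargs are shared across parents.

import inspect
from typing import Any, Dict, List

import torch
from torch.nn import Module


class SharedKwargsCache:
    """Per-module hidden-state cache plus one model-wide kwargs dict."""

    def __init__(self) -> None:
        self.hidden_states: Dict[Module, List[torch.Tensor]] = {}
        self.shared_kwargs: Dict[str, Any] = {}

    def attach(self, module: Module):
        """Register a forward pre-hook that populates both caches."""
        self.hidden_states[module] = []

        def pre_hook(mod: Module, args, kwargs):
            bound = inspect.signature(mod.forward).bind(*args, **kwargs)
            first_name, first_value = next(iter(bound.arguments.items()))
            # per module and per batch: only the hidden states differ
            self.hidden_states[mod].append(first_value.detach())
            # model-wide: remaining kwargs are stored once and reused
            bound.arguments.pop(first_name)
            self.shared_kwargs.update(bound.arguments)

        return module.register_forward_pre_hook(pre_hook, with_kwargs=True)

    def replay(self, module: Module) -> List[torch.Tensor]:
        """Re-run the cached batches through the module, AWQ-style."""
        params = inspect.signature(module.forward).parameters
        first_name = next(iter(params))
        outputs = []
        for hidden in self.hidden_states[module]:
            extra = {
                k: v
                for k, v in self.shared_kwargs.items()
                if k in params and k != first_name
            }
            out = module(hidden, **extra)
            outputs.append(out[0] if isinstance(out, tuple) else out)
        return outputs

The trade-off noted in the commit comments applies here as well: the shared kwargs are assumed identical across parents and batches, which is what makes the approach memory-efficient but potentially incorrect for some model definitions.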
From 5917a16d5f8642020a35d80025d32a1ea17dd39a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Nov 2025 22:20:36 +0000 Subject: [PATCH 6/7] cleanup Signed-off-by: Brian Dellabetta --- examples/awq/qwen3_moe_example.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py index eeed51bc96..6894393abb 100644 --- a/examples/awq/qwen3_moe_example.py +++ b/examples/awq/qwen3_moe_example.py @@ -56,10 +56,8 @@ def tokenize(sample): ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"], scheme="W4A16", targets=["Linear"], - use_auto_awq_mem_hack=False, # GPU VRAM 37784MiB - # use_auto_awq_mem_hack=True, # GPU VRAM 37792MiB - # use_auto_awq_mem_hack=os.getenv("USE_HACK", "") == "yes", # GPU VRAM 37784MiB - # use_auto_awq_mem_hack=True, # GPU VRAM 37792MiB + # use_auto_awq_mem_hack=os.getenv("USE_HACK", "") == "yes", + # GPU VRAM consistently peaks at ~37784MiB regardless ), ] From d2837a5881f4f97e455ad8bdbf79f142619f9e62 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Nov 2025 22:20:43 +0000 Subject: [PATCH 7/7] cleanup Signed-off-by: Brian Dellabetta --- examples/awq/qwen3_moe_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/awq/qwen3_moe_example.py b/examples/awq/qwen3_moe_example.py index 6894393abb..f36660bfca 100644 --- a/examples/awq/qwen3_moe_example.py +++ b/examples/awq/qwen3_moe_example.py @@ -56,7 +56,7 @@ def tokenize(sample): ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"], scheme="W4A16", targets=["Linear"], - # use_auto_awq_mem_hack=os.getenv("USE_HACK", "") == "yes", + use_auto_awq_mem_hack=os.getenv("USE_HACK", "") == "yes", # GPU VRAM consistently peaks at ~37784MiB regardless ), ]
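The `ignore_missing` flag added to `IntermediatesCache.fetch` in PATCH 2 can be summarized with a toy reimplementation. This is a sketch only: `TinyCache` is an invented stand-in that skips the real class's tensor offloading and onloading, and exists purely to show the intended semantics of the new parameter.

from typing import Any, Dict, List, Optional


class TinyCache:
    """Stripped-down stand-in for IntermediatesCache (no offloading)."""

    def __init__(self) -> None:
        self.batch_intermediates: List[Dict[str, Any]] = []

    def append(self, values: Dict[str, Any]) -> None:
        self.batch_intermediates.append(values)

    def fetch(
        self,
        batch_index: int,
        input_names: Optional[List[str]] = None,
        ignore_missing: bool = False,
    ) -> Dict[str, Any]:
        # with ignore_missing=True, an out-of-range batch index yields an
        # empty dict instead of raising IndexError
        if ignore_missing and batch_index >= len(self.batch_intermediates):
            return {}
        intermediates = self.batch_intermediates[batch_index]
        return {
            k: v
            for k, v in intermediates.items()
            if input_names is None or k in input_names
        }


cache = TinyCache()
cache.append({"hidden_states": "h0"})
assert cache.fetch(0) == {"hidden_states": "h0"}
assert cache.fetch(3, ignore_missing=True) == {}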