diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py
index 53b98621f..c17ec01a1 100644
--- a/examples/multimodal_vision/llama4_example.py
+++ b/examples/multimodal_vision/llama4_example.py
@@ -19,7 +19,7 @@
 # NOTE: This restructuring is specifically required for vLLM compatibility.
 # To define custom calibration logic, create a new calibration module in
 # modeling/llama4.py that inherits from `MoECalibrationModule`, and register
-# it using the `@register_moe_calibration` decorator with the appropriate
+# it using the `@MoECalibrationModule.register` decorator with the appropriate
 # module class name (e.g., "Llama4TextMoe").

 DATASET_ID = "neuralmagic/calibration"
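As the NOTE above describes, custom calibration logic now hooks in through the `@MoECalibrationModule.register` decorator rather than the removed `register_moe_calibration` helper. A minimal sketch of that flow, assuming a hypothetical original module class named "MyModelMoeBlock" (the class body below is illustrative, not part of this patch):

# Illustrative sketch only; "MyModelMoeBlock" and this class body are placeholders.
import torch

from llmcompressor.modeling.moe_context import MoECalibrationModule


@MoECalibrationModule.register("MyModelMoeBlock")  # class name of the original module
class CalibrationMyModelMoeBlock(MoECalibrationModule):
    def __init__(self, original, config, calibrate_all_experts: bool = True):
        super().__init__()
        self.original = original
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor):
        # Custom calibration routing (e.g. sending tokens to all experts) goes here.
        return self.original(hidden_states)

The registered name must match `module.__class__.__name__` of the original block so that `moe_calibration_context` can discover it while walking `model.named_modules()`.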
diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py
index 9d105d823..1aeec42b0 100644
--- a/src/llmcompressor/modeling/__init__.py
+++ b/src/llmcompressor/modeling/__init__.py
@@ -9,5 +9,13 @@
 needed for efficient compression.
 """

+# trigger registration
+from .deepseek_v3 import CalibrationDeepseekV3MoE  # noqa: F401
+from .llama4 import SequentialLlama4TextMoe  # noqa: F401
+from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock  # noqa: F401
+from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock  # noqa: F401
+from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock  # noqa: F401
+# TODO: add granite4, Qwen3Next
+
 from .fuse import *
 from .prepare import *
diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py
index c2dd8f4b6..4618d15b6 100644
--- a/src/llmcompressor/modeling/deepseek_v3.py
+++ b/src/llmcompressor/modeling/deepseek_v3.py
@@ -4,13 +4,10 @@
     DeepseekV3MoE as OriginalDeepseekV3MoE,
 )

-from llmcompressor.modeling.moe_context import (
-    MoECalibrationModule,
-    register_moe_calibration,
-)
+from llmcompressor.modeling.moe_context import MoECalibrationModule


-@register_moe_calibration("DeepseekV3MoE")
+@MoECalibrationModule.register("DeepseekV3MoE")
 class CalibrationDeepseekV3MoE(MoECalibrationModule):
     """
     Calibration version of DeepseekV3MoE that sends all tokens to all experts.
diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py
index 2b49a652a..1f2ef9b77 100644
--- a/src/llmcompressor/modeling/llama4.py
+++ b/src/llmcompressor/modeling/llama4.py
@@ -11,14 +11,11 @@
     Llama4TextMoe,
 )

-from llmcompressor.modeling.moe_context import (
-    MoECalibrationModule,
-    register_moe_calibration,
-)
+from llmcompressor.modeling.moe_context import MoECalibrationModule
 from llmcompressor.utils.dev import skip_weights_initialize


-@register_moe_calibration("Llama4TextMoe")
+@MoECalibrationModule.register("Llama4TextMoe")
 class SequentialLlama4TextMoe(MoECalibrationModule):
     """
     Calibration version of Llama4TextMoe that unpacks experts for sequential processing.
@@ -38,10 +35,8 @@ def __init__(
         calibrate_all_experts: bool = True,
     ):
         super().__init__()
-        # Extract text config from multimodal config if needed
-        text_config = (
-            config.get_text_config() if hasattr(config, "get_text_config") else config
-        )
+        # Extract text config from multimodal config
+        text_config: Llama4TextConfig = config.get_text_config()
         self.top_k = text_config.num_experts_per_tok
         self.hidden_dim = text_config.hidden_size
         self.num_experts = text_config.num_local_experts
diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py
index 232b271c0..6e96c4e2d 100644
--- a/src/llmcompressor/modeling/moe_context.py
+++ b/src/llmcompressor/modeling/moe_context.py
@@ -8,28 +8,25 @@

 Key components:
 - MoECalibrationModule: Abstract base class for calibration modules
-- MOE_CALIBRATION_MODULES: Registry mapping module class names to calibration classes
 - moe_calibration_context: Context manager that applies calibration to a model
 """

 import contextlib
 from abc import ABC
-from typing import Dict, Type

 import torch
+from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
 from loguru import logger
 from tqdm import tqdm
 from transformers import PreTrainedModel

 __all__ = [
     "MoECalibrationModule",
-    "MOE_CALIBRATION_MODULES",
-    "register_moe_calibration",
     "moe_calibration_context",
 ]


-class MoECalibrationModule(ABC, torch.nn.Module):
+class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
     """
     Abstract base class for MoE calibration modules.
@@ -62,32 +59,6 @@ def restore(self, original: torch.nn.Module) -> torch.nn.Module:
         )


-# Registry: module class name -> calibration module class
-MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {}
-
-
-def register_moe_calibration(module_class_name: str):
-    """
-    Decorator to register a MoE calibration module.
-
-    Usage:
-        @register_moe_calibration("DeepseekV3MoE")
-        class CalibrationDeepseekV3MoE(MoECalibrationModule):
-            ...
-
-    Args:
-        module_class_name: The class name of the original module to replace
-    """
-
-    def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]:
-        if not issubclass(cls, MoECalibrationModule):
-            raise TypeError(f"{cls.__name__} must inherit from MoECalibrationModule")
-        MOE_CALIBRATION_MODULES[module_class_name] = cls
-        return cls
-
-    return decorator
-
-
 @contextlib.contextmanager
 def moe_calibration_context(
     model: PreTrainedModel,
@@ -115,14 +86,15 @@
             model(**batch)
     # Model is now restored (unless permanent)
     """
+    replaced = {}

     # Step 1: Collect all MoE modules that need replacement
-    logger.info("Entering MoE calibration context")
+    logger.debug("Entering MoE calibration context")
     modules_to_replace = []
     for name, module in model.named_modules():
         class_name = module.__class__.__name__
-        if class_name in MOE_CALIBRATION_MODULES:
+        if _is_registered(class_name, MoECalibrationModule):
             modules_to_replace.append((name, module, class_name))

     # Step 2: Replace modules with progress bar
@@ -131,8 +103,8 @@
     for name, module, class_name in tqdm(
         modules_to_replace, desc="Replacing MoE modules for calibration"
     ):
-        calibration_cls = MOE_CALIBRATION_MODULES[class_name]
-        replacement = calibration_cls(
+        replacement = MoECalibrationModule.load_from_registry(
+            class_name,
             module,
             model.config,
             calibrate_all_experts=calibrate_all_experts,
@@ -165,3 +137,7 @@
             if not replacement.is_permanent:
                 restored = replacement.restore(original)
                 model.set_submodule(name, restored)
+
+
+def _is_registered(name: str, subclass: RegistryMixin):
+    return standardize_lookup_name(name) in subclass.registered_names()
diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index 42173bb8b..af9920d1b 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -10,33 +10,12 @@
 from compressed_tensors.utils import deprecated, replace_module
 from transformers import PreTrainedModel

-# Import MoE calibration modules to trigger registration
-from llmcompressor.modeling.deepseek_v3 import (  # noqa: F401
-    CalibrationDeepseekV3MoE,
-)
-from llmcompressor.modeling.deepseek_v3 import (
-    replace as replace_deepseekv3,
-)
-from llmcompressor.modeling.llama4 import (  # noqa: F401
-    SequentialLlama4TextMoe,
-)
-from llmcompressor.modeling.llama4 import (
-    replace as replace_llama4,
-)
-from llmcompressor.modeling.moe_context import (  # noqa: F401
-    moe_calibration_context,
-)
-from llmcompressor.modeling.qwen3_moe import (  # noqa: F401
-    CalibrationQwen3MoeSparseMoeBlock,
-)
-from llmcompressor.modeling.qwen3_next_moe import (  # noqa: F401
-    CalibrationQwen3NextSparseMoeBlock,
-)
-from llmcompressor.modeling.qwen3_vl_moe import (
-    replace as replace_Qwen3VLMoE,
-)
+# deprecated replacement functions
+from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
+from llmcompressor.modeling.llama4 import replace as replace_llama4
+from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE

-__all__ = ["moe_calibration_context", "replace_modules_for_calibration"]
+__all__ = ["replace_modules_for_calibration"]

 # ---------------------- module replacements; permanent -------------------------
 replacements = {
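Note that `moe_calibration_context` is no longer re-exported through `prepare.py`, so callers import it from `llmcompressor.modeling.moe_context`; importing the `llmcompressor.modeling` package (which that import pulls in) is what triggers the registrations added to `__init__.py` above. A usage sketch, with a placeholder model id and calibration input:

# Usage sketch; the model id and calibration input below are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling.moe_context import moe_calibration_context

model_id = "Qwen/Qwen3-30B-A3B"  # any model whose MoE block class is registered
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
batch = tokenizer("Sample calibration text", return_tensors="pt")

# On entry, registered MoE blocks are swapped for their calibration versions;
# on exit, non-permanent replacements are restored via restore().
with moe_calibration_context(model, calibrate_all_experts=True):
    model(**batch)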
diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py
index 5432b731b..678e32f10 100644
--- a/src/llmcompressor/modeling/qwen3_moe.py
+++ b/src/llmcompressor/modeling/qwen3_moe.py
@@ -20,13 +20,10 @@
     Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
 )

-from llmcompressor.modeling.moe_context import (
-    MoECalibrationModule,
-    register_moe_calibration,
-)
+from llmcompressor.modeling.moe_context import MoECalibrationModule


-@register_moe_calibration("Qwen3MoeSparseMoeBlock")
+@MoECalibrationModule.register("Qwen3MoeSparseMoeBlock")
 class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule):
     """
     Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts.
diff --git a/src/llmcompressor/modeling/qwen3_next_moe.py b/src/llmcompressor/modeling/qwen3_next_moe.py
index cf11a84d0..823ca779b 100644
--- a/src/llmcompressor/modeling/qwen3_next_moe.py
+++ b/src/llmcompressor/modeling/qwen3_next_moe.py
@@ -16,13 +16,10 @@

 import torch

-from llmcompressor.modeling.moe_context import (
-    MoECalibrationModule,
-    register_moe_calibration,
-)
+from llmcompressor.modeling.moe_context import MoECalibrationModule


-@register_moe_calibration("Qwen3NextSparseMoeBlock")
+@MoECalibrationModule.register("Qwen3NextSparseMoeBlock")
 class CalibrationQwen3NextSparseMoeBlock(MoECalibrationModule):
     from transformers import Qwen3NextConfig
     from transformers.models.qwen3_next.modeling_qwen3_next import (
diff --git a/src/llmcompressor/modeling/qwen3_vl_moe.py b/src/llmcompressor/modeling/qwen3_vl_moe.py
index 5af6b7abf..c162c6d0c 100644
--- a/src/llmcompressor/modeling/qwen3_vl_moe.py
+++ b/src/llmcompressor/modeling/qwen3_vl_moe.py
@@ -1,19 +1,39 @@
 import torch
+from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
+from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+    Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
+)

+from llmcompressor.modeling.moe_context import MoECalibrationModule
 from llmcompressor.utils.dev import skip_weights_initialize


-class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module):
-    def __init__(self, config, original, calibrate_all_experts):
+@MoECalibrationModule.register("Qwen3VLMoeTextSparseMoeBlock")
+class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
+    """
+    Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all
+    experts.
+ """ + + is_permanent = True + + def __init__( + self, + original: OriginalQwen3VLMoeTextSparseMoeBlock, + config: Qwen3VLMoeConfig, + calibrate_all_experts: bool, + ): super().__init__() - self.hidden_size = config.hidden_size - self.num_experts = config.num_experts + text_config: Qwen3VLMoeTextConfig = config.get_text_config() + + self.hidden_size = text_config.hidden_size + self.num_experts = text_config.num_experts self.top_k = original.top_k # Note: gate was changed to be a Linear layer in transformers==4.57.0 # https://github.com/JJJYmmm/transformers/commit/f5dea1c694af8c994c769170813a8702332119ee self.gate = original.gate self.calibrate_all_experts = calibrate_all_experts - self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts) + self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape @@ -64,6 +84,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: next_states = next_states.reshape(batch_size, sequence_length, hidden_dim) return next_states, router_logits + def restore(self, original: torch.nn.Module) -> torch.nn.Module: + return original + class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList): def __init__(self, config, original): @@ -91,9 +114,13 @@ def __init__(self, config, original): self[i].down_proj.weight.data = down.t().clone().contiguous() -def replace(config, module, calibrate_all_experts): - return LinearQwen3VLMoeTextSparseMoeBlock( - config=config.get_text_config(), - original=module, +def replace( + config: Qwen3VLMoeConfig, + original: OriginalQwen3VLMoeTextSparseMoeBlock, + calibrate_all_experts: bool, +): + return CalibrateQwen3VLMoeTextSparseMoeBlock( + original=original, + config=config, calibrate_all_experts=calibrate_all_experts, ) diff --git a/tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py b/tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py index 513fb8cd7..46694a38d 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py @@ -1,20 +1,19 @@ import torch +from transformers import Qwen3VLMoeConfig +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeTextSparseMoeBlock, +) -from llmcompressor.modeling.qwen3_vl_moe import LinearQwen3VLMoeTextSparseMoeBlock +from llmcompressor.modeling.qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_gpu @requires_gpu def test_calib_qwen3_vl_moe_module(): - from transformers import Qwen3VLMoeTextConfig - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( - Qwen3VLMoeTextSparseMoeBlock, - ) - - config = Qwen3VLMoeTextConfig() + config = Qwen3VLMoeConfig() with torch.device("cuda"): - original = Qwen3VLMoeTextSparseMoeBlock(config).eval() + original = Qwen3VLMoeTextSparseMoeBlock(config.get_text_config()).eval() # these are initialized as empty / all 0s which results in outputs # from the experts being all 0 # update to use a small random value @@ -22,23 +21,23 @@ def test_calib_qwen3_vl_moe_module(): original.experts.down_proj.data.normal_(mean=0.0, std=0.02) # Create dummy input tensor that simulates hidden_states - hidden_dim = config.hidden_size + hidden_dim = config.get_text_config().hidden_size batch, seq_len = 4, 32 sample = torch.randn(batch, seq_len, hidden_dim, device="cuda") with 
         true_output = original(sample)

-    module = LinearQwen3VLMoeTextSparseMoeBlock(
-        config, original, calibrate_all_experts=True
+    module = CalibrateQwen3VLMoeTextSparseMoeBlock(
+        original, config, calibrate_all_experts=True
     )
     with calibration_forward_context(module):
         output = module(sample)
     assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
     assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10

-    module = LinearQwen3VLMoeTextSparseMoeBlock(
-        config, original, calibrate_all_experts=False
+    module = CalibrateQwen3VLMoeTextSparseMoeBlock(
+        original, config, calibrate_all_experts=False
     )
     with calibration_forward_context(module):
         output = module(sample)
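For completeness, a rough sketch of the registry round trip that replaces the old MOE_CALIBRATION_MODULES dict; the Dummy* classes below are made up purely to exercise the same lookup path that `moe_calibration_context` now uses:

# Toy example; DummyMoeBlock/CalibrationDummyMoeBlock are not part of this patch.
import torch
from compressed_tensors.registry import standardize_lookup_name

from llmcompressor.modeling.moe_context import MoECalibrationModule


class DummyMoeBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


@MoECalibrationModule.register("DummyMoeBlock")
class CalibrationDummyMoeBlock(MoECalibrationModule):
    def __init__(self, original, config, calibrate_all_experts: bool = True):
        super().__init__()
        self.original = original

    def forward(self, x):
        return self.original(x)


original = DummyMoeBlock()
class_name = original.__class__.__name__

# Same check and construction that moe_calibration_context performs:
assert standardize_lookup_name(class_name) in MoECalibrationModule.registered_names()
replacement = MoECalibrationModule.load_from_registry(
    class_name, original, None, calibrate_all_experts=True
)
assert isinstance(replacement, CalibrationDummyMoeBlock)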