88
99Key components:
1010- MoECalibrationModule: Abstract base class for calibration modules
11- - MOE_CALIBRATION_MODULES: Registry mapping module class names to calibration classes
1211- moe_calibration_context: Context manager that applies calibration to a model
1312"""
1413
1514import contextlib
1615from abc import ABC
17- from typing import Dict , Type
1816
1917import torch
18+ from compressed_tensors .registry import RegistryMixin , standardize_lookup_name
2019from loguru import logger
2120from tqdm import tqdm
2221from transformers import PreTrainedModel
2322
2423__all__ = [
2524 "MoECalibrationModule" ,
26- "MOE_CALIBRATION_MODULES" ,
27- "register_moe_calibration" ,
2825 "moe_calibration_context" ,
2926]
3027
3128
32- class MoECalibrationModule (ABC , torch .nn .Module ):
29+ class MoECalibrationModule (ABC , torch .nn .Module , RegistryMixin ):
3330 """
3431 Abstract base class for MoE calibration modules.
3532
@@ -62,32 +59,6 @@ def restore(self, original: torch.nn.Module) -> torch.nn.Module:
6259 )
6360
6461
65- # Registry: module class name -> calibration module class
66- MOE_CALIBRATION_MODULES : Dict [str , Type [MoECalibrationModule ]] = {}
67-
68-
69- def register_moe_calibration (module_class_name : str ):
70- """
71- Decorator to register a MoE calibration module.
72-
73- Usage:
74- @register_moe_calibration("DeepseekV3MoE")
75- class CalibrationDeepseekV3MoE(MoECalibrationModule):
76- ...
77-
78- Args:
79- module_class_name: The class name of the original module to replace
80- """
81-
82- def decorator (cls : Type [MoECalibrationModule ]) -> Type [MoECalibrationModule ]:
83- if not issubclass (cls , MoECalibrationModule ):
84- raise TypeError (f"{ cls .__name__ } must inherit from MoECalibrationModule" )
85- MOE_CALIBRATION_MODULES [module_class_name ] = cls
86- return cls
87-
88- return decorator
89-
90-
9162@contextlib .contextmanager
9263def moe_calibration_context (
9364 model : PreTrainedModel ,
@@ -115,12 +86,6 @@ def moe_calibration_context(
11586 model(**batch)
11687 # Model is now restored (unless permanent)
11788 """
118- # trigger registration
119- from .deepseek_v3 import CalibrationDeepseekV3MoE # noqa: F401
120- from .llama4 import SequentialLlama4TextMoe # noqa: F401
121- from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
122- from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
123- # TODO: add granite4, Qwen3Next
12489
12590 replaced = {}
12691
@@ -129,7 +94,7 @@ def moe_calibration_context(
12994 modules_to_replace = []
13095 for name , module in model .named_modules ():
13196 class_name = module .__class__ .__name__
132- if class_name in MOE_CALIBRATION_MODULES :
97+ if _is_registered ( class_name , MoECalibrationModule ) :
13398 modules_to_replace .append ((name , module , class_name ))
13499
135100 # Step 2: Replace modules with progress bar
@@ -138,8 +103,8 @@ def moe_calibration_context(
138103 for name , module , class_name in tqdm (
139104 modules_to_replace , desc = "Replacing MoE modules for calibration"
140105 ):
141- calibration_cls = MOE_CALIBRATION_MODULES [ class_name ]
142- replacement = calibration_cls (
106+ replacement = MoECalibrationModule . load_from_registry (
107+ class_name ,
143108 module ,
144109 model .config ,
145110 calibrate_all_experts = calibrate_all_experts ,
@@ -172,3 +137,7 @@ def moe_calibration_context(
172137 if not replacement .is_permanent :
173138 restored = replacement .restore (original )
174139 model .set_submodule (name , restored )
140+
141+
142+ def _is_registered (name : str , subclass : RegistryMixin ):
143+ return standardize_lookup_name (name ) in subclass .registered_names ()
0 commit comments