1414
1515import contextlib
1616from abc import ABC
17+ from typing import Dict , Type
1718
1819import torch
19- from compressed_tensors .registry import RegistryMixin
2020from loguru import logger
2121from tqdm import tqdm
2222from transformers import PreTrainedModel
2323
2424__all__ = [
2525 "MoECalibrationModule" ,
26+ "MOE_CALIBRATION_MODULES" ,
27+ "register_moe_calibration" ,
2628 "moe_calibration_context" ,
2729]
2830
2931
30- class MoECalibrationModule (ABC , torch .nn .Module , RegistryMixin ):
32+ class MoECalibrationModule (ABC , torch .nn .Module ):
3133 """
3234 Abstract base class for MoE calibration modules.
3335
@@ -60,6 +62,32 @@ def restore(self) -> torch.nn.Module:
6062 )
6163
6264
# Registry: module class name -> calibration module class.
# Entries are added by the `register_moe_calibration` decorator; keys are the
# `__class__.__name__` of the original module, which is what
# `moe_calibration_context` uses to look up the replacement class.
MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {}
67+
68+
def register_moe_calibration(module_class_name: str):
    """
    Decorator to register a MoE calibration module.

    The decorated class is stored in ``MOE_CALIBRATION_MODULES`` under
    ``module_class_name`` and returned unchanged.

    Usage:
        @register_moe_calibration("DeepseekV3MoE")
        class CalibrationDeepseekV3MoE(MoECalibrationModule):
            ...

    Args:
        module_class_name: The class name of the original module to replace

    Returns:
        A class decorator that registers the class and returns it as-is.

    Raises:
        TypeError: Raised by the returned decorator when the decorated object
            is not a class inheriting from MoECalibrationModule.
    """

    def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]:
        # Check `isinstance(cls, type)` first: `issubclass` itself raises an
        # unhelpful "arg 1 must be a class" error for non-class objects.
        if not (isinstance(cls, type) and issubclass(cls, MoECalibrationModule)):
            cls_name = getattr(cls, "__name__", repr(cls))
            raise TypeError(f"{cls_name} must inherit from MoECalibrationModule")

        # Last registration still wins, but make the override visible instead
        # of letting import order silently decide which class is used.
        previous = MOE_CALIBRATION_MODULES.get(module_class_name)
        if previous is not None and previous is not cls:
            logger.warning(
                f"Overriding MoE calibration module registered for "
                f"'{module_class_name}': {previous.__name__} -> {cls.__name__}"
            )

        MOE_CALIBRATION_MODULES[module_class_name] = cls
        return cls

    return decorator
89+
90+
6391@contextlib .contextmanager
6492def moe_calibration_context (
6593 model : PreTrainedModel ,
@@ -99,10 +127,9 @@ def moe_calibration_context(
99127 # Step 1: Collect all MoE modules that need replacement
100128 logger .debug ("Entering MoE calibration context" )
101129 modules_to_replace = []
102- moe_class_names = MoECalibrationModule .registered_names ()
103130 for name , module in model .named_modules ():
104131 class_name = module .__class__ .__name__
105- if class_name in moe_class_names :
132+ if class_name in MOE_CALIBRATION_MODULES :
106133 modules_to_replace .append ((name , module , class_name ))
107134
108135 # Step 2: Replace modules with progress bar
@@ -111,8 +138,8 @@ def moe_calibration_context(
111138 for name , module , class_name in tqdm (
112139 modules_to_replace , desc = "Replacing MoE modules for calibration"
113140 ):
114- replacement = MoECalibrationModule . load_from_registry (
115- class_name ,
141+ calibration_cls = MOE_CALIBRATION_MODULES [ class_name ]
142+ replacement = calibration_cls (
116143 module ,
117144 model .config ,
118145 calibrate_all_experts = calibrate_all_experts ,
0 commit comments