1414
1515import contextlib
1616from abc import ABC
17- from typing import Dict , Type
1817
1918import torch
19+ from compressed_tensors .registry import RegistryMixin
2020from loguru import logger
2121from tqdm import tqdm
2222from transformers import PreTrainedModel
2323
2424__all__ = [
2525 "MoECalibrationModule" ,
26- "MOE_CALIBRATION_MODULES" ,
27- "register_moe_calibration" ,
2826 "moe_calibration_context" ,
2927]
3028
3129
32- class MoECalibrationModule (ABC , torch .nn .Module ):
30+ class MoECalibrationModule (ABC , torch .nn .Module , RegistryMixin ):
3331 """
3432 Abstract base class for MoE calibration modules.
3533
@@ -62,32 +60,6 @@ def restore(self, original: torch.nn.Module) -> torch.nn.Module:
6260 )
6361
6462
65- # Registry: module class name -> calibration module class
66- MOE_CALIBRATION_MODULES : Dict [str , Type [MoECalibrationModule ]] = {}
67-
68-
69- def register_moe_calibration (module_class_name : str ):
70- """
71- Decorator to register a MoE calibration module.
72-
73- Usage:
74- @register_moe_calibration("DeepseekV3MoE")
75- class CalibrationDeepseekV3MoE(MoECalibrationModule):
76- ...
77-
78- Args:
79- module_class_name: The class name of the original module to replace
80- """
81-
82- def decorator (cls : Type [MoECalibrationModule ]) -> Type [MoECalibrationModule ]:
83- if not issubclass (cls , MoECalibrationModule ):
84- raise TypeError (f"{ cls .__name__ } must inherit from MoECalibrationModule" )
85- MOE_CALIBRATION_MODULES [module_class_name ] = cls
86- return cls
87-
88- return decorator
89-
90-
9163@contextlib .contextmanager
9264def moe_calibration_context (
9365 model : PreTrainedModel ,
@@ -127,9 +99,10 @@ def moe_calibration_context(
12799 # Step 1: Collect all MoE modules that need replacement
128100 logger .debug ("Entering MoE calibration context" )
129101 modules_to_replace = []
102+ moe_class_names = MoECalibrationModule .registered_names ()
130103 for name , module in model .named_modules ():
131104 class_name = module .__class__ .__name__
132- if class_name in MOE_CALIBRATION_MODULES :
105+ if class_name in moe_class_names :
133106 modules_to_replace .append ((name , module , class_name ))
134107
135108 # Step 2: Replace modules with progress bar
@@ -138,8 +111,8 @@ def moe_calibration_context(
138111 for name , module , class_name in tqdm (
139112 modules_to_replace , desc = "Replacing MoE modules for calibration"
140113 ):
141- calibration_cls = MOE_CALIBRATION_MODULES [ class_name ]
142- replacement = calibration_cls (
114+ replacement = MoECalibrationModule . load_from_registry (
115+ class_name ,
143116 module ,
144117 model .config ,
145118 calibrate_all_experts = calibrate_all_experts ,
0 commit comments