Skip to content

Commit 4e78ae4

Browse files
committed
use registry mixin
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent fc34db6 commit 4e78ae4

File tree

5 files changed

+14
-53
lines changed

5 files changed

+14
-53
lines changed

src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@
44
DeepseekV3MoE as OriginalDeepseekV3MoE,
55
)
66

7-
from llmcompressor.modeling.moe_context import (
8-
MoECalibrationModule,
9-
register_moe_calibration,
10-
)
7+
from llmcompressor.modeling.moe_context import MoECalibrationModule
118

129

13-
@register_moe_calibration("DeepseekV3MoE")
10+
@MoECalibrationModule.register("DeepseekV3MoE")
1411
class CalibrationDeepseekV3MoE(MoECalibrationModule):
1512
"""
1613
Calibration version of DeepseekV3MoE that sends all tokens to all experts.

src/llmcompressor/modeling/llama4.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,11 @@
1111
Llama4TextMoe,
1212
)
1313

14-
from llmcompressor.modeling.moe_context import (
15-
MoECalibrationModule,
16-
register_moe_calibration,
17-
)
14+
from llmcompressor.modeling.moe_context import MoECalibrationModule
1815
from llmcompressor.utils.dev import skip_weights_initialize
1916

2017

21-
@register_moe_calibration("Llama4TextMoe")
18+
@MoECalibrationModule.register("Llama4TextMoe")
2219
class SequentialLlama4TextMoe(MoECalibrationModule):
2320
"""
2421
Calibration version of Llama4TextMoe that unpacks experts for sequential processing.

src/llmcompressor/modeling/moe_context.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,20 @@
1414

1515
import contextlib
1616
from abc import ABC
17-
from typing import Dict, Type
1817

1918
import torch
19+
from compressed_tensors.registry import RegistryMixin
2020
from loguru import logger
2121
from tqdm import tqdm
2222
from transformers import PreTrainedModel
2323

2424
__all__ = [
2525
"MoECalibrationModule",
26-
"MOE_CALIBRATION_MODULES",
27-
"register_moe_calibration",
2826
"moe_calibration_context",
2927
]
3028

3129

32-
class MoECalibrationModule(ABC, torch.nn.Module):
30+
class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
3331
"""
3432
Abstract base class for MoE calibration modules.
3533
@@ -62,32 +60,6 @@ def restore(self, original: torch.nn.Module) -> torch.nn.Module:
6260
)
6361

6462

65-
# Registry: module class name -> calibration module class
66-
MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {}
67-
68-
69-
def register_moe_calibration(module_class_name: str):
70-
"""
71-
Decorator to register a MoE calibration module.
72-
73-
Usage:
74-
@register_moe_calibration("DeepseekV3MoE")
75-
class CalibrationDeepseekV3MoE(MoECalibrationModule):
76-
...
77-
78-
Args:
79-
module_class_name: The class name of the original module to replace
80-
"""
81-
82-
def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]:
83-
if not issubclass(cls, MoECalibrationModule):
84-
raise TypeError(f"{cls.__name__} must inherit from MoECalibrationModule")
85-
MOE_CALIBRATION_MODULES[module_class_name] = cls
86-
return cls
87-
88-
return decorator
89-
90-
9163
@contextlib.contextmanager
9264
def moe_calibration_context(
9365
model: PreTrainedModel,
@@ -127,9 +99,10 @@ def moe_calibration_context(
12799
# Step 1: Collect all MoE modules that need replacement
128100
logger.debug("Entering MoE calibration context")
129101
modules_to_replace = []
102+
moe_class_names = MoECalibrationModule.registered_names()
130103
for name, module in model.named_modules():
131104
class_name = module.__class__.__name__
132-
if class_name in MOE_CALIBRATION_MODULES:
105+
if class_name in moe_class_names:
133106
modules_to_replace.append((name, module, class_name))
134107

135108
# Step 2: Replace modules with progress bar
@@ -138,8 +111,8 @@ def moe_calibration_context(
138111
for name, module, class_name in tqdm(
139112
modules_to_replace, desc="Replacing MoE modules for calibration"
140113
):
141-
calibration_cls = MOE_CALIBRATION_MODULES[class_name]
142-
replacement = calibration_cls(
114+
replacement = MoECalibrationModule.load_from_registry(
115+
class_name,
143116
module,
144117
model.config,
145118
calibrate_all_experts=calibrate_all_experts,

src/llmcompressor/modeling/qwen3_moe.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,10 @@
2020
Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
2121
)
2222

23-
from llmcompressor.modeling.moe_context import (
24-
MoECalibrationModule,
25-
register_moe_calibration,
26-
)
23+
from llmcompressor.modeling.moe_context import MoECalibrationModule
2724

2825

29-
@register_moe_calibration("Qwen3MoeSparseMoeBlock")
26+
@MoECalibrationModule.register("Qwen3MoeSparseMoeBlock")
3027
class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule):
3128
"""
3229
Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts.

src/llmcompressor/modeling/qwen3_vl_moe.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,11 @@
44
Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
55
)
66

7-
from llmcompressor.modeling.moe_context import (
8-
MoECalibrationModule,
9-
register_moe_calibration,
10-
)
7+
from llmcompressor.modeling.moe_context import MoECalibrationModule
118
from llmcompressor.utils.dev import skip_weights_initialize
129

1310

14-
@register_moe_calibration("CalibrationQwen3VLMoeTextSparseMoeBlock")
11+
@MoECalibrationModule.register("Qwen3VLMoeTextSparseMoeBlock")
1512
class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
1613
"""
1714
Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all

0 commit comments

Comments
 (0)