Skip to content

Commit 8696bd9

Browse files
committed
Revert "use registry mixin"
This reverts commit 6dd0320.
1 parent 6dd0320 commit 8696bd9

File tree

5 files changed

+53
-14
lines changed

5 files changed

+53
-14
lines changed

src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
DeepseekV3MoE as OriginalDeepseekV3MoE,
55
)
66

7-
from llmcompressor.modeling.moe_context import MoECalibrationModule
7+
from llmcompressor.modeling.moe_context import (
8+
MoECalibrationModule,
9+
register_moe_calibration,
10+
)
811

912

10-
@MoECalibrationModule.register("DeepseekV3MoE")
13+
@register_moe_calibration("DeepseekV3MoE")
1114
class CalibrationDeepseekV3MoE(MoECalibrationModule):
1215
"""
1316
Calibration version of DeepseekV3MoE that sends all tokens to all experts.

src/llmcompressor/modeling/llama4.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
Llama4TextMoe,
1212
)
1313

14-
from llmcompressor.modeling.moe_context import MoECalibrationModule
14+
from llmcompressor.modeling.moe_context import (
15+
MoECalibrationModule,
16+
register_moe_calibration,
17+
)
1518
from llmcompressor.utils.dev import skip_weights_initialize
1619

1720

18-
@MoECalibrationModule.register("Llama4TextMoe")
21+
@register_moe_calibration("Llama4TextMoe")
1922
class SequentialLlama4TextMoe(MoECalibrationModule):
2023
"""
2124
Calibration version of Llama4TextMoe that unpacks experts for sequential processing.

src/llmcompressor/modeling/moe_context.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@
1414

1515
import contextlib
1616
from abc import ABC
17+
from typing import Dict, Type
1718

1819
import torch
19-
from compressed_tensors.registry import RegistryMixin
2020
from loguru import logger
2121
from tqdm import tqdm
2222
from transformers import PreTrainedModel
2323

2424
__all__ = [
2525
"MoECalibrationModule",
26+
"MOE_CALIBRATION_MODULES",
27+
"register_moe_calibration",
2628
"moe_calibration_context",
2729
]
2830

2931

30-
class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
32+
class MoECalibrationModule(ABC, torch.nn.Module):
3133
"""
3234
Abstract base class for MoE calibration modules.
3335
@@ -60,6 +62,32 @@ def restore(self) -> torch.nn.Module:
6062
)
6163

6264

65+
# Registry mapping an original module's class name to the calibration
# module class that should stand in for it during calibration.
MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {}


def register_moe_calibration(module_class_name: str):
    """
    Decorator to register a MoE calibration module.

    Usage:
        @register_moe_calibration("DeepseekV3MoE")
        class CalibrationDeepseekV3MoE(MoECalibrationModule):
            ...

    Args:
        module_class_name: The class name of the original module to replace
    """

    def _register(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]:
        # Fail fast on misuse: only MoECalibrationModule subclasses are
        # valid calibration replacements.
        if issubclass(cls, MoECalibrationModule):
            MOE_CALIBRATION_MODULES[module_class_name] = cls
            return cls
        raise TypeError(f"{cls.__name__} must inherit from MoECalibrationModule")

    return _register
89+
90+
6391
@contextlib.contextmanager
6492
def moe_calibration_context(
6593
model: PreTrainedModel,
@@ -99,10 +127,9 @@ def moe_calibration_context(
99127
# Step 1: Collect all MoE modules that need replacement
100128
logger.debug("Entering MoE calibration context")
101129
modules_to_replace = []
102-
moe_class_names = MoECalibrationModule.registered_names()
103130
for name, module in model.named_modules():
104131
class_name = module.__class__.__name__
105-
if class_name in moe_class_names:
132+
if class_name in MOE_CALIBRATION_MODULES:
106133
modules_to_replace.append((name, module, class_name))
107134

108135
# Step 2: Replace modules with progress bar
@@ -111,8 +138,8 @@ def moe_calibration_context(
111138
for name, module, class_name in tqdm(
112139
modules_to_replace, desc="Replacing MoE modules for calibration"
113140
):
114-
replacement = MoECalibrationModule.load_from_registry(
115-
class_name,
141+
calibration_cls = MOE_CALIBRATION_MODULES[class_name]
142+
replacement = calibration_cls(
116143
module,
117144
model.config,
118145
calibrate_all_experts=calibrate_all_experts,

src/llmcompressor/modeling/qwen3_moe.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
2121
)
2222

23-
from llmcompressor.modeling.moe_context import MoECalibrationModule
23+
from llmcompressor.modeling.moe_context import (
24+
MoECalibrationModule,
25+
register_moe_calibration,
26+
)
2427

2528

26-
@MoECalibrationModule.register("Qwen3MoeSparseMoeBlock")
29+
@register_moe_calibration("Qwen3MoeSparseMoeBlock")
2730
class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule):
2831
"""
2932
Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts.

src/llmcompressor/modeling/qwen3_vl_moe.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
55
)
66

7-
from llmcompressor.modeling.moe_context import MoECalibrationModule
7+
from llmcompressor.modeling.moe_context import (
8+
MoECalibrationModule,
9+
register_moe_calibration,
10+
)
811
from llmcompressor.utils.dev import skip_weights_initialize
912

1013

11-
@MoECalibrationModule.register("CalibrationQwen3VLMoeTextSparseMoeBlock")
14+
@register_moe_calibration("Qwen3VLMoeTextSparseMoeBlock")
1215
class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
1316
"""
1417
Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all

0 commit comments

Comments
 (0)