Commit 8091a43

fix issues, use registry

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent: bb39dac

8 files changed: +96, -66 lines

src/llmcompressor/modeling/__init__.py
Lines changed: 7 additions & 0 deletions

@@ -9,5 +9,12 @@
 needed for efficient compression.
 """
 
+# trigger registration
+from .deepseek_v3 import CalibrationDeepseekV3MoE  # noqa: F401
+from .llama4 import SequentialLlama4TextMoe  # noqa: F401
+from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock  # noqa: F401
+from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock  # noqa: F401
+# TODO: add granite4, Qwen3Next
+
 from .fuse import *
 from .prepare import *

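These package-level imports exist for their side effect: each model file applies @MoECalibrationModule.register(...) when it is imported, so the registry is populated as soon as llmcompressor.modeling is imported rather than lazily inside moe_calibration_context (where the same imports used to live, see moe_context.py below). A minimal sketch of the pattern, with toy names (Base, ExampleCalibration) that are not part of the codebase; only the RegistryMixin calls mirror the diff:

from compressed_tensors.registry import RegistryMixin


class Base(RegistryMixin):
    """Toy registry parent, standing in for MoECalibrationModule."""


# The decorator runs when the class definition executes, i.e. when the defining
# module is imported -- which is exactly what the __init__ imports trigger.
@Base.register("ExampleBlock")
class ExampleCalibration(Base):
    """Toy calibration class, registered under the original module's class name."""
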
src/llmcompressor/modeling/deepseek_v3.py
Lines changed: 2 additions & 5 deletions

@@ -4,13 +4,10 @@
     DeepseekV3MoE as OriginalDeepseekV3MoE,
 )
 
-from llmcompressor.modeling.moe_context import (
-    MoECalibrationModule,
-    register_moe_calibration,
-)
+from llmcompressor.modeling.moe_context import MoECalibrationModule
 
 
-@register_moe_calibration("DeepseekV3MoE")
+@MoECalibrationModule.register("DeepseekV3MoE")
 class CalibrationDeepseekV3MoE(MoECalibrationModule):
     """
     Calibration version of DeepseekV3MoE that sends all tokens to all experts.

src/llmcompressor/modeling/llama4.py
Lines changed: 2 additions & 5 deletions

@@ -11,14 +11,11 @@
     Llama4TextMoe,
 )
 
-from llmcompressor.modeling.moe_context import (
-    MoECalibrationModule,
-    register_moe_calibration,
-)
+from llmcompressor.modeling.moe_context import MoECalibrationModule
 from llmcompressor.utils.dev import skip_weights_initialize
 
 
-@register_moe_calibration("Llama4TextMoe")
+@MoECalibrationModule.register("Llama4TextMoe")
 class SequentialLlama4TextMoe(MoECalibrationModule):
     """
     Calibration version of Llama4TextMoe that unpacks experts for sequential processing.

src/llmcompressor/modeling/moe_context.py
Lines changed: 9 additions & 40 deletions

@@ -8,28 +8,25 @@
 
 Key components:
 - MoECalibrationModule: Abstract base class for calibration modules
-- MOE_CALIBRATION_MODULES: Registry mapping module class names to calibration classes
 - moe_calibration_context: Context manager that applies calibration to a model
 """
 
 import contextlib
 from abc import ABC
-from typing import Dict, Type
 
 import torch
+from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
 from loguru import logger
 from tqdm import tqdm
 from transformers import PreTrainedModel
 
 __all__ = [
     "MoECalibrationModule",
-    "MOE_CALIBRATION_MODULES",
-    "register_moe_calibration",
     "moe_calibration_context",
 ]
 
 
-class MoECalibrationModule(ABC, torch.nn.Module):
+class MoECalibrationModule(ABC, torch.nn.Module, RegistryMixin):
     """
     Abstract base class for MoE calibration modules.
 
@@ -62,32 +59,6 @@ def restore(self, original: torch.nn.Module) -> torch.nn.Module:
         )
 
 
-# Registry: module class name -> calibration module class
-MOE_CALIBRATION_MODULES: Dict[str, Type[MoECalibrationModule]] = {}
-
-
-def register_moe_calibration(module_class_name: str):
-    """
-    Decorator to register a MoE calibration module.
-
-    Usage:
-        @register_moe_calibration("DeepseekV3MoE")
-        class CalibrationDeepseekV3MoE(MoECalibrationModule):
-            ...
-
-    Args:
-        module_class_name: The class name of the original module to replace
-    """
-
-    def decorator(cls: Type[MoECalibrationModule]) -> Type[MoECalibrationModule]:
-        if not issubclass(cls, MoECalibrationModule):
-            raise TypeError(f"{cls.__name__} must inherit from MoECalibrationModule")
-        MOE_CALIBRATION_MODULES[module_class_name] = cls
-        return cls
-
-    return decorator
-
-
 @contextlib.contextmanager
 def moe_calibration_context(
     model: PreTrainedModel,
@@ -115,12 +86,6 @@ def moe_calibration_context(
             model(**batch)
     # Model is now restored (unless permanent)
     """
-    # trigger registration
-    from .deepseek_v3 import CalibrationDeepseekV3MoE  # noqa: F401
-    from .llama4 import SequentialLlama4TextMoe  # noqa: F401
-    from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock  # noqa: F401
-    from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock  # noqa: F401
-    # TODO: add granite4, Qwen3Next
 
     replaced = {}
 
@@ -129,7 +94,7 @@ def moe_calibration_context(
     modules_to_replace = []
     for name, module in model.named_modules():
         class_name = module.__class__.__name__
-        if class_name in MOE_CALIBRATION_MODULES:
+        if _is_registered(class_name, MoECalibrationModule):
             modules_to_replace.append((name, module, class_name))
 
     # Step 2: Replace modules with progress bar
@@ -138,8 +103,8 @@ def moe_calibration_context(
     for name, module, class_name in tqdm(
         modules_to_replace, desc="Replacing MoE modules for calibration"
    ):
-        calibration_cls = MOE_CALIBRATION_MODULES[class_name]
-        replacement = calibration_cls(
+        replacement = MoECalibrationModule.load_from_registry(
+            class_name,
             module,
             model.config,
             calibrate_all_experts=calibrate_all_experts,
@@ -172,3 +137,7 @@ def moe_calibration_context(
             if not replacement.is_permanent:
                 restored = replacement.restore(original)
                 model.set_submodule(name, restored)
+
+
+def _is_registered(name: str, subclass: RegistryMixin):
+    return standardize_lookup_name(name) in subclass.registered_names()

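With the hand-rolled dict gone, moe_calibration_context discovers and builds calibration modules entirely through the shared registry: _is_registered checks membership by standardized name, and load_from_registry resolves the class and constructs it in one call. A small self-contained sketch of that flow, using toy classes (CalibBase, ToyCalib, ToyMoE) that are not part of llm-compressor; only the registry calls mirror the diff:

import torch
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name


class CalibBase(torch.nn.Module, RegistryMixin):
    """Toy stand-in for MoECalibrationModule."""


@CalibBase.register("ToyMoE")
class ToyCalib(CalibBase):
    def __init__(self, original, config=None, calibrate_all_experts=True):
        super().__init__()
        self.original = original
        self.config = config
        self.calibrate_all_experts = calibrate_all_experts


class ToyMoE(torch.nn.Module):
    """Pretend original MoE block, as found via model.named_modules()."""


module = ToyMoE()
class_name = module.__class__.__name__  # "ToyMoE"

# membership check, mirroring the new _is_registered() helper
if standardize_lookup_name(class_name) in CalibBase.registered_names():
    # resolve the calibration class by name and construct it in one call,
    # replacing the old MOE_CALIBRATION_MODULES[class_name] lookup
    replacement = CalibBase.load_from_registry(
        class_name, original=module, calibrate_all_experts=True
    )
    assert isinstance(replacement, ToyCalib)
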
src/llmcompressor/modeling/qwen3_moe.py
Lines changed: 2 additions & 5 deletions

@@ -20,13 +20,10 @@
     Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
 )
 
-from llmcompressor.modeling.moe_context import (
-    MoECalibrationModule,
-    register_moe_calibration,
-)
+from llmcompressor.modeling.moe_context import MoECalibrationModule
 
 
-@register_moe_calibration("Qwen3MoeSparseMoeBlock")
+@MoECalibrationModule.register("Qwen3MoeSparseMoeBlock")
 class CalibrationQwen3MoeSparseMoeBlock(MoECalibrationModule):
     """
     Calibration version of Qwen3MoeSparseMoeBlock that sends all tokens to all experts.

src/llmcompressor/modeling/qwen3_vl_moe.py
Lines changed: 3 additions & 6 deletions

@@ -4,14 +4,11 @@
     Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
 )
 
-from llmcompressor.modeling.moe_context import (
-    MoECalibrationModule,
-    register_moe_calibration,
-)
+from llmcompressor.modeling.moe_context import MoECalibrationModule
 from llmcompressor.utils.dev import skip_weights_initialize
 
 
-@register_moe_calibration("CalibrationQwen3VLMoeTextSparseMoeBlock")
+@MoECalibrationModule.register("Qwen3VLMoeTextSparseMoeBlock")
 class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
     """
     Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all
@@ -118,7 +115,7 @@ def replace(
     calibrate_all_experts: bool,
 ):
     return CalibrateQwen3VLMoeTextSparseMoeBlock(
-        config=config.get_text_config(),
         original=original,
+        config=config,
         calibrate_all_experts=calibrate_all_experts,
     )

tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py
Lines changed: 5 additions & 5 deletions

@@ -1,5 +1,5 @@
 import torch
-from transformers import Qwen3VLMoeTextConfig
+from transformers import Qwen3VLMoeConfig
 from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
     Qwen3VLMoeTextSparseMoeBlock,
 )
@@ -10,18 +10,18 @@
 
 
 @requires_gpu
-def test_calib_qwen3_moe_module():
-    config = Qwen3VLMoeTextConfig()
+def test_calib_qwen3_vl_moe_module():
+    config = Qwen3VLMoeConfig()
     with torch.device("cuda"):
-        original = Qwen3VLMoeTextSparseMoeBlock(config).eval()
+        original = Qwen3VLMoeTextSparseMoeBlock(config.get_text_config()).eval()
         # these are initialized as empty / all 0s which results in outputs
         # from the experts being all 0
         # update to use a small random value
         original.experts.gate_up_proj.data.normal_(mean=0.0, std=0.02)
         original.experts.down_proj.data.normal_(mean=0.0, std=0.02)
 
     # Create dummy input tensor that simulates hidden_states
-    hidden_dim = config.hidden_size
+    hidden_dim = config.get_text_config().hidden_size
     batch, seq_len = 4, 32
     sample = torch.randn(batch, seq_len, hidden_dim, device="cuda")

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+import os
+
+import pytest
+import torch
+from safetensors.torch import load_file
+
+from llmcompressor import oneshot, ptq_weights
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from tests.testing_utils import requires_gpu
+
+
+@requires_gpu
+@pytest.mark.parametrize("scheme", ["FP8_dynamic", "NVFP4A16"])
+def test_weights_ptq_e2e(scheme, tmp_path):
+    model = "nm-testing/tinysmokellama-3.2"
+    ptq_ignore = ["model.embed_tokens.weight", "lm_head.weight", "re:.*norm.weight$"]
+    oneshot_ignore = ["lm_head"]
+    device = "cuda:0"
+
+    ptq_outdir = tmp_path / "weights_out"
+    oneshot_outdir = tmp_path / "oneshot_out"
+
+    ptq_weights(
+        model,
+        ptq_outdir,
+        scheme=scheme,
+        max_workers=2,
+        device=device,
+        ignore=ptq_ignore,
+    )
+
+    oneshot(
+        model=model,
+        recipe=QuantizationModifier(
+            targets="Linear", scheme=scheme, ignore=oneshot_ignore
+        ),
+        output_dir=oneshot_outdir,
+    )
+
+    ptq_st_files = _get_safetensors_files(ptq_outdir)
+    oneshot_st_files = _get_safetensors_files(oneshot_outdir)
+    assert set(ptq_st_files) == set(oneshot_st_files)
+
+    for file_name in ptq_st_files:
+        _assert_safetensors_equal(ptq_outdir / file_name, oneshot_outdir / file_name)
+
+
+def _get_safetensors_files(dir_path: str) -> list[str]:
+    return [
+        file_name
+        for file_name in os.listdir(dir_path)
+        if file_name.endswith("safetensors")
+    ]
+
+
+def _assert_safetensors_equal(a_path: str, b_path: str) -> bool:
+    a = load_file(a_path)
+    b = load_file(b_path)
+
+    assert a.keys() == b.keys(), (a.keys() - b.keys(), b.keys() - a.keys())
+
+    for key in a.keys():
+        value_equal = torch.equal(a[key].to(torch.bfloat16), b[key].to(torch.bfloat16))
+        dtype_equal = a[key].dtype == b[key].dtype
+
+        assert value_equal and dtype_equal, (key, value_equal, dtype_equal)

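The new test above asserts that the standalone ptq_weights entrypoint and an equivalent oneshot run with a QuantizationModifier recipe emit the same set of safetensors files, with matching tensor keys, values, and dtypes. For reference, the ptq_weights call it exercises looks like the following (arguments copied from the test; not an exhaustive description of the API):

from llmcompressor import ptq_weights

ptq_weights(
    "nm-testing/tinysmokellama-3.2",  # model stub used by the test
    "weights_out",                    # output directory
    scheme="FP8_dynamic",             # or "NVFP4A16", per the parametrization
    max_workers=2,
    device="cuda:0",
    ignore=["model.embed_tokens.weight", "lm_head.weight", "re:.*norm.weight$"],
)
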