
Commit 3ba4f00

apply patch
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 1c85a66 commit 3ba4f00

File tree

4 files changed: +59 -50 lines changed


src/llmcompressor/modeling/llama4.py

Lines changed: 2 additions & 4 deletions
@@ -38,10 +38,8 @@ def __init__(
         calibrate_all_experts: bool = True,
     ):
         super().__init__()
-        # Extract text config from multimodal config if needed
-        text_config = (
-            config.get_text_config() if hasattr(config, "get_text_config") else config
-        )
+        # Extract text config from multimodal config
+        text_config: Llama4TextConfig = config.get_text_config()
         self.top_k = text_config.num_experts_per_tok
         self.hidden_dim = text_config.hidden_size
         self.num_experts = text_config.num_local_experts
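
The removed guard existed to handle text-only configs that lack get_text_config(). In recent transformers releases, PretrainedConfig.get_text_config() is defined for both composite and text-only configs, returning the nested text config or the config itself, which is what makes the simplification safe. A minimal sketch of that assumption (config classes and behavior come from transformers, not from this patch):

from transformers import Llama4Config, Llama4TextConfig

# Composite multimodal config: get_text_config() returns the nested text config.
config = Llama4Config()
assert isinstance(config.get_text_config(), Llama4TextConfig)

# Text-only config: get_text_config() returns the config itself,
# so the hasattr() fallback removed above is no longer needed.
text_config = Llama4TextConfig()
assert text_config.get_text_config() is text_config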

src/llmcompressor/modeling/prepare.py

Lines changed: 13 additions & 26 deletions
@@ -10,33 +10,20 @@
 from compressed_tensors.utils import deprecated, replace_module
 from transformers import PreTrainedModel

-# Import MoE calibration modules to trigger registration
-from llmcompressor.modeling.deepseek_v3 import (  # noqa: F401
-    CalibrationDeepseekV3MoE,
-)
-from llmcompressor.modeling.deepseek_v3 import (
-    replace as replace_deepseekv3,
-)
-from llmcompressor.modeling.llama4 import (  # noqa: F401
-    SequentialLlama4TextMoe,
-)
-from llmcompressor.modeling.llama4 import (
-    replace as replace_llama4,
-)
-from llmcompressor.modeling.moe_context import (  # noqa: F401
-    moe_calibration_context,
-)
-from llmcompressor.modeling.qwen3_moe import (  # noqa: F401
-    CalibrationQwen3MoeSparseMoeBlock,
-)
-from llmcompressor.modeling.qwen3_next_moe import (  # noqa: F401
-    CalibrationQwen3NextSparseMoeBlock,
-)
-from llmcompressor.modeling.qwen3_vl_moe import (
-    replace as replace_Qwen3VLMoE,
-)
+# deprecated replacement functions
+from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
+from llmcompressor.modeling.llama4 import replace as replace_llama4
+from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE
+
+# trigger registration
+from .deepseek_v3 import CalibrationDeepseekV3MoE  # noqa: F401
+from .llama4 import SequentialLlama4TextMoe  # noqa: F401
+from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock  # noqa: F401
+from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock  # noqa: F401
+
+# TODO: add granite4, Qwen3Next

-__all__ = ["moe_calibration_context", "replace_modules_for_calibration"]
+__all__ = ["replace_modules_for_calibration"]

 # ---------------------- module replacements; permanent -------------------------
 replacements = {
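
The reorganized imports separate two concerns: the deprecated replace helpers still used by replace_modules_for_calibration, and imports kept purely for their side effect of registering calibration modules (hence the unused-import noqa markers). A hedged sketch of the name-keyed registry pattern those imports rely on; the real llmcompressor.modeling.moe_context implementation may differ:

# Illustrative registry only; names mirror the imports above.
MOE_CALIBRATION_MODULES: dict[str, type] = {}

def register_moe_calibration(name: str):
    def decorator(cls: type) -> type:
        MOE_CALIBRATION_MODULES[name] = cls  # populated at import time
        return cls
    return decorator

Importing a module that defines a decorated class is enough to populate such a registry, which is why the classes under "# trigger registration" are imported but never referenced directly.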

src/llmcompressor/modeling/qwen3_vl_moe.py

Lines changed: 34 additions & 9 deletions
@@ -1,19 +1,40 @@
 import torch
-
+from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
+from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+    Qwen3VLMoeTextSparseMoeBlock as OriginalQwen3VLMoeTextSparseMoeBlock,
+)
+
+from llmcompressor.modeling.moe_context import (
+    MoECalibrationModule,
+    register_moe_calibration,
+)
 from llmcompressor.utils.dev import skip_weights_initialize


-class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module):
-    def __init__(self, config, original, calibrate_all_experts):
+@register_moe_calibration("CalibrationQwen3VLMoeTextSparseMoeBlock")
+class CalibrateQwen3VLMoeTextSparseMoeBlock(MoECalibrationModule):
+    """
+    Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all
+    experts.
+    """
+
+    def __init__(
+        self,
+        original: OriginalQwen3VLMoeTextSparseMoeBlock,
+        config: Qwen3VLMoeConfig,
+        calibrate_all_experts: bool,
+    ):
         super().__init__()
-        self.hidden_size = config.hidden_size
-        self.num_experts = config.num_experts
+        text_config: Qwen3VLMoeTextConfig = config.get_text_config()
+
+        self.hidden_size = text_config.hidden_size
+        self.num_experts = text_config.num_experts
         self.top_k = original.top_k
         # Note: gate was changed to be a Linear layer in transformers==4.57.0
         # https://github.com/JJJYmmm/transformers/commit/f5dea1c694af8c994c769170813a8702332119ee
         self.gate = original.gate
         self.calibrate_all_experts = calibrate_all_experts
-        self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts)
+        self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -91,9 +112,13 @@ def __init__(self, config, original):
             self[i].down_proj.weight.data = down.t().clone().contiguous()


-def replace(config, module, calibrate_all_experts):
-    return LinearQwen3VLMoeTextSparseMoeBlock(
+def replace(
+    config: Qwen3VLMoeConfig,
+    original: OriginalQwen3VLMoeTextSparseMoeBlock,
+    calibrate_all_experts: bool,
+):
+    return CalibrateQwen3VLMoeTextSparseMoeBlock(
         config=config.get_text_config(),
-        original=module,
+        original=original,
         calibrate_all_experts=calibrate_all_experts,
     )
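
Per the new docstring, the calibration block runs every expert on every token during calibration while still combining only the router-selected outputs, so observers and hooks see activations for all experts without changing the block's result. A rough sketch of that idea, not the file's actual forward():

import torch

def sketch_moe_forward(hidden, gate, experts, top_k, calibrate_all_experts=True):
    # hidden: (num_tokens, hidden_dim); gate: Linear(hidden_dim, num_experts)
    scores = torch.nn.functional.softmax(gate(hidden), dim=-1)
    topk_scores, topk_idx = scores.topk(top_k, dim=-1)

    out = torch.zeros_like(hidden)
    for i, expert in enumerate(experts):
        routed = (topk_idx == i).any(dim=-1)         # tokens routed to expert i
        if calibrate_all_experts:
            expert_out = expert(hidden)[routed]      # all tokens pass through expert i
        else:
            expert_out = expert(hidden[routed])      # original sparse behavior
        weight = (topk_scores * (topk_idx == i)).sum(-1)[routed]
        out[routed] += expert_out * weight.unsqueeze(-1)
    return out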

tests/llmcompressor/modeling/test_calib_qwen3_vl_moe.py

Lines changed: 10 additions & 11 deletions
@@ -1,17 +1,16 @@
 import torch
+from transformers import Qwen3VLMoeTextConfig
+from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+    Qwen3VLMoeTextSparseMoeBlock,
+)

-from llmcompressor.modeling.qwen3_vl_moe import LinearQwen3VLMoeTextSparseMoeBlock
+from llmcompressor.modeling.qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock
 from llmcompressor.utils.helpers import calibration_forward_context
 from tests.testing_utils import requires_gpu


 @requires_gpu
-def test_calib_qwen3_vl_moe_module():
-    from transformers import Qwen3VLMoeTextConfig
-    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
-        Qwen3VLMoeTextSparseMoeBlock,
-    )
-
+def test_calib_qwen3_moe_module():
     config = Qwen3VLMoeTextConfig()
     with torch.device("cuda"):
         original = Qwen3VLMoeTextSparseMoeBlock(config).eval()
@@ -29,16 +28,16 @@ def test_calib_qwen3_vl_moe_module():
     with calibration_forward_context(original):
         true_output = original(sample)

-    module = LinearQwen3VLMoeTextSparseMoeBlock(
-        config, original, calibrate_all_experts=True
+    module = CalibrateQwen3VLMoeTextSparseMoeBlock(
+        original, config, calibrate_all_experts=True
     )
     with calibration_forward_context(module):
         output = module(sample)
     assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
     assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10

-    module = LinearQwen3VLMoeTextSparseMoeBlock(
-        config, original, calibrate_all_experts=False
+    module = CalibrateQwen3VLMoeTextSparseMoeBlock(
+        original, config, calibrate_all_experts=False
    )
     with calibration_forward_context(module):
         output = module(sample)
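
Both assertions pass because only the routed experts' outputs are combined into the result; calibrate_all_experts changes which tokens each expert processes, not what the block returns. A hedged sketch of exercising the same module through the updated replace() helper on CPU; the small config values are illustrative assumptions about the Qwen3-VL-MoE text config, not values from the test:

import torch
from transformers import Qwen3VLMoeTextConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
    Qwen3VLMoeTextSparseMoeBlock,
)

from llmcompressor.modeling.qwen3_vl_moe import replace

# hidden_size/num_experts appear in the diff; the other field names are assumptions
config = Qwen3VLMoeTextConfig(
    hidden_size=64, moe_intermediate_size=32, num_experts=4, num_experts_per_tok=2
)
original = Qwen3VLMoeTextSparseMoeBlock(config).eval()
calib = replace(config, original, calibrate_all_experts=True)

with torch.no_grad():
    sample = torch.randn(1, 8, config.hidden_size)
    assert torch.nn.functional.mse_loss(original(sample)[0], calib(sample)[0]) < 1e-10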
