Commit 6f51b4c

Quant fallback to 8w per token + other quant improvements for multimodal (#154)

1 parent cebeb3d
File tree: 4 files changed, +164 −33 lines

optimum/commands/export/executorch.py

Lines changed: 28 additions & 2 deletions
@@ -76,12 +76,14 @@ def parse_args_executorch(parser):
     required_group.add_argument(
         "--qlinear",
         type=str,
-        choices=["8da4w", "4w", "8w"],
+        choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
         required=False,
         help=(
             "Quantization config for decoder linear layers.\n\n"
             "Options:\n"
             " 8da4w - 8-bit dynamic activation, 4-bit weight\n"
+            " 8da8w - 8-bit dynamic activation, 8-bit weight\n"
+            " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n"
             " 4w - 4-bit weight only\n"
             " 8w - 8-bit weight only"
         ),
@@ -104,12 +106,14 @@ def parse_args_executorch(parser):
     required_group.add_argument(
         "--qlinear_encoder",
         type=str,
-        choices=["8da4w", "4w", "8w"],
+        choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
         required=False,
         help=(
             "Quantization config for encoder linear layers.\n\n"
             "Options:\n"
             " 8da4w - 8-bit dynamic activation, 4-bit weight\n"
+            " 8da8w - 8-bit dynamic activation, 8-bit weight\n"
+            " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight; fallback to 8-bit dynamic activation, 8-bit weight per-channel where the group size doesn't divide the layer's hidden dimension cleanly\n"
             " 4w - 4-bit weight only\n"
             " 8w - 8-bit weight only"
         ),
@@ -144,6 +148,24 @@ def parse_args_executorch(parser):
     required_group.add_argument(
         "--qembedding_group_size", type=int, required=False, help="Group size for embedding quantization."
     )
+    required_group.add_argument(
+        "--qembedding_encoder",
+        type=str,
+        choices=["4w", "8w"],
+        required=False,
+        help=(
+            "Quantization config for the encoder embedding layer, for model architectures with an encoder.\n\n"
+            "Options:\n"
+            " 4w - 4-bit weight only\n"
+            " 8w - 8-bit weight only"
+        ),
+    )
+    required_group.add_argument(
+        "--qembedding_encoder_group_size",
+        type=int,
+        required=False,
+        help="Group size for encoder embedding quantization, for model architectures with an encoder.",
+    )
     required_group.add_argument(
         "--max_seq_len",
         type=int,
@@ -220,6 +242,10 @@ def run(self):
             kwargs["qembedding"] = self.args.qembedding
         if self.args.qembedding_group_size:
            kwargs["qembedding_group_size"] = self.args.qembedding_group_size
+        if self.args.qembedding_encoder:
+            kwargs["qembedding_encoder"] = self.args.qembedding_encoder
+        if self.args.qembedding_encoder_group_size:
+            kwargs["qembedding_encoder_group_size"] = self.args.qembedding_encoder_group_size
         if self.args.max_seq_len:
             kwargs["max_seq_len"] = self.args.max_seq_len
         if hasattr(self.args, "dtype") and self.args.dtype:
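Editor's note on the new "8da4w,8da8w" choice: it is a comma-separated pair, with the primary config first and an optional per-token fallback second. The sketch below is a hypothetical standalone helper (the name split_qlinear_config is not in the repo; the authoritative parsing lives in quantize_model_ in quantization.py) showing how such a value splits into the two keys.

from typing import Optional, Tuple


def split_qlinear_config(value: str) -> Tuple[str, Optional[str]]:
    """Hypothetical helper mirroring how quantize_model_ interprets --qlinear values."""
    parts = [part.strip() for part in value.split(",")]
    if any(not part for part in parts):
        raise ValueError("Linear quantization config entries must be non-empty.")
    if len(parts) > 2:
        raise ValueError("Expected at most one fallback linear quantization config.")
    # First entry is the primary config; the optional second entry is the per-token fallback.
    return parts[0], (parts[1] if len(parts) == 2 else None)


assert split_qlinear_config("8da4w,8da8w") == ("8da4w", "8da8w")
assert split_qlinear_config("8w") == ("8w", None)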

optimum/exporters/executorch/quantization.py

Lines changed: 82 additions & 25 deletions
@@ -44,7 +44,7 @@ def quantize_model_(
         if qlinear_config == "8w":
             assert (
                 qembedding_group_size == 0
-            ), "8-bit embedding quantization only supports per-channel at the moment, please use qembedding_group_size = 0."
+            ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
         if qembedding_group_size == 0:
             embedding_weight_granularity = PerAxis(0)
         else:
@@ -71,42 +71,99 @@ def quantize_model_(
         )

     if qlinear_config:
+
+        def build_linear_config(quant_config_key: str, granularity: str, packing_format: Optional[str] = None):
+            if quant_config_key == "8da4w":
+                return Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int4,
+                    weight_granularity=granularity,
+                )
+            if quant_config_key == "4w":
+                # Determine if we need to use Int4WeightOnlyConfig with int4_packing_format
+                if packing_format:
+                    return Int4WeightOnlyConfig(
+                        group_size=qlinear_group_size,
+                        int4_packing_format=packing_format,
+                        int4_choose_qparams_algorithm="hqq",
+                    )
+                else:
+                    return IntxWeightOnlyConfig(
+                        weight_dtype=torch.int4,
+                        granularity=granularity,
+                    )
+            if quant_config_key == "8w":
+                return IntxWeightOnlyConfig(
+                    weight_dtype=torch.int8,
+                    granularity=granularity,
+                )
+            if quant_config_key == "8da8w":
+                return Int8DynamicActivationIntxWeightConfig(
+                    weight_dtype=torch.int8,
+                    weight_granularity=PerAxis(0),
+                )
+            raise ValueError(f"Unsupported linear quantization config '{quant_config_key}'.")
+
+        qlinear_configs = [cfg.strip() for cfg in qlinear_config.split(",")]
+        if any(cfg == "" for cfg in qlinear_configs):
+            raise ValueError("Linear quantization config entries must be non-empty.")
+        if len(qlinear_configs) > 2:
+            raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.")
+
+        primary_linear_config_key = qlinear_configs[0]
+        fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None
+
         if qlinear_group_size == 0:
             linear_weight_granularity = PerAxis(0)
+            if fallback_linear_config_key is not None:
+                logging.warning(
+                    "qlinear_group_size is 0, fallback linear config will not be used as all layers will be quantized with per-axis granularity."
+                )
+                fallback_linear_config_key = None
         else:
-            assert qlinear_group_size % 2 == 0, "Linear quantization group size must be a multiple of 2."
+            assert (
+                qlinear_group_size % 2 == 0
+            ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
             linear_weight_granularity = PerGroup(qlinear_group_size)

         logging.info("Quantizing linear layers.")
+        primary_linear_config = build_linear_config(
+            primary_linear_config_key, linear_weight_granularity, qlinear_packing_format
+        )

-        # Determine if we need to use Int4WeightOnlyConfig with int4_packing_format
-        if qlinear_config == "4w" and qlinear_packing_format:
-            linear_config = Int4WeightOnlyConfig(
-                group_size=qlinear_group_size,
-                int4_packing_format=qlinear_packing_format,
-                int4_choose_qparams_algorithm="hqq",
-            )
-        else:
-            linear_config = {
-                "8da4w": Int8DynamicActivationIntxWeightConfig(
-                    weight_dtype=torch.int4,
-                    weight_granularity=linear_weight_granularity,
-                ),
-                "4w": IntxWeightOnlyConfig(
-                    weight_dtype=torch.int4,
-                    granularity=linear_weight_granularity,
-                ),
-                "8w": IntxWeightOnlyConfig(
-                    weight_dtype=torch.int8,
-                    granularity=linear_weight_granularity,
-                ),
-            }[qlinear_config]
+        # First, quantize layers that are compatible with group quantization
+        def per_group_filter(module, fqn):
+            if isinstance(module, torch.nn.Linear):
+                # Check if hidden dimension is divisible by group size
+                # For Linear layers, weight shape is [out_features, in_features]
+                # Group quantization typically applies to the in_features dimension (dim=1)
+                return qlinear_group_size == 0 or (module.weight.shape[1] % qlinear_group_size == 0)
+            return False

         quantize_(
             eager_model,
-            linear_config,
+            primary_linear_config,
+            filter_fn=per_group_filter,
         )

+        # Then, quantize incompatible layers using the fallback per-axis config
+        if fallback_linear_config_key is not None:
+            fallback_linear_config = build_linear_config(fallback_linear_config_key, PerAxis(0))
+
+            def per_token_filter(module, fqn):
+                if isinstance(module, torch.nn.Linear):
+                    return module.weight.shape[1] % qlinear_group_size != 0
+                return False
+
+            logging.info(
+                f"Applying fallback linear config '{fallback_linear_config_key}' (per-axis)"
+                f" to layers incompatible with group size {qlinear_group_size}."
+            )
+            quantize_(
+                eager_model,
+                fallback_linear_config,
+                filter_fn=per_token_filter,
+            )
+
     # TODO: remove after ExecuTorch dep on Torch >= 2.10.0.
     if parse(torch_version) < parse("2.10.0.dev20251104"):
         unwrap_tensor_subclass(eager_model)
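To make the primary/fallback routing above concrete: with a group size of 32, a linear layer whose in_features is divisible by 32 receives the grouped primary config, while a layer like Gemma3's 4304-wide projection routes to the per-token fallback. Below is a small standalone sketch of that divisibility check; the layer names and shapes are illustrative, not taken from the repo.

import torch

group_size = 32
layers = {
    "text.mlp.up_proj": torch.nn.Linear(4096, 11008),  # 4096 % 32 == 0 -> primary per-group config
    "vision.projection": torch.nn.Linear(4304, 1152),  # 4304 % 32 != 0 -> per-token fallback config
}
for name, layer in layers.items():
    in_features = layer.weight.shape[1]  # weight is [out_features, in_features]
    route = "primary (per-group)" if in_features % group_size == 0 else "fallback (per-token)"
    print(f"{name}: in_features={in_features} -> {route}")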

optimum/exporters/executorch/tasks/multimodal_text_to_text.py

Lines changed: 36 additions & 3 deletions
@@ -14,6 +14,7 @@


 import json
+import logging
 import os.path

 import torchao
@@ -202,8 +203,12 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
     qlinear_encoder_packing_format = kwargs.get("qlinear_encoder_packing_format", None)
     qembedding_config = kwargs.get("qembedding", None)
     qembedding_group_size = kwargs.get("qembedding_group_size", None)
+    qembedding_encoder_config = kwargs.get("qembedding_encoder", None)
+    qembedding_encoder_group_size = kwargs.get("qembedding_encoder_group_size", None)

     # Quantize decoder linear weights.
+    if qlinear_config:
+        logging.info("Quantizing decoder linears...")
     quantize_decoder_kwargs = {
         "eager_model": getattr(eager_model, decoder_name),
         "qlinear_config": qlinear_config,
@@ -214,7 +219,26 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
         quantize_decoder_kwargs["qlinear_packing_format"] = qlinear_packing_format
     quantize_model_(**quantize_decoder_kwargs)

+    # Quantize lm head, if it is separate from the decoder model.
+    # e.g. Sometimes the top-level model will have:
+    # def __init__(self, ...):
+    #     self.decoder = ...
+    #     self.lm_head = ...  # lm_head is not part of the decoder instance
+    #     ...
+    if not hasattr(getattr(eager_model, decoder_name), "lm_head"):
+        if not hasattr(eager_model, "lm_head"):
+            raise AttributeError(
+                f"Could not find `lm_head` for {model_name_or_path}, please double check if this is expected."
+            )
+        quantize_lm_head_kwargs = {
+            "eager_model": eager_model.lm_head,
+            "qlinear_config": qlinear_config,
+        }
+        quantize_model_(**quantize_lm_head_kwargs)
+
     # Quantize encoder linear weights.
+    if qlinear_encoder_config:
+        logging.info("Quantizing encoder linears...")
     quantize_encoder_kwargs = {
         "eager_model": getattr(eager_model, encoder_name),
         "qlinear_config": qlinear_encoder_config,
@@ -225,9 +249,9 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
         quantize_encoder_kwargs["qlinear_packing_format"] = qlinear_encoder_packing_format
     quantize_model_(**quantize_encoder_kwargs)

-    # TODO: quantize other parts of the model, e.g. MultimodalProjector?
-
     # Quantize decoder embeddings.
+    if qembedding_config:
+        logging.info("Quantizing embeddings...")
     quantize_decoder_embedding_kwargs = {
         "eager_model": getattr(eager_model, decoder_name),
         "qembedding_config": qembedding_config,
"qembedding_config": qembedding_config,
@@ -236,7 +260,16 @@ def load_multimodal_text_to_text_model(model_name_or_path: str, **kwargs):
236260
quantize_decoder_embedding_kwargs["qembedding_group_size"] = qembedding_group_size
237261
quantize_model_(**quantize_decoder_embedding_kwargs)
238262

239-
# TODO: quantize encoder embeddings.
263+
# Quantize encoder embeddings.
264+
if qembedding_encoder_config:
265+
logging.info("Quantizing embeddings...")
266+
quantize_encoder_embedding_kwargs = {
267+
"eager_model": getattr(eager_model, encoder_name),
268+
"qembedding_config": qembedding_encoder_config,
269+
}
270+
if qembedding_encoder_group_size is not None:
271+
quantize_encoder_embedding_kwargs["qembedding_group_size"] = qembedding_encoder_group_size
272+
quantize_model_(**quantize_encoder_embedding_kwargs)
240273

241274
return MultiModalTextToTextExportableModule(
242275
model=eager_model,
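As a reference for the lm_head handling above, here is a toy sketch of the lookup pattern (module names and shapes are illustrative, not the real model classes): use the decoder's own head when it has one, otherwise quantize the top-level lm_head, and raise if neither exists.

import torch


class ToyMultimodal(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = torch.nn.Identity()      # stands in for a decoder without its own lm_head
        self.lm_head = torch.nn.Linear(16, 32)  # head attached to the top-level model instead


model = ToyMultimodal()
if hasattr(model.decoder, "lm_head"):
    head_to_quantize = model.decoder.lm_head  # already covered by the decoder quantization pass
elif hasattr(model, "lm_head"):
    head_to_quantize = model.lm_head          # separate top-level head, quantized explicitly
else:
    raise AttributeError("Could not find `lm_head`, please double check if this is expected.")
print(type(head_to_quantize))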

tests/models/test_modeling_gemma3.py

Lines changed: 18 additions & 3 deletions
@@ -309,9 +309,24 @@ def test_gemma3_image_vision_with_custom_sdpa_kv_cache_8da4w_8we(self):
             use_custom_kv_cache=True,
             qlinear="8da4w",
             qlinear_group_size=32,
-            # Can't quantize the encoder at the moment, hidden dim of 4304 doesn't fit ExecuTorch's
-            # XNNPack 32-group size quantized kernels. See https://github.com/pytorch/executorch/issues/14221.
-            qembedding_config="8w",
+            qlinear_encoder="8da4w,8da8w",
+            qlinear_encoder_group_size=32,
+            qembedding="8w",
+            qembedding_encoder="8w",
+        )
+
+        # Check file size is approximately 3GB (allow 1% tolerance)
+        file_size_bytes = os.path.getsize(os.path.join(model._temp_dir.name, "model.pte"))
+        file_size_gb = file_size_bytes / (1024**3)
+        expected_size_gb = 2.96
+        tolerance = 0.01  # 1% tolerance
+
+        logging.info(f"model.pte size: {file_size_gb:.2f} GB")
+        self.assertAlmostEqual(
+            file_size_gb,
+            expected_size_gb,
+            delta=expected_size_gb * tolerance,
+            msg=f"Expected file size ~{expected_size_gb}GB, but got {file_size_gb:.2f}GB",
         )

         # Generate
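A note on the tolerance check above: assertAlmostEqual with delta passes when the absolute difference is at most delta, so a 1% band around 2.96 GB accepts sizes between roughly 2.93 GB and 2.99 GB. A standalone illustration with hypothetical numbers:

import unittest


class FileSizeToleranceExample(unittest.TestCase):
    def test_size_within_one_percent(self):
        expected_size_gb = 2.96
        measured_size_gb = 2.94  # hypothetical value standing in for os.path.getsize(...) / 1024**3
        # Passes because abs(2.94 - 2.96) = 0.02 <= 2.96 * 0.01 = 0.0296
        self.assertAlmostEqual(measured_size_gb, expected_size_gb, delta=expected_size_gb * 0.01)


if __name__ == "__main__":
    unittest.main()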
