Skip to content

Commit 5470020

Browse files
committed
add example for torchao with deepseek scaling
Summary:

Test Plan:
```
with-proxy python quantize_hf_model_with_torchao.py --quant_type fp8 --granularity "a1x128_w128x128" --save_model_to_disk True
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 1d2aed5 commit 5470020

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

hf_torchao_vllm/quantize_hf_model_with_torchao.py

Lines changed: 17 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -24,7 +24,12 @@
2424
NVFP4InferenceConfig,
2525
NVFP4MMConfig,
2626
)
27-
from torchao.quantization import ModuleFqnToConfig
27+
from torchao.quantization import (
28+
ModuleFqnToConfig,
29+
PerBlock,
30+
PerRow,
31+
PerTensor,
32+
)
2833
from torchao.quantization.quant_api import (
2934
CutlassInt4PackedLayout,
3035
Float8DynamicActivationFloat8WeightConfig,
@@ -34,8 +39,6 @@
3439
Int8DynamicActivationInt4WeightConfig,
3540
Int8DynamicActivationInt8WeightConfig,
3641
Int8WeightOnlyConfig,
37-
PerRow,
38-
PerTensor,
3942
)
4043
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
4144

@@ -62,6 +65,7 @@ def get_quantization_config(args):
6265
granularity_mapping = {
6366
"per_row": PerRow(),
6467
"per_tensor": PerTensor(),
68+
"a1x128_w128x128": [PerBlock([1, 128]), PerBlock([128, 128])],
6569
}
6670

6771
gran = granularity_mapping[args.granularity]
@@ -71,7 +75,13 @@ def get_quantization_config(args):
7175
return TorchAoConfig("autoquant", min_sqnr=args.min_sqnr)
7276
case "fp8":
7377
single_config = Float8DynamicActivationFloat8WeightConfig(
74-
granularity=gran
78+
granularity=gran,
79+
# the 125m model has a lot of activation zeroes for some
80+
# prompts, need to set a lower bound to prevent scales from
81+
# being 0.
82+
# TODO seems like torchao should do this for me.
83+
# TODO tool to find this (I used bisect on this tiny model).
84+
activation_value_lb=1.0e-12,
7585
)
7686
if args.experts_only_qwen_1_5_moe_a_2_7b:
7787
expert_fqn_to_config = {}
@@ -325,7 +335,9 @@ def main(
325335
"mxfp4",
326336
"nvfp4",
327337
] = "fp8",
328-
granularity: Literal["per_row", "per_tensor"] = "per_row",
338+
granularity: Literal[
339+
"per_row", "per_tensor", "a1x128_w128x128"
340+
] = "per_row",
329341
min_sqnr: float | None = None,
330342
max_new_tokens: int = 64,
331343
benchmark: bool = False,

0 commit comments

Comments (0)