
Commit a9f2dc1

support eval of float8_a1x128_w128x128 (#3269)
1 parent 63fafb2 commit a9f2dc1

3 files changed: +25 −4 lines changed

scripts/download.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     parser.add_argument(
         "--repo_id",
         type=str,
-        default="checkpoints/meta-llama/llama-2-7b-chat-hf",
+        default="meta-llama/llama-2-7b-chat-hf",
         help="Repository ID to download from.",
     )
     parser.add_argument(
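The fix here: the old default prepended a local checkpoints/ directory to the repo id, which is not a valid Hugging Face repository ID (those are bare org/name strings). A minimal sketch of the distinction, assuming the script wraps huggingface_hub's snapshot_download (the local_dir shown is illustrative, not necessarily this script's actual layout):

```python
from huggingface_hub import snapshot_download

# The repo id is a bare "org/name" string; the "checkpoints/..." prefix from
# the old default belongs in the local download path, not in the repo id.
snapshot_download(
    repo_id="meta-llama/llama-2-7b-chat-hf",
    local_dir="checkpoints/meta-llama/llama-2-7b-chat-hf",
)
```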

torchao/_models/llama/eval.py

Lines changed: 24 additions & 1 deletion
@@ -23,6 +23,7 @@
     Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     UIntXWeightOnlyConfig,
@@ -44,6 +45,7 @@ def run_evaluation(
     calibration_limit: Optional[int] = None,
     calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: bool = False,
+    print_model: bool = False,
 ):
     """Runs the evaluation of a model using LM Eval."""
     print(
@@ -169,6 +171,14 @@ def run_evaluation(
                 model,
                 Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
             )
+        if quantization == "float8_a1x128_w128x128":
+            config = Float8DynamicActivationFloat8WeightConfig(
+                granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
+                activation_value_lb=1e-12,
+            )
+            # TODO(future): all workflows in this file should be skipping quantization
+            # of `lm_head`
+            quantize_(model, config)
         if "autoround" in quantization:
             from transformers import AutoTokenizer
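For reference, the new branch boils down to the following standalone use of quantize_. A minimal sketch, assuming PerBlock is exported from torchao.quantization like the other names in this file's import block and that the machine has float8 support; the toy Linear stands in for the Llama model, with dimensions chosen as multiples of the 128-wide blocks:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

# Toy stand-in for the model eval.py loads; the (1, 128) activation and
# (128, 128) weight block sizes need dimensions divisible by 128.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
        activation_value_lb=1e-12,
    ),
)

x = torch.randn(1, 256, device="cuda", dtype=torch.bfloat16)
y = model(x)  # forward pass with dynamically quantized float8 activations
```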

@@ -273,7 +283,16 @@ def run_evaluation(
     )

     if compile:
-        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        # TODO(future PR): clean this up
+        if quantization == "float8_a1x128_w128x128":
+            # we don't need max-autotune for float8 blockwise quant
+            model = torch.compile(model)
+        else:
+            model = torch.compile(model, mode="max-autotune", fullgraph=True)
+
+    if print_model:
+        print(model)
+
     with torch.no_grad():
         print("Running evaluation ...")
         # avoid circular imports
@@ -371,6 +390,9 @@ def run_evaluation(
         default=False,
         help="pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower",
     )
+    parser.add_argument(
+        "--print_model", action="store_true", help="Whether to print the model."
+    )

     args = parser.parse_args()
     run_evaluation(
@@ -387,4 +409,5 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
+        args.print_model,
     )

torchao/quantization/quant_api.py

Lines changed: 0 additions & 2 deletions
@@ -1778,8 +1778,6 @@ def __post_init__(self):

         default_use_fast_accum = True
         if _granularity_is_a_1_128_w_128_128(self.granularity):
-            assert self.activation_value_lb is None, "unimplemented"
-            assert self.activation_value_ub is None, "unimplemented"
             assert self.kernel_preference in (
                 KernelPreference.AUTO,
                 KernelPreference.TORCH,
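These asserts guarded a previously unimplemented combination; with activation_value_lb now accepted for the a1x128/w128x128 granularity, eval.py can pass 1e-12 as shown above. A rough numeric illustration of why such a lower bound matters, not torchao's actual code path (block_scale and FLOAT8_E4M3_MAX are names made up here):

```python
import torch

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def block_scale(block: torch.Tensor, value_lb: float | None = None) -> torch.Tensor:
    # Derive a per-block scale from the block's max absolute value, optionally
    # clamping it from below the way a value lower bound would.
    amax = block.abs().amax().float()
    if value_lb is not None:
        amax = torch.clamp(amax, min=value_lb)
    return amax / FLOAT8_E4M3_MAX

zeros = torch.zeros(1, 128)  # an all-zero 1x128 activation block
print(1.0 / block_scale(zeros))         # inf: degenerate reciprocal scale
print(1.0 / block_scale(zeros, 1e-12))  # finite once the lower bound kicks in
```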
