Commit f22ca14

[Tests] Update Schemes (#2013)
SUMMARY:
- Latest compressed-tensors (CT) introduces `scale_dtype`, which, if not provided, falls back to the weight dtype for the scales.
- As a result, we no longer hardcode FP8 as the dtype when generating scales for NVFP4 and instead rely on this field. Update the tests to reflect this.

TESTING:
- Addresses 2 of the 14 failures. All remaining failures come from a single test case, which will be resolved in a follow-up.
1 parent cfe9169 commit f22ca14

File tree

1 file changed (+3, −0 lines)

tests/llmcompressor/modifiers/calibration/test_lifecycle.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 import pytest
 import torch
 from compressed_tensors.quantization import (
+    FP8_E4M3_DATA,
     QuantizationScheme,
     forward_quantize,
     initialize_module_for_quantization,
@@ -83,6 +84,7 @@
             symmetric=True,
             strategy="tensor_group",  # requires float4
             group_size=3,
+            scale_dtype=FP8_E4M3_DATA.dtype,
         ),
         torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]),
         torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]),
@@ -195,6 +197,7 @@ def test_static_weight_quantization(
             strategy="tensor_group",
             dynamic="local",
             group_size=3,
+            scale_dtype=FP8_E4M3_DATA.dtype,
         ),
         None,
         None,
```
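For context, a minimal sketch of the pattern this commit relies on: pinning the scale dtype explicitly instead of depending on CT's new fallback to the weight dtype. It assumes the `QuantizationArgs`/`QuantizationScheme` constructors from `compressed-tensors`; the `num_bits=4`, `type="float"`, and `targets=["Linear"]` values are illustrative and not taken from this diff.

```python
from compressed_tensors.quantization import (
    FP8_E4M3_DATA,
    QuantizationArgs,
    QuantizationScheme,
)

# NVFP4-style weight args: float4 values quantized per tensor group.
# If scale_dtype is omitted, recent compressed-tensors falls back to
# the weight dtype for the scales; passing FP8_E4M3_DATA.dtype keeps
# the FP8 scales that were previously hardcoded.
weight_args = QuantizationArgs(
    num_bits=4,                        # illustrative: float4 weights
    type="float",
    symmetric=True,
    strategy="tensor_group",
    group_size=3,                      # matches the group size in these tests
    scale_dtype=FP8_E4M3_DATA.dtype,   # the field this commit adds to the tests
)

scheme = QuantizationScheme(targets=["Linear"], weights=weight_args)
```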
