
Commit a257166

make float8 a1x128_w128x128 granularity serializeable (#3279)

1 parent a9f2dc1

File tree

6 files changed: +13 -6 lines changed

test/core/test_config.py

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,7 @@
     AWQConfig,
     AWQStep,
 )
+from torchao.quantization import PerBlock
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
@@ -45,6 +46,9 @@
     Float8DynamicActivationFloat8WeightConfig(),
     Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
     Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()]),
+    Float8DynamicActivationFloat8WeightConfig(
+        granularity=[PerBlock([1, 128]), PerBlock([128, 128])]
+    ),
     Float8WeightOnlyConfig(
         weight_dtype=torch.float8_e4m3fn,
     ),
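
The new test entry exercises the round trip that previously failed. A minimal sketch of that round trip, assuming `torchao.core.config` exposes the `config_to_dict`/`config_from_dict` helpers this test file uses (exact import path is an assumption):

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import PerBlock
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

config = Float8DynamicActivationFloat8WeightConfig(
    granularity=[PerBlock([1, 128]), PerBlock([128, 128])]
)
reconstructed = config_from_dict(config_to_dict(config))
assert reconstructed == config  # holds now that block_size round-trips as a list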

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 3 additions & 3 deletions
@@ -91,7 +91,7 @@ def setUp(self):
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
         "granularity",
-        [PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
+        [PerTensor(), PerRow(), (PerBlock([1, 128]), PerBlock([128, 128]))],
     )
     @common_utils.parametrize(
         "kernel_preference",
@@ -125,7 +125,7 @@ def test_fp8_linear_variants(
         elif mode == "weight-only":
             return unittest.skip("unimplemented")
 
-        elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
+        elif granularity == (PerBlock([1, 128]), PerBlock([128, 128])):
             if dtype is not torch.bfloat16:
                 return unittest.skip("unimplemented")
         elif mode != "dynamic":
@@ -199,7 +199,7 @@ def test_fp8_linear_variants(
             assert qs1.shape == (N, 1)
             assert qs2.shape == (K, 1)
         else:
            assert granularity == (PerBlock([1, 128]), PerBlock([128, 128]))
             assert qs1.shape == (N // 128, K // 128)
             assert qs2.shape == (K // 128, N // 128)
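
The shape assertions follow from simple block counting: each block of the quantized tensor gets one scale, so a (N, K) weight under PerBlock([128, 128]) carries an (N // 128, K // 128) scale tensor. A small illustration (the helper below is hypothetical, not part of torchao):

def blockwise_scale_shape(tensor_shape, block_size):
    # one scale per block along each dimension
    assert len(tensor_shape) == len(block_size)
    return tuple(d // b for d, b in zip(tensor_shape, block_size))

assert blockwise_scale_shape((512, 256), (128, 128)) == (4, 2)  # weight scales
assert blockwise_scale_shape((1, 256), (1, 128)) == (1, 2)      # activation scales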

torchao/_models/llama/eval.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def run_evaluation(
     )
     if quantization == "float8_a1x128_w128x128":
         config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
+            granularity=(PerBlock([1, 128]), PerBlock([128, 128])),
             activation_value_lb=1e-12,
         )
     # TODO(future): all workflows in this file should be skipping quantization
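
For context, this is roughly how the eval script's config gets applied to a model. A usage sketch, assuming a CUDA device with float8 support; the toy Linear dimensions are illustrative:

import torch
from torchao.quantization import PerBlock, quantize_
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# a1x128_w128x128 wants dims divisible by 128 and bfloat16 inputs
model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to(torch.bfloat16).to("cuda")
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock([1, 128]), PerBlock([128, 128])),
    activation_value_lb=1e-12,
)
quantize_(model, config)  # swaps Linear weights for float8 blockwise-quantized tensors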

torchao/float8/inference.py

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ def _granularity_is_a_1_128_w_128_128(
         list[FP8Granularity],
     ],
 ) -> bool:
-    return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))
+    return len(g) == 2 and g[0] == PerBlock([1, 128]) and g[1] == PerBlock([128, 128])
 
 
 def _normalize_granularity(
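
The list spelling in this predicate matters: assuming PerBlock is a dataclass with generated equality (as granularity.py suggests), comparison looks at block_size directly, and a tuple never equals a list:

from torchao.quantization import PerBlock

assert PerBlock([1, 128]) == PerBlock([1, 128])
assert PerBlock((1, 128)) != PerBlock([1, 128])  # (1, 128) != [1, 128]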

torchao/quantization/granularity.py

Lines changed: 4 additions & 0 deletions
@@ -126,4 +126,8 @@ class PerBlock(Granularity):
     # 1. `block_size` in this class can support tensors of multiple ranks
     # 2. `block_size` in other places in the codebase has rank equal to the
     #    corresponding tensor
+    # TODO(future PR): change to list or support serialization with tuples,
+    # currently serialization only works when `block_size` is specified as a
+    # list. Example error:
+    # https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c
     block_size: tuple[int, ...]
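
The root cause the TODO points at is reproducible without torchao: JSON has no tuple type, so a tuple-valued block_size comes back from a round trip as a list and no longer compares equal. A self-contained sketch with a hypothetical stand-in class:

import dataclasses
import json

@dataclasses.dataclass(frozen=True)
class PerBlockSketch:
    """Hypothetical stand-in for torchao.quantization.PerBlock."""
    block_size: tuple

original = PerBlockSketch((1, 128))
# tuples serialize as JSON arrays and deserialize as lists
restored = PerBlockSketch(**json.loads(json.dumps(dataclasses.asdict(original))))
assert restored.block_size == [1, 128]
assert restored != original  # (1, 128) != [1, 128]: round-trip equality breaks
# constructing with a list up front keeps both sides consistent
assert PerBlockSketch([1, 128]) == PerBlockSketch(block_size=[1, 128])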

torchao/quantization/quant_api.py

Lines changed: 0 additions & 1 deletion
@@ -1782,7 +1782,6 @@ def __post_init__(self):
             KernelPreference.AUTO,
             KernelPreference.TORCH,
         ), "unimplemented"
-        assert self.mm_config is None, "unimplemented"
         assert self.version >= 2, "unimplemented"
         default_use_fast_accum = False
