Skip to content

Commit 07eb188

Browse files
committed
[wip] float8 rowwise quant along row 1 of tensor rank 2
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: c40d96b ghstack-comment-id: 3497584430 Pull-Request: #3303
1 parent 58b07f0 commit 07eb188

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torchao.quantization import (
1919
Float8DynamicActivationFloat8WeightConfig,
2020
Float8WeightOnlyConfig,
21+
PerAxis,
2122
PerBlock,
2223
PerRow,
2324
PerTensor,
@@ -466,6 +467,39 @@ def forward(self, x):
466467
sqnr = compute_error(original, quantized)
467468
self.assertTrue(sqnr > 20)
468469

470+
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
471+
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
472+
def test_bmm_weight_in_bkn_layout(self):
473+
# Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
474+
# and contigous with that shape. Since the `K` dimension is not last, we
475+
# need to specify granularity with `PerAxis(1)`.
476+
477+
# only support per row quantization
478+
granularity = [PerRow(), PerAxis(1)]
479+
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
480+
481+
class Model(torch.nn.Module):
482+
def __init__(self, weight):
483+
super().__init__()
484+
self.weight = weight
485+
486+
def forward(self, x):
487+
return torch.bmm(x, self.weight)
488+
489+
dtype = torch.bfloat16
490+
device = "cuda"
491+
492+
B, M, K, N = 10, 32, 128, 256
493+
494+
input = torch.randn(B, M, K, dtype=dtype, device=device)
495+
weight = torch.randn(B, K, N, dtype=dtype, device=device)
496+
m = Model(weight).eval()
497+
original = m(input)
498+
quantize_(m, config, filter_fn=lambda x, fqn: True)
499+
quantized = m(input)
500+
sqnr = compute_error(original, quantized)
501+
self.assertTrue(sqnr > 20)
502+
469503
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
470504
@common_utils.parametrize(
471505
"sizes",

torchao/float8/inference.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
1616
from torchao.float8.types import FP8Granularity
1717
from torchao.quantization.granularity import (
18+
PerAxis,
1819
PerBlock,
1920
PerRow,
2021
PerTensor,
@@ -247,13 +248,21 @@ def _normalize_granularity(
247248
granularity[1], PerTensor
248249
)
249250
is_per_row = isinstance(granularity[0], PerRow) and isinstance(
250-
granularity[1], PerRow
251+
granularity[1], (PerRow, PerAxis)
251252
)
252253
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)
253254

254255
if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
255256
raise ValueError(f"Unsupported granularity types: {granularity}.")
256-
if not isinstance(granularity[0], type(granularity[1])):
257+
258+
a_w_granularities_match = (
259+
# direct match
260+
isinstance(granularity[0], type(granularity[1]))
261+
# PerAxis is a more general version of PerRow
262+
or (isinstance(granularity[0], PerRow) and isinstance(granularity[1], PerAxis))
263+
)
264+
265+
if not a_w_granularities_match:
257266
raise ValueError(
258267
f"Different granularities for activation and weight are not supported: {granularity}."
259268
)
@@ -280,7 +289,7 @@ def _check_hardware_support(
280289
granularities[1], PerTensor
281290
)
282291
is_per_row = isinstance(granularities[0], PerRow) and isinstance(
283-
granularities[1], PerRow
292+
granularities[1], (PerRow, PerAxis)
284293
)
285294
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)
286295

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,8 @@ def _(func, types, args, kwargs):
423423

424424
b_data = weight_tensor.qdata
425425
b_scale = weight_tensor.scale
426+
print('a', a_data.shape, a_scale.shape, input_tensor.block_size)
427+
print('b', b_data.shape, b_scale.shape, weight_tensor.block_size)
426428

427429
assert (
428430
weight_tensor.block_size[0] == 1

0 commit comments

Comments
 (0)