111 changes: 67 additions & 44 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -31,6 +31,7 @@
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import (
_is_fbgemm_gpu_genai_available,
auto_detect_device,
is_sm_at_least_89,
is_sm_at_least_90,
is_sm_at_least_100,
@@ -40,6 +41,8 @@
# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 128

_DEVICE = auto_detect_device()


class ToyLinearModel(torch.nn.Module):
def __init__(self, in_features, out_features, bias):
@@ -137,16 +140,19 @@ def check_weight_scaling(self, granularity: Granularity):

# TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
@unittest.skipIf(
not torch.accelerator.is_available(), "skipping when gpu is not available"
)
@unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_89(), "Need sm89+")
class TestFloat8Tensor(TorchAOIntegrationTestCase):
def setUp(self):
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
self.GPU_DEVICES = [_DEVICE] if torch.accelerator.is_available() else []
torch.set_grad_enabled(False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need accelerator available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
torch.cuda.is_available() and not is_sm_at_least_89(),
"Requires GPU with compute capability >= 8.9",
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@@ -191,9 +197,10 @@ def test_fp8_linear_variants(
ToyLinearModel(K, N, bias),
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need accelerator available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
torch.cuda.is_available() and not is_sm_at_least_89(),
"Requires GPU with compute capability >= 8.9",
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@@ -230,7 +237,7 @@ def test_fp8_matmul_lora_variants(
kernel_preference,
sizes,
bias=False,
model=model.to("cuda"),
model=model.to(_DEVICE),
)

def _test_fp8_matmul_model(
@@ -253,6 +260,8 @@ def _test_fp8_matmul_model(
return unittest.skip("unimplemented")

elif granularity == (PerBlock([1, 128]), PerBlock([128, 128])):
if _DEVICE.type == "xpu":
return unittest.skip("PerBlock granularity not supported on XPU")
if dtype is not torch.bfloat16:
return unittest.skip("unimplemented")
elif mode != "dynamic":
@@ -284,7 +293,8 @@ def _test_fp8_matmul_model(
)

if kernel_preference == KernelPreference.FBGEMM and (
(not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_90())
(not _is_fbgemm_gpu_genai_available())
or (torch.cuda.is_available() and not is_sm_at_least_90())
):
return unittest.skip(
"Requires fbgemm_gpu_genai to run fbgemm kernel preference test"
@@ -298,8 +308,8 @@ def _test_fp8_matmul_model(

with error_context:
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
model = model.eval().to(dtype).to("cuda")
input_tensor = torch.randn(*M, K, dtype=dtype, device=_DEVICE)
model = model.eval().to(dtype).to(_DEVICE)

quantized_model = copy.deepcopy(model)

@@ -328,9 +338,10 @@ def _test_fp8_matmul_model(
f"Quantization error is too high got a SQNR of {error}"
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need accelerator available")
@unittest.skipIf(
not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
torch.cuda.is_available() and not is_sm_at_least_100(),
"Requires GPU with compute capability >= 10.0",
)
@unittest.skipIf(
not _is_fbgemm_gpu_genai_available(),
@@ -369,7 +380,7 @@ def test_fp8_conv_variants(
kernel_size = 3

# Note: this is channel last memory format
input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=_DEVICE)
if dim == 3:
input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
else:
@@ -384,7 +395,7 @@ def test_fp8_conv_variants(
bias=False,
padding=0,
dtype=dtype,
device="cuda",
device=_DEVICE,
).eval()

quantized_model = copy.deepcopy(model)
@@ -411,9 +422,10 @@ def test_fp8_conv_variants(
f"Quantization error is too high got a SQNR of {error}"
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not torch.accelerator.is_available(), "Need accelerator available")
@unittest.skipIf(
not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
torch.cuda.is_available() and not is_sm_at_least_100(),
"Requires GPU with compute capability >= 10.0",
)
@unittest.skipIf(
not _is_fbgemm_gpu_genai_available(),
@@ -453,7 +465,7 @@ def test_fp8_conv_skip_quant(
kernel_size = 3

# Note: this is channel last memory format
input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=_DEVICE)
if dim == 3:
input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
else:
@@ -467,7 +479,7 @@ def test_fp8_conv_skip_quant(
bias=False,
padding=0,
dtype=dtype,
device="cuda",
device=_DEVICE,
).eval()

quantized_model = copy.deepcopy(model)
@@ -488,14 +500,14 @@ def test_fp8_conv_skip_quant(

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@unittest.skipIf(
not is_sm_at_least_90(),
torch.cuda.is_available() and not is_sm_at_least_90(),
"Failing in SM89 right now: "
"AssertionError: tensor(False, device='cuda:0') is not true : sqnr: -2.90625, will fix a bit later",
)
def test_slice(self, granularity):
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
device = _DEVICE
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
dummy1.weight = torch.nn.Parameter(
@@ -567,9 +579,9 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
"""
M, N, K = sizes
dtype = torch.bfloat16
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
input_tensor = torch.randn(*M, K, dtype=dtype, device=_DEVICE)
# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")
model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to(_DEVICE)

# reference kernel preference and results
# we are using KerenelPreference.TORCH as the reference
@@ -586,6 +598,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
]
if (
_is_fbgemm_gpu_genai_available()
and torch.cuda.is_available()
and is_sm_at_least_90()
and not isinstance(granularity, PerTensor)
):
@@ -614,9 +627,9 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
def test_slice_preserves_aliasing(self, granularity):
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE)
)
quantize_(l, config)
param = l.weight
@@ -631,7 +644,9 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
self._test_slice_and_copy_similar_to_vllm(config)

@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(
torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
)
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_bmm(self):
# only support per row quantization
@@ -646,7 +661,7 @@ def forward(self, x):
return torch.bmm(x, self.weight.transpose(-2, -1))

dtype = torch.bfloat16
device = "cuda"
device = _DEVICE

B, M, K, N = 10, 32, 128, 256

@@ -659,7 +674,9 @@ def forward(self, x):
sqnr = compute_error(original, quantized)
self.assertTrue(sqnr > 20)

@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(
torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
)
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_bmm_weight_in_bkn_layout(self):
# Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
@@ -679,7 +696,7 @@ def forward(self, x):
return torch.bmm(x, self.weight)

dtype = torch.bfloat16
device = "cuda"
device = _DEVICE

B, M, K, N = 10, 32, 128, 256

@@ -739,7 +756,7 @@ def test_to_device(self, granularity, sizes):
def test_cat(self, granularity, sizes):
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
device = _DEVICE
M, N, K = sizes
linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
@@ -788,7 +805,9 @@ def test_cat(self, granularity, sizes):
@unittest.skip(
"This requires rowwise scaling for weight in layout BKN across axis 1 to work"
)
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(
torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
)
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_moe_weight_reshape_ops(self):
# only per row quantization is supported for bmm
@@ -799,7 +818,9 @@ def test_moe_weight_reshape_ops(self):
# TODO: we have some other tests living in https://github.com/pytorch/ao/blob/4ecc89edd7b5cfc12e6f80854c85d04c472a0eb0/test/dtypes/test_affine_quantized_float.py#L743
# that should be moved here after v1 config is deprecated:
# https://github.com/pytorch/ao/issues/2649
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(
torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
)
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_expected_gpu_kernel_fbgemm(self):
"""Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels
@@ -809,15 +830,15 @@ def test_expected_gpu_kernel_fbgemm(self):

M, K, N = 128, 256, 512
m = torch.nn.Sequential(
torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
torch.nn.Linear(K, N, device=_DEVICE, dtype=torch.bfloat16)
)
config = Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(),
kernel_preference=KernelPreference.FBGEMM,
)
quantize_(m, config)
m = torch.compile(m)
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
x = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
out, code = run_and_get_code(m, x)

# 1. check at least one occurrence of the quantize op and rowwise gemm op
@@ -830,7 +851,9 @@ def test_expected_gpu_kernel_fbgemm(self):
".run("
).run(code[0])

@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(
torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
)
def test_index_select(self):
"""
test that `x_0 = x[0]` works when `x` is a 3D `Float8Tensor`. This is
@@ -840,7 +863,7 @@ def test_index_select(self):
"""

E, K, N = 128, 256, 512
x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
x = torch.randn(E, N, K, device=_DEVICE, dtype=torch.bfloat16)
x_fp8 = Float8Tensor.from_hp(x)
x_fp8_1 = x_fp8[1]
torch.testing.assert_close(
@@ -858,7 +881,7 @@ def test_unsqueeze_operation(self, granularity, sizes):
def test_unsqueeze_operation(self, granularity, sizes):
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
device = _DEVICE
M, N, K = sizes

# Create a linear layer and quantize it
@@ -914,7 +937,7 @@ def test_unsqueeze_conv2d_weight(self):
granularity = PerTensor()
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
device = _DEVICE
N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
dim = len(spatial_dims)
kernel_size = 3
@@ -1002,7 +1025,7 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
"""Test slicing operations on 3D Float8Tensor across all dimensions"""
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
device = _DEVICE

B, S, H = tensor_shape

@@ -1100,7 +1123,7 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
self.assertEqual(sliced_dequantized, sliced_original)

def test_to_dtype_layout(self):
x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x = torch.randn(128, 512, device=_DEVICE, dtype=torch.bfloat16)
x_fp8 = Float8Tensor.from_hp(x)
y_fp8 = torch.ops.aten.to.dtype_layout(
x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu"
Expand All @@ -1110,9 +1133,9 @@ def test_to_dtype_layout(self):
self.assertEqual(y_fp8.device, torch.device("cpu"))

def test_has_compatible_shallow_copy_type(self):
x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
x1 = torch.randn(128, 512, device=_DEVICE, dtype=torch.bfloat16)
x2 = torch.randn(128, 512, device=_DEVICE, dtype=torch.bfloat16)
x3 = torch.randn(128, 256, device=_DEVICE, dtype=torch.bfloat16)
x1_fp8 = Float8Tensor.from_hp(x1)
x2_fp8 = Float8Tensor.from_hp(x2)
x3_fp8 = Float8Tensor.from_hp(x3)
Expand All @@ -1123,7 +1146,7 @@ def test_has_compatible_shallow_copy_type(self):
self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8))

def test_transpose(self):
x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x = torch.randn(128, 512, device=_DEVICE, dtype=torch.bfloat16)
x_fp8 = Float8Tensor.from_hp(x)
x_fp8_t = x_fp8.t()
torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0)
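
Note on the new helper: the test changes above import `auto_detect_device()` from `torchao.utils` and bind it to a module-level `_DEVICE`, but the helper's implementation is not part of this diff. A minimal sketch of the behavior the tests assume (current accelerator when one is present, CPU otherwise); the function body below is an illustration, only the `torch.accelerator` calls are standard PyTorch APIs:

import torch

def auto_detect_device() -> torch.device:
    # Sketch only: prefer the active accelerator (cuda, xpu, ...) when one is
    # available, otherwise fall back to CPU. The real helper lives in
    # torchao/utils.py and may differ in detail.
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

_DEVICE = auto_detect_device()
x = torch.randn(4, 4, device=_DEVICE)  # lands on cuda/xpu when present, else cpu

This is consistent with how the tests use `_DEVICE`: both as a `device=` argument and via `_DEVICE.type == "xpu"` checks, which expect a `torch.device` rather than a plain string.
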
4 changes: 3 additions & 1 deletion torchao/float8/inference.py
@@ -285,7 +285,9 @@ def _check_hardware_support(
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)

if is_per_tensor or is_per_row:
assert is_sm_at_least_89() or is_MI300(), (
assert torch.xpu.is_available() or (
torch.cuda.is_available() and is_sm_at_least_89() or is_MI300()
), (
"Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
)
elif is_a_1_128_w_128_128:
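
On the `inference.py` hunk: the rewritten assert in `_check_hardware_support` relies on `and` binding tighter than `or`, so it admits XPU unconditionally and otherwise keeps the pre-existing gate of CUDA compute capability >= 8.9 or an MI300-class device. A spelled-out, equivalent form of that condition as a readability sketch (assuming `is_sm_at_least_89` and `is_MI300` are importable from `torchao.utils`, as `is_sm_at_least_89` is in the test file above; the wrapper function name is hypothetical):

import torch

from torchao.utils import is_MI300, is_sm_at_least_89

def _per_tensor_or_per_row_supported() -> bool:
    # XPU is accepted unconditionally for PerTensor/PerRow float8 dynamic quantization.
    if torch.xpu.is_available():
        return True
    # Otherwise keep the original requirement: CUDA with compute capability
    # >= 8.9, or an MI300-class ROCm device.
    return (torch.cuda.is_available() and is_sm_at_least_89()) or is_MI300()

assert _per_tensor_or_per_row_supported(), (
    "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
)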