153 changes: 119 additions & 34 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -87,6 +87,8 @@ def __init__(
)
if dim == 3:
self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+ elif dim == 2:
+ self.conv = self.conv.to(memory_format=torch.channels_last)

def forward(self, x):
return self.conv(x)
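For context on the memory-format handling above: torch.channels_last applies to 4D (N, C, H, W) tensors and torch.channels_last_3d to 5D (N, C, D, H, W) tensors; the conversion only permutes strides, the logical shape is unchanged. A standalone illustration (shapes are assumed, not taken from the test):

import torch

x2d = torch.randn(4, 16, 32, 32).to(memory_format=torch.channels_last)
x3d = torch.randn(4, 16, 8, 32, 32).to(memory_format=torch.channels_last_3d)
assert x2d.is_contiguous(memory_format=torch.channels_last)
assert x3d.is_contiguous(memory_format=torch.channels_last_3d)
assert x2d.shape == (4, 16, 32, 32)  # logical shape is unchanged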
@@ -337,33 +339,43 @@ def _test_fp8_matmul_model(
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("inference_mode", [True, False])
- # only test for 3D conv for now
- # Inputs are (N, C_in, C_out, D, H, W)
+ # test for 2D/3D conv
+ # Inputs are (N, C_in, C_out, (D, H, W)) or
+ # (N, C_in, C_out, (H, W))
@common_utils.parametrize(
"sizes",
[
- (4, 16, 64, 32, 32, 32),
+ (4, 16, 64, (32, 32, 32)),
+ (4, 16, 64, (32, 32)),
],
)
def test_fp8_conv_variants(
self,
dtype: torch.dtype,
compile: bool,
inference_mode: bool,
+ kernel_preference: KernelPreference,
sizes: Tuple,
):
torch.compiler.reset()
granularity = PerTensor()
- kernel_preference = KernelPreference.AUTO
- N, C_in, C_out, D, H, W = sizes
- dim = 3

+ N, C_in, C_out, spatial_dims = sizes
+ dim = len(spatial_dims)
+ convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+ assert dim in convs, f"Unsupported dim: {dim}"
+ conv_class = convs[dim]

kernel_size = 3

# Note: this uses channels-last memory format
- input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
- input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+ input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+ if dim == 3:
+ input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+ else:
+ assert dim == 2
+ input_tensor = input_tensor.to(memory_format=torch.channels_last)

# Create a conv model with the given dtype
model = ToyConvModel(
dim,
C_in,
@@ -382,9 +394,9 @@ def test_fp8_conv_variants(
kernel_preference=kernel_preference,
)

- _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+ _is_conv = lambda m, fqn: isinstance(m, conv_class)

- quantize_(quantized_model, config, filter_fn=_is_conv3d)
+ quantize_(quantized_model, config, filter_fn=_is_conv)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -408,13 +420,16 @@ def test_fp8_conv_variants(
"Requires fbgemm_gpu_genai to be installed",
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
- # only test for 3D conv for now
- # Inputs are (N, C_in, C_out, D, H, W)
+ # test for 2D/3D conv
+ # Inputs are (N, C_in, C_out, (D, H, W)) or
+ # (N, C_in, C_out, (H, W))
@common_utils.parametrize(
"sizes",
[
- (4, 12, 64, 32, 32, 32),
- (4, 16, 12, 32, 32, 32),
+ (4, 12, 64, (32, 32, 32)),
+ (4, 16, 12, (32, 32, 32)),
+ (4, 12, 64, (32, 32)),
+ (4, 16, 12, (32, 32)),
],
)
def test_fp8_conv_skip_quant(
@@ -427,14 +442,23 @@ def test_fp8_conv_skip_quant(
"""
granularity = PerTensor()
kernel_preference = KernelPreference.AUTO
- N, C_in, C_out, D, H, W = sizes
- dim = 3

+ N, C_in, C_out, spatial_dims = sizes
+
+ dim = len(spatial_dims)
+ convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+ assert dim in convs, f"Unsupported dim: {dim}"
+ conv_class = convs[dim]

kernel_size = 3

# Note: this uses channels-last memory format
- input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
- input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
- # Create a linear layer with bfloat16 dtype
+ input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+ if dim == 3:
+ input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+ else:
+ input_tensor = input_tensor.to(memory_format=torch.channels_last)

model = ToyConvModel(
dim,
C_in,
@@ -453,9 +477,9 @@ def test_fp8_conv_skip_quant(
kernel_preference=kernel_preference,
)

- _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+ _is_conv = lambda m, fqn: isinstance(m, conv_class)

- quantize_(quantized_model, config, filter_fn=_is_conv3d)
+ quantize_(quantized_model, config, filter_fn=_is_conv)
assert not isinstance(quantized_model.conv.weight, Float8Tensor)

output_original = model(input_tensor)
@@ -832,7 +856,6 @@ def test_index_select(self):
],
)
def test_unsqueeze_operation(self, granularity, sizes):
"""Test aten.unsqueeze.default operation on Float8Tensor"""
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
@@ -845,7 +868,7 @@ def test_unsqueeze_operation(self, granularity, sizes):
original_weight = linear.weight
original_shape = original_weight.shape

- # Test unsqueeze operation at dim=0 (only supported dimension)
+ # Test unsqueeze operation at dim=0
unsqueezed_weight = original_weight.unsqueeze(0)

# Verify the unsqueezed tensor has correct shape
@@ -887,22 +910,84 @@ def test_unsqueeze_operation(self, granularity, sizes):

self.assertEqual(unsqueezed_dequant, expected_dequant)

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
def test_unsqueeze_error_cases(self, granularity):
"""Test error cases for aten.unsqueeze.default operation"""
def test_unsqueeze_conv2d_weight(self):
granularity = PerTensor()
- config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
+ N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
+ dim = len(spatial_dims)
+ kernel_size = 3

- # Create a linear layer and quantize it
- linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
- quantize_(linear, config)
+ input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=device)
+ input_tensor = input_tensor.to(memory_format=torch.channels_last)
+ model = ToyConvModel(
+ dim,
+ C_in,
+ C_out,
+ kernel_size,
+ bias=False,
+ padding=0,
+ dtype=dtype,
+ device=device,
+ ).eval()

+ quantized_model = copy.deepcopy(model)

+ config = Float8DynamicActivationFloat8WeightConfig(
+ granularity=granularity,
+ )

+ _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d)

- weight = linear.weight
+ quantize_(quantized_model, config, filter_fn=_is_conv)

- # Test that unsqueezing on unsupported dimensions raises an error
- with self.assertRaisesRegex(AssertionError, "Only dim == 0 is supported"):
- weight.unsqueeze(1)  # dim=1 should not be supported
+ original_weight = quantized_model.conv.weight
+ original_shape = original_weight.shape
+
+ # Test unsqueeze operation at dim=2
+ unsqueezed_weight = original_weight.unsqueeze(2)
+
+ # Verify the unsqueezed tensor has correct shape
+ original_shape_list = list(original_shape)
+ expected_shape = original_shape_list[:2] + [1] + original_shape_list[2:]
+ scale_shape_list = list(original_weight.scale.shape)
+ expected_scale_shape = scale_shape_list[:2] + [1] + scale_shape_list[2:]
+
+ self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))
+ # Verify qdata and scale shapes
+ expected_qdata_shape = expected_shape
+
+ self.assertEqual(
+ unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
+ )
+ self.assertEqual(
+ unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
+ )
+
+ # Verify block_size is correctly updated
+ expected_block_size = []
+ for i in range(len(expected_shape)):
+ expected_block_size.append(expected_shape[i] // expected_scale_shape[i])
+
+ self.assertEqual(unsqueezed_weight.block_size, expected_block_size)
+
+ # Test that metadata is preserved
+ self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
+ self.assertEqual(
+ unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
+ )
+ self.assertEqual(
+ unsqueezed_weight.kernel_preference, original_weight.kernel_preference
+ )
+ self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)
+
+ # Test numerical correctness
+ original_dequant = original_weight.dequantize()
+ unsqueezed_dequant = unsqueezed_weight.dequantize()
+ expected_dequant = original_dequant.unsqueeze(2)
+
+ self.assertEqual(unsqueezed_dequant, expected_dequant)
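The block_size check in this new test encodes the invariant block_size[i] = shape[i] // scale.shape[i]. A tiny standalone illustration with assumed shapes (a (64, 16, 3, 3) conv2d weight unsqueezed at dim=2, with the scale assumed to be stored as one element per dimension):

shape = [64, 16, 1, 3, 3]
scale_shape = [1, 1, 1, 1, 1]  # per-tensor scale, one element per dim (assumed)
block_size = [s // ss for s, ss in zip(shape, scale_shape)]
assert block_size == [64, 16, 1, 3, 3]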

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize("slice_dim", [0, 1, 2])
14 changes: 10 additions & 4 deletions torchao/quantization/quant_api.py
@@ -1816,13 +1816,19 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
_check_hardware_support(granularity)
activation_granularity, weight_granularity = granularity

- if weight.dim() == 5:
- # weights for conv3d
+ # Note: right now we assume the weight belongs to conv2d or conv3d purely
+ # based on its dimensionality; currently there is no conflict with 2D linear
+ # weights or 3D MoE weights.
+ # If we need to support conv1d, which also has a 3D weight, we may have to
+ # pass the module around as well to distinguish conv1d weights from 3D MoE weights.
+ if weight.dim() in [4, 5]:
Contributor: is there anything more robust we can check here?

Contributor Author: we can't really distinguish linear and conv weights from here, I think (right now linear weights are 2D/3D and conv weights are 4D/5D, but a conv1d weight would be 3D, which overlaps with linear).

but we could potentially separate the handling of conv and linear by passing the module around as well, if this is needed in the future.

I can add a comment for now.
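A minimal sketch of the module-aware dispatch the author alludes to, assuming a hypothetical helper (_choose_quant_path and its module argument are illustrative; today the config handler only sees the weight tensor):

import torch

def _choose_quant_path(weight: torch.Tensor, module: torch.nn.Module) -> str:
    # With the module in hand, conv1d (3D weight) no longer collides with
    # 3D MoE / 2D linear weights: dispatch on the module type first.
    if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
        return "conv"
    # Fall back to the current dimension-based heuristic for everything else:
    # 2D linear weights and 3D MoE weights both take the linear path.
    if weight.dim() in (2, 3):
        return "linear"
    raise ValueError(f"unsupported weight shape {tuple(weight.shape)}")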
+ # weights for conv2d or conv3d
assert isinstance(activation_granularity, PerTensor) and isinstance(
weight_granularity, PerTensor
), "5D tensor only supports per tensor activation and weight quantization"
), "4D/5D tensor only supports per tensor activation and weight quantization"

- # weight dim: (C_out, C_in, K1, K2, K3)
+ # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+ # conv2d weight dim: (C_out, C_in, K1, K2)
# skip quantization when either C_out or C_in
# is not a multiple of 16
if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
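To make the guard concrete: the skip-quant test above parametrizes channel counts of 12 precisely because 12 is not a multiple of 16, so those weights are left unquantized and the test asserts the weight is not a Float8Tensor. The same predicate in isolation (weight shape assumed for illustration):

import torch

w = torch.randn(12, 64, 3, 3, 3)  # conv3d weight with C_out = 12
skip = w.shape[0] % 16 != 0 or w.shape[1] % 16 != 0
assert skip  # C_out = 12 is not a multiple of 16, so quantization is skipped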