188 changes: 176 additions & 12 deletions test/quantization/pt2e/test_x86inductor_fusion.py
@@ -138,6 +138,40 @@ def forward(self, input):
return out


class FP8QDQConv2d(torch.nn.Module):
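    """Reference Conv2d that emulates FP8 quantize-dequantize (QDQ).

    The weight is stored in float8_e4m3fn and both weight and activation are passed
    through torchao's FP8 quantize/dequantize ops before a plain float conv2d, so the
    exported graph carries the FP8 q/dq + convolution pattern targeted by the
    X86Inductor fusion passes. Used below to patch eager Conv2d modules in tests.
    """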
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
super().__init__()
self.qtype = torch.float8_e4m3fn
        self.weight = torch.randn(
            (out_channels, in_channels // groups, *kernel_size)
        ).to(self.qtype)
self.weight_scale = 2.0
self.scale = 2.0
self.bias = None
if bias:
self.bias = torch.randn((out_channels,))
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups

def forward(self, input):
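        # Dequantize the stored FP8 weight back to float using its per-tensor scale.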
weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
tensor=self.weight.data,
scale=torch.tensor([self.weight_scale]),
output_dtype=torch.float,
)
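        # Quantize the input activation to FP8 using the module's activation scale.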
q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
tensor=input,
scale=torch.tensor([self.scale]),
float8_dtype=self.qtype,
)
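        # Dequantize the activation again so the traced graph contains an explicit q/dq pair.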
dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
tensor=q_input,
scale=torch.tensor([self.scale]),
output_dtype=torch.float,
)

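        # Run the convolution in float on the dequantized input and weight.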
        return torch.nn.functional.conv2d(
            dq_input,
            weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

def qdq(input, scale):
dtype = input.dtype
q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
@@ -172,7 +206,7 @@ def create_mod_info_recursion(parent):
for name, mod in model.named_modules():
mod_type_str = mod.__class__.__name__
if mod_type_str not in [
"Linear",
"Linear", "Conv2d"
]:
continue
param = mod.weight
@@ -190,6 +224,11 @@ def create_mod_info_recursion(parent):
patched_mod.bias = mod.bias
patched_mod.weight_scale = weight_scale.item()
patched_mod.weight.data = q_param
elif mod_type_str in ["Conv2d"]:
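                # Build the FP8 QDQ replacement without a bias, then copy the original
                # module's bias and the quantized weight/scale over below.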
                patched_mod = FP8QDQConv2d(
                    mod.in_channels,
                    mod.out_channels,
                    mod.kernel_size,
                    mod.stride,
                    mod.padding,
                    mod.dilation,
                    mod.groups,
                    False,
                )
patched_mod.bias = mod.bias
patched_mod.weight_scale = weight_scale.item()
patched_mod.weight.data = q_param

parent = parent_child_mod_dict[mod].parent
name = parent_child_mod_dict[mod].name
@@ -382,7 +421,7 @@ def _test_code_common(

@unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+")
class TestPatternMatcher(TestPatternMatcherBase):
def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False):
def _qconv2d_test_helper(self, device="cpu", mixed_bf16=False, is_fp8=False):
class M(torch.nn.Module):
def __init__(
self,
@@ -408,14 +447,14 @@ def forward(self, x):
def matcher_check_fn():
# 1. Dequant-Conv2D pattern matched in QConv2D weight prepack * 1
# int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution]
# int8_mixed_bf16: [dequant_node, optional(convert_element_type_4),
# mixed_bf16: [dequant_node, optional(convert_element_type_4),
# dequantize_per_channel, optional(convert_element_type_3), clone, convolution]
self.assertEqual(
counters["inductor"]["qconv_weight_prepack_matcher_count"], 3
)
self.assertEqual(
counters["inductor"]["qconv_weight_prepack_matcher_nodes"],
18 if int8_mixed_bf16 else 12,
18 if mixed_bf16 else 12,
)
self.assertEqual(
counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3
@@ -426,7 +465,8 @@ def matcher_check_fn():
(v,),
matcher_check_fn,
check_quantization=True,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
is_fp8=is_fp8,
)

@skipIfNoDynamoSupport
@@ -438,6 +478,16 @@ def test_qconv2d_cpu(self):
"""
self._qconv2d_test_helper("cpu")

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skip_if_rocm("Not applicable to ROCm")
@skipIfNoFloat8Support
def test_qconv2d_fp8_cpu(self):
r"""
        This testcase will quantize a single Conv2d module with fp8 quantization.
"""
self._qconv2d_test_helper("cpu", is_fp8=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@@ -446,14 +496,26 @@ def test_qconv2d_int8_mixed_bf16(self):
r"""
This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization.
"""
self._qconv2d_test_helper(int8_mixed_bf16=True)
self._qconv2d_test_helper(mixed_bf16=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skip_if_rocm("Not applicable to ROCm")
@skipIfNoFloat8Support
def test_qconv2d_fp8_mixed_bf16(self):
r"""
        This testcase will quantize a single Conv2d module with fp8_mixed_bf16 quantization.
"""
self._qconv2d_test_helper(mixed_bf16=True, is_fp8=True)

def _qconv2d_unary_test_helper(
self,
device="cpu",
int8_mixed_bf16=False,
mixed_bf16=False,
unary_op=torch.nn.ReLU(),
qconv_unary_matcher_nodes=None,
is_fp8=False,
):
class M(torch.nn.Module):
def __init__(
@@ -502,8 +564,9 @@ def matcher_check_fn():
mod,
(v,),
check_quantization=True,
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32,
check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32,
matcher_check_fn=matcher_check_fn,
is_fp8=is_fp8,
)

@skipIfNoDynamoSupport
@@ -514,14 +577,23 @@ def test_qconv2d_relu_cpu(self):
"""
self._qconv2d_unary_test_helper(device="cpu")

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
def test_qconv2d_relu_fp8_cpu(self):
r"""
        This testcase will quantize Conv2d->ReLU pattern with fp8 quantization.
"""
self._qconv2d_unary_test_helper(device="cpu", is_fp8=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
def test_qconv2d_relu_int8_mixed_bf16_xpu(self):
r"""
This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization.
"""
self._qconv2d_unary_test_helper(int8_mixed_bf16=True)
self._qconv2d_unary_test_helper(mixed_bf16=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@@ -531,6 +603,15 @@ def test_qconv2d_relu6_cpu(self):
"""
self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6())

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
def test_qconv2d_relu6_fp8_cpu(self):
r"""
        This testcase will quantize Conv2d->ReLU6 pattern with fp8 quantization.
"""
        self._qconv2d_unary_test_helper(
            device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True
        )

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_hardtanh_cpu(self):
@@ -539,6 +620,15 @@ def test_qconv2d_hardtanh_cpu(self):
"""
self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh())

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
def test_qconv2d_hardtanh_fp8_cpu(self):
r"""
        This testcase will quantize Conv2d->Hardtanh pattern with fp8 quantization.
"""
        self._qconv2d_unary_test_helper(
            device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True
        )

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@@ -551,10 +641,28 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self):
"""
self._qconv2d_unary_test_helper(
unary_op=torch.nn.Hardtanh(),
int8_mixed_bf16=True,
mixed_bf16=True,
qconv_unary_matcher_nodes=11,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoFloat8Support
def test_qconv2d_hardtanh_fp8_mixed_bf16_cpu(self):
r"""
        This testcase will quantize Conv2d->Hardtanh pattern with fp8_mixed_bf16 quantization.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type]
"""
self._qconv2d_unary_test_helper(
unary_op=torch.nn.Hardtanh(),
mixed_bf16=True,
qconv_unary_matcher_nodes=11,
is_fp8=True,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_hardswish_cpu(self):
Expand All @@ -563,6 +671,15 @@ def test_qconv2d_hardswish_cpu(self):
"""
self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish())

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
def test_qconv2d_hardswish_fp8_cpu(self):
r"""
        This testcase will quantize Conv2d->Hardswish pattern with fp8 quantization.
"""
        self._qconv2d_unary_test_helper(
            device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True
        )

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@@ -576,10 +693,29 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self):
"""
self._qconv2d_unary_test_helper(
unary_op=torch.nn.Hardswish(),
int8_mixed_bf16=True,
mixed_bf16=True,
qconv_unary_matcher_nodes=17,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoFloat8Support
def test_qconv2d_hardswish_fp8_mixed_bf16_cpu(self):
r"""
        This testcase will quantize Conv2d->Hardswish pattern with fp8_mixed_bf16 quantization.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, add, clamp_min,
clamp_max, mul, div, convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type]
"""
self._qconv2d_unary_test_helper(
unary_op=torch.nn.Hardswish(),
mixed_bf16=True,
qconv_unary_matcher_nodes=17,
is_fp8=True,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_silu_cpu(self):
Expand All @@ -588,6 +724,15 @@ def test_qconv2d_silu_cpu(self):
"""
self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU())

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoFloat8Support
def test_qconv2d_silu_fp8_cpu(self):
r"""
        This testcase will quantize Conv2d->SiLU pattern with fp8 quantization.
"""
        self._qconv2d_unary_test_helper(
            device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True
        )

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@@ -601,8 +746,27 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self):
"""
self._qconv2d_unary_test_helper(
unary_op=torch.nn.SiLU(),
int8_mixed_bf16=True,
mixed_bf16=True,
qconv_unary_matcher_nodes=11,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
@skipIfNoFloat8Support
def test_qconv2d_silu_fp8_mixed_bf16_cpu(self):
r"""
        This testcase will quantize Conv2d->SiLU pattern with fp8_mixed_bf16 quantization.
Match.nodes:
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul,
convert_element_type, quantize_per_tensor]
[qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type]
"""
self._qconv2d_unary_test_helper(
unary_op=torch.nn.SiLU(),
mixed_bf16=True,
qconv_unary_matcher_nodes=11,
is_fp8=True,
)

def _qconv2d_add_test_helper(