From 04bf850e8a687294ac73cf580d72bff3eecf419d Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Thu, 30 Oct 2025 09:55:34 +0000 Subject: [PATCH 1/3] [Inductor][float8] Register qconv weight prepack pass for float8 --- .../pt2e/test_x86inductor_fusion.py | 73 +++++++++++- .../quantization/pt2e/inductor_passes/x86.py | 107 ++++++++++-------- 2 files changed, 129 insertions(+), 51 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 520b5fbdfb..8b3923ee23 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -138,6 +138,40 @@ def forward(self, input): return out +class FP8QDQConv2d(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + super().__init__() + self.qtype = torch.float8_e4m3fn + self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype) + self.weight_scale = 2.0 + self.scale = 2.0 + self.bias = None + if bias: + self.bias = torch.randn((out_channels,)) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + + def forward(self, input): + weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + tensor=self.weight.data, + scale=torch.tensor([self.weight_scale]), + output_dtype=torch.float, + ) + q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( + tensor=input, + scale=torch.tensor([self.scale]), + float8_dtype=self.qtype, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + tensor=q_input, + scale=torch.tensor([self.scale]), + output_dtype=torch.float, + ) + + return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + def qdq(input, scale): dtype = input.dtype q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( @@ -172,7 +206,7 @@ def create_mod_info_recursion(parent): for name, mod in model.named_modules(): mod_type_str = mod.__class__.__name__ if mod_type_str not in [ - "Linear", + "Linear", "Conv2d" ]: continue param = mod.weight @@ -190,6 +224,11 @@ def create_mod_info_recursion(parent): patched_mod.bias = mod.bias patched_mod.weight_scale = weight_scale.item() patched_mod.weight.data = q_param + elif mod_type_str in ["Conv2d"]: + patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False) + patched_mod.bias = mod.bias + patched_mod.weight_scale = weight_scale.item() + patched_mod.weight.data = q_param parent = parent_child_mod_dict[mod].parent name = parent_child_mod_dict[mod].name @@ -382,7 +421,7 @@ def _test_code_common( @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+") class TestPatternMatcher(TestPatternMatcherBase): - def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): + def _qconv2d_test_helper(self, device="cpu", mixed_bf16=False, is_fp8=False): class M(torch.nn.Module): def __init__( self, @@ -408,14 +447,14 @@ def forward(self, x): def matcher_check_fn(): # 1. 
Dequant-Conv2D pattern matched in QConv2D weight prepack * 1 # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution] - # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), + # mixed_bf16: [dequant_node, optional(convert_element_type_4), # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 ) self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_nodes"], - 18 if int8_mixed_bf16 else 12, + 18 if mixed_bf16 else 12, ) self.assertEqual( counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3 @@ -426,7 +465,8 @@ def matcher_check_fn(): (v,), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -438,6 +478,16 @@ def test_qconv2d_cpu(self): """ self._qconv2d_test_helper("cpu") + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skip_if_rocm("Not applicable to ROCm") + @skipIfNoFloat8Support + def test_qconv2d_fp8_cpu(self): + r""" + This testcase will quantize a single Conv2d module. + """ + self._qconv2d_test_helper("cpu", is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -446,7 +496,18 @@ def test_qconv2d_int8_mixed_bf16(self): r""" This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. """ - self._qconv2d_test_helper(int8_mixed_bf16=True) + self._qconv2d_test_helper(mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skip_if_rocm("Not applicable to ROCm") + @skipIfNoFloat8Support + def test_qconv2d_fp8_mixed_bf16(self): + r""" + This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. 
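+        Here the weight and activation are quantized to float8_e4m3fn (is_fp8=True) instead of int8, still with bf16 autocast (mixed_bf16=True).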
+ """ + self._qconv2d_test_helper(mixed_bf16=True, is_fp8=True) def _qconv2d_unary_test_helper( self, diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index a0aef11541..ddcba76f2d 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -167,24 +167,28 @@ def get_dequantize_per_tensor_activation_pattern( KeywordArg("w_dtype"), ) -dequantize_per_channel_to_bf16_weight_pattern = ( - _may_generate_pattern_with_dtype_convert( - dequantize_per_channel_weight_pattern, +dequantize_fp8_weight_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + output_dtype=KeywordArg("w_dtype"), +) + +def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern): + return _may_generate_pattern_with_dtype_convert( + dequant_wgt_pattern, KeywordArg("autocast_wgt_dtype"), ) -) -dequantize_per_channel_clone_weight_pattern = CallFunction( - aten.clone.default, - dequantize_per_channel_weight_pattern, - memory_format=KeywordArg("memory_format"), -) +def get_dequantize_clone_weight_pattern(dequant_wgt_pattern): + return CallFunction( + aten.clone.default, + dequant_wgt_pattern, + memory_format=KeywordArg("memory_format"), + ) -dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( - aten.clone.default, - dequantize_per_channel_to_bf16_weight_pattern, - memory_format=KeywordArg("memory_format"), -) +def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern): + return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)) def get_qconv_pt2e_pattern(users=1): @@ -711,7 +715,7 @@ def _inner(match): return _inner -def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_conv_pattern(dtype), @@ -724,7 +728,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): | dequant_per_tensor | - Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight Insert weight prepack node and change the pattern to: int8 activation @@ -747,7 +751,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): ) if dtype == torch.float32: - dequant_per_channel = ( + dequant = ( clone_node.args[0] # type: ignore[union-attr] if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] @@ -758,9 +762,9 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] ) - dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + dequant = weight_to_bf16_node.args[0] # type: ignore[union-attr] - assert dequant_per_channel.target in [ # type: ignore[union-attr] + assert dequant.target in [ # type: ignore[union-attr] quantized_decomposed.dequantize_per_channel.default, torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ] @@ -768,7 +772,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): # Activation QParams qx, x_zp, x_scale = ( kwargs["x"], - kwargs["x_zp"], + kwargs["x_zp"] if "x_zp" in kwargs else None, kwargs["x_scale"], ) @@ -776,7 +780,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): qw, w_scale, w_zp = ( kwargs["q_weight"], kwargs["w_scale"], - 
kwargs["w_zp"], + kwargs["w_zp"] if "w_zp" in kwargs else None, ) # Conv Params @@ -792,14 +796,19 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. x_shape = None + if is_fp8 and w_scale.target is torch.ops.aten.full.default: + with torch.utils._python_dispatch._disable_current_modes(): + w_scale_tensor = torch.tensor([w_scale.args[1]]) + match.graph.owning_module.register_buffer("w_scale", w_scale_tensor) + w_scale = match.graph.create_node("get_attr", "w_scale") graph = match.graph with graph.inserting_before(conv_node): # Insert weight prepack node and the QConv node packed_weight_inputs = ( qw, w_scale, - x_scale, - x_zp, + x_scale.args[1] if is_fp8 and x_scale.target is torch.ops.aten.full.default else x_scale, + 0, stride, padding, dilation, @@ -830,9 +839,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): [], # scalars "", # algorithm ) - new_conv_node = graph.call_function( - torch.ops.onednn.qconv_pointwise.default, args=new_args - ) + Node = torch.fx.node.Node + # fp8 not need zp + if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8): + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.tensor, args=new_args + ) + else: + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.default, args=new_args + ) conv_node.replace_all_uses_with(new_conv_node) new_conv_node.meta.update(conv_node.meta) @@ -847,7 +863,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): graph.erase_node(clone_node) # type: ignore[arg-type] if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] - graph.erase_node(dequant_per_channel) # type: ignore[arg-type] + graph.erase_node(dequant) # type: ignore[arg-type] counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( match.nodes @@ -855,17 +871,17 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): def _generate_dequant_convolution_node_pattern( - _dequant_per_channel_pattern, dtype=torch.float32 + _dequant_pattern, dtype=torch.float32, is_fp8=False ): assert dtype in [torch.float32, torch.bfloat16] dequant_convolution_node_pattern = CallFunction( aten.convolution.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(), + get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), - _dequant_per_channel_pattern, + _dequant_pattern, KeywordArg("b"), KeywordArg("stride"), KeywordArg("padding"), @@ -877,24 +893,30 @@ def _generate_dequant_convolution_node_pattern( return dequant_convolution_node_pattern -def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False): assert dtype in [torch.float32, torch.bfloat16] + if is_fp8: + dequant_wgt_pattern = dequantize_fp8_weight_pattern + else: + dequant_wgt_pattern = dequantize_per_channel_weight_pattern return ( _generate_dequant_convolution_node_pattern( - dequantize_per_channel_weight_pattern + dequant_wgt_pattern if dtype == torch.float32 - else dequantize_per_channel_to_bf16_weight_pattern, + else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern), dtype, + is_fp8=is_fp8, ), # There is another pattern due to the pass of convert_conv_weights_to_channels_last # 
https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. # Depend on some heuristics, it may or may not insert to(channel_last) node - # between convolution and dequant_per_channel node + # between convolution and dequant node _generate_dequant_convolution_node_pattern( - dequantize_per_channel_clone_weight_pattern + get_dequantize_clone_weight_pattern(dequant_wgt_pattern) if dtype == torch.float32 - else dequantize_per_channel_to_bf16_clone_weight_pattern, + else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern), dtype, + is_fp8=is_fp8, ), ) @@ -1302,12 +1324,7 @@ def _generate_qlinear_weight_prepack_patterns( is_fp8=False, ): if is_fp8: - dequant_wgt_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, - KeywordArg("q_weight"), - KeywordArg("w_scale"), - output_dtype=KeywordArg("w_dtype"), - ) + dequant_wgt_pattern = dequantize_fp8_weight_pattern else: dequant_wgt_pattern = dequantize_per_channel_weight_pattern if input_dim_exceeds_two and not input_contiguous: @@ -1449,12 +1466,12 @@ def _register_dequant_promotion(): def _register_qconv_weight_prepack(): - for dtype in [torch.float32, torch.bfloat16]: - weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]): + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. _register_qconv_weight_prepack_pass( - weight_prepack_pattern, pass_number=1, dtype=dtype + weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8 ) From 0160379562fa4460e6e0e6510c44db4da9f2463b Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Tue, 11 Nov 2025 05:22:14 +0000 Subject: [PATCH 2/3] [Inductor][float8] Register qconv-unary fusion pass for float8 --- .../pt2e/test_x86inductor_fusion.py | 115 +++++++++++++++++- .../quantization/pt2e/inductor_passes/x86.py | 101 ++++++++------- 2 files changed, 157 insertions(+), 59 deletions(-) diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 8b3923ee23..fe5aa36f53 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -512,9 +512,10 @@ def test_qconv2d_fp8_mixed_bf16(self): def _qconv2d_unary_test_helper( self, device="cpu", - int8_mixed_bf16=False, + mixed_bf16=False, unary_op=torch.nn.ReLU(), qconv_unary_matcher_nodes=None, + is_fp8=False, ): class M(torch.nn.Module): def __init__( @@ -563,8 +564,9 @@ def matcher_check_fn(): mod, (v,), check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, matcher_check_fn=matcher_check_fn, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -575,6 +577,15 @@ def test_qconv2d_relu_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu") + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_relu_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU pattern. 
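+        The activation and weight are quantized to float8_e4m3fn (is_fp8=True), so the float8 qconv weight-prepack and unary fusion passes are exercised.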
+ """ + self._qconv2d_unary_test_helper(device="cpu", is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -582,7 +593,7 @@ def test_qconv2d_relu_int8_mixed_bf16_xpu(self): r""" This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization. """ - self._qconv2d_unary_test_helper(int8_mixed_bf16=True) + self._qconv2d_unary_test_helper(mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -592,6 +603,15 @@ def test_qconv2d_relu6_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_relu6_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_hardtanh_cpu(self): @@ -600,6 +620,15 @@ def test_qconv2d_hardtanh_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_hardtanh_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -612,8 +641,26 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardtanh(), - int8_mixed_bf16=True, + mixed_bf16=True, + qconv_unary_matcher_nodes=11, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_hardtanh_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardtanh(), + mixed_bf16=True, qconv_unary_matcher_nodes=11, + is_fp8=True, ) @skipIfNoDynamoSupport @@ -624,6 +671,15 @@ def test_qconv2d_hardswish_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_hardswish_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -637,8 +693,27 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardswish(), - int8_mixed_bf16=True, + mixed_bf16=True, + qconv_unary_matcher_nodes=17, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_hardswish_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. 
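+        It runs with float8 quantization and bf16 autocast (mixed_bf16=True, is_fp8=True).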
+ Match.nodes: + [qconv2d_pointwise_default, convert_element_type, add, clamp_min, + clamp_max, mul, div, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardswish(), + mixed_bf16=True, qconv_unary_matcher_nodes=17, + is_fp8=True, ) @skipIfNoDynamoSupport @@ -649,6 +724,15 @@ def test_qconv2d_silu_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_silu_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -662,10 +746,29 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.SiLU(), - int8_mixed_bf16=True, + mixed_bf16=True, qconv_unary_matcher_nodes=11, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_silu_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, + convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.SiLU(), + mixed_bf16=True, + qconv_unary_matcher_nodes=11, + is_fp8=True, + ) + def _qconv2d_add_test_helper( self, device="cpu", use_relu=False, int8_mixed_bf16=False ): diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index ddcba76f2d..09e35105e0 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -191,9 +191,14 @@ def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern): return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)) -def get_qconv_pt2e_pattern(users=1): +def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): + qconv_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) return CallFunction( - torch.ops.onednn.qconv_pointwise.default, + qconv_op, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), @@ -215,35 +220,6 @@ def get_qconv_pt2e_pattern(users=1): ) -def get_qconv2d_binary_pt2e_pattern(users=1): - return CallFunction( - torch.ops.onednn.qconv2d_pointwise.binary, - KeywordArg("x"), - KeywordArg("x_scale"), - KeywordArg("x_zp"), - KeywordArg("packed_weight"), - KeywordArg("w_scale"), - KeywordArg("w_zp"), - KeywordArg("accum"), - KeywordArg("b"), - KeywordArg("stride"), - KeywordArg("padding"), - KeywordArg("dilation"), - KeywordArg("groups"), - KeywordArg("output_scale"), - KeywordArg("output_zero_point"), - KeywordArg("output_dtype"), - KeywordArg("accum_scale"), - KeywordArg("accum_zero_point"), - KeywordArg("binary_op_name"), - KeywordArg("alpha"), - KeywordArg("unary_op_name"), - KeywordArg("unary_op_args"), - KeywordArg("unary_op_algorithm"), - _users=users, - ) - - def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): qlinear_op = ( torch.ops.onednn.qlinear_pointwise.tensor @@ -2070,13 +2046,19 @@ def qconv(match: Match, *args, **kwargs): 
kwargs["groups"], ) output_dtype = _get_pattern_output_dtype(match) - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16] # Output QParams - o_inv_scale = ( - kwargs["o_inv_scale"] - if (output_dtype == torch.uint8 or output_dtype == torch.int8) - else 1.0 - ) + if output_dtype == torch.float8_e4m3fn: + # For float8, torchao.quantize_affine_float8 requires tensor as scale + # Support scale node is full firstly + assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default + o_inv_scale = kwargs["o_inv_scale"].args[1] + else: + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 1.0 + ) o_zero_point = ( kwargs["o_zp"] if (output_dtype == torch.uint8 or output_dtype == torch.int8) @@ -2182,56 +2164,69 @@ def _register_qconv_unary_fusion(): _silu_fusion, ) - for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + combinations = itertools.product( + [torch.float32, torch.bfloat16], [False, True], [False, True] + ) + for original_pattern_output_dtype, x_scale_zp_are_tensors, is_fp8 in combinations: # Priority 1 to match: QConv2d Unary pattern with int8 output # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant is_bf16 = original_pattern_output_dtype == torch.bfloat16 + computation_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) conv_unary_replace_patterns = { PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + is_fp8=is_fp8, ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( - get_qconv_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default ), + is_fp8=is_fp8, ), PostOpAttr( "none", None, "hardtanh", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), 1, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), PostOpAttr( "none", None, "hardswish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), PostOpAttr( "none", None, "swish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), } @@ -2240,21 +2235,21 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv_pointwise.default, # computation_op + computation_op, # computation_op unary_attr, # unary_attr ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output conv_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - get_qconv_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default ), 
PostOpAttr( "none", None, "hardtanh", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), 1, is_bf16, ), @@ -2266,7 +2261,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), @@ -2278,7 +2273,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), @@ -2292,7 +2287,7 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - torch.ops.onednn.qconv_pointwise.default, # computation_op + computation_op, # computation_op unary_attr, # unary_attr ) @@ -2310,7 +2305,7 @@ def _register_qconv_binary_fusion(): ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2322,7 +2317,7 @@ def _register_qconv_binary_fusion(): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2349,7 +2344,7 @@ def _register_qconv_binary_fusion(): PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2387,7 +2382,7 @@ def _register_qconv_binary_fusion(): "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(users=1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, From 23bd2b82edafe308d8a38bfd5cb63359b613d651 Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Wed, 12 Nov 2025 09:16:26 +0000 Subject: [PATCH 3/3] [Inductor][float8] Register qconv-binary fusion pass for float8 --- .../quantization/pt2e/inductor_passes/x86.py | 52 ++++++++++++------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index 09e35105e0..791e9a42e4 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -376,9 +376,9 @@ def fn(match): return fn -def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False): +def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False, is_fp8=False): return ( - _is_valid_qconv_binary_optimization_pattern() + _is_valid_qconv_binary_optimization_pattern(is_fp8=is_fp8) if has_binary_post_op else _is_valid_quantized_conv_optimization_pattern() ) @@ -408,9 +408,11 @@ def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False): ) -def _is_valid_qconv_binary_optimization_pattern(): +def _is_valid_qconv_binary_optimization_pattern(is_fp8=False): return _is_valid_quantized_op_binary_optimization_pattern( - torch.ops.onednn.qconv_pointwise + 
torch.ops.onednn.qconv_pointwise, + # we don't insert q-dq for extra input in fp8 recipe + extra_input_from_dequant= not is_fp8, ) @@ -2016,12 +2018,13 @@ def _register_qconv_post_op_fusion_pass( pass_number, computation_op, post_op_attr, + is_fp8=False, ): has_binary_post_op = post_op_attr.binary_op_name != "none" @register_freezing_graph_pattern( pattern, - extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op), + extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op, is_fp8=is_fp8), pass_number=pass_number, ) def qconv(match: Match, *args, **kwargs): @@ -2097,7 +2100,7 @@ def qconv(match: Match, *args, **kwargs): else: accum = ( kwargs["accum"] - if output_dtype in [torch.uint8, torch.int8] + if output_dtype in [torch.uint8, torch.int8] or is_fp8 else kwargs["accum_after_dequant"] ) accum_scale = ( @@ -2237,6 +2240,7 @@ def _register_qconv_unary_fusion(): 3, # pass_number computation_op, # computation_op unary_attr, # unary_attr + is_fp8=is_fp8, ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output @@ -2289,15 +2293,21 @@ def _register_qconv_unary_fusion(): 4, # pass_number computation_op, # computation_op unary_attr, # unary_attr + is_fp8=is_fp8, ) def _register_qconv_binary_fusion(): - for int8_mixed_bf16_with_inplace_add in [False, True]: + for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]): + qconv_binary_op = ( + torch.ops.onednn.qconv2d_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv2d_pointwise.binary + ) # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output swap_binary_inputs_list = [False, True] binary_replace_patterns = {} - for swap_inputs in swap_binary_inputs_list: + for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]): binary_replace_patterns.update( { PostOpAttr( @@ -2305,11 +2315,12 @@ def _register_qconv_binary_fusion(): ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(users=1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), + is_fp8=is_fp8, ), PostOpAttr( "sum", 1.0, "relu", [], "" @@ -2317,13 +2328,14 @@ def _register_qconv_binary_fusion(): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(users=1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), aten.relu.default, ), + is_fp8=is_fp8, ), } ) @@ -2332,8 +2344,9 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr + is_fp8=is_fp8, ) # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output @@ -2344,8 +2357,8 @@ def _register_qconv_binary_fusion(): PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(users=1), - KeywordArg("accum_after_dequant"), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + KeywordArg("accum") if is_fp8 else KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), @@ -2362,15 +2375,17 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 
3, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr + is_fp8=is_fp8, ) else: _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr + is_fp8=is_fp8, ) # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output @@ -2382,8 +2397,8 @@ def _register_qconv_binary_fusion(): "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(users=1), - KeywordArg("accum_after_dequant"), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + KeywordArg("accum") if is_fp8 else KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), @@ -2397,8 +2412,9 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr + is_fp8=is_fp8, )
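
Usage sketch (illustrative only, not applied by this patch): a minimal example of the kind of graph these passes rewrite, assuming torchao is installed (so the torch.ops.torchao float8 quantize/dequantize ops are available), that the x86 inductor quantization passes have been registered as the test harness above arranges, and that the test-only FP8QDQConv2d helper from the test file is in scope. The patterns are registered as freezing graph patterns, so Inductor freezing must be enabled; on a successful match the dequantize -> convolution chain is replaced by onednn.qconv_pointwise (the .tensor overload in the fp8 case, since scales are passed as tensors) and the matcher counters checked by the tests are incremented.

    import torch
    from torch._dynamo.utils import counters
    from torch._inductor import config as inductor_config

    # The qconv weight-prepack and fusion patterns above only run during freezing.
    inductor_config.freezing = True

    # Test-only module defined in test_x86inductor_fusion.py above: it stores an
    # fp8 (float8_e4m3fn) weight and emits the quantize/dequantize_affine_float8
    # ops that the new patterns match.
    mod = FP8QDQConv2d(3, 16, (3, 3), bias=False).eval()
    x = torch.randn(1, 3, 16, 16)

    with torch.no_grad():
        _ = torch.compile(mod)(x)

    # Non-zero counters indicate the fp8 weight-prepack pattern matched,
    # mirroring the assertions in the tests above.
    print(counters["inductor"]["qconv_weight_prepack_matcher_count"])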