diff --git a/test/quantization/pt2e/test_x86inductor_fusion.py b/test/quantization/pt2e/test_x86inductor_fusion.py index 520b5fbdfb..fe5aa36f53 100644 --- a/test/quantization/pt2e/test_x86inductor_fusion.py +++ b/test/quantization/pt2e/test_x86inductor_fusion.py @@ -138,6 +138,40 @@ def forward(self, input): return out +class FP8QDQConv2d(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + super().__init__() + self.qtype = torch.float8_e4m3fn + self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype) + self.weight_scale = 2.0 + self.scale = 2.0 + self.bias = None + if bias: + self.bias = torch.randn((out_channels,)) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + + def forward(self, input): + weight = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + tensor=self.weight.data, + scale=torch.tensor([self.weight_scale]), + output_dtype=torch.float, + ) + q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( + tensor=input, + scale=torch.tensor([self.scale]), + float8_dtype=self.qtype, + ) + dq_input = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default( + tensor=q_input, + scale=torch.tensor([self.scale]), + output_dtype=torch.float, + ) + + return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + def qdq(input, scale): dtype = input.dtype q_input = torch.ops.torchao.quantize_affine_float8_non_decomposed.default( @@ -172,7 +206,7 @@ def create_mod_info_recursion(parent): for name, mod in model.named_modules(): mod_type_str = mod.__class__.__name__ if mod_type_str not in [ - "Linear", + "Linear", "Conv2d" ]: continue param = mod.weight @@ -190,6 +224,11 @@ def create_mod_info_recursion(parent): patched_mod.bias = mod.bias patched_mod.weight_scale = weight_scale.item() patched_mod.weight.data = q_param + elif mod_type_str in ["Conv2d"]: + patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False) + patched_mod.bias = mod.bias + patched_mod.weight_scale = weight_scale.item() + patched_mod.weight.data = q_param parent = parent_child_mod_dict[mod].parent name = parent_child_mod_dict[mod].name @@ -382,7 +421,7 @@ def _test_code_common( @unittest.skipIf(not torch_version_at_least("2.8.0"), "Requires torch 2.8+") class TestPatternMatcher(TestPatternMatcherBase): - def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): + def _qconv2d_test_helper(self, device="cpu", mixed_bf16=False, is_fp8=False): class M(torch.nn.Module): def __init__( self, @@ -408,14 +447,14 @@ def forward(self, x): def matcher_check_fn(): # 1. 
Dequant-Conv2D pattern matched in QConv2D weight prepack * 1 # int8_mixed_fp32: [dequant_node, dequantize_per_channel, clone, convolution] - # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), + # mixed_bf16: [dequant_node, optional(convert_element_type_4), # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 ) self.assertEqual( counters["inductor"]["qconv_weight_prepack_matcher_nodes"], - 18 if int8_mixed_bf16 else 12, + 18 if mixed_bf16 else 12, ) self.assertEqual( counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3 ) @@ -426,7 +465,8 @@ def matcher_check_fn(): (v,), matcher_check_fn, check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -438,6 +478,16 @@ def test_qconv2d_cpu(self): """ self._qconv2d_test_helper("cpu") + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skip_if_rocm("Not applicable to ROCm") + @skipIfNoFloat8Support + def test_qconv2d_fp8_cpu(self): + r""" + This testcase will quantize a single Conv2d module. + """ + self._qconv2d_test_helper("cpu", is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -446,14 +496,26 @@ def test_qconv2d_int8_mixed_bf16(self): r""" This testcase will quantize a single Conv2d module with int8_mixed_bf16 quantization. """ - self._qconv2d_test_helper(int8_mixed_bf16=True) + self._qconv2d_test_helper(mixed_bf16=True) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skip_if_rocm("Not applicable to ROCm") + @skipIfNoFloat8Support + def test_qconv2d_fp8_mixed_bf16(self): + r""" + This testcase will quantize a single Conv2d module with fp8 mixed bf16 quantization. + """ + self._qconv2d_test_helper(mixed_bf16=True, is_fp8=True) def _qconv2d_unary_test_helper( self, device="cpu", - int8_mixed_bf16=False, + mixed_bf16=False, unary_op=torch.nn.ReLU(), qconv_unary_matcher_nodes=None, + is_fp8=False, ): class M(torch.nn.Module): def __init__( @@ -502,8 +564,9 @@ def matcher_check_fn(): mod, (v,), check_quantization=True, - check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float32, + check_autocast=torch.bfloat16 if mixed_bf16 else torch.float32, matcher_check_fn=matcher_check_fn, + is_fp8=is_fp8, ) @skipIfNoDynamoSupport @@ -514,6 +577,15 @@ def test_qconv2d_relu_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu") + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_relu_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -521,7 +593,7 @@ def test_qconv2d_relu_int8_mixed_bf16_xpu(self): r""" This testcase will quantize Conv2d->ReLU pattern with int8_mixed_bf16 quantization. """ - self._qconv2d_unary_test_helper(int8_mixed_bf16=True) + self._qconv2d_unary_test_helper(mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -531,6 +603,15 @@ def test_qconv2d_relu6_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_relu6_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->ReLU6 pattern.
+ """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_hardtanh_cpu(self): @@ -539,6 +620,15 @@ def test_qconv2d_hardtanh_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_hardtanh_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -551,10 +641,28 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardtanh(), - int8_mixed_bf16=True, + mixed_bf16=True, qconv_unary_matcher_nodes=11, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_hardtanh_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardtanh pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, clamp_min, clamp_max, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardtanh(), + mixed_bf16=True, + qconv_unary_matcher_nodes=11, + is_fp8=True, + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_hardswish_cpu(self): @@ -563,6 +671,15 @@ def test_qconv2d_hardswish_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_hardswish_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. + """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -576,10 +693,29 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardswish(), - int8_mixed_bf16=True, + mixed_bf16=True, qconv_unary_matcher_nodes=17, ) + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_hardswish_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->Hardswish pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, add, clamp_min, + clamp_max, mul, div, convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, add, clamp_min, clamp_max, mul, div, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.Hardswish(), + mixed_bf16=True, + qconv_unary_matcher_nodes=17, + is_fp8=True, + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_silu_cpu(self): @@ -588,6 +724,15 @@ def test_qconv2d_silu_cpu(self): """ self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU()) + @skipIfNoDynamoSupport + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_silu_fp8_cpu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern. 
+ """ + self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True) + @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN @@ -601,8 +746,27 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self): """ self._qconv2d_unary_test_helper( unary_op=torch.nn.SiLU(), - int8_mixed_bf16=True, + mixed_bf16=True, + qconv_unary_matcher_nodes=11, + ) + + @skipIfNoDynamoSupport + @skipIfNoONEDNNBF16 + @skipIfNoONEDNN + @skipIfNoFloat8Support + def test_qconv2d_silu_fp8_mixed_bf16_cpu(self): + r""" + This testcase will quantize Conv2d->SiLU pattern. + Match.nodes: + [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, + convert_element_type, quantize_per_tensor] + [qconv2d_pointwise_default, convert_element_type, sigmoid, mul, convert_element_type] + """ + self._qconv2d_unary_test_helper( + unary_op=torch.nn.SiLU(), + mixed_bf16=True, qconv_unary_matcher_nodes=11, + is_fp8=True, ) def _qconv2d_add_test_helper( diff --git a/torchao/quantization/pt2e/inductor_passes/x86.py b/torchao/quantization/pt2e/inductor_passes/x86.py index a0aef11541..791e9a42e4 100644 --- a/torchao/quantization/pt2e/inductor_passes/x86.py +++ b/torchao/quantization/pt2e/inductor_passes/x86.py @@ -167,29 +167,38 @@ def get_dequantize_per_tensor_activation_pattern( KeywordArg("w_dtype"), ) -dequantize_per_channel_to_bf16_weight_pattern = ( - _may_generate_pattern_with_dtype_convert( - dequantize_per_channel_weight_pattern, +dequantize_fp8_weight_pattern = CallFunction( + torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + output_dtype=KeywordArg("w_dtype"), +) + +def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern): + return _may_generate_pattern_with_dtype_convert( + dequant_wgt_pattern, KeywordArg("autocast_wgt_dtype"), ) -) -dequantize_per_channel_clone_weight_pattern = CallFunction( - aten.clone.default, - dequantize_per_channel_weight_pattern, - memory_format=KeywordArg("memory_format"), -) +def get_dequantize_clone_weight_pattern(dequant_wgt_pattern): + return CallFunction( + aten.clone.default, + dequant_wgt_pattern, + memory_format=KeywordArg("memory_format"), + ) -dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( - aten.clone.default, - dequantize_per_channel_to_bf16_weight_pattern, - memory_format=KeywordArg("memory_format"), -) +def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern): + return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)) -def get_qconv_pt2e_pattern(users=1): +def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1): + qconv_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) return CallFunction( - torch.ops.onednn.qconv_pointwise.default, + qconv_op, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), @@ -211,35 +220,6 @@ def get_qconv_pt2e_pattern(users=1): ) -def get_qconv2d_binary_pt2e_pattern(users=1): - return CallFunction( - torch.ops.onednn.qconv2d_pointwise.binary, - KeywordArg("x"), - KeywordArg("x_scale"), - KeywordArg("x_zp"), - KeywordArg("packed_weight"), - KeywordArg("w_scale"), - KeywordArg("w_zp"), - KeywordArg("accum"), - KeywordArg("b"), - KeywordArg("stride"), - KeywordArg("padding"), - KeywordArg("dilation"), - KeywordArg("groups"), - KeywordArg("output_scale"), - KeywordArg("output_zero_point"), - KeywordArg("output_dtype"), - KeywordArg("accum_scale"), - KeywordArg("accum_zero_point"), - 
KeywordArg("binary_op_name"), - KeywordArg("alpha"), - KeywordArg("unary_op_name"), - KeywordArg("unary_op_args"), - KeywordArg("unary_op_algorithm"), - _users=users, - ) - - def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): qlinear_op = ( torch.ops.onednn.qlinear_pointwise.tensor @@ -396,9 +376,9 @@ def fn(match): return fn -def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False): +def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False, is_fp8=False): return ( - _is_valid_qconv_binary_optimization_pattern() + _is_valid_qconv_binary_optimization_pattern(is_fp8=is_fp8) if has_binary_post_op else _is_valid_quantized_conv_optimization_pattern() ) @@ -428,9 +408,11 @@ def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False): ) -def _is_valid_qconv_binary_optimization_pattern(): +def _is_valid_qconv_binary_optimization_pattern(is_fp8=False): return _is_valid_quantized_op_binary_optimization_pattern( - torch.ops.onednn.qconv_pointwise + torch.ops.onednn.qconv_pointwise, + # we don't insert q-dq for extra input in fp8 recipe + extra_input_from_dequant= not is_fp8, ) @@ -711,7 +693,7 @@ def _inner(match): return _inner -def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False): @register_freezing_graph_pattern( pattern, extra_check=_is_valid_dequant_conv_pattern(dtype), @@ -724,7 +706,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): | dequant_per_tensor | - Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + Conv2d <- optional(aten.clone.default) <- dequant <- int8_weight Insert weight prepack node and change the pattern to: int8 activation @@ -747,7 +729,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): ) if dtype == torch.float32: - dequant_per_channel = ( + dequant = ( clone_node.args[0] # type: ignore[union-attr] if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] @@ -758,9 +740,9 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_clone_to_channel_last_node_in_pattern else conv_node.args[1] ) - dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + dequant = weight_to_bf16_node.args[0] # type: ignore[union-attr] - assert dequant_per_channel.target in [ # type: ignore[union-attr] + assert dequant.target in [ # type: ignore[union-attr] quantized_decomposed.dequantize_per_channel.default, torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, ] @@ -768,7 +750,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): # Activation QParams qx, x_zp, x_scale = ( kwargs["x"], - kwargs["x_zp"], + kwargs["x_zp"] if "x_zp" in kwargs else None, kwargs["x_scale"], ) @@ -776,7 +758,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): qw, w_scale, w_zp = ( kwargs["q_weight"], kwargs["w_scale"], - kwargs["w_zp"], + kwargs["w_zp"] if "w_zp" in kwargs else None, ) # Conv Params @@ -792,14 +774,19 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if has_free_symbols(x_shape): # For dynamic shape case, we can't get activation shape ahead of runtime. 
x_shape = None + if is_fp8 and w_scale.target is torch.ops.aten.full.default: + with torch.utils._python_dispatch._disable_current_modes(): + w_scale_tensor = torch.tensor([w_scale.args[1]]) + match.graph.owning_module.register_buffer("w_scale", w_scale_tensor) + w_scale = match.graph.create_node("get_attr", "w_scale") graph = match.graph with graph.inserting_before(conv_node): # Insert weight prepack node and the QConv node packed_weight_inputs = ( qw, w_scale, - x_scale, - x_zp, + x_scale.args[1] if is_fp8 and x_scale.target is torch.ops.aten.full.default else x_scale, + 0, stride, padding, dilation, @@ -830,9 +817,16 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): [], # scalars "", # algorithm ) - new_conv_node = graph.call_function( - torch.ops.onednn.qconv_pointwise.default, args=new_args - ) + Node = torch.fx.node.Node + # fp8 does not need a zero point + if isinstance(x_scale, Node) and (isinstance(x_zp, Node) or is_fp8): + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.tensor, args=new_args + ) + else: + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.default, args=new_args + ) conv_node.replace_all_uses_with(new_conv_node) new_conv_node.meta.update(conv_node.meta) @@ -847,7 +841,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): graph.erase_node(clone_node) # type: ignore[arg-type] if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] - graph.erase_node(dequant_per_channel) # type: ignore[arg-type] + graph.erase_node(dequant) # type: ignore[arg-type] counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( match.nodes @@ -855,17 +849,17 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): def _generate_dequant_convolution_node_pattern( - _dequant_per_channel_pattern, dtype=torch.float32 + _dequant_pattern, dtype=torch.float32, is_fp8=False ): assert dtype in [torch.float32, torch.bfloat16] dequant_convolution_node_pattern = CallFunction( aten.convolution.default, _may_generate_pattern_with_dtype_convert( - get_dequantize_per_tensor_activation_pattern(), + get_dequantize_per_tensor_activation_pattern(is_fp8=is_fp8), KeywordArg("autocast_act_dtype"), dtype == torch.bfloat16, ), - _dequant_per_channel_pattern, + _dequant_pattern, KeywordArg("b"), KeywordArg("stride"), KeywordArg("padding"), @@ -877,24 +871,30 @@ def _generate_dequant_convolution_node_pattern( return dequant_convolution_node_pattern -def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32, is_fp8=False): assert dtype in [torch.float32, torch.bfloat16] + if is_fp8: + dequant_wgt_pattern = dequantize_fp8_weight_pattern + else: + dequant_wgt_pattern = dequantize_per_channel_weight_pattern return ( _generate_dequant_convolution_node_pattern( - dequantize_per_channel_weight_pattern + dequant_wgt_pattern if dtype == torch.float32 - else dequantize_per_channel_to_bf16_weight_pattern, + else get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern), dtype, + is_fp8=is_fp8, ), # There is another pattern due to the pass of convert_conv_weights_to_channels_last # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
# Depend on some heuristics, it may or may not insert to(channel_last) node - # between convolution and dequant_per_channel node + # between convolution and dequant node _generate_dequant_convolution_node_pattern( - dequantize_per_channel_clone_weight_pattern + get_dequantize_clone_weight_pattern(dequant_wgt_pattern) if dtype == torch.float32 - else dequantize_per_channel_to_bf16_clone_weight_pattern, + else get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern), dtype, + is_fp8=is_fp8, ), ) @@ -1302,12 +1302,7 @@ def _generate_qlinear_weight_prepack_patterns( is_fp8=False, ): if is_fp8: - dequant_wgt_pattern = CallFunction( - torch.ops.torchao.dequantize_affine_float8_non_decomposed.default, - KeywordArg("q_weight"), - KeywordArg("w_scale"), - output_dtype=KeywordArg("w_dtype"), - ) + dequant_wgt_pattern = dequantize_fp8_weight_pattern else: dequant_wgt_pattern = dequantize_per_channel_weight_pattern if input_dim_exceeds_two and not input_contiguous: @@ -1449,12 +1444,12 @@ def _register_dequant_promotion(): def _register_qconv_weight_prepack(): - for dtype in [torch.float32, torch.bfloat16]: - weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]): + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8) for weight_prepack_pattern in weight_prepack_patterns: # Register to pass_number 1, so we can do dequant promotion in pass_number 0. _register_qconv_weight_prepack_pass( - weight_prepack_pattern, pass_number=1, dtype=dtype + weight_prepack_pattern, pass_number=1, dtype=dtype, is_fp8=is_fp8 ) @@ -2023,12 +2018,13 @@ def _register_qconv_post_op_fusion_pass( pass_number, computation_op, post_op_attr, + is_fp8=False, ): has_binary_post_op = post_op_attr.binary_op_name != "none" @register_freezing_graph_pattern( pattern, - extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op), + extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op, is_fp8=is_fp8), pass_number=pass_number, ) def qconv(match: Match, *args, **kwargs): @@ -2053,13 +2049,19 @@ def qconv(match: Match, *args, **kwargs): kwargs["groups"], ) output_dtype = _get_pattern_output_dtype(match) - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16] # Output QParams - o_inv_scale = ( - kwargs["o_inv_scale"] - if (output_dtype == torch.uint8 or output_dtype == torch.int8) - else 1.0 - ) + if output_dtype == torch.float8_e4m3fn: + # For float8, torchao.quantize_affine_float8 requires tensor as scale + # Support scale node is full firstly + assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default + o_inv_scale = kwargs["o_inv_scale"].args[1] + else: + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 1.0 + ) o_zero_point = ( kwargs["o_zp"] if (output_dtype == torch.uint8 or output_dtype == torch.int8) @@ -2098,7 +2100,7 @@ def qconv(match: Match, *args, **kwargs): else: accum = ( kwargs["accum"] - if output_dtype in [torch.uint8, torch.int8] + if output_dtype in [torch.uint8, torch.int8] or is_fp8 else kwargs["accum_after_dequant"] ) accum_scale = ( @@ -2165,56 +2167,69 @@ def _register_qconv_unary_fusion(): _silu_fusion, ) - for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + combinations = itertools.product( + [torch.float32, torch.bfloat16], 
[False, True], [False, True] + ) + for original_pattern_output_dtype, x_scale_zp_are_tensors, is_fp8 in combinations: # Priority 1 to match: QConv2d Unary pattern with int8 output # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant is_bf16 = original_pattern_output_dtype == torch.bfloat16 + computation_op = ( + torch.ops.onednn.qconv_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv_pointwise.default + ) conv_unary_replace_patterns = { PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + is_fp8=is_fp8, ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( - get_qconv_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default ), + is_fp8=is_fp8, ), PostOpAttr( "none", None, "hardtanh", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), 1, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), PostOpAttr( "none", None, "hardswish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), PostOpAttr( "none", None, "swish", [], "" ): generate_pattern_with_output_quant( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), with_dtype_convert=is_bf16, + is_fp8=is_fp8, ), } @@ -2223,21 +2238,22 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv_pointwise.default, # computation_op + computation_op, # computation_op unary_attr, # unary_attr + is_fp8=is_fp8, ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output conv_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - get_qconv_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), aten.relu.default ), PostOpAttr( "none", None, "hardtanh", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), 1, is_bf16, ), @@ -2249,7 +2265,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardswish_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), @@ -2261,7 +2277,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _silu_fusion, - get_qconv_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1 if is_bf16 else 2), 2, is_bf16, ), @@ -2275,17 +2291,23 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - torch.ops.onednn.qconv_pointwise.default, # computation_op + computation_op, # computation_op unary_attr, # unary_attr + is_fp8=is_fp8, ) def _register_qconv_binary_fusion(): - for 
int8_mixed_bf16_with_inplace_add in [False, True]: + for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]): + qconv_binary_op = ( + torch.ops.onednn.qconv2d_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qconv2d_pointwise.binary + ) # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output swap_binary_inputs_list = [False, True] binary_replace_patterns = {} - for swap_inputs in swap_binary_inputs_list: + for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]): binary_replace_patterns.update( { PostOpAttr( @@ -2293,11 +2315,12 @@ def _register_qconv_binary_fusion(): ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), + is_fp8=is_fp8, ), PostOpAttr( "sum", 1.0, "relu", [], "" @@ -2305,13 +2328,14 @@ def _register_qconv_binary_fusion(): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), aten.relu.default, ), + is_fp8=is_fp8, ), } ) @@ -2320,8 +2344,9 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr + is_fp8=is_fp8, ) # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output @@ -2332,8 +2357,8 @@ def _register_qconv_binary_fusion(): PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), - KeywordArg("accum_after_dequant"), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + KeywordArg("accum") if is_fp8 else KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), @@ -2350,15 +2375,17 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr + is_fp8=is_fp8, ) else: _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr + is_fp8=is_fp8, ) # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output @@ -2370,8 +2397,8 @@ def _register_qconv_binary_fusion(): "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, - get_qconv_pt2e_pattern(1), - KeywordArg("accum_after_dequant"), + get_qconv_pt2e_pattern(x_scale_zp_are_tensors, 1), + KeywordArg("accum") if is_fp8 else KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, ), @@ -2385,8 +2412,9 @@ def _register_qconv_binary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number - torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + qconv_binary_op, # computation_op binary_unary_attr, # binary_unary_attr + is_fp8=is_fp8, )
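
Reviewer note (not part of the patch): the sketch below illustrates how the new FP8 qconv weight-prepack path is expected to be exercised end to end. It mirrors the FP8QDQConv2d module added in the test file and uses only the torchao ops referenced by this patch; the names ToyFP8Conv and run_fp8_conv_example are hypothetical, and triggering the pass via torch._inductor.config.freezing plus torch.compile is an assumption about the setup outside the test harness. Whether the fusion actually fires depends on the installed torch/torchao versions.

# Illustrative sketch only (assumes a torchao build that provides the
# *_affine_float8_non_decomposed ops and a torch build with inductor freezing).
import torch
import torch._inductor.config as inductor_config


class ToyFP8Conv(torch.nn.Module):
    # Minimal stand-in for FP8QDQConv2d from the test above: fp8 weight stored as a
    # plain tensor attribute, activation passed through an fp8 quantize/dequantize pair.
    def __init__(self):
        super().__init__()
        self.qtype = torch.float8_e4m3fn
        self.weight = torch.randn(16, 3, 3, 3).to(self.qtype)
        self.scale = 2.0
        self.weight_scale = 2.0

    def forward(self, x):
        w = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            tensor=self.weight,
            scale=torch.tensor([self.weight_scale]),
            output_dtype=torch.float,
        )
        q_x = torch.ops.torchao.quantize_affine_float8_non_decomposed.default(
            tensor=x,
            scale=torch.tensor([self.scale]),
            float8_dtype=self.qtype,
        )
        dq_x = torch.ops.torchao.dequantize_affine_float8_non_decomposed.default(
            tensor=q_x,
            scale=torch.tensor([self.scale]),
            output_dtype=torch.float,
        )
        return torch.nn.functional.conv2d(dq_x, w)


def run_fp8_conv_example():
    # Freezing is what runs the x86 inductor passes registered above, which should
    # rewrite the dequant -> conv2d subgraph into onednn.qconv_pointwise.
    inductor_config.freezing = True
    mod = ToyFP8Conv().eval()
    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        compiled = torch.compile(mod)
        return compiled(x)

If the new pattern matches, counters["inductor"]["qconv_weight_prepack_matcher_count"] increases; this is the same counter the tests in this patch assert on.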