From 9e3eef7b3a2863df716dbe84b28b2c06a0d22362 Mon Sep 17 00:00:00 2001
From: libohao
Date: Fri, 31 Oct 2025 16:15:10 +0800
Subject: [PATCH 01/12] Add ut test_decomp.py on XPU

---
 test/xpu/run_test_with_only.py |   54 +-
 test/xpu/test_decomp.py        | 1364 ++++++++++++++++++++++++++++++++
 2 files changed, 1417 insertions(+), 1 deletion(-)
 create mode 100644 test/xpu/test_decomp.py

diff --git a/test/xpu/run_test_with_only.py b/test/xpu/run_test_with_only.py
index 0c3d11b504..134be9bd19 100644
--- a/test/xpu/run_test_with_only.py
+++ b/test/xpu/run_test_with_only.py
@@ -43,7 +43,59 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_comprehensive_nn_functional_nll_loss_xpu_float64",
     "bincount",
 )
-res += launch_test("test_decomp_xpu.py", exe_list=execute_list)
+skip_list = (
+    "test_comprehensive_baddbmm_xpu_float64",
+    "test_comprehensive_logspace_tensor_overload_xpu_int16",
+    "test_comprehensive_logspace_tensor_overload_xpu_int32",
+    "test_comprehensive_logspace_tensor_overload_xpu_int64",
+    "test_comprehensive_logspace_xpu_int16",
+    "test_comprehensive_logspace_xpu_int32",
+    "test_comprehensive_logspace_xpu_int64",
+    "test_comprehensive_nn_functional_conv_transpose2d_xpu_bfloat16",
+    "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex128",
+    "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex32",
+    "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex64",
+    "test_comprehensive_nn_functional_conv_transpose2d_xpu_float16",
+    "test_comprehensive_nn_functional_conv_transpose2d_xpu_float32",
+    "test_comprehensive_nn_functional_conv_transpose2d_xpu_float64",
+    "test_comprehensive_nn_functional_conv_transpose3d_xpu_bfloat16",
+    "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex128",
+    "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex32",
+    "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex64",
+    "test_comprehensive_nn_functional_conv_transpose3d_xpu_float16",
+    "test_comprehensive_nn_functional_conv_transpose3d_xpu_float32",
+    "test_comprehensive_nn_functional_conv_transpose3d_xpu_float64",
+    "test_comprehensive_nn_functional_instance_norm_xpu_float64",
+    "test_comprehensive_nn_functional_nll_loss_xpu_float16",
+    "test_comprehensive_nn_functional_pad_reflect_xpu_bfloat16",
+    "test_comprehensive_torch_ops_aten__flash_attention_forward_xpu_float16",
+    "test_comprehensive_vdot_xpu_complex128",
+    "test_comprehensive_vdot_xpu_complex64",
+    "test_quick_addmm_xpu_float64",
+    "test_quick_baddbmm_xpu_float64",
+    "test_quick_core_backward_baddbmm_xpu_float64",
+    "test_quick_core_backward_mv_xpu_float64",
+    "test_quick_logspace_tensor_overload_xpu_int16",
+    "test_quick_logspace_tensor_overload_xpu_int32",
+    "test_quick_logspace_tensor_overload_xpu_int64",
+    "test_quick_logspace_xpu_int16",
+    "test_quick_logspace_xpu_int32",
+    "test_quick_logspace_xpu_int64",
+    "test_quick_vdot_xpu_complex128",
+    "test_quick_vdot_xpu_complex64",
+    "test_exponential_non_inf_xpu",
+    "test_aten_core_operators",
+    "test_has_decomposition",
+    "test_comprehensive_diff_xpu_complex128",
+    "test_comprehensive_ormqr_xpu_complex128",
+    "test_quick_var_mean_xpu_float64",
+    "test_comprehensive_diff_xpu_complex64",
+    "test_comprehensive_ormqr_xpu_complex64",
+    "test_quick_mean_xpu_complex128",
+    "test_comprehensive_grid_sampler_2d_xpu_bfloat16",
+)
+# res += launch_test("test_decomp_xpu.py", exe_list=execute_list)
+res += launch_test("test_decomp.py", skip_list=skip_list)
 
 if os.name == "nt":
     sys.exit(res)
diff --git a/test/xpu/test_decomp.py
b/test/xpu/test_decomp.py new file mode 100644 index 0000000000..2b1d637e6b --- /dev/null +++ b/test/xpu/test_decomp.py @@ -0,0 +1,1364 @@ +# Owner(s): ["module: decompositions"] + +import functools +import itertools +import re +import unittest +from collections import defaultdict +from functools import partial + +import torch._inductor.decomposition +import torch.autograd +from torch import Tensor +from torch._decomp import core_aten_decompositions, decomposition_table +from torch._dispatch.python import enable_python_dispatcher +from torch._export.utils import _is_cia_op +from torch._ops import DispatchKey +from torch.testing import make_tensor +from torch.testing._internal.common_cuda import SM70OrLater, tf32_off +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyCPU, + onlyNativeDeviceTypes, + ops, +) +from torch.testing._internal.common_methods_invocations import ( + op_db, + skip, + skipOps, + xfail, +) +from torch.testing._internal.common_modules import module_db, modules +from torch.testing._internal.common_utils import ( + is_iterable_of_tensors, + run_tests, + skipIfCrossRef, + skipIfTorchDynamo, + suppress_warnings, + TEST_WITH_ASAN, + TEST_WITH_SLOW, + TestCase, + unMarkDynamoStrictTest, +) +from torch.utils import _pytree as pytree +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) + +aten = torch.ops.aten + + +# TODO: this isn't going to work with non-aten namespaces +def overload_to_aten_name(op): + return op._schema.name.split("::")[1] + + +# All operators that can have decomp tests +decomposition_names = { + overload_to_aten_name(k) + for k in decomposition_table + if isinstance(k, torch._ops.OpOverload) +} +core_decomposition_names = { + overload_to_aten_name(k) + for k in core_aten_decompositions() + if isinstance(k, torch._ops.OpOverload) and not _is_cia_op(k) +} +_decomp_test_ops = [ + op + for op in op_db + if op.aten_name in decomposition_names + or op.aten_backward_name in decomposition_names +] +_decomp_test_ops_core_autograd = [ + op + for op in op_db + if op.aten_name in core_decomposition_names and op.supports_autograd +] +_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name] + + +def diff_arg(arg, requires_grad=True): + def is_differentiable_arg(arg): + if requires_grad: + return arg.requires_grad + else: + return arg.is_floating_point() or arg.is_complex() + + if is_iterable_of_tensors(arg): + if all(is_differentiable_arg(a) for a in arg): + return True + if all(not is_differentiable_arg(a) for a in arg): + return False + raise RuntimeError("NYI: The test runner can't handle this") + return isinstance(arg, Tensor) and is_differentiable_arg(arg) + + +# Version of autograd.grad with some differences: +# - pytree inputs is allowed (but leaves of the pytree have to all +# be tensors) +# - if an input is not used as part of derivatives, we will return a +# zero-filled tensor for the result +def _autograd_grad( + outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True +): + inputs, inputs_spec = tree_flatten(inputs) + diff_inputs = tuple(inp for inp in inputs if inp.requires_grad) + if grad_outputs is None: + diff_outputs = tuple(out for out in outputs if out.requires_grad) + else: + diff_grad_outputs = [ + (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad + ] + if 
len(diff_grad_outputs) == 0: + diff_outputs, grad_outputs = (), () + else: + diff_outputs, grad_outputs = zip(*diff_grad_outputs) + grad_inputs = torch.autograd.grad( + diff_outputs, + diff_inputs, + grad_outputs, + retain_graph=retain_graph, + create_graph=create_graph, + allow_unused=True, + ) + result = [] + grad_inputs_iter = iter(grad_inputs) + for inp in inputs: + if inp.requires_grad: + grad_input = next(grad_inputs_iter) + if grad_input is None: + result.append(torch.zeros_like(inp)) + else: + result.append(grad_input) + else: + result.append(torch.zeros_like(inp)) + return tree_unflatten(result, inputs_spec) + + +def _as_tuple(val): + if isinstance(val, tuple): + return val + return (val,) + + +def ref_vjp_no_create(f, *primals): + result = f(*primals) + + def wrapped(cotangents): + return _autograd_grad( + _as_tuple(result), + primals, + _as_tuple(cotangents), + create_graph=False, + retain_graph=True, + ) + + return result, wrapped + + +dtype_precisions = { + torch.float16: (0.001, 1e-5), + torch.bfloat16: (0.016, 1e-4), + torch.float32: (1.3e-6, 1e-5), + torch.float64: (1e-7, 1e-7), + torch.complex32: (0.001, 1e-5), + torch.complex64: (1.3e-6, 1e-5), + torch.complex128: (1e-7, 1e-7), +} +# Returns the "default" rtol and atol for comparing scalars or +# tensors of the given dtypes. + + +def _getDefaultRtolAndAtol(dtype0, dtype1): + rtol = max( + dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0] + ) + atol = max( + dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1] + ) + return rtol, atol + + + +def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs): + assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" + if orig.numel() == 0 or decomp.numel() == 0: + assert orig.numel() == decomp.numel() + return + assert orig.shape == decomp.shape, f"{i} Operation: {op}" + tol_table = { + (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5, + (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5, + (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3, + (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2, + (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5, + (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5, + (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5, + (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5, + (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5, + (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5, + (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4, + (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4, + (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7, + (torch.float16, torch.ops.aten.var_mean.correction): 5e-7, + (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7, + (torch.float16, torch.ops.aten.var_mean.dim): 5e-7, + (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2, + (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1, + (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2, + (torch.float16, torch.ops.aten.nll_loss2d_backward.default): 1e-4, + (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1, + (torch.float16, torch.ops.aten.hardswish.default): 2e-7, + (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7, + (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2, + (torch.bfloat16, 
torch.ops.aten.multi_margin_loss.default): 5e-2, + (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2, + (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2, + (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3, + (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3, + (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3, + (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3, + (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3, + (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2, + # see https://github.com/pytorch/pytorch/pull/96264 + (torch.float16, torch.ops.aten.mv.default): 1e-5, + (torch.bfloat16, torch.ops.aten.mv.default): 1e-5, + (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5, + (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7, + # XPU specific + ( + torch.float16, + torch.ops.aten._batch_norm_with_update.default, + ): 2e-7, # adjust tolerance for xpu + ( + torch.bfloat16, + torch.ops.aten._batch_norm_with_update.default, + ): 2e-7, # adjust tolerance for xpu + } + if ref.is_floating_point(): + orig_diff = (orig - ref).abs().max() + decomp_diff = (decomp - ref).abs().max() + atol = tol_table.get((test_dtype, op), 1e-7) + if decomp_diff > orig_diff + atol: + raise RuntimeError( + f"Difference from float64 is larger with decomposition {op.__name__}" + f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n" + f"atol = {atol}\n" + f"args = {args}\n" + f"kwargs = {kwargs}" + ) + else: + test_case.assertEqual( + orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}" + ) + + +def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs): + test_case.assertEqual( + orig.dtype, + decomp.dtype, + f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}", + ) + # Before adding an entry to this table, make sure your decomposition is right :) + tol_table = { + # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161 + (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3), + (torch.float32, torch.ops.aten.native_layer_norm_backward.default): ( + 1e-3, + 1e-3, + ), + (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6), + # This exceeds default tolerances only on CPU, on CUDA it's fine + (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5), + # Exceeds tolerances on CUDA, likely due to fma + (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5), + (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5), + (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4), + (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4), + # The decomposition is TOO correct. It computes everything in int64, so sometimes + # there's an off-by-one error. 
See + # https://github.com/pytorch/pytorch/issues/81996 + # https://github.com/pytorch/pytorch/issues/82230 + (torch.int8, torch.ops.aten.linspace.default): (0, 1), + (torch.uint8, torch.ops.aten.linspace.default): (0, 1), + (torch.int16, torch.ops.aten.linspace.default): (0, 1), + (torch.int32, torch.ops.aten.linspace.default): (0, 1), + (torch.int64, torch.ops.aten.linspace.default): (0, 1), + (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + } + if (decomp.dtype, op) in tol_table: + rtol, atol = tol_table[(decomp.dtype, op)] + else: + rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype) + test_case.assertEqual( + orig, + decomp, + rtol=rtol, + atol=atol, + msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}", + ) + + +# Given f, returns an f' such that: +# - f' takes only positional arguments +# - All arguments to f' are floating-point Tensors +# - All outputs of f' are floating-point Tensors +def normalize_op_input_output2( + f, args, kwargs, output_process_fn_grad=None, requires_grad=True +): + flat_args, args_spec = tree_flatten(args) + diff_argnums = tuple( + i + for i, arg in enumerate(flat_args) + if diff_arg(arg, requires_grad=requires_grad) + ) + assert len(diff_argnums) > 0 + primals = tuple(flat_args[i] for i in diff_argnums) + + @functools.wraps(f) + def wrapped(*primals): + _args = list(flat_args) + for num, arg in zip(diff_argnums, primals): + _args[num] = arg + _args = tree_unflatten(_args, args_spec) + result = f(*_args, **kwargs) + if output_process_fn_grad is not None: + result = output_process_fn_grad(result) + if isinstance(result, tuple): + # TODO We should check that the integer outputs also agree + result = tuple( + r + for r in result + if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex()) + ) + assert len(result) > 0 + return result + + return wrapped, primals + + +# NB: This also upcasts dtype arguments +# TODO: handle complex correctly +def upcast_tensor(x, dtype=torch.float32): + if isinstance(x, Tensor) and x.dtype.is_floating_point: + return x.to(dtype=dtype) + elif isinstance(x, torch.dtype) and x in [ + torch.float16, + torch.bfloat16, + torch.float, + ]: + return dtype + else: + return x + + +def normalize_op_input_output(f, sample, requires_grad=True): + args = tuple([sample.input] + list(sample.args)) + return normalize_op_input_output2( + f, + args, + sample.kwargs, + sample.output_process_fn_grad, + requires_grad=requires_grad, + ) + + +CROSS_REF_EXCLUDE_SET = { + # CUBLAS_STATUS_NOT_SUPPORTED when calling + # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, + # (void*)&falpha, a, CUDA_R_16BF, (int)lda, 
stridea, b, CUDA_R_16BF, + # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, + # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)` + ("cuda", torch.bfloat16, "nn.functional.bilinear"), + # randomness + (None, None, "special.ndtr"), # aten.special_ndtr was not decomposed + (None, None, "new_empty"), + (None, None, "empty_like"), + (None, None, "empty"), + # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default. + (None, None, "item"), + # It's the only in-place op without an out-of-place equivalent in the Python API + # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`. + (None, None, "zero_"), + # No idea what's going on here + # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), []) + # in the test, but it seems to pass when tested locally and in the logsumexp test + (None, torch.float32, "masked.logsumexp"), + (None, torch.float64, "masked.logsumexp"), + # exp_vml_cpu not implemented for Half + (torch.cpu, torch.float16, "signal.windows.exponential"), + (torch.cpu, torch.float16, "signal.windows.gaussian"), + # sin_vml_cpu not implemented for Half + (torch.cpu, torch.float16, "signal.windows.cosine"), + # CompositeAutogradImplicit + # See https://github.com/pytorch/pytorch/issues/81669 + (None, None, "nn.functional.relu6"), + # This decomp runs before autograd. + (None, None, "nn.functional.rrelu"), + (None, None, "meshgrid"), + # Decomposition registered as Autograd + (None, None, "nn.functional.hardshrink"), + (None, None, "nn.functional.softshrink"), + # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) + (None, None, "diag"), + # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 + ("cpu", torch.bfloat16, "_softmax_backward_data"), + (None, None, "norm"), + # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise) + (None, None, "native_batch_norm"), + (None, None, "_upsample_bilinear2d_aa"), + (None, None, "empty_strided"), # aten.empty_strided was not decomposed + ( + None, + None, + "bernoulli", + ), # bernoulli is a function of randomness, so couldn't do cross-reference. + # XPU specific exclude cases + # ("xpu", None, "some_xpu_specific_op"), # 根据需要添加 XPU 特定的排除项 +} + +CROSS_REF_BACKWARD_EXCLUDE_SET = { + # Decomposed backward formula is not as precise + ("cpu", torch.bfloat16, "nn.functional.hardswish"), + ("cuda", torch.float16, "nn.functional.cross_entropy"), + ( + None, + None, + "bernoulli", + ), # bernoulli is a function of randomness, so couldn't do cross-reference. 
+ # XPU specific backward exclude cases + # ("xpu", torch.float16, "nn.functional.some_op"), # 根据需要添加 +} + +all_decomposed = set() +all_called = defaultdict(int) + +# Helpful snippet for testing coverage +""" +import atexit +def check_coverage(): + print("missing coverage:") + print("\n".join(map(str, decomposition_table.keys() - all_decomposed))) +atexit.register(check_coverage) +""" + +# Helpful snippet for Horace to create his google sheet :) +""" +import atexit +def dump_ops(): + with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g: + for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__): + f.write(f'{op.__name__}\n') + g.write(f'{count}\n') + with open('run_decompositions.txt', 'w') as f: + for op in sorted([i.__name__ for i in all_decomposed]): + f.write(f'{op}\n') + +atexit.register(dump_ops) +""" + + +def any_unsupported(args, kwargs): + def test_unsupported(t): + if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: + # These are all things that we haven't coded decompositions + # to handle correctly. Maybe they should. + return any( + [ + t.is_sparse_csr, + t.is_sparse, + t.is_mkldnn, + t.is_quantized, + t.is_nested, + torch._is_functional_tensor(t), + ] + ) + elif torch.overrides.is_tensor_like(t): + # Decompositions will generally change the behavior of Tensor-like + # subclasses, so bypass tests in this case too + return True + else: + return False + + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + return any(test_unsupported(x) for x in flat_args) + + +core_backward_failures = { + skip("_softmax_backward_data"), # slow: fails with --timeout=360 secs + xfail("addcdiv"), + skip("addcmul"), # slow: fails with --timeout=360 secs + skip("deg2rad"), # slow: fails with --timeout=360 secs + skip("diag_embed"), # slow: fails with --timeout=360 secs + skip("frac"), # slow: fails with --timeout=360 secs + skip("grid_sampler_2d"), # slow: fails with --timeout=360 secs + xfail("lerp"), + skip("logaddexp"), # slow: fails with --timeout=360 secs + skip("native_dropout_backward"), # slow: fails with --timeout=360 secs + xfail("nn.functional.binary_cross_entropy_with_logits"), + skip("nn.functional.glu"), # slow: fails with --timeout=360 secs + xfail("nn.functional.hardshrink"), + xfail("nn.functional.softshrink"), + skip("nn.functional.unfold"), # slow: fails with --timeout=360 secs + xfail("norm"), + xfail("norm", "fro"), + xfail("norm", "inf"), + xfail("norm", "nuc"), + skip("rad2deg"), # slow: fails with --timeout=360 secs + skip("renorm"), # slow: fails with --timeout=360 secs + skip("rot90"), # slow: fails with --timeout=360 secs + skip("rsub"), # slow: fails with --timeout=360 secs + skip("sgn"), # slow: fails with --timeout=360 secs + skip("special.xlog1py"), # slow: fails with --timeout=360 secs + xfail("stack"), + skip("tril"), # slow: fails with --timeout=360 secs + skip("triu"), # slow: fails with --timeout=360 secs + skip("unfold_copy"), # slow: fails with --timeout=360 secs + skip("xlogy"), # slow: fails with --timeout=360 secs + xfail("zero_"), +} +if not TEST_WITH_SLOW: + core_backward_failures.update( + { + skip("addr"), # slow: takes 46 sec on A100 + skip("baddbmm"), # slow: takes 800+ sec on A100 + skip("clamp_min"), # slow: takes 800 sec on A100 + skip("clamp_max"), # slow: takes 800 sec on A100 + skip("logit"), # slow: takes 44 sec on A100 + skip("nn.functional.hardswish"), # slow: takes 60 sec on A100 + skip("std_mean"), # slow: takes 170 sec on A100 + skip("split", variant_name="list_args"), # slow: takes 118 sec on A100 + 
skip("transpose"), # slow: takes 50 sec on A100 + skip("unbind"), # slow: takes 70 sec on A100 + skip("unsafe_split"), # slow: takes 49 sec on A100 + } + ) + +comprehensive_failures = { + xfail( + "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,) + ), # off by one error + xfail( + "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,) + ), # off by one error + xfail( + "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,) + ), # off by one error +} + + +@unMarkDynamoStrictTest +class TestDecomp(TestCase): + longMessage = True + + # NB: This actually overlaps with test_comprehensive, but it only + # runs on things that are definitely decomposed so it's a lot faster + # to run + @onlyNativeDeviceTypes + @skipIfCrossRef + @suppress_warnings + @ops(_decomp_test_ops) + def test_quick(self, device, dtype, op): + self.do_cross_ref(device, dtype, op, run_all=False) + + @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures) + @onlyNativeDeviceTypes + @skipIfCrossRef + @suppress_warnings + @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,)) + def test_quick_core_backward(self, device, dtype, op): + test_keys = [ + (torch.device(device).type, dtype, op.name), + (None, dtype, op.name), + (None, None, op.name), + ] + if any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys): + self.skipTest(f"{op.name} in {dtype} not supported") + for sample_input in op.sample_inputs(device, dtype, requires_grad=True): + aten_name = op.decomp_aten_name or op.aten_name + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + func = partial(op.get_op(), **kwargs) + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all=False + ) as mode, + enable_python_dispatcher(), + ): + torch.autograd.gradcheck(func, args) + self.check_decomposed(aten_name, mode) + + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @onlyNativeDeviceTypes + @skipIfCrossRef + @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures) + @suppress_warnings + @ops(op_db) + def test_comprehensive(self, device, dtype, op): + self.do_cross_ref(device, dtype, op, run_all=True) + + def test_uniform(self, device): + size = (2, 3, 4, 5) + dtype = torch.float32 + x = make_tensor(size, dtype=dtype, device=device) + low = 0.3 + high = 0.9 + + torch.manual_seed(123) + ref = torch.ops.aten.uniform(x, low, high) + torch.manual_seed(123) + res = torch._decomp.decompositions.uniform(x, low=low, high=high) + self.assertEqual(ref, res) + + def test_bernoulli_default(self, device): + p = 0.3 + p_t = p * torch.ones(5, 5) + torch.manual_seed(123) + ref = torch.ops.aten.bernoulli.default(p_t) + torch.manual_seed(123) + res = torch._decomp.decompositions.bernoulli(p_t) + ref_p = ref.sum() / torch.prod(torch.tensor(ref.size())) + res_p = res.sum() / torch.prod(torch.tensor(res.size())) + self.assertEqual(ref_p, res_p, atol=0.06 * p, rtol=0.06) + + def test_broadcasting_index_copy(self, device): + x = torch.zeros([1, 10], device=device) + xs = torch.ones([2, 10], device=device) + + def index_copy(xs, x): + torch._decomp.decompositions.index_copy_( + xs, 0, torch.tensor(0).to(device), x + ) + + index_copy(xs, x) + + xs_two = torch.ones([2, 10], device=device) + xs_two[0] = x + + self.assertEqual(xs, xs_two) + + def test_cat_single_input(self, device): + decomp_table = torch._inductor.decomposition.select_decomp_table() + cat_inductor = decomp_table[torch.ops.aten.cat.default] + + inp = torch.rand([2048, 2048], device=device) + 
inps = [inp for _ in range(10)] + + for dim in (-1, 0, 1): + self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim)) + + @suppress_warnings + @tf32_off() + # only tests RNNs since we have py dispsatcher decomps for them + @modules( + filter( + lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), + module_db, + ) + ) + def test_rnn_decomp_module(self, device, dtype, module_info, training): + module_cls = module_info.module_cls + module_inputs = module_info.module_inputs_func( + module_info, + device=device, + dtype=dtype, + requires_grad=True, + training=training, + ) + for module_input in module_inputs: + if module_input.forward_input is None: + continue + args, kwargs = ( + module_input.constructor_input.args, + module_input.constructor_input.kwargs, + ) + m = module_cls(*args, **kwargs) + m.to(device).to(dtype) + + args, kwargs = ( + module_input.forward_input.args, + module_input.forward_input.kwargs, + ) + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all=True + ), + enable_python_dispatcher(), + ): + decomp_out = m(*args, **kwargs) + + non_decomp_out = m(*args, **kwargs) + # without this check, incorrect decomps at the python dispatcher level can still pass because + # they're checking aten decomps at the torch_dispatch level + self.assertEqual(decomp_out, non_decomp_out) + + def test_batch_norm_unflatten_weight_bias(self, device): + # https://github.com/pytorch/pytorch/issues/100970 + shape = (1, 3, 2, 2) + input = torch.randn(shape, device=device) + weight = torch.randn((3, 1, 1, 1), device=device) + bias = torch.randn(3, device=device) + mean = torch.randn(3, device=device) + var = torch.randn(3, device=device) + res = torch._decomp.decompositions.native_batch_norm( + input, weight, bias, mean, var, False, 1, 1e-05 + ) + self.assertEqual(shape, res[0].shape) + + def test_arange_graph(self, device): + from torch.fx.experimental.proxy_tensor import make_fx + + def func(x, start): + le = x.shape[-1] + if start is None: + a = torch.arange(le, dtype=torch.float32, device=x.device) + else: + a = torch.arange(start, le, dtype=torch.float32, device=x.device) + return a + + pattern = r", device = device\(.+\), requires_grad = False" + + cfunc = make_fx(func, decomposition_table=decomposition_table) + fx_g = cfunc(torch.rand(10, device=device), None) + fx_g_code = fx_g.code.strip() + # Remove device and requires_grad + fx_g_code = re.sub(pattern, "", fx_g_code) + self.assertExpectedInline( + fx_g_code, + """\ +def forward(self, x_1, start_1): + iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64) + mul = torch.ops.prims.mul.default(iota, 1); iota = None + add = torch.ops.prims.add.default(mul, 0); mul = None + convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None + return convert_element_type""", + ) + + fx_g = cfunc(torch.rand(10, device=device), 1) + fx_g_code = fx_g.code.strip() + # Remove device and requires_grad + fx_g_code = re.sub(pattern, "", fx_g_code) + self.assertExpectedInline( + fx_g_code, + """\ +def forward(self, x_1, start_1): + iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64) + mul = torch.ops.prims.mul.default(iota, 1); iota = None + add = torch.ops.prims.add.default(mul, 1); mul = None + convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None + return convert_element_type""", + ) + + def test_masked_fill(self, device): + from torch.fx.experimental.proxy_tensor 
import make_fx + + if torch.device(device).type not in [ + "xpu", + "cuda", + torch._C._get_privateuse1_backend_name(), + ]: + self.skipTest("only runs on XPU, CUDA and PrivateUse1.") + + def func(scores, mask, value): + return scores.masked_fill(mask, value) + + scores_t = torch.tensor([1, 2, 3, 4], device=device) + mask_t = torch.tensor([True, True, True, True], device=device) + value_t = torch.tensor(0, dtype=scores_t.dtype) + cfunc = make_fx(func, decomposition_table=decomposition_table) + fx_g = cfunc(scores_t, mask_t, value_t) + self.assertExpectedInline( + fx_g.code.strip(), + """\ +def forward(self, scores_1, mask_1, value_1): + where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None + return where""", + ) + + class DecompCrossRefMode(TorchDispatchMode): + def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all): + self.test_case = test_case + self.saved_precision = saved_precision + self.saved_rel_tol = saved_rel_tol + self.test_dtype = dtype + self.run_all = run_all + + # We check the correctness of each decomposition right after running it. + # So, when we encounter a decomposition, we run the function normally, and + # then run the decomposition, and ensure they're identical. + self.called = set() + self.decomposed = set() + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + self.test_case.precision = self.saved_precision + self.test_case.rel_tol = self.saved_rel_tol + + self.called.add(func) + all_called[func] += 1 + + # Stuff we shouldn't bother testing + # (TODO: remove detach from the decomp table?) + # N.b. Testing in-place ops would need dedicated logic + in_place = func.name()[-1] == "_" + ignored_ops = [ + torch.ops.aten.detach.default, + # non-deterministic ops + torch.ops.aten.empty.memory_format, + torch.ops.aten.empty_like.default, + torch.ops.aten.new_empty.default, + torch.ops.aten.empty_strided.default, + torch.ops.aten.new_empty_strided.default, + torch.ops.aten.randn.default, + torch.ops.aten.native_dropout.default, + ] + if ( + func not in decomposition_table + or func in ignored_ops + or torch.Tag.nondeterministic_seeded in func.tags + or any_unsupported(args, kwargs) + or in_place + ): + return func(*args, **kwargs) + + self.decomposed.add(func) + all_decomposed.add(func) + + # We take 2 main strategies for verifying correctness/numerical stability of decompositions + # The first one is simply tolerance checking between decomp_out and pytorch_out + # However, for fp16/bf16 and reductions, this becomes very + # finicky, as there are not many guarantees we can make. + # So, for fp16/bf16, we instead compare the difference of + # {decomp_out, pytorch_out_64} and {pytorch_out, + # pytorch_out_64}. In other words, we compare how far the + # decomposition and pytorch are from the "ground truth" (i.e. + # fp64). 
If the decomposition results in more error, we error + + # We also decompose the decomposition recursively for + # further coverage, as some paths not be exercised directly by + # OpInfos (sadly) but just by other ops + + decomposition = decomposition_table[func] + + do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16] + if self.run_all: + # Execute recursively via DFS, to find the root of a possible error first + with self: + decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs)) + else: + decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs)) + + # At this stage we should not be decomposing an in-place op + # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops + # because decompositions are run after functionalisation and we would not like them to + # de-functionalise the graph, as that would break AoTAutograd + # We run the real function *after* the decomposition to make sure that the + # decomposition does not modify any of the inputs in-place. If it does + # real_out should be different than decom_out so we should catch this + real_out_unflat = func(*args, **kwargs) + real_out = pytree.tree_leaves(real_out_unflat) + + assert len(real_out) == len(decomp_out) + + if do_relative_check: + device_arg = kwargs.get("device", None) + + def upcast(x): + if (isinstance(x, Tensor) and x.device.type == "mps") or ( + device_arg and torch.device(device_arg).type == "mps" + ): + return upcast_tensor(x, dtype=torch.float32) + else: + return upcast_tensor(x, dtype=torch.float64) + + real_out_double, _ = tree_flatten( + func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) + ) + for i, (orig, decomp, ref) in enumerate( + zip(real_out, decomp_out, real_out_double) + ): + if not isinstance(orig, torch.Tensor): + assert type(orig) == type(decomp) + assert orig == decomp + continue + op_assert_ref( + self.test_case, + func, + self.test_dtype, + i, + orig, + decomp, + ref, + args, + kwargs, + ) + else: + for orig, decomp in zip(real_out, decomp_out): + if not isinstance(orig, torch.Tensor): + assert type(orig) == type(decomp) + assert orig == decomp + continue + op_assert_equal( + self.test_case, + func, + self.test_dtype, + orig, + decomp, + args, + kwargs, + ) + + return real_out_unflat + + def check_decomposed(self, aten_name, mode): + self.assertTrue( + any(overload_to_aten_name(c) == aten_name for c in mode.decomposed), + msg=( + f"aten.{aten_name} was not decomposed, saw calls for: " + f"{', '.join(map(str, list(mode.called)))}. If your op is " + f"CompositeImplicitAutograd you should skip this test " + f"by updating CROSS_REF_EXCLUDE_SET." + ), + ) + + @skipIfTorchDynamo("Test does not work with TorchDynamo") + def do_cross_ref(self, device, dtype, op, *, run_all): + test_keys = [ + (torch.device(device).type, dtype, op.name), + (None, dtype, op.name), + (None, None, op.name), + ] + if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys): + self.skipTest(f"{op.name} in {dtype} not supported") + + skip_decomp_vjp = any( + key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys + ) + + requires_grad = ( + op.supports_autograd + and dtype in op.supported_backward_dtypes(torch.device(device).type) + # TODO: OpInfo really ought to error out for this case, but it's + # not exercised in test_ops_gradients atm. 
The problem is not + # complex32 per-se (which is supported by data movement only ops) + # but that when we do backwards we expect other ops like add to work + and not dtype == torch.complex32 + ) + samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) + + aten_name = op.decomp_aten_name or op.aten_name + + func = op.get_op() + + def run_without_python_dispatcher(mode): + return any( + isinstance(op, torch._ops.OpOverload) + and op.has_kernel_for_dispatch_key( + DispatchKey.CompositeImplicitAutograd + ) + for op in mode.decomposed.union([func]) + ) + + for sample_input in samples: + if requires_grad: + fn, primals = normalize_op_input_output(func, sample_input) + primals = tree_map( + lambda x: x if isinstance(x, torch.Tensor) else x, primals + ) + + # Once https://github.com/pytorch/pytorch/pull/75965/ I can + # store the called list on the mode object instance and no + # explicit clearing is necessary as I will create a fresh mode + # for each region + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode, + enable_python_dispatcher(), + ): + decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) + if run_without_python_dispatcher(mode): + # without this check, incorrect decomps at the python dispatcher level can still pass because + # they're checking aten decomps at the torch_dispatch level. + with self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode: + decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) + if aten_name in decomposition_names: + self.check_decomposed(aten_name, mode) + + if not skip_decomp_vjp and ( + op.aten_backward_name in decomposition_names or run_all + ): + cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out) + + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode, + enable_python_dispatcher(), + ): + decomp_vjp_fn(cotangents) + if run_without_python_dispatcher(mode): + # without this check, incorrect decomps at the python dispatcher level can still pass because + # they're checking aten decomps at the torch_dispatch level. + with self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode: + decomp_vjp_fn(cotangents) + if not run_all: + self.check_decomposed(op.aten_backward_name, mode) + + elif aten_name in decomposition_names or run_all: + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + # A failure here might be because the decomposition for the op is wrong or because a + # decomposition used by the particular op is wrong. + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode, + enable_python_dispatcher(), + ): + func(*args, **kwargs) + + if run_without_python_dispatcher(mode): + # without this check, incorrect decomps at the python dispatcher level can still pass because + # they're checking aten decomps at the torch_dispatch level. 
+ with self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode: + func(*args, **kwargs) + + if not run_all: + self.check_decomposed(aten_name, mode) + else: + assert op.supports_autograd + self.skipTest( + "only backwards is decomposed, but dtype doesn't support AD" + ) + + +instantiate_device_type_tests(TestDecomp, globals(), only_for="xpu", allow_xpu=True) + + +class DecompOneOffTests(TestCase): + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_contiguous_softmax(self, device): + size = (2, 4, 3, 3) + stride = (9, 18, 3, 1) + dtype = torch.float32 + + x = torch.randn(size, dtype=dtype, device=device) + x = torch.as_strided(x, size, stride) + + ref = torch.ops.aten._softmax(x, -1, False) + res = torch._decomp.decompositions._softmax(x, -1, False) + self.assertEqual(ref.stride(), res.stride()) + + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_contiguous_log_softmax(self, device): + size = (2, 4, 3, 3) + stride = (9, 18, 3, 1) + + dtype = torch.float32 + x = torch.randn(size, dtype=dtype, device=device) + x = torch.as_strided(x, size, stride) + + ref = torch.ops.aten._log_softmax(x, -1, False) + res = torch._decomp.decompositions._log_softmax(x, -1, False) + self.assertEqual(ref.stride(), res.stride()) + + def test_exponential_non_inf(self, device): + inp = torch.empty((4, 400, 256), device=device) + + with torch._dynamo.utils.preserve_rng_state(): + exp_ref = inp.exponential_() + exp = torch._refs.exponential(inp) + + self.assertEqual(exp, exp_ref) + self.assertFalse(exp.isinf().any()) + + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @skipIfCrossRef + def test_amp_batch_norm_backward(self): + device = device_type + grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) + x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) + weight = torch.randn((2,), dtype=torch.float32, device=device) + rmean = torch.randn((2,), dtype=torch.float32, device=device) + rvar = torch.randn((2,), dtype=torch.float32, device=device) + mean = torch.randn((0,), dtype=torch.float32, device=device) + + ref = torch.ops.aten.native_batch_norm_backward( + grad_out, + x, + weight, + rmean, + rvar, + mean, + mean, + False, + 1e-05, + [True, True, True], + ) + res = torch._decomp.decompositions.native_batch_norm_backward( + grad_out, + x, + weight, + rmean, + rvar, + mean, + mean, + False, + 1e-05, + [True, True, True], + ) + for a, b in zip(ref, res): + self.assertEqual(a.stride(), b.stride()) + self.assertEqual(a.dtype, b.dtype) + + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_elu_backward(self, device): + size = (2, 4, 3, 3) + dtype = torch.float32 + grad_out = torch.randn(size, dtype=dtype, device=device) + out = torch.randn(size, dtype=dtype, device=device) + + ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out) + res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out) + self.assertEqual(ref, res) + + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_threshold_backward_dtype(self, device): + grad = torch.randint(10, (4,), device=device) + input_tensor = torch.randint(10, (4,), device=device) + + ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1) + res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1) + self.assertEqual(ref.dtype, res.dtype) + + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_weight_norm_interface(self, device): + g = torch.randn((3, 10, 10), device=device) + v = torch.randn((1, 1, 10), device=device) + + ref = 
torch.ops.aten._weight_norm_interface(g, v, 2) + res = torch._decomp.decompositions._weight_norm_interface(g, v, 2) + self.assertTrue(torch.allclose(ref[0], res[0])) + self.assertTrue(torch.allclose(ref[1], res[1])) + + inp = torch.rand([30, 10], device=device) + inp2 = torch.rand([30, 1], device=device) + + self.assertEqual( + torch.ops.aten._weight_norm_interface(inp, inp2), + torch._decomp.decompositions._weight_norm_interface(inp, inp2), + ) + + @onlyCPU + @skipIfCrossRef + @skipOps( + "DecompOneOffTests", + "test_sdpa", + [ + xfail( + "nn.functional.scaled_dot_product_attention", + dtypes=[torch.half], + ), + ], + ) + @ops(_sdpa_op_info) + def test_sdpa(self, device, dtype, op): + # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we + # add support for float16 over there we should update this test as well. + query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) + key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) + value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) + masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)] + + atol, rtol = dtype_precisions[dtype] + + for mask in masks: + is_causal = mask is None + decomposed_res = ( + torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu( + query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask + ) + ) + actual_res = decomposed_res[0] + # Output has form (N, H, L, E), but should be continuous on (L, N, H, E) + # in order for subsequent view(L * N, H * E) to be valid. + # So permute(2, 0, 1, 3) before checking that tensor is contiguous + self.assertTrue(actual_res.permute(2, 0, 1, 3).is_contiguous()) + + eager_res = op( + query_layer, + key_layer, + value_layer, + attn_mask=mask, + dropout_p=0.0, + is_causal=is_causal, + ) + + self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol)) + + @onlyCPU + def test_native_layer_norm_cpu_decomp(self, device): + def f(x, w, b): + return torch.ops.aten.native_layer_norm.default(x, [1, 2, 3], w, b, eps=0.5) + + x = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu") + w = torch.randn(1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu") + b = torch.randn(1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu") + out_ref = f(x, w, b) + + from torch._subclasses.fake_tensor import FakeTensorMode + + with enable_python_dispatcher(), FakeTensorMode(): + x = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu") + w = torch.randn( + 1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu" + ) + b = torch.randn( + 1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu" + ) + out = f(x, w, b) + + for o_ref, o in zip(out_ref, out): + self.assertEqual(o_ref.dtype, o.dtype) + + @unittest.skipIf(not SM70OrLater, "triton") + def test_rms_norm_decomp_accelerator(self, device): + @torch.compile + def rms_norm_sinh(a, b, c): + output = torch.nn.functional.rms_norm(a, b, c) + return torch.sinh(output) + + normalized_shape_arg = (3, 3, 3) + input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + + def forward_pass_fn(): + return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) + + model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( + forward_pass_fn + ) + + # check RMSNorm was fused with sinh + self.assertTrue( + "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in 
generated_codes[0] + ) + self.assertTrue( + "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] + ) + + +instantiate_device_type_tests( + DecompOneOffTests, globals(), only_for="xpu", allow_xpu=True +) + + +class HasDecompTest(TestCase): + def setUp(self): + super().setUp() + self.maxDiff = None + + @staticmethod + def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool: + has_tensor_arg = any( + "Tensor" in str(a.type) + for a in itertools.chain(op._schema.arguments, op._schema.returns) + ) + if not has_tensor_arg: + return False + + try: + # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions + return not _is_cia_op(op) + except RuntimeError as e: + # has_key fails for some jit-registered ops, which shouldn't be + # relevant here anyway + if "does not exist" in str(e): + return False + raise + + def test_has_decomposition(self): + def all_aten_overloads(): + for name in torch._C._dispatch_get_all_op_names(): + if not name.startswith("aten::"): + continue + + name = name[6:] + if "." in name: + packet_name, overload_name = name.split(".") + else: + packet_name, overload_name = name, "default" + + packet = getattr(aten, packet_name) + assert isinstance(packet, torch._ops.OpOverloadPacket) + op = getattr(packet, overload_name) + yield op + + # This is for operators that are only registered in some CI + # configurations, so would cause the test to fail + allow_list = {aten.get_gradients.default} + + overloads_wanting_decomp = { + op for op in all_aten_overloads() if self._can_appear_in_trace(op) + } + ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys() + ops_missing_decomp -= allow_list + self.assertExpected( + "".join(sorted(op.name() + "\n" for op in ops_missing_decomp)) + ) + + def test_aten_core_operators(self): + # If a decomposition isn't included in the core decompositions, + # then it must decompose a core ATen operator. + # + # See NOTE [Core ATen Ops] + # + # If this test fails then either: + # - Add the decomposition to torch._decomp.core_aten_decompositions, + # if decomposition should be used by inductor (not a core operator). + # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of + # core ATen operators (and inductor will not use the decomposition). + + # Some decompositions are registered for CompositeImplicitAutograd + # operators, which never appear in AOTAutograd's graph so are never used. 
+ useful_decomps = { + op + for op in decomposition_table.keys() + if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op) + } + core_decomps = torch._decomp.core_aten_decompositions().keys() + core_aten_ops = useful_decomps - core_decomps + self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops))) + + +if __name__ == "__main__": + run_tests() From fb001fd20ccb1fd79199fc52c69c0e96233dad1e Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Sun, 2 Nov 2025 19:33:13 -0800 Subject: [PATCH 02/12] Fix lint error --- test/xpu/test_decomp.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/xpu/test_decomp.py b/test/xpu/test_decomp.py index 2b1d637e6b..9d48984675 100644 --- a/test/xpu/test_decomp.py +++ b/test/xpu/test_decomp.py @@ -183,7 +183,6 @@ def _getDefaultRtolAndAtol(dtype0, dtype1): return rtol, atol - def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs): assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" if orig.numel() == 0 or decomp.numel() == 0: @@ -432,7 +431,7 @@ def normalize_op_input_output(f, sample, requires_grad=True): "bernoulli", ), # bernoulli is a function of randomness, so couldn't do cross-reference. # XPU specific exclude cases - # ("xpu", None, "some_xpu_specific_op"), # 根据需要添加 XPU 特定的排除项 + # ("xpu", None, "some_xpu_specific_op"), } CROSS_REF_BACKWARD_EXCLUDE_SET = { @@ -445,7 +444,7 @@ def normalize_op_input_output(f, sample, requires_grad=True): "bernoulli", ), # bernoulli is a function of randomness, so couldn't do cross-reference. # XPU specific backward exclude cases - # ("xpu", torch.float16, "nn.functional.some_op"), # 根据需要添加 + # ("xpu", torch.float16, "nn.functional.some_op"), } all_decomposed = set() From 7a1c46f1e778c5971f9459a45a0e0718ae954905 Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Sun, 2 Nov 2025 19:44:53 -0800 Subject: [PATCH 03/12] Fix lint error --- test/xpu/test_decomp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/xpu/test_decomp.py b/test/xpu/test_decomp.py index 9d48984675..eb79ac80f3 100644 --- a/test/xpu/test_decomp.py +++ b/test/xpu/test_decomp.py @@ -431,7 +431,7 @@ def normalize_op_input_output(f, sample, requires_grad=True): "bernoulli", ), # bernoulli is a function of randomness, so couldn't do cross-reference. # XPU specific exclude cases - # ("xpu", None, "some_xpu_specific_op"), + # ("xpu", None, "some_xpu_specific_op"), } CROSS_REF_BACKWARD_EXCLUDE_SET = { @@ -444,7 +444,7 @@ def normalize_op_input_output(f, sample, requires_grad=True): "bernoulli", ), # bernoulli is a function of randomness, so couldn't do cross-reference. # XPU specific backward exclude cases - # ("xpu", torch.float16, "nn.functional.some_op"), + # ("xpu", torch.float16, "nn.functional.some_op"), } all_decomposed = set() From 0657beebf419d9b4dc5a39b815587e2b1aeaa3a8 Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Fri, 7 Nov 2025 22:33:14 -0800 Subject: [PATCH 04/12] Add test_decomp.py in run_test_with_skip.py 1. windows skip the whole file - test_decomp.py 2. 
linux skip specific cases in test_decomp.py --- test/xpu/run_test_with_skip.py | 83 ++++++++++++++++++++++++++++++++-- test/xpu/skip_list_common.py | 64 ++++++++++++++++++++++++++ test/xpu/windows_skip_cases.py | 23 ++++++++++ 3 files changed, 165 insertions(+), 5 deletions(-) create mode 100644 test/xpu/windows_skip_cases.py diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py index 57e361d68a..dfa7617301 100644 --- a/test/xpu/run_test_with_skip.py +++ b/test/xpu/run_test_with_skip.py @@ -13,32 +13,105 @@ default="selected", help="Test cases scope", ) +# Add skip-cases parameter to import window skip dictionary +parser.add_argument( + "--skip-cases", + action="store_true", + default=False, + help="Use window skip dictionary for test cases", +) args = parser.parse_args() +def should_skip_entire_file(skip_list): + """Check if the skip list contains any entire file skip pattern (*.py::)""" + if not skip_list: + return False + return any(item.endswith('.py::') for item in skip_list) + +# Import window skip dictionary if skip-cases is True +if args.skip_cases: + try: + # Import the window skip dictionary module + from window_skip_dict import skip_dict as window_skip_dict + + # Merge the window skip dictionary with the default one using intelligent strategy + merged_skip_dict = {} + + # First, copy all keys from default skip_dict + for key in skip_dict: + merged_skip_dict[key] = skip_dict[key].copy() if skip_dict[key] else [] + + # Then merge with window_skip_dict using intelligent strategy + for key in window_skip_dict: + window_skip_list = window_skip_dict[key] + + if key in merged_skip_dict: + default_skip_list = merged_skip_dict[key] + + # Intelligent merge strategy: + if should_skip_entire_file(window_skip_list): + # If Windows wants to skip entire file, use ONLY Windows skip list + merged_skip_dict[key] = window_skip_list + print(f"Windows entire file skip detected for {key}, using: {window_skip_list}") + else: + # Otherwise, merge both lists and remove duplicates + combined_list = default_skip_list + [item for item in window_skip_list if item not in default_skip_list] + merged_skip_dict[key] = combined_list + print(f"Windows merging skip lists for {key}: {combined_list}") + else: + # Add new key-value pair from window_skip_dict + merged_skip_dict[key] = window_skip_list + print(f"Windows adding new skip key: {key} with {window_skip_list}") + + print("Using intelligently merged skip dictionary") + + except ImportError: + print("Warning: window_skip_dict module not found, using default skip dictionary") + merged_skip_dict = skip_dict + except Exception as e: + print(f"Error importing window skip dictionary: {e}") + merged_skip_dict = skip_dict +else: + merged_skip_dict = skip_dict + print("Using default skip dictionary") res = 0 fail_test = [] -for key in skip_dict: - skip_list = skip_dict[key] +for key in merged_skip_dict: + skip_list = merged_skip_dict[key] exe_list = None + if args.test_cases == "skipped": + # When running only skipped cases, use skip_list as exe_list exe_list = skip_list skip_list = None - if exe_list is None: + if not exe_list: # Check if exe_list is empty + print(f"Skipping {key} as no tests to execute") continue elif args.test_cases == "all": + # When running all cases, don't skip any skip_list = None + # For "selected" case, use the skip_list as is + + print(f"Running test case: {key}") + if skip_list: + print(f"Skip list: {skip_list}") + if exe_list: + print(f"Execute list: {exe_list}") + fail = launch_test(key, skip_list=skip_list, 
exe_list=exe_list) res += fail if fail: fail_test.append(key) + if fail_test: print(",".join(fail_test) + " have failures") - +else: + print("All tests passed!") if os.name == "nt": sys.exit(res) else: exit_code = os.WEXITSTATUS(res) - sys.exit(exit_code) + sys.exit(exit_code) \ No newline at end of file diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index e2ce184a10..ce587cdacb 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -889,4 +889,68 @@ "test_sparse_matmul_xpu_float64", # - RuntimeError: Double and complex datatype matmul is not supported in oneDNN "test_sparse_mm_xpu_float64", # - NotImplementedError: Could not run 'aten::addmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or wa... ), + "test_decomp.py": ( + # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! + "test_comprehensive_baddbmm_xpu_float64", + "test_comprehensive_logspace_tensor_overload_xpu_int16", + "test_comprehensive_logspace_tensor_overload_xpu_int32", + "test_comprehensive_logspace_tensor_overload_xpu_int64", + "test_comprehensive_logspace_xpu_int16", + "test_comprehensive_logspace_xpu_int32", + "test_comprehensive_logspace_xpu_int64", + # RuntimeError: could not create a primitive descriptor for the deconvolution forward propagation primitive. + "test_comprehensive_nn_functional_conv_transpose2d_xpu_bfloat16", + "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex128", + "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex32", + "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex64", + "test_comprehensive_nn_functional_conv_transpose2d_xpu_float16", + "test_comprehensive_nn_functional_conv_transpose2d_xpu_float32", + "test_comprehensive_nn_functional_conv_transpose2d_xpu_float64", + "test_comprehensive_nn_functional_conv_transpose3d_xpu_bfloat16", + "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex128", + "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex32", + "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex64", + "test_comprehensive_nn_functional_conv_transpose3d_xpu_float16", + "test_comprehensive_nn_functional_conv_transpose3d_xpu_float32", + "test_comprehensive_nn_functional_conv_transpose3d_xpu_float64", + # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! + "test_comprehensive_nn_functional_instance_norm_xpu_float64", + # RuntimeError: Difference from float64 is larger with decomposition nll_loss_forward.default than original on output 0. + "test_comprehensive_nn_functional_nll_loss_xpu_float16", + "test_comprehensive_nn_functional_pad_reflect_xpu_bfloat16", + # NotImplementedError: Could not run 'aten::_flash_attention_forward' with arguments from the 'CPU' backend. + "test_comprehensive_torch_ops_aten__flash_attention_forward_xpu_float16", + # AssertionError: Scalars are not close! ; Exception: Scalars are not close! + "test_comprehensive_vdot_xpu_complex128", + "test_comprehensive_vdot_xpu_complex64", + # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! + "test_quick_addmm_xpu_float64", + "test_quick_baddbmm_xpu_float64", + "test_quick_core_backward_baddbmm_xpu_float64", + # Exception: Jacobian mismatch for output 0 with respect to input 0 + "test_quick_core_backward_mv_xpu_float64", + # AssertionError: Tensor-likes are not equal! ; Exception: Tensor-likes are not equal! 
+ "test_quick_logspace_tensor_overload_xpu_int16", + "test_quick_logspace_tensor_overload_xpu_int32", + "test_quick_logspace_tensor_overload_xpu_int64", + "test_quick_logspace_xpu_int16", + "test_quick_logspace_xpu_int32", + "test_quick_logspace_xpu_int64", + # AssertionError: Scalars are not close! ; Exception: Scalars are not close! + "test_quick_vdot_xpu_complex128", + "test_quick_vdot_xpu_complex64", + # AssertionError: Tensor-likes are not close! + "test_exponential_non_inf_xpu", + # RuntimeError: I got this output for HasDecompTest.test_aten_core_operators: + "test_aten_core_operators", + "test_has_decomposition", + # AssertionError: Tensor-likes are not close! + "test_comprehensive_diff_xpu_complex128", + "test_comprehensive_ormqr_xpu_complex128", + "test_quick_var_mean_xpu_float64", + "test_comprehensive_diff_xpu_complex64", + "test_comprehensive_ormqr_xpu_complex64", + "test_quick_mean_xpu_complex128", + "test_comprehensive_grid_sampler_2d_xpu_bfloat16", + ) } diff --git a/test/xpu/windows_skip_cases.py b/test/xpu/windows_skip_cases.py new file mode 100644 index 0000000000..8293770358 --- /dev/null +++ b/test/xpu/windows_skip_cases.py @@ -0,0 +1,23 @@ +""" +Window specific skip list for unit tests +Using pytest -k filtering syntax +""" + +skip_dict = { + # Windows: Skip entire files using *.py:: pattern + "test_decomp": [ + "test_decomp.py::", # Skip entire file on Windows + ], + + # Files where Windows only needs to skip specific tests (will merge with Linux defaults) + # "test_linalg": [ + # "test_cholesky_windows_bug", # Only skip specific Windows issues + # "test_qr_windows_memory", # Will be merged with Linux skip list + # ], + + # New test groups only needed on Windows + # "windows_specific_issues": [ + # "test_dll_loading", + # "test_path_length", + # ], +} \ No newline at end of file From f45f0461974287ceaca75407550fb953035f87d0 Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Fri, 7 Nov 2025 22:55:06 -0800 Subject: [PATCH 05/12] Fix lint error. 
--- test/xpu/run_test_with_skip.py | 42 +++++++++++++++++++++------------- test/xpu/skip_list_common.py | 18 +++++++-------- test/xpu/windows_skip_cases.py | 4 +--- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py index dfa7617301..e1e578b16a 100644 --- a/test/xpu/run_test_with_skip.py +++ b/test/xpu/run_test_with_skip.py @@ -17,56 +17,66 @@ parser.add_argument( "--skip-cases", action="store_true", - default=False, + default=False, help="Use window skip dictionary for test cases", ) args = parser.parse_args() + def should_skip_entire_file(skip_list): """Check if the skip list contains any entire file skip pattern (*.py::)""" if not skip_list: return False - return any(item.endswith('.py::') for item in skip_list) + return any(item.endswith(".py::") for item in skip_list) + # Import window skip dictionary if skip-cases is True if args.skip_cases: try: # Import the window skip dictionary module from window_skip_dict import skip_dict as window_skip_dict - + # Merge the window skip dictionary with the default one using intelligent strategy merged_skip_dict = {} - + # First, copy all keys from default skip_dict for key in skip_dict: merged_skip_dict[key] = skip_dict[key].copy() if skip_dict[key] else [] - + # Then merge with window_skip_dict using intelligent strategy for key in window_skip_dict: window_skip_list = window_skip_dict[key] - + if key in merged_skip_dict: default_skip_list = merged_skip_dict[key] - + # Intelligent merge strategy: if should_skip_entire_file(window_skip_list): # If Windows wants to skip entire file, use ONLY Windows skip list merged_skip_dict[key] = window_skip_list - print(f"Windows entire file skip detected for {key}, using: {window_skip_list}") + print( + f"Windows entire file skip detected for {key}, using: {window_skip_list}" + ) else: # Otherwise, merge both lists and remove duplicates - combined_list = default_skip_list + [item for item in window_skip_list if item not in default_skip_list] + combined_list = default_skip_list + [ + item + for item in window_skip_list + if item not in default_skip_list + ] merged_skip_dict[key] = combined_list print(f"Windows merging skip lists for {key}: {combined_list}") else: # Add new key-value pair from window_skip_dict merged_skip_dict[key] = window_skip_list print(f"Windows adding new skip key: {key} with {window_skip_list}") - + print("Using intelligently merged skip dictionary") - + except ImportError: - print("Warning: window_skip_dict module not found, using default skip dictionary") + print( + "Warning: window_skip_dict module not found, using default skip dictionary" + ) merged_skip_dict = skip_dict except Exception as e: print(f"Error importing window skip dictionary: {e}") @@ -81,7 +91,7 @@ def should_skip_entire_file(skip_list): for key in merged_skip_dict: skip_list = merged_skip_dict[key] exe_list = None - + if args.test_cases == "skipped": # When running only skipped cases, use skip_list as exe_list exe_list = skip_list @@ -93,13 +103,13 @@ def should_skip_entire_file(skip_list): # When running all cases, don't skip any skip_list = None # For "selected" case, use the skip_list as is - + print(f"Running test case: {key}") if skip_list: print(f"Skip list: {skip_list}") if exe_list: print(f"Execute list: {exe_list}") - + fail = launch_test(key, skip_list=skip_list, exe_list=exe_list) res += fail if fail: @@ -114,4 +124,4 @@ def should_skip_entire_file(skip_list): sys.exit(res) else: exit_code = os.WEXITSTATUS(res) - sys.exit(exit_code) 
\ No newline at end of file + sys.exit(exit_code) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index cb2934f3fa..56f3fb1c17 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -753,7 +753,7 @@ "test_sparse_mm_xpu_float64", # - NotImplementedError: Could not run 'aten::addmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or wa... ), "test_decomp.py": ( - # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! + # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! "test_comprehensive_baddbmm_xpu_float64", "test_comprehensive_logspace_tensor_overload_xpu_int16", "test_comprehensive_logspace_tensor_overload_xpu_int32", @@ -776,35 +776,35 @@ "test_comprehensive_nn_functional_conv_transpose3d_xpu_float16", "test_comprehensive_nn_functional_conv_transpose3d_xpu_float32", "test_comprehensive_nn_functional_conv_transpose3d_xpu_float64", - # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! + # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! "test_comprehensive_nn_functional_instance_norm_xpu_float64", - # RuntimeError: Difference from float64 is larger with decomposition nll_loss_forward.default than original on output 0. + # RuntimeError: Difference from float64 is larger with decomposition nll_loss_forward.default than original on output 0. "test_comprehensive_nn_functional_nll_loss_xpu_float16", "test_comprehensive_nn_functional_pad_reflect_xpu_bfloat16", # NotImplementedError: Could not run 'aten::_flash_attention_forward' with arguments from the 'CPU' backend. "test_comprehensive_torch_ops_aten__flash_attention_forward_xpu_float16", - # AssertionError: Scalars are not close! ; Exception: Scalars are not close! + # AssertionError: Scalars are not close! ; Exception: Scalars are not close! "test_comprehensive_vdot_xpu_complex128", "test_comprehensive_vdot_xpu_complex64", - # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! + # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! "test_quick_addmm_xpu_float64", "test_quick_baddbmm_xpu_float64", "test_quick_core_backward_baddbmm_xpu_float64", # Exception: Jacobian mismatch for output 0 with respect to input 0 "test_quick_core_backward_mv_xpu_float64", - # AssertionError: Tensor-likes are not equal! ; Exception: Tensor-likes are not equal! + # AssertionError: Tensor-likes are not equal! ; Exception: Tensor-likes are not equal! "test_quick_logspace_tensor_overload_xpu_int16", "test_quick_logspace_tensor_overload_xpu_int32", "test_quick_logspace_tensor_overload_xpu_int64", "test_quick_logspace_xpu_int16", "test_quick_logspace_xpu_int32", "test_quick_logspace_xpu_int64", - # AssertionError: Scalars are not close! ; Exception: Scalars are not close! + # AssertionError: Scalars are not close! ; Exception: Scalars are not close! "test_quick_vdot_xpu_complex128", "test_quick_vdot_xpu_complex64", - # AssertionError: Tensor-likes are not close! + # AssertionError: Tensor-likes are not close! "test_exponential_non_inf_xpu", - # RuntimeError: I got this output for HasDecompTest.test_aten_core_operators: + # RuntimeError: I got this output for HasDecompTest.test_aten_core_operators: "test_aten_core_operators", "test_has_decomposition", # AssertionError: Tensor-likes are not close! 
diff --git a/test/xpu/windows_skip_cases.py b/test/xpu/windows_skip_cases.py index 8293770358..e0f6633d71 100644 --- a/test/xpu/windows_skip_cases.py +++ b/test/xpu/windows_skip_cases.py @@ -8,16 +8,14 @@ "test_decomp": [ "test_decomp.py::", # Skip entire file on Windows ], - # Files where Windows only needs to skip specific tests (will merge with Linux defaults) # "test_linalg": [ # "test_cholesky_windows_bug", # Only skip specific Windows issues # "test_qr_windows_memory", # Will be merged with Linux skip list # ], - # New test groups only needed on Windows # "windows_specific_issues": [ # "test_dll_loading", # "test_path_length", # ], -} \ No newline at end of file +} From 4dd91a29228411e505cf911b5a7052ce658a12e3 Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Fri, 7 Nov 2025 23:03:10 -0800 Subject: [PATCH 06/12] "Remove duplicated skip list entries in test/xpu/skip_list_common.py to fix lint error." --- test/xpu/skip_list_common.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 56f3fb1c17..1d9d562c76 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -733,25 +733,6 @@ # CUDA specific case "test_cufft_plan_cache_xpu_float64", ), - "test_sparse_xpu.py": ( - "test_bmm_deterministic_xpu_float64", # - AssertionError: Torch not compiled with CUDA enabled - "test_bmm_oob_xpu", # - NotImplementedError: Could not run 'aten::bmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was ... - "test_bmm_xpu_float64", # - NotImplementedError: Could not run 'aten::bmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was ... - "test_dsmm_xpu_float64", # - NotImplementedError: Could not run 'aten::mm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was o... - "test_empty_like_xpu_float64", # - AssertionError: "Could not run 'aten::empty_strided' with arguments from the 'Sparse(CPU|CUDA)' backend" does not match "Could not run 'aten::empty_strided' with argu... - "test_factory_device_type_inference_xpu", # - RuntimeError: PyTorch is not linked with support for cuda devices - "test_hsmm_xpu_float64", # - NotImplementedError: Could not run 'aten::hspmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or wa... - "test_mv_xpu_float64", # - NotImplementedError: Could not run 'aten::mm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was o... - "test_new_device_single_gpu_xpu", # - RuntimeError: PyTorch was compiled without CUDA support - "test_print_coalesced_xpu_float64", # - RuntimeError: I got this output for TestSparseXPU.test_print_coalesced_xpu_float64: - "test_print_uncoalesced_xpu_float64", # - RuntimeError: I got this output for TestSparseXPU.test_print_uncoalesced_xpu_float64 - "test_sparse_addmm_xpu_bfloat16", # - NotImplementedError: Could not run 'aten::addmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or wa... - "test_sparse_addmm_xpu_float16", # - NotImplementedError: Could not run 'aten::addmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or wa... 
- "test_sparse_addmm_xpu_float64", # - NotImplementedError: Could not run 'aten::addmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or wa... - "test_sparse_matmul_xpu_float32", # - NotImplementedError: Could not run 'aten::_sparse_sparse_matmul' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for thi... - "test_sparse_matmul_xpu_float64", # - RuntimeError: Double and complex datatype matmul is not supported in oneDNN - "test_sparse_mm_xpu_float64", # - NotImplementedError: Could not run 'aten::addmm' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or wa... - ), "test_decomp.py": ( # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! "test_comprehensive_baddbmm_xpu_float64", From e1179f9795af4d414b2de6ce08c0f289e03eaf10 Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Sat, 8 Nov 2025 07:30:35 -0800 Subject: [PATCH 07/12] skip 5 cases under test_decomp.py - worker crashed --- test/xpu/skip_list_common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 1d9d562c76..08788f115f 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -796,6 +796,12 @@ "test_comprehensive_ormqr_xpu_complex64", "test_quick_mean_xpu_complex128", "test_comprehensive_grid_sampler_2d_xpu_bfloat16", + # worker 'gw[x]' crashed + "test_quick_core_backward__unsafe_masked_index_xpu_float64", + "test_comprehensive_to_sparse_xpu_int8", + "test_comprehensive_grid_sampler_2d_xpu_float64", + "test_quick_core_backward__unsafe_masked_index_put_accumulate_xpu_float64", + "test_quick_core_backward__unsafe_masked_index_xpu_float64", ), "functorch/test_ops_functorch_xpu.py": None, "test_sparse_xpu.py": None, From ff54b585abd5f0073ce47b59f7d4214bbead98b6 Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Sat, 8 Nov 2025 07:54:20 -0800 Subject: [PATCH 08/12] Fix lint error in test/xpu/skip_list_common.py. --- test/xpu/skip_list_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 08788f115f..767a437d58 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -796,7 +796,7 @@ "test_comprehensive_ormqr_xpu_complex64", "test_quick_mean_xpu_complex128", "test_comprehensive_grid_sampler_2d_xpu_bfloat16", - # worker 'gw[x]' crashed + # worker 'gw[x]' crashed "test_quick_core_backward__unsafe_masked_index_xpu_float64", "test_comprehensive_to_sparse_xpu_int8", "test_comprehensive_grid_sampler_2d_xpu_float64", From e1bd2cdc3745b8e05bf4bf3e165131d2bdfb9eed Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Sun, 9 Nov 2025 18:29:24 -0800 Subject: [PATCH 09/12] Skip "test_quick_core_backward" and "test_comprehensive" due to worker crash error. --- test/xpu/skip_list_common.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 9aff44f527..0bdc47476e 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -792,6 +792,9 @@ # RuntimeError: I got this output for HasDecompTest.test_aten_core_operators: "test_aten_core_operators", "test_has_decomposition", + # worker 'gw[x]' crashed + "test_quick_core_backward", + "test_comprehensive", # AssertionError: Tensor-likes are not close! 
"test_comprehensive_diff_xpu_complex128", "test_comprehensive_ormqr_xpu_complex128", @@ -800,12 +803,6 @@ "test_comprehensive_ormqr_xpu_complex64", "test_quick_mean_xpu_complex128", "test_comprehensive_grid_sampler_2d_xpu_bfloat16", - # worker 'gw[x]' crashed - "test_quick_core_backward__unsafe_masked_index_xpu_float64", - "test_comprehensive_to_sparse_xpu_int8", - "test_comprehensive_grid_sampler_2d_xpu_float64", - "test_quick_core_backward__unsafe_masked_index_put_accumulate_xpu_float64", - "test_quick_core_backward__unsafe_masked_index_xpu_float64", ), "functorch/test_ops_functorch_xpu.py": None, "test_sparse_xpu.py": None, From 085911e11fbc8ef2a8a4263c3de9708d7d71df8c Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Sun, 9 Nov 2025 22:04:28 -0800 Subject: [PATCH 10/12] Revise max-worker-restart to 1000 for pytest invocations --- test/xpu/skip_list_common.py | 3 --- test/xpu/xpu_test_utils.py | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 0bdc47476e..335b88cc83 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -792,9 +792,6 @@ # RuntimeError: I got this output for HasDecompTest.test_aten_core_operators: "test_aten_core_operators", "test_has_decomposition", - # worker 'gw[x]' crashed - "test_quick_core_backward", - "test_comprehensive", # AssertionError: Tensor-likes are not close! "test_comprehensive_diff_xpu_complex128", "test_comprehensive_ormqr_xpu_complex128", diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index a4d57b607b..7914f5cfd8 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -1182,7 +1182,7 @@ def launch_test(test_case, skip_list=None, exe_list=None): skip_options += skip_option skip_options += '"' test_command = ( - f"pytest --junit-xml=./op_ut_with_skip_{test_case}.xml " + test_case + f"pytest --junit-xml=./op_ut_with_skip_{test_case}.xml --max-worker-restart=1000 " + test_case ) test_command += skip_options elif exe_list is not None: @@ -1192,11 +1192,11 @@ def launch_test(test_case, skip_list=None, exe_list=None): exe_options += exe_option exe_options += '"' test_command = ( - f"pytest --junit-xml=./op_ut_with_exe_{test_case}.xml " + test_case + f"pytest --junit-xml=./op_ut_with_exe_{test_case}.xml --max-worker-restart=1000 " + test_case ) test_command += exe_options else: test_command = ( - f"pytest --junit-xml=./op_ut_with_all_{test_case}.xml " + test_case + f"pytest --junit-xml=./op_ut_with_all_{test_case}.xml --max-worker-restart=1000 " + test_case ) return os.system(test_command) From 6290b244f41f43bb902b2e17f6ea312b5f02a43f Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Sun, 9 Nov 2025 22:08:00 -0800 Subject: [PATCH 11/12] Fix lint error --- test/xpu/xpu_test_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 7914f5cfd8..586895929a 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -1182,7 +1182,8 @@ def launch_test(test_case, skip_list=None, exe_list=None): skip_options += skip_option skip_options += '"' test_command = ( - f"pytest --junit-xml=./op_ut_with_skip_{test_case}.xml --max-worker-restart=1000 " + test_case + f"pytest --junit-xml=./op_ut_with_skip_{test_case}.xml --max-worker-restart=1000 " + + test_case ) test_command += skip_options elif exe_list is not None: @@ -1192,11 +1193,13 @@ def launch_test(test_case, skip_list=None, exe_list=None): exe_options += exe_option 
exe_options += '"' test_command = ( - f"pytest --junit-xml=./op_ut_with_exe_{test_case}.xml --max-worker-restart=1000 " + test_case + f"pytest --junit-xml=./op_ut_with_exe_{test_case}.xml --max-worker-restart=1000 " + + test_case ) test_command += exe_options else: test_command = ( - f"pytest --junit-xml=./op_ut_with_all_{test_case}.xml --max-worker-restart=1000 " + test_case + f"pytest --junit-xml=./op_ut_with_all_{test_case}.xml --max-worker-restart=1000 " + + test_case ) return os.system(test_command) From da1c581427a9db526d329c636c294e6cb845fe7e Mon Sep 17 00:00:00 2001 From: libohao1201 Date: Sun, 16 Nov 2025 17:45:14 -0800 Subject: [PATCH 12/12] Revise xml file naming convention in xpu tests --- test/xpu/xpu_test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 492449dc9e..29bee2ce3a 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -1099,7 +1099,7 @@ def launch_test(test_case, skip_list=None, exe_list=None): skip_options += skip_option skip_options += '"' test_command = ( - f"pytest --junit-xml=./op_ut_with_skip_{test_case}.xml --max-worker-restart=1000 " + f"pytest --junit-xml=./op_ut_with_skip.{test_case}.xml --max-worker-restart=1000 " + test_case ) test_command += skip_options @@ -1110,13 +1110,13 @@ def launch_test(test_case, skip_list=None, exe_list=None): exe_options += exe_option exe_options += '"' test_command = ( - f"pytest --junit-xml=./op_ut_with_exe_{test_case}.xml --max-worker-restart=1000 " + f"pytest --junit-xml=./op_ut_with_exe.{test_case}.xml --max-worker-restart=1000 " + test_case ) test_command += exe_options else: test_command = ( - f"pytest --junit-xml=./op_ut_with_all_{test_case}.xml --max-worker-restart=1000 " + f"pytest --junit-xml=./op_ut_with_all.{test_case}.xml --max-worker-restart=1000 " + test_case ) return os.system(test_command)
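
For reference, the skip-dictionary merge introduced in run_test_with_skip.py can be summarized as follows: a Windows entry ending in ".py::" means "skip the entire file" and replaces the common (Linux) skip list outright, while ordinary test names are appended to it with duplicates dropped. The sketch below condenses the logic shown in the patches; the sample dictionaries are illustrative only, not the real skip lists.

    def should_skip_entire_file(skip_list):
        """True if any entry uses the whole-file skip pattern '*.py::'."""
        return bool(skip_list) and any(item.endswith(".py::") for item in skip_list)


    def merge_skip_dicts(default_skip_dict, windows_skip_dict):
        """Condensed version of the merge performed in run_test_with_skip.py."""
        merged = {key: list(value) if value else [] for key, value in default_skip_dict.items()}
        for key, windows_list in windows_skip_dict.items():
            if key in merged and not should_skip_entire_file(windows_list):
                # Per-test skips: append Windows-only cases to the common list.
                merged[key] += [case for case in windows_list if case not in merged[key]]
            else:
                # Whole-file skip (or a key only present on Windows): the Windows list wins.
                merged[key] = list(windows_list)
        return merged


    if __name__ == "__main__":
        common = {"test_decomp.py": ("test_quick_vdot_xpu_complex64",)}
        windows = {"test_decomp.py": ["test_decomp.py::"]}  # skip the whole file on Windows
        print(merge_skip_dicts(common, windows))
        # -> {'test_decomp.py': ['test_decomp.py::']}

Because a whole-file entry replaces rather than extends the common list, the common per-test skips for that file are intentionally discarded on Windows; they would be redundant once the entire file is deselected.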
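
The later patches change how launch_test() in test/xpu/xpu_test_utils.py shells out to pytest: --max-worker-restart=1000 keeps the run going across the "worker 'gw[x]' crashed" restarts mentioned in the skip-list comments, and the junit XML files are renamed from op_ut_with_skip_<case>.xml to op_ut_with_skip.<case>.xml. Only fragments of the option-string assembly are visible in the hunks above, so the helper below is a hypothetical reconstruction: the name launch_test_sketch and the exact -k expressions ("not A and not B" for a skip list, "A or B" for an execute list) are assumptions, consistent with the "pytest -k filtering syntax" note in windows_skip_cases.py.

    import os


    def launch_test_sketch(test_case, skip_list=None, exe_list=None):
        """Hypothetical stand-in for launch_test(); the real helper lives in
        test/xpu/xpu_test_utils.py and may differ in how -k options are built."""
        if skip_list:
            # Deselect the skipped cases: -k "not A and not B and ..."
            options = ' -k "' + " and ".join(f"not {case}" for case in skip_list) + '"'
            xml = f"./op_ut_with_skip.{test_case}.xml"
        elif exe_list:
            # Run only the selected cases: -k "A or B or ..."
            options = ' -k "' + " or ".join(exe_list) + '"'
            xml = f"./op_ut_with_exe.{test_case}.xml"
        else:
            options = ""
            xml = f"./op_ut_with_all.{test_case}.xml"

        test_command = (
            f"pytest --junit-xml={xml} --max-worker-restart=1000 " + test_case + options
        )
        return os.system(test_command)


    # Example (hypothetical): deselect two flaky vdot cases when running test_decomp.py
    # launch_test_sketch(
    #     "test_decomp.py",
    #     skip_list=["test_quick_vdot_xpu_complex64", "test_quick_vdot_xpu_complex128"],
    # )

pytest-xdist stops replacing crashed workers once the restart budget is exhausted, so the generous --max-worker-restart limit lets the remaining test_decomp.py cases finish even after several worker crashes instead of aborting the whole invocation.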