Skip to content

Commit ac20889

Browse files
authored
use stock pytorch defintion of python_ref_db for test_ops (#2303)
disable_build disable_e2e disable_distributed
1 parent cb0dee7 commit ac20889

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

test/xpu/skip_list_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
"test_python_ref_executor__refs_mul_executor_aten_xpu_complex32",
1515
# https://github.com/intel/torch-xpu-ops/issues/2254
1616
"histogramdd",
17+
"_vdot_",
18+
"_dot_",
19+
"_flash_attention_",
20+
"_efficient_attention_",
1721
),
1822
"test_binary_ufuncs_xpu.py": (
1923
"test_fmod_remainder_by_zero_integral_xpu_int64", # zero division is an undefined behavior: different handles on different backends

test/xpu/xpu_test_utils.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import unittest
88

99
import torch
10-
from torch import bfloat16, cuda
10+
from torch import cuda
1111
from torch.testing._internal import (
1212
common_cuda,
1313
common_device_type,
@@ -354,6 +354,11 @@
354354
"_refs.div",
355355
"test_python_ref_torch_fallback",
356356
),
357+
("_refs.true_div", "test_python_ref"),
358+
(
359+
"_refs.true_div",
360+
"test_python_ref_torch_fallback",
361+
),
357362
("argsort", "test_non_standard_bool_values"),
358363
("sort", "test_non_standard_bool_values"),
359364
]
@@ -865,7 +870,6 @@ def __init__(self, patch_test_case=True) -> None:
865870
)
866871
self.foreach_reduce_op_db = common_methods_invocations.foreach_reduce_op_db
867872
self.foreach_other_op_db = common_methods_invocations.foreach_other_op_db
868-
self.python_ref_db = common_methods_invocations.python_ref_db
869873
self.ops_and_refs = common_methods_invocations.ops_and_refs
870874
self.largeTensorTest = common_device_type.largeTensorTest
871875
self.TEST_CUDA = common_cuda.TEST_CUDA
@@ -921,19 +925,10 @@ def gen_xpu_wrappers(op_name, wrappers):
921925

922926
def align_supported_dtypes(self, db):
923927
for opinfo in db:
924-
if (
925-
opinfo.name not in _xpu_computation_op_list
926-
and (
927-
opinfo.torch_opinfo.name not in _xpu_computation_op_list
928-
if db == common_methods_invocations.python_ref_db
929-
else True
930-
)
931-
) or opinfo.name in _ops_without_cuda_support:
928+
if opinfo.name in _ops_without_cuda_support:
932929
opinfo.dtypesIf["xpu"] = opinfo.dtypes
933930
else:
934931
backward_dtypes = set(opinfo.backward_dtypesIfCUDA)
935-
if bfloat16 in opinfo.dtypesIf["xpu"]:
936-
backward_dtypes.add(bfloat16)
937932
opinfo.backward_dtypes = tuple(backward_dtypes)
938933

939934
if opinfo.name in _ops_dtype_different_cuda_support:
@@ -1039,13 +1034,13 @@ def __init__(self, *args):
10391034
self.align_db_decorators(db)
10401035
self.filter_fp64_sample_input(db)
10411036
self.align_db_decorators(module_db)
1042-
common_methods_invocations.python_ref_db = [
1037+
_python_ref_db = [
10431038
op
1044-
for op in self.python_ref_db
1039+
for op in common_methods_invocations.python_ref_db
10451040
if op.torch_opinfo_name in _xpu_computation_op_list
10461041
]
10471042
common_methods_invocations.ops_and_refs = (
1048-
common_methods_invocations.op_db + common_methods_invocations.python_ref_db
1043+
common_methods_invocations.op_db + _python_ref_db
10491044
)
10501045
common_methods_invocations.unary_ufuncs = [
10511046
op
@@ -1128,7 +1123,6 @@ def __exit__(self, exc_type, exc_value, traceback):
11281123
self.instantiate_parametrized_tests_fn
11291124
)
11301125
common_utils.TestCase = self.test_case_cls
1131-
common_methods_invocations.python_ref_db = self.python_ref_db
11321126
common_methods_invocations.ops_and_refs = self.ops_and_refs
11331127
common_device_type.largeTensorTest = self.largeTensorTest
11341128
common_cuda.TEST_CUDA = self.TEST_CUDA

0 commit comments

Comments
 (0)