|
7 | 7 | import unittest |
8 | 8 |
|
9 | 9 | import torch |
10 | | -from torch import bfloat16, cuda |
| 10 | +from torch import cuda |
11 | 11 | from torch.testing._internal import ( |
12 | 12 | common_cuda, |
13 | 13 | common_device_type, |
|
354 | 354 | "_refs.div", |
355 | 355 | "test_python_ref_torch_fallback", |
356 | 356 | ), |
| 357 | + ("_refs.true_div", "test_python_ref"), |
| 358 | + ( |
| 359 | + "_refs.true_div", |
| 360 | + "test_python_ref_torch_fallback", |
| 361 | + ), |
357 | 362 | ("argsort", "test_non_standard_bool_values"), |
358 | 363 | ("sort", "test_non_standard_bool_values"), |
359 | 364 | ] |
@@ -865,7 +870,6 @@ def __init__(self, patch_test_case=True) -> None: |
865 | 870 | ) |
866 | 871 | self.foreach_reduce_op_db = common_methods_invocations.foreach_reduce_op_db |
867 | 872 | self.foreach_other_op_db = common_methods_invocations.foreach_other_op_db |
868 | | - self.python_ref_db = common_methods_invocations.python_ref_db |
869 | 873 | self.ops_and_refs = common_methods_invocations.ops_and_refs |
870 | 874 | self.largeTensorTest = common_device_type.largeTensorTest |
871 | 875 | self.TEST_CUDA = common_cuda.TEST_CUDA |
@@ -921,19 +925,10 @@ def gen_xpu_wrappers(op_name, wrappers): |
921 | 925 |
|
922 | 926 | def align_supported_dtypes(self, db): |
923 | 927 | for opinfo in db: |
924 | | - if ( |
925 | | - opinfo.name not in _xpu_computation_op_list |
926 | | - and ( |
927 | | - opinfo.torch_opinfo.name not in _xpu_computation_op_list |
928 | | - if db == common_methods_invocations.python_ref_db |
929 | | - else True |
930 | | - ) |
931 | | - ) or opinfo.name in _ops_without_cuda_support: |
| 928 | + if opinfo.name in _ops_without_cuda_support: |
932 | 929 | opinfo.dtypesIf["xpu"] = opinfo.dtypes |
933 | 930 | else: |
934 | 931 | backward_dtypes = set(opinfo.backward_dtypesIfCUDA) |
935 | | - if bfloat16 in opinfo.dtypesIf["xpu"]: |
936 | | - backward_dtypes.add(bfloat16) |
937 | 932 | opinfo.backward_dtypes = tuple(backward_dtypes) |
938 | 933 |
|
939 | 934 | if opinfo.name in _ops_dtype_different_cuda_support: |
@@ -1039,13 +1034,13 @@ def __init__(self, *args): |
1039 | 1034 | self.align_db_decorators(db) |
1040 | 1035 | self.filter_fp64_sample_input(db) |
1041 | 1036 | self.align_db_decorators(module_db) |
1042 | | - common_methods_invocations.python_ref_db = [ |
| 1037 | + _python_ref_db = [ |
1043 | 1038 | op |
1044 | | - for op in self.python_ref_db |
| 1039 | + for op in common_methods_invocations.python_ref_db |
1045 | 1040 | if op.torch_opinfo_name in _xpu_computation_op_list |
1046 | 1041 | ] |
1047 | 1042 | common_methods_invocations.ops_and_refs = ( |
1048 | | - common_methods_invocations.op_db + common_methods_invocations.python_ref_db |
| 1043 | + common_methods_invocations.op_db + _python_ref_db |
1049 | 1044 | ) |
1050 | 1045 | common_methods_invocations.unary_ufuncs = [ |
1051 | 1046 | op |
@@ -1128,7 +1123,6 @@ def __exit__(self, exc_type, exc_value, traceback): |
1128 | 1123 | self.instantiate_parametrized_tests_fn |
1129 | 1124 | ) |
1130 | 1125 | common_utils.TestCase = self.test_case_cls |
1131 | | - common_methods_invocations.python_ref_db = self.python_ref_db |
1132 | 1126 | common_methods_invocations.ops_and_refs = self.ops_and_refs |
1133 | 1127 | common_device_type.largeTensorTest = self.largeTensorTest |
1134 | 1128 | common_cuda.TEST_CUDA = self.TEST_CUDA |
|
0 commit comments