
Commit 0bbb067

Authored by Piotr Papierkowski (ppapierk), with Copilot, tadkrawiec, and mengfei25
Add exception to float->int conversion test case (#2105)
Downcasting float32 values beyond the target type's value range is undefined behaviour according to the torch documentation.

---------

Co-authored-by: Piotr Papierkowski <ppapierkowski@habana.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Tadeusz Krawiec <72737249+tadkrawiec@users.noreply.github.com>
Co-authored-by: mengfei25 <mengfei.li@Intel.com>
1 parent: dbf6b0d · commit: 0bbb067
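For context (not part of this commit): a minimal, hedged sketch of the behaviour the commit message refers to. torch.finfo(torch.float32).min lies far outside the int8/int16 range, so per the PyTorch type-conversion notes the result of the downcast is undefined and may differ between backends.

# Minimal sketch, not from the repository: out-of-range float -> int downcast.
import torch

min_f32 = torch.finfo(torch.float32).min  # about -3.4e38, far below the int8/int16 range

x = torch.tensor([min_f32, -2.0, 2.0], dtype=torch.float32)
print(x.to(torch.int8))   # first element is undefined behaviour: backend-dependent
print(x.to(torch.int16))  # likewise for int16; only the in-range values are portable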

2 files changed: +4, -6 lines changed

test/xpu/skip_list_common.py

Lines changed: 1 addition & 5 deletions
@@ -658,11 +658,7 @@
     "nn/test_pooling_xpu.py": None,
     "nn/test_dropout_xpu.py": None,
     "test_dataloader_xpu.py": None,
-    "test_tensor_creation_ops_xpu.py": (
-        # CPU only (vs Numpy). CUDA skips these cases since non-deterministic results are outputed for inf and nan.
-        "test_float_to_int_conversion_finite_xpu_int8",
-        "test_float_to_int_conversion_finite_xpu_int16",
-    ),
+    "test_tensor_creation_ops_xpu.py": None,
     "test_autocast_xpu.py": None,
     "test_autograd_xpu.py": (
         # AttributeError: module 'torch.xpu' has no attribute
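
For illustration only, a hypothetical sketch of how a skip mapping with this shape (test file name mapped to a tuple of skipped test names, or None for no skips) could be consumed; skip_dict and should_skip are illustrative names, not taken from the repository.

# Hypothetical sketch: consuming a skip mapping shaped like the entries above.
# None means "skip nothing in that file"; a tuple lists individual cases to skip.
skip_dict = {
    "test_dataloader_xpu.py": None,
    "test_tensor_creation_ops_xpu.py": None,  # after this commit: no cases skipped
}

def should_skip(test_file: str, test_name: str) -> bool:
    skips = skip_dict.get(test_file)
    return skips is not None and test_name in skips

print(should_skip("test_tensor_creation_ops_xpu.py",
                  "test_float_to_int_conversion_finite_xpu_int8"))  # False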

test/xpu/test_tensor_creation_ops_xpu.py

Lines changed: 3 additions & 1 deletion
@@ -1226,8 +1226,10 @@ def test_float_to_int_conversion_finite(self, device, dtype):
         vals = (min, -2, -1.5, -0.5, 0, 0.5, 1.5, 2, max)
         refs = None
         if self.device_type == "cuda" or self.device_type == "xpu":
-            if torch.version.hip:
+            if torch.version.hip or torch.version.xpu:
                 # HIP min float -> int64 conversion is divergent
+                # XPU min float -> int8 conversion is divergent
+                # XPU min float -> int16 conversion is divergent
                 vals = (-2, -1.5, -0.5, 0, 0.5, 1.5, 2)
             else:
                 vals = (min, -2, -1.5, -0.5, 0, 0.5, 1.5, 2)
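
A short sketch (assuming an XPU build of PyTorch is available) of the divergence the new comments describe; it mirrors the idea of the test's CPU-reference comparison rather than its exact helper.

# Sketch only: the same out-of-range float32 value can convert differently on
# CPU and on XPU, which is why `min` is dropped from `vals` on HIP/XPU above.
import torch

min_f32 = torch.finfo(torch.float32).min
cpu_result = torch.tensor([min_f32]).to(torch.int8)

if hasattr(torch, "xpu") and torch.xpu.is_available():
    xpu_result = torch.tensor([min_f32], device="xpu").to(torch.int8).cpu()
    print(cpu_result, xpu_result)  # may differ: the conversion is undefined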
