11"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
22"""
3+ import itertools
4+
35import pytest
46import torch
57
@@ -51,7 +53,10 @@ def test_two_args(self):
5153 def test_multi_arg (self ):
5254 torch .set_default_dtype (torch .float32 )
5355
54- args = [1 , 2 , 3j , xp .arange (3 ), 4 , 5 , 6 ]
56+ args = [1. , 5 , 3 , torch .asarray ([3 ], dtype = torch .float16 ), 5 , 6 , 1. ]
57+ assert xp .result_type (* args ) == torch .float16
58+
59+ args = [1 , 2 , 3j , xp .arange (3 , dtype = xp .float32 ), 4 , 5 , 6 ]
5560 assert xp .result_type (* args ) == xp .complex64
5661
5762 args = [1 , 2 , 3j , xp .float64 , 4 , 5 , 6 ]
@@ -60,5 +65,10 @@ def test_multi_arg(self):
6065 args = [1 , 2 , 3j , xp .float64 , 4 , xp .asarray (3 , dtype = xp .int16 ), 5 , 6 , False ]
6166 assert xp .result_type (* args ) == xp .complex128
6267
68+ i64 = xp .ones (1 , dtype = xp .int64 )
69+ f16 = xp .ones (1 , dtype = xp .float16 )
70+ for i in itertools .permutations ([i64 , f16 , 1.0 , 1.0 ]):
71+ assert xp .result_type (* i ) == xp .float16 , f"{ i } "
72+
6373 with pytest .raises (ValueError ):
6474 xp .result_type (1 , 2 , 3 , 4 )
0 commit comments