File tree Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -151,13 +151,18 @@ def result_type(
151151 return _reduce (_result_type , others + scalars )
152152
153153
154- def _result_type (x , y ):
154+ def _result_type (
155+ x : Array | DType | bool | int | float | complex ,
156+ y : Array | DType | bool | int | float | complex ,
157+ ) -> DType :
155158 if not (isinstance (x , _py_scalars ) or isinstance (y , _py_scalars )):
156- xdt = x . dtype if not isinstance (x , torch .dtype ) else x
157- ydt = y . dtype if not isinstance (y , torch .dtype ) else y
159+ xdt = x if isinstance (x , torch .dtype ) else x . dtype
160+ ydt = y if isinstance (y , torch .dtype ) else y . dtype
158161
159- if ( xdt , ydt ) in _promotion_table :
162+ try :
160163 return _promotion_table [xdt , ydt ]
164+ except KeyError :
165+ pass
161166
162167 # This doesn't result_type(dtype, dtype) for non-array API dtypes
163168 # because torch.result_type only accepts tensors. This does however, allow
You can’t perform that action at this time.
0 commit comments