|
1 | 1 | import math |
2 | 2 | import operator |
3 | 3 | from enum import Enum, auto |
4 | | -from typing import Callable, List, NamedTuple, Optional, Union |
| 4 | +from typing import Callable, List, NamedTuple, Optional, TypeVar, Union |
5 | 5 |
|
6 | 6 | import pytest |
7 | 7 | from hypothesis import assume, given |
@@ -85,11 +85,14 @@ def default_filter(s: Scalar) -> bool: |
85 | 85 | return math.isfinite(s) and s is not -0.0 and s is not +0.0 |
86 | 86 |
|
87 | 87 |
|
| 88 | +T = TypeVar("T") |
| 89 | + |
| 90 | + |
88 | 91 | def unary_assert_against_refimpl( |
89 | 92 | func_name: str, |
90 | 93 | in_: Array, |
91 | 94 | res: Array, |
92 | | - refimpl: Callable[[Scalar], Scalar], |
| 95 | + refimpl: Callable[[T], T], |
93 | 96 | expr_template: Optional[str] = None, |
94 | 97 | res_stype: Optional[ScalarType] = None, |
95 | 98 | filter_: Callable[[Scalar], bool] = default_filter, |
@@ -136,7 +139,7 @@ def binary_assert_against_refimpl( |
136 | 139 | left: Array, |
137 | 140 | right: Array, |
138 | 141 | res: Array, |
139 | | - refimpl: Callable[[Scalar, Scalar], Scalar], |
| 142 | + refimpl: Callable[[T, T], T], |
140 | 143 | expr_template: Optional[str] = None, |
141 | 144 | res_stype: Optional[ScalarType] = None, |
142 | 145 | left_sym: str = "x1", |
@@ -382,7 +385,7 @@ def binary_param_assert_against_refimpl( |
382 | 385 | right: Union[Array, Scalar], |
383 | 386 | res: Array, |
384 | 387 | op_sym: str, |
385 | | - refimpl: Callable[[Scalar, Scalar], Scalar], |
| 388 | + refimpl: Callable[[T, T], T], |
386 | 389 | res_stype: Optional[ScalarType] = None, |
387 | 390 | filter_: Callable[[Scalar], bool] = default_filter, |
388 | 391 | strict_check: Optional[bool] = None, |
@@ -456,7 +459,7 @@ def test_abs(ctx, data): |
456 | 459 | ctx.func_name, |
457 | 460 | x, |
458 | 461 | out, |
459 | | - abs, |
| 462 | + abs, # type: ignore |
460 | 463 | expr_template="abs({})={}", |
461 | 464 | filter_=lambda s: ( |
462 | 465 | s == float("infinity") or (math.isfinite(s) and s is not -0.0) |
@@ -1013,7 +1016,7 @@ def test_negative(ctx, data): |
1013 | 1016 | ph.assert_dtype(ctx.func_name, x.dtype, out.dtype) |
1014 | 1017 | ph.assert_shape(ctx.func_name, out.shape, x.shape) |
1015 | 1018 | unary_assert_against_refimpl( |
1016 | | - ctx.func_name, x, out, operator.neg, expr_template="-({})={}" |
| 1019 | + ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore |
1017 | 1020 | ) |
1018 | 1021 |
|
1019 | 1022 |
|
|
0 commit comments