|
4 | 4 | from builtins import all as _builtin_all, any as _builtin_any |
5 | 5 |
|
6 | 6 | from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose, |
7 | | - vecdot as _aliases_vecdot) |
| 7 | + vecdot as _aliases_vecdot, clip as _aliases_clip) |
8 | 8 | from .._internal import get_xp |
9 | 9 |
|
10 | 10 | import torch |
@@ -189,6 +189,8 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep |
189 | 189 | return torch.clone(x) |
190 | 190 | return torch.amin(x, axis, keepdims=keepdims) |
191 | 191 |
|
| 192 | +clip = get_xp(torch)(_aliases_clip) |
| 193 | + |
192 | 194 | # torch.sort also returns a tuple |
193 | 195 | # https://github.com/pytorch/pytorch/issues/70921 |
194 | 196 | def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: |
@@ -706,8 +708,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - |
706 | 708 | 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', |
707 | 709 | 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal', |
708 | 710 | 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow', |
709 | | - 'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum', |
710 | | - 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', |
| 711 | + 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod', |
| 712 | + 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', |
711 | 713 | 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', |
712 | 714 | 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', |
713 | 715 | 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', |
|
0 commit comments