|
15 | 15 | from .library import * |
16 | 16 | from .array import * |
17 | 17 | from .util import * |
| 18 | +from .util import _is_number |
18 | 19 |
|
19 | 20 | def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32): |
20 | 21 | """ |
@@ -781,3 +782,118 @@ def upper(a, is_unit_diag=False): |
781 | 782 | out = Array() |
782 | 783 | safe_call(backend.get().af_upper(ct.pointer(out.arr), a.arr, is_unit_diag)) |
783 | 784 | return out |
| 785 | + |
| 786 | +def select(cond, lhs, rhs): |
| 787 | + """ |
| 788 | + Select elements from one of two arrays based on condition. |
| 789 | +
|
| 790 | + Parameters |
| 791 | + ---------- |
| 792 | +
|
| 793 | + cond : af.Array |
| 794 | + Conditional array |
| 795 | +
|
| 796 | + lhs : af.Array or scalar |
| 797 | + numerical array whose elements are picked when conditional element is True |
| 798 | +
|
| 799 | + rhs : af.Array or scalar |
| 800 | + numerical array whose elements are picked when conditional element is False |
| 801 | +
|
| 802 | + Returns |
| 803 | + -------- |
| 804 | +
|
| 805 | + out: af.Array |
| 806 | + An array containing elements from `lhs` when `cond` is True and `rhs` when False. |
| 807 | +
|
| 808 | + Examples |
| 809 | + --------- |
| 810 | +
|
| 811 | + >>> import arrayfire as af |
| 812 | + >>> a = af.randu(3,3) |
| 813 | + >>> b = af.randu(3,3) |
| 814 | + >>> cond = a > b |
| 815 | + >>> res = af.select(cond, a, b) |
| 816 | +
|
| 817 | + >>> af.display(a) |
| 818 | + [3 3 1 1] |
| 819 | + 0.4107 0.1794 0.3775 |
| 820 | + 0.8224 0.4198 0.3027 |
| 821 | + 0.9518 0.0081 0.6456 |
| 822 | +
|
| 823 | + >>> af.display(b) |
| 824 | + [3 3 1 1] |
| 825 | + 0.7269 0.3569 0.3341 |
| 826 | + 0.7104 0.1437 0.0899 |
| 827 | + 0.5201 0.4563 0.5363 |
| 828 | +
|
| 829 | + >>> af.display(res) |
| 830 | + [3 3 1 1] |
| 831 | + 0.7269 0.3569 0.3775 |
| 832 | + 0.8224 0.4198 0.3027 |
| 833 | + 0.9518 0.4563 0.6456 |
| 834 | + """ |
| 835 | + out = Array() |
| 836 | + |
| 837 | + is_left_array = isinstance(lhs, Array) |
| 838 | + is_right_array = isinstance(rhs, Array) |
| 839 | + |
| 840 | + if not (is_left_array or is_right_array): |
| 841 | + raise TypeError("Atleast one input needs to be of type arrayfire.array") |
| 842 | + |
| 843 | + elif (is_left_array and is_right_array): |
| 844 | + safe_call(backend.get().af_select(ct.pointer(out.arr), cond.arr, lhs.arr, rhs.arr)) |
| 845 | + |
| 846 | + elif (_is_number(rhs)): |
| 847 | + safe_call(backend.get().af_select_scalar_r(ct.pointer(out.arr), cond.arr, lhs.arr, ct.c_double(rhs))) |
| 848 | + else: |
| 849 | + safe_call(backend.get().af_select_scalar_l(ct.pointer(out.arr), cond.arr, ct.c_double(lhs), rhs.arr)) |
| 850 | + |
| 851 | + return out |
| 852 | + |
| 853 | +def replace(lhs, cond, rhs): |
| 854 | + """ |
| 855 | + Select elements from one of two arrays based on condition. |
| 856 | +
|
| 857 | + Parameters |
| 858 | + ---------- |
| 859 | +
|
| 860 | + lhs : af.Array or scalar |
| 861 | + numerical array whose elements are replaced with `rhs` when conditional element is False |
| 862 | +
|
| 863 | + cond : af.Array |
| 864 | + Conditional array |
| 865 | +
|
| 866 | + rhs : af.Array or scalar |
| 867 | + numerical array whose elements are picked when conditional element is False |
| 868 | +
|
| 869 | + Examples |
| 870 | + --------- |
| 871 | + >>> import arrayfire as af |
| 872 | + >>> a = af.randu(3,3) |
| 873 | + >>> af.display(a) |
| 874 | + [3 3 1 1] |
| 875 | + 0.4107 0.1794 0.3775 |
| 876 | + 0.8224 0.4198 0.3027 |
| 877 | + 0.9518 0.0081 0.6456 |
| 878 | +
|
| 879 | + >>> cond = (a >= 0.25) & (a <= 0.75) |
| 880 | + >>> af.display(cond) |
| 881 | + [3 3 1 1] |
| 882 | + 1 0 1 |
| 883 | + 0 1 1 |
| 884 | + 0 0 1 |
| 885 | +
|
| 886 | + >>> af.replace(a, cond, 0.3333) |
| 887 | + >>> af.display(a) |
| 888 | + [3 3 1 1] |
| 889 | + 0.3333 0.1794 0.3333 |
| 890 | + 0.8224 0.3333 0.3333 |
| 891 | + 0.9518 0.0081 0.3333 |
| 892 | +
|
| 893 | + """ |
| 894 | + is_right_array = isinstance(rhs, Array) |
| 895 | + |
| 896 | + if (is_right_array): |
| 897 | + safe_call(backend.get().af_replace(lhs.arr, cond.arr, rhs.arr)) |
| 898 | + else: |
| 899 | + safe_call(backend.get().af_replace_scalar(lhs.arr, cond.arr, ct.c_double(rhs))) |
0 commit comments