Skip to content

Commit faaf223

Browse files
committed
FEAT: adding select and replace
1 parent 718f1cc commit faaf223

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

arrayfire/data.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .library import *
1616
from .array import *
1717
from .util import *
18+
from .util import _is_number
1819

1920
def constant(val, d0, d1=None, d2=None, d3=None, dtype=Dtype.f32):
2021
"""
@@ -781,3 +782,118 @@ def upper(a, is_unit_diag=False):
781782
out = Array()
782783
safe_call(backend.get().af_upper(ct.pointer(out.arr), a.arr, is_unit_diag))
783784
return out
785+
786+
def select(cond, lhs, rhs):
787+
"""
788+
Select elements from one of two arrays based on condition.
789+
790+
Parameters
791+
----------
792+
793+
cond : af.Array
794+
Conditional array
795+
796+
lhs : af.Array or scalar
797+
numerical array whose elements are picked when conditional element is True
798+
799+
rhs : af.Array or scalar
800+
numerical array whose elements are picked when conditional element is False
801+
802+
Returns
803+
--------
804+
805+
out: af.Array
806+
An array containing elements from `lhs` when `cond` is True and `rhs` when False.
807+
808+
Examples
809+
---------
810+
811+
>>> import arrayfire as af
812+
>>> a = af.randu(3,3)
813+
>>> b = af.randu(3,3)
814+
>>> cond = a > b
815+
>>> res = af.select(cond, a, b)
816+
817+
>>> af.display(a)
818+
[3 3 1 1]
819+
0.4107 0.1794 0.3775
820+
0.8224 0.4198 0.3027
821+
0.9518 0.0081 0.6456
822+
823+
>>> af.display(b)
824+
[3 3 1 1]
825+
0.7269 0.3569 0.3341
826+
0.7104 0.1437 0.0899
827+
0.5201 0.4563 0.5363
828+
829+
>>> af.display(res)
830+
[3 3 1 1]
831+
0.7269 0.3569 0.3775
832+
0.8224 0.4198 0.3027
833+
0.9518 0.4563 0.6456
834+
"""
835+
out = Array()
836+
837+
is_left_array = isinstance(lhs, Array)
838+
is_right_array = isinstance(rhs, Array)
839+
840+
if not (is_left_array or is_right_array):
841+
raise TypeError("Atleast one input needs to be of type arrayfire.array")
842+
843+
elif (is_left_array and is_right_array):
844+
safe_call(backend.get().af_select(ct.pointer(out.arr), cond.arr, lhs.arr, rhs.arr))
845+
846+
elif (_is_number(rhs)):
847+
safe_call(backend.get().af_select_scalar_r(ct.pointer(out.arr), cond.arr, lhs.arr, ct.c_double(rhs)))
848+
else:
849+
safe_call(backend.get().af_select_scalar_l(ct.pointer(out.arr), cond.arr, ct.c_double(lhs), rhs.arr))
850+
851+
return out
852+
853+
def replace(lhs, cond, rhs):
854+
"""
855+
Select elements from one of two arrays based on condition.
856+
857+
Parameters
858+
----------
859+
860+
lhs : af.Array or scalar
861+
numerical array whose elements are replaced with `rhs` when conditional element is False
862+
863+
cond : af.Array
864+
Conditional array
865+
866+
rhs : af.Array or scalar
867+
numerical array whose elements are picked when conditional element is False
868+
869+
Examples
870+
---------
871+
>>> import arrayfire as af
872+
>>> a = af.randu(3,3)
873+
>>> af.display(a)
874+
[3 3 1 1]
875+
0.4107 0.1794 0.3775
876+
0.8224 0.4198 0.3027
877+
0.9518 0.0081 0.6456
878+
879+
>>> cond = (a >= 0.25) & (a <= 0.75)
880+
>>> af.display(cond)
881+
[3 3 1 1]
882+
1 0 1
883+
0 1 1
884+
0 0 1
885+
886+
>>> af.replace(a, cond, 0.3333)
887+
>>> af.display(a)
888+
[3 3 1 1]
889+
0.3333 0.1794 0.3333
890+
0.8224 0.3333 0.3333
891+
0.9518 0.0081 0.3333
892+
893+
"""
894+
is_right_array = isinstance(rhs, Array)
895+
896+
if (is_right_array):
897+
safe_call(backend.get().af_replace(lhs.arr, cond.arr, rhs.arr))
898+
else:
899+
safe_call(backend.get().af_replace_scalar(lhs.arr, cond.arr, ct.c_double(rhs)))

tests/simple/data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,9 @@ def simple_data(verbose=False):
7676
af.transpose_inplace(a)
7777
display_func(a)
7878

79+
display_func(af.select(a > 0.3, a, -0.3))
80+
81+
af.replace(a, a > 0.3, -0.3)
82+
display_func(a)
83+
7984
_util.tests['data'] = simple_data

0 commit comments

Comments
 (0)