Skip to content

Commit bdaf733

Browse files
committed
Rest of array methods
Added 5.3 operators
1 parent 355f8c7 commit bdaf733

File tree

2 files changed

+83
-16
lines changed

2 files changed

+83
-16
lines changed

scripts/lib/impl/operators.lua

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
local af = require("arrayfire")
55

66
-- Forward declarations --
7+
local CallWrap
78
local GetLib
89
local TwoArrays
910

@@ -35,6 +36,7 @@ end
3536
function M.Add (array_module, meta)
3637
-- Import these here since the array module is not yet registered.
3738
GetLib = array_module.GetLib
39+
CallWrap = array_module.CallWrap
3840
TwoArrays = array_module.TwoArrays
3941

4042
--
@@ -49,30 +51,35 @@ function M.Add (array_module, meta)
4951
--
5052
for k, v in pairs{
5153
__add = Binary("add"),
54+
__band = Binary("bitand"),
55+
__bnot = function(a)
56+
return CallWrap("af_not", a:get())
57+
end,
58+
__bor = Binary("bitor"),
59+
__bxor = Binary("bitxor"),
5260
__call = function(a, ...)
5361
-- operator()... ugh (proxy types, __index and __newindex shenanigans)
5462
end,
5563
__div = Binary("div"),
5664
__eq = Binary("eq", true),
5765
__lt = Binary("lt", true),
5866
__le = Binary("le", true),
59-
__mod = Binary("mod"),
67+
__mod = Binary("rem"),
6068
__mul = Binary("mul"),
61-
--[[
62-
__newindex = function(arr, k, v)
69+
__newindex = function(a, k, v)
6370
-- TODO: disable for non-proxies?
6471

6572
if k == "_" then
6673
-- lvalue assign of v
6774
end
68-
end
69-
]]
75+
end,
7076
__pow = Binary("pow"),
77+
__shl = Binary("bitshiftl"),
78+
__shr = Binary("bitshiftr"),
7179
__sub = Binary("sub"),
7280
__unm = function(a)
7381
return 0 - a
74-
end,
75-
-- TODO: 5.3 supports bitwise ops...
82+
end
7683
} do
7784
meta[k] = v
7885
end

scripts/lib/methods/methods.lua

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,26 @@ function M.Add (array_module, meta)
1717
local CallWrap = array_module.CallWrap
1818
local GetLib = array_module.GetLib
1919

20+
--
21+
local function Wrap (name)
22+
name = "af_" .. name
23+
24+
return function(arr)
25+
return CallWrap(name, arr:get())
26+
end
27+
end
28+
29+
-- --
30+
local SizeOf = {}
31+
32+
for prefix, size in ("f32 f64 s32 u32 s64 u64 u8 b8 c32 c64 s16 u16"):gmatch "(%a)(%d+)" do
33+
local k = af[prefix .. size]
34+
35+
if k then -- account for earlier versions
36+
SizeOf[k] = tonumber(size) / (prefix == "c" and 4 or 8) -- 8 bits to a byte; double complex types
37+
end
38+
end
39+
2040
--
2141
for k, v in pairs{
2242
--
@@ -25,10 +45,16 @@ function M.Add (array_module, meta)
2545
end,
2646

2747
--
28-
copy = function(arr)
29-
return CallWrap("af_copy_array", arr:get())
48+
bytes = function(arr)
49+
local ha = arr:get()
50+
local n, dtype = Call("af_get_elements", ha), Call("af_get_type", ha)
51+
52+
return n * (SizeOf[dtype] or 4)
3053
end,
3154

55+
--
56+
copy = Wrap("copy_array"),
57+
3258
--
3359
dims = function(arr, i)
3460
if i then
@@ -39,18 +65,50 @@ function M.Add (array_module, meta)
3965
end,
4066

4167
--
42-
elements = function(arr)
43-
return Call("af_get_elements", arr:get())
44-
end,
68+
elements = Wrap("get_elements"),
4569

4670
--
47-
eval = function(arr)
48-
Call("af_eval", arr:get())
49-
end,
71+
eval = Wrap("eval"),
5072

5173
--
5274
get = array_module.GetHandle,
5375

76+
--
77+
isbool = Wrap("is_bool"),
78+
79+
--
80+
iscolumn = Wrap("is_column"),
81+
82+
--
83+
iscomplex = Wrap("is_complex"),
84+
85+
--
86+
isdouble = Wrap("is_double"),
87+
88+
--
89+
isempty = Wrap("is_empty"),
90+
91+
--
92+
isfloating = Wrap("is_floating"),
93+
94+
--
95+
isinteger = Wrap("is_integer"),
96+
97+
--
98+
isrealfloating = Wrap("is_real_floating"),
99+
100+
--
101+
isrow = Wrap("is_row"),
102+
103+
--
104+
isscalar = Wrap("is_scalar"),
105+
106+
--
107+
issingle = Wrap("is_single"),
108+
109+
--
110+
isvector = Wrap("is_vector"),
111+
54112
--
55113
H = function(arr)
56114
return GetLib().transpose(arr, true)
@@ -67,7 +125,9 @@ function M.Add (array_module, meta)
67125
--
68126
T = function(arr)
69127
return GetLib().transpose(arr)
70-
end
128+
end,
129+
130+
type = Wrap("get_type")
71131
} do
72132
meta[k] = v
73133
end

0 commit comments

Comments
 (0)