Skip to content

Commit 355f8c7

Browse files
committed
Added where()
Copied over C++ code for index... still puzzling out a good attack :P Allow subtypes in ephemeral environments (currently just arrays, but meant to support proxies, sequences, and indices at least)
1 parent c2dbe91 commit 355f8c7

File tree

5 files changed

+290
-50
lines changed

5 files changed

+290
-50
lines changed

scripts/lib/funcs/vector.lua

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
--- Vector functions.
22

33
-- Standard library imports --
4+
local assert = assert
45
local min = math.min
56
local select = select
67
local type = type
@@ -20,6 +21,7 @@ local ToType = array.ToType
2021
local M = {}
2122

2223
-- See also: https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/reduce.cpp
24+
-- https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/where.cpp
2325

2426
--
2527
local function Bool (value)
@@ -172,7 +174,14 @@ function M.Add (into)
172174
end,
173175

174176
--
175-
sum = ReduceNaN("sum")
177+
sum = ReduceNaN("sum"),
178+
179+
--
180+
where = function(in_arr)
181+
assert(not GetLib().gforGet(), "WHERE can not be used inside GFOR") -- TODO: AF_ERR_RUNTIME);
182+
183+
return CallWrap("af_where", in_arr:get())
184+
end
176185
} do
177186
into[k] = v
178187
end

scripts/lib/impl/array.lua

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,10 @@ end
139139

140140
--- DOCME
141141
-- @tparam LuaArray arr
142-
-- @bool remove
143142
-- @treturn ?|af_array|nil X
144-
function M.GetHandle (arr, remove)
143+
function M.GetHandle (arr)
145144
-- TODO: If proxy, add reference?
146-
local ha = arr.m_handle
147-
148-
if remove then
149-
arr.m_handle = nil
150-
end
151-
152-
return ha
145+
return arr.m_handle
153146
end
154147

155148
-- --
@@ -188,6 +181,8 @@ function M.IsConstant (item)
188181
return not not Constants[item] -- metatable redundant; coerce to false if missing
189182
end
190183

184+
-- TODO: IsProxy(), MakeProxy()...
185+
191186
--- DOCME
192187
-- @tparam LuaArray arr
193188
-- @tparam ?|af_array|nil handle
@@ -237,9 +232,8 @@ function M.ToType (ret_type, real, imag)
237232
if rtype == "c32" or rtype == "c64" then
238233
return { real = real, imag = imag }
239234
else
240-
return real
235+
return real -- TODO: Improve this!
241236
end
242-
-- TODO: Improve these a bit
243237
end
244238

245239
--- DOCME
@@ -279,7 +273,7 @@ end
279273
function M.WrapArray (arr)
280274
local wrapped = setmetatable({ m_handle = arr }, ArrayMethodsAndMetatable)
281275

282-
_AddToCurrentEnvironment_(wrapped)
276+
_AddToCurrentEnvironment_("array", wrapped)
283277

284278
return wrapped
285279
end
@@ -299,8 +293,8 @@ end
299293
for _, v in ipairs{
300294
"lib.impl.ephemeral",
301295
"lib.impl.operators",
302-
"lib.impl.seq",
303-
"lib.impl.index", -- depends on seq
296+
"lib.impl.seq", -- depends on ephemeral
297+
"lib.impl.index", -- depends on ephemeral, seq
304298
"lib.methods.methods"
305299
} do
306300
require(v).Add(M, ArrayMethodsAndMetatable)
@@ -309,6 +303,17 @@ end
309303
ArrayMethodsAndMetatable.__index = ArrayMethodsAndMetatable
310304
ArrayMethodsAndMetatable.__metatable = MetaValue
311305

306+
-- Register array environment type.
307+
M.RegisterEnvironmentCleanup("array", function(arr)
308+
local ha = arr:get()
309+
310+
arr.m_handle = nil -- set() can error out
311+
312+
return not ha or af.af_release_array(ha) == SUCCESS
313+
-- TODO: pooling?
314+
end, "Errors releasing %i arrays")
315+
-- TODO: Register "array_proxy"?
316+
312317
-- By default, check valid names.
313318
M.CheckNames(true)
314319

scripts/lib/impl/ephemeral.lua

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,14 @@
22

33
-- Standard library imports --
44
local assert = assert
5+
local concat = table.concat
56
local collectgarbage = collectgarbage
67
local error = error
78
local pairs = pairs
89
local pcall = pcall
910
local rawequal = rawequal
1011
local remove = table.remove
1112

12-
-- Modules --
13-
local af = require("arrayfire")
14-
15-
-- Forward declarations --
16-
local IsArray
17-
1813
-- Cookies --
1914
local _command = {}
2015

@@ -30,62 +25,88 @@ local Stack = {}
3025
-- --
3126
local Top = 0
3227

28+
--
29+
local function Remove (lists, elem)
30+
for elem_type, list in pairs(lists) do
31+
if list[elem] then
32+
list[elem] = nil
33+
34+
return elem_type
35+
end
36+
end
37+
end
38+
39+
-- --
40+
local Types = {}
41+
3342
--
3443
local function NewEnv ()
35-
local id, list, mode, step = ID, {}
44+
local id, lists, mode, step = ID, {}
3645

3746
ID = ID + 1
3847

48+
for elem_type in pairs(Types) do
49+
lists[elem_type] = {}
50+
end
51+
3952
return function(a, b, c)
4053
if rawequal(a, _command) then -- a: _command, b: what, c: arg
4154
if b == "set_mode" then
4255
mode = c
4356
elseif b == "get_id" then
4457
return id
45-
elseif b == "get_list" then
46-
return list
58+
elseif b == "get_lists" then
59+
return lists
4760
elseif b == "set_step" then
4861
step = c
4962
end
5063
elseif a == "get_step" then -- a: "get_step"
5164
return step
52-
elseif IsArray(a) then -- a: array?
65+
else -- a: element?
5366
local env = Stack[Top]
5467

5568
assert(env and env(_command, "get_id") == id, "Environment not active") -- is self?
5669

57-
local lower_env = (mode == "parent" or mode == "parent_gc") and Stack[Top - 1]
58-
-- TODO: pingpong, pingpong_gc
59-
if lower_env then
60-
lower_env(_command, "get_list")[a] = true
61-
end
70+
local elem_type = Remove(lists, a)
6271

63-
list[a] = nil
72+
if elem_type then
73+
local lower_env = (mode == "parent" or mode == "parent_gc") and Stack[Top - 1]
74+
-- TODO: pingpong, pingpong_gc
75+
if lower_env then
76+
lower_env(_command, "get_lists")[elem_type][a] = true
77+
end
6478

65-
return a
79+
return a
80+
end
6681
end
6782
end
6883
end
6984

7085
--
71-
local function Purge (list)
72-
local nerrs = 0
86+
local function Purge (lists)
87+
local errs
7388

74-
for arr in pairs(list) do
75-
local ha = arr:get(true)
89+
for elem_type, type_info in pairs(Types) do
90+
local elem_list, cleanup, nerrs = lists[elem_type], type_info.cleanup, 0
7691

77-
if ha then
78-
local err = af.af_release_array(ha)
79-
80-
if err ~= af.AF_SUCCESS then
92+
--
93+
for elem in pairs(elem_list) do
94+
if not cleanup(elem) then
8195
nerrs = nerrs + 1
8296
end
97+
98+
elem_list[elem] = nil
8399
end
84100

85-
list[arr] = nil
101+
--
102+
if nerrs > 0 then
103+
errs = errs or {}
104+
105+
errs[#errs + 1] = type_info.message:format(nerrs)
106+
end
86107
end
87108

88-
return nerrs
109+
return errs and concat(errs, "\n")
89110
end
90111

91112
-- --
@@ -97,7 +118,7 @@ local function GetResults (env, ok, a, ...)
97118

98119
env(_command, "set_mode", nil)
99120

100-
local nerrs = Purge(env(_command, "get_list"))
121+
local errs = Purge(env(_command, "get_lists"))
101122
-- Pingpong or normal? (How to end?)
102123
Cache[#Cache + 1] = env
103124
Top, Stack[Top] = Top - 1
@@ -106,25 +127,22 @@ local function GetResults (env, ok, a, ...)
106127
collectgarbage()
107128
end
108129
-- TODO: pingpong_gc
109-
if ok and nerrs == 0 then
130+
if ok and not errs then
110131
return a, ...
111132
else
112133
-- Clean up if pingpong
113-
error(not ok and a or ("Errors releasing %i arrays"):format(nerrs))
134+
error(not ok and a or errs)
114135
end
115136
end
116137

117138
--
118139
function M.Add (array_module)
119-
-- Import these here since the array module is not yet registered.
120-
IsArray = array_module.IsArray
121-
122140
--
123-
function array_module.AddToCurrentEnvironment (arr)
141+
function array_module.AddToCurrentEnvironment (elem_type, arr)
124142
local env = Top > 0 and Stack[Top]
125143

126144
if env then
127-
env(_command, "get_list")[arr] = true
145+
env(_command, "get_lists")[elem_type][arr] = true
128146
end
129147
end
130148
-- AddOneEnv
@@ -154,6 +172,13 @@ function M.Add (array_module)
154172

155173
return GetResults(env, pcall(func, env, ...))
156174
end
175+
176+
--
177+
function array_module.RegisterEnvironmentCleanup (elem_type, cleanup, message)
178+
assert(Top == 0 and #Cache == 0, "Attempt to register new environment type after launch")
179+
180+
Types[elem_type] = { cleanup = cleanup, message = message }
181+
end
157182
end
158183

159184
-- Export the module.

0 commit comments

Comments
 (0)