@@ -3,6 +3,8 @@ namespace GraphBLAS.FSharp.Backend.Common
33open Brahma.FSharp
44open FSharp.Quotations
55open GraphBLAS.FSharp .Backend .Quotes
6+ open GraphBLAS.FSharp .Backend .Objects .ArraysExtensions
7+ open GraphBLAS.FSharp .Backend .Objects .ClCell
68
79module PrefixSum =
810 let private update ( opAdd : Expr < 'a -> 'a -> 'a >) ( clContext : ClContext ) workGroupSize =
@@ -38,7 +40,7 @@ module PrefixSum =
3840 )
3941
4042 processor.Post( Msg.CreateRunMsg<_, _> kernel)
41- processor.Post ( Msg.CreateFreeMsg ( mirror))
43+ mirror.Free processor
4244
4345 let private scanGeneral
4446 beforeLocalSumClear
@@ -48,10 +50,8 @@ module PrefixSum =
4850 workGroupSize
4951 =
5052
51- let subSum = SubSum.treeSum opAdd
52-
5353 let scan =
54- <@ fun ( ndRange : Range1D ) inputArrayLength verticesLength ( resultBuffer : ClArray < 'a >) ( verticesBuffer : ClArray < 'a >) ( totalSumBuffer : ClCell < 'a >) ( zero : ClCell < 'a >) ( mirror : ClCell < bool >) ->
54+ <@ fun ( ndRange : Range1D ) inputArrayLength verticesLength ( inputArray : ClArray < 'a >) ( verticesBuffer : ClArray < 'a >) ( totalSumBuffer : ClCell < 'a >) ( zero : ClCell < 'a >) ( mirror : ClCell < bool >) ->
5555
5656 let mirror = mirror.Value
5757
@@ -62,46 +62,34 @@ module PrefixSum =
6262 if mirror then
6363 i <- inputArrayLength - 1 - i
6464
65- let localID = ndRange.LocalID0
65+ let lid = ndRange.LocalID0
6666
6767 let zero = zero.Value
6868
6969 if gid < inputArrayLength then
70- resultLocalBuffer.[ localID ] <- resultBuffer .[ i]
70+ resultLocalBuffer.[ lid ] <- inputArray .[ i]
7171 else
72- resultLocalBuffer.[ localID ] <- zero
72+ resultLocalBuffer.[ lid ] <- zero
7373
7474 barrierLocal ()
7575
76- (% subSum) workGroupSize localID resultLocalBuffer
77-
78- if localID = workGroupSize - 1 then
79- if verticesLength <= 1 && localID = gid then
80- totalSumBuffer.Value <- resultLocalBuffer.[ localID]
81-
82- verticesBuffer.[ gid / workGroupSize] <- resultLocalBuffer.[ localID]
83- (% beforeLocalSumClear) resultBuffer resultLocalBuffer.[ localID] inputArrayLength gid i
84- resultLocalBuffer.[ localID] <- zero
76+ // Local tree reduce
77+ (% SubSum.upSweep opAdd) workGroupSize lid resultLocalBuffer
8578
86- let mutable step = workGroupSize
79+ if lid = workGroupSize - 1 then
80+ // if last iteration
81+ if verticesLength <= 1 && lid = gid then
82+ totalSumBuffer.Value <- resultLocalBuffer.[ lid]
8783
88- while step > 1 do
89- barrierLocal ()
84+ verticesBuffer.[ gid / workGroupSize] <- resultLocalBuffer.[ lid]
85+ (% beforeLocalSumClear) inputArray resultLocalBuffer.[ lid] inputArrayLength gid i
86+ resultLocalBuffer.[ lid] <- zero
9087
91- if localID < workGroupSize / step then
92- let i = step * ( localID + 1 ) - 1
93- let j = i - ( step >>> 1 )
94-
95- let tmp = resultLocalBuffer.[ i]
96- let buff = (% opAdd) tmp resultLocalBuffer.[ j]
97- resultLocalBuffer.[ i] <- buff
98- resultLocalBuffer.[ j] <- tmp
99-
100- step <- step >>> 1
88+ (% SubSum.downSweep opAdd) workGroupSize lid resultLocalBuffer
10189
10290 barrierLocal ()
10391
104- (% writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @>
92+ (% writeData) inputArray resultLocalBuffer inputArrayLength workGroupSize gid i lid @>
10593
10694 let program = clContext.Compile( scan)
10795
@@ -132,13 +120,14 @@ module PrefixSum =
132120 )
133121
134122 processor.Post( Msg.CreateRunMsg<_, _> kernel)
135- processor.Post( Msg.CreateFreeMsg( zero))
136- processor.Post( Msg.CreateFreeMsg( mirror))
123+
124+ zero.Free processor
125+ mirror.Free processor
137126
138127 let private scanExclusive < 'a when 'a : struct > =
139128 scanGeneral
140129 <@ fun ( _ : ClArray < 'a >) ( _ : 'a ) ( _ : int ) ( _ : int ) ( _ : int ) -> () @>
141- <@ fun ( resultBuffer : ClArray < 'a >) ( resultLocalBuffer : 'a []) ( inputArrayLength : int ) ( smth : int ) ( gid : int ) ( i : int ) ( localID : int ) ->
130+ <@ fun ( resultBuffer : ClArray < 'a >) ( resultLocalBuffer : 'a []) ( inputArrayLength : int ) ( _ : int ) ( gid : int ) ( i : int ) ( localID : int ) ->
142131
143132 if gid < inputArrayLength then
144133 resultBuffer.[ i] <- resultLocalBuffer.[ localID] @>
@@ -206,8 +195,8 @@ module PrefixSum =
206195 verticesArrays <- swap verticesArrays
207196 verticesLength <- ( verticesLength - 1 ) / workGroupSize + 1
208197
209- processor.Post ( Msg.CreateFreeMsg ( firstVertices))
210- processor.Post ( Msg.CreateFreeMsg ( secondVertices))
198+ firstVertices.Free processor
199+ secondVertices.Free processor
211200
212201 totalSum
213202
@@ -226,7 +215,7 @@ module PrefixSum =
226215 /// <code>
227216 /// let arr = [| 1; 1; 1; 1 |]
228217 /// let sum = [| 0 |]
229- /// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
218+ /// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
230219 /// |> ignore
231220 /// ...
232221 /// > val arr = [| 0; 1; 2; 3 |]
@@ -252,7 +241,7 @@ module PrefixSum =
252241 /// <code>
253242 /// let arr = [| 1; 1; 1; 1 |]
254243 /// let sum = [| 0 |]
255- /// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
244+ /// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
256245 /// |> ignore
257246 /// ...
258247 /// > val arr = [| 1; 2; 3; 4 |]
@@ -270,3 +259,73 @@ module PrefixSum =
270259 fun ( processor : MailboxProcessor < _ >) ( inputArray : ClArray < int >) ->
271260
272261 scan processor inputArray 0
262+
/// Prefix sums computed independently inside each run of equal keys,
/// with one GPU work item walking each run sequentially.
263+ module ByKey =
// Builds and compiles the per-segment scan kernel, then returns a function that launches it.
//   opWrite  - spliced quotation choosing what is stored back, given (previousSum, currentSum).
//   opAdd    - spliced quotation combining the running sum with the next value.
//   zero     - initial value of the running sum for every segment.
// The kernel overwrites `values` in place.
264+ let private sequentialSegments opWrite ( clContext : ClContext ) workGroupSize opAdd zero =
265+
266+ let kernel =
// NOTE(review): `lenght` is a misspelling of `length`; it receives values.Length from the launcher below.
267+ <@ fun ( ndRange : Range1D ) lenght uniqueKeysCount ( values : ClArray < 'a >) ( keys : ClArray < int >) ( offsets : ClArray < int >) ->
268+ let gid = ndRange.GlobalID0
269+
// One work item per unique key; the remaining work items do nothing.
270+ if gid < uniqueKeysCount then
// Presumably offsets.[gid] is the index of the first element of the gid-th key run - TODO confirm with callers.
271+ let sourcePosition = offsets.[ gid]
272+ let sourceKey = keys.[ sourcePosition]
273+
274+ let mutable currentSum = zero
275+ let mutable previousSum = zero
276+
277+ let mutable currentPosition = sourcePosition
278+
// Walk forward until the array ends or the key changes.
279+ while currentPosition < lenght
280+ && keys.[ currentPosition] = sourceKey do
281+
// Keep the sum before and after adding the current element so opWrite can pick either.
282+ previousSum <- currentSum
283+ currentSum <- (% opAdd) currentSum values.[ currentPosition]
284+
285+ values.[ currentPosition] <- (% opWrite) previousSum currentSum
286+
287+ currentPosition <- currentPosition + 1 @>
288+
289+ let kernel = clContext.Compile kernel
290+
// Launcher: posts the argument-setting message and then the run message to `processor`.
291+ fun ( processor : MailboxProcessor < _ >) uniqueKeysCount ( values : ClArray < 'a >) ( keys : ClArray < int >) ( offsets : ClArray < int >) ->
292+
293+ let kernel = kernel.GetKernel()
294+
// NOTE(review): the range is sized by values.Length although only work items with gid < uniqueKeysCount do any work.
295+ let ndRange =
296+ Range1D.CreateValid( values.Length, workGroupSize)
297+
298+ processor.Post(
299+ Msg.MsgSetArguments
300+ ( fun () -> kernel.KernelFunc ndRange values.Length uniqueKeysCount values keys offsets)
301+ )
302+
303+ processor.Post( Msg.CreateRunMsg<_, _> kernel)
304+
305+ /// <summary>
306+ /// Exclusive scan by key: within each run of equal keys, every element is replaced by the accumulated value of the elements before it in that run.
307+ /// </summary>
308+ /// <example>
309+ /// <code>
310+ /// let arr = [| 1; 1; 1; 1; 1; 1|]
311+ /// let keys = [| 1; 2; 2; 2; 3; 3 |]
312+ /// ...
313+ /// > val result = [| 0; 0; 1; 2; 0; 1 |]
314+ /// </code>
315+ /// </example>
// Partial application of sequentialSegments; presumably Map.fst () selects previousSum - TODO confirm.
316+ let sequentialExclude clContext =
317+ sequentialSegments ( Map.fst ()) clContext
318+
319+ /// <summary>
320+ /// Inclusive scan by key: within each run of equal keys, every element is replaced by the accumulated value of the run up to and including itself.
321+ /// </summary>
322+ /// <example>
323+ /// <code>
324+ /// let arr = [| 1; 1; 1; 1; 1; 1|]
325+ /// let keys = [| 1; 2; 2; 2; 3; 3 |]
326+ /// ...
327+ /// > val result = [| 1; 1; 2; 3; 1; 2 |]
328+ /// </code>
329+ /// </example>
// Partial application of sequentialSegments; presumably Map.snd () selects currentSum - TODO confirm.
330+ let sequentialInclude clContext =
331+ sequentialSegments ( Map.snd ()) clContext
0 commit comments