Skip to content

Commit e850a49

Browse files
authored
Merge pull request #71 from IgorErin/scan
Scan by key
2 parents 7885e8a + 420e2b8 commit e850a49

File tree

8 files changed

+265
-61
lines changed

8 files changed

+265
-61
lines changed

src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs

Lines changed: 96 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ namespace GraphBLAS.FSharp.Backend.Common
33
open Brahma.FSharp
44
open FSharp.Quotations
55
open GraphBLAS.FSharp.Backend.Quotes
6+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
7+
open GraphBLAS.FSharp.Backend.Objects.ClCell
68

79
module PrefixSum =
810
let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize =
@@ -38,7 +40,7 @@ module PrefixSum =
3840
)
3941

4042
processor.Post(Msg.CreateRunMsg<_, _> kernel)
41-
processor.Post(Msg.CreateFreeMsg(mirror))
43+
mirror.Free processor
4244

4345
let private scanGeneral
4446
beforeLocalSumClear
@@ -48,10 +50,8 @@ module PrefixSum =
4850
workGroupSize
4951
=
5052

51-
let subSum = SubSum.treeSum opAdd
52-
5353
let scan =
54-
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (resultBuffer: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->
54+
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (inputArray: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->
5555

5656
let mirror = mirror.Value
5757

@@ -62,46 +62,34 @@ module PrefixSum =
6262
if mirror then
6363
i <- inputArrayLength - 1 - i
6464

65-
let localID = ndRange.LocalID0
65+
let lid = ndRange.LocalID0
6666

6767
let zero = zero.Value
6868

6969
if gid < inputArrayLength then
70-
resultLocalBuffer.[localID] <- resultBuffer.[i]
70+
resultLocalBuffer.[lid] <- inputArray.[i]
7171
else
72-
resultLocalBuffer.[localID] <- zero
72+
resultLocalBuffer.[lid] <- zero
7373

7474
barrierLocal ()
7575

76-
(%subSum) workGroupSize localID resultLocalBuffer
77-
78-
if localID = workGroupSize - 1 then
79-
if verticesLength <= 1 && localID = gid then
80-
totalSumBuffer.Value <- resultLocalBuffer.[localID]
81-
82-
verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[localID]
83-
(%beforeLocalSumClear) resultBuffer resultLocalBuffer.[localID] inputArrayLength gid i
84-
resultLocalBuffer.[localID] <- zero
76+
// Local tree reduce
77+
(%SubSum.upSweep opAdd) workGroupSize lid resultLocalBuffer
8578

86-
let mutable step = workGroupSize
79+
if lid = workGroupSize - 1 then
80+
// final pass: a single work-group covers the input, so its sum is the total
81+
if verticesLength <= 1 && lid = gid then
82+
totalSumBuffer.Value <- resultLocalBuffer.[lid]
8783

88-
while step > 1 do
89-
barrierLocal ()
84+
verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[lid]
85+
(%beforeLocalSumClear) inputArray resultLocalBuffer.[lid] inputArrayLength gid i
86+
resultLocalBuffer.[lid] <- zero
9087

91-
if localID < workGroupSize / step then
92-
let i = step * (localID + 1) - 1
93-
let j = i - (step >>> 1)
94-
95-
let tmp = resultLocalBuffer.[i]
96-
let buff = (%opAdd) tmp resultLocalBuffer.[j]
97-
resultLocalBuffer.[i] <- buff
98-
resultLocalBuffer.[j] <- tmp
99-
100-
step <- step >>> 1
88+
(%SubSum.downSweep opAdd) workGroupSize lid resultLocalBuffer
10189

10290
barrierLocal ()
10391

104-
(%writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @>
92+
(%writeData) inputArray resultLocalBuffer inputArrayLength workGroupSize gid i lid @>
10593

10694
let program = clContext.Compile(scan)
10795

@@ -132,13 +120,14 @@ module PrefixSum =
132120
)
133121

134122
processor.Post(Msg.CreateRunMsg<_, _> kernel)
135-
processor.Post(Msg.CreateFreeMsg(zero))
136-
processor.Post(Msg.CreateFreeMsg(mirror))
123+
124+
zero.Free processor
125+
mirror.Free processor
137126

138127
let private scanExclusive<'a when 'a: struct> =
139128
scanGeneral
140129
<@ fun (_: ClArray<'a>) (_: 'a) (_: int) (_: int) (_: int) -> () @>
141-
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (smth: int) (gid: int) (i: int) (localID: int) ->
130+
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (_: int) (gid: int) (i: int) (localID: int) ->
142131

143132
if gid < inputArrayLength then
144133
resultBuffer.[i] <- resultLocalBuffer.[localID] @>
@@ -206,8 +195,8 @@ module PrefixSum =
206195
verticesArrays <- swap verticesArrays
207196
verticesLength <- (verticesLength - 1) / workGroupSize + 1
208197

209-
processor.Post(Msg.CreateFreeMsg(firstVertices))
210-
processor.Post(Msg.CreateFreeMsg(secondVertices))
198+
firstVertices.Free processor
199+
secondVertices.Free processor
211200

212201
totalSum
213202

@@ -226,7 +215,7 @@ module PrefixSum =
226215
/// <code>
227216
/// let arr = [| 1; 1; 1; 1 |]
228217
/// let sum = [| 0 |]
229-
/// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
218+
/// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
230219
/// |> ignore
231220
/// ...
232221
/// > val arr = [| 0; 1; 2; 3 |]
@@ -252,7 +241,7 @@ module PrefixSum =
252241
/// <code>
253242
/// let arr = [| 1; 1; 1; 1 |]
254243
/// let sum = [| 0 |]
255-
/// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
244+
/// runIncludeInplace clContext workGroupSize processor arr sum (+) 0
256245
/// |> ignore
257246
/// ...
258247
/// > val arr = [| 1; 2; 3; 4 |]
@@ -270,3 +259,73 @@ module PrefixSum =
270259
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<int>) ->
271260

272261
scan processor inputArray 0
262+
263+
module ByKey =
264+
let private sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =
265+
266+
let kernel =
267+
<@ fun (ndRange: Range1D) lenght uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->
268+
let gid = ndRange.GlobalID0
269+
270+
if gid < uniqueKeysCount then
271+
let sourcePosition = offsets.[gid]
272+
let sourceKey = keys.[sourcePosition]
273+
274+
let mutable currentSum = zero
275+
let mutable previousSum = zero
276+
277+
let mutable currentPosition = sourcePosition
278+
279+
while currentPosition < lenght
280+
&& keys.[currentPosition] = sourceKey do
281+
282+
previousSum <- currentSum
283+
currentSum <- (%opAdd) currentSum values.[currentPosition]
284+
285+
values.[currentPosition] <- (%opWrite) previousSum currentSum
286+
287+
currentPosition <- currentPosition + 1 @>
288+
289+
let kernel = clContext.Compile kernel
290+
291+
fun (processor: MailboxProcessor<_>) uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->
292+
293+
let kernel = kernel.GetKernel()
294+
295+
let ndRange =
296+
Range1D.CreateValid(values.Length, workGroupSize)
297+
298+
processor.Post(
299+
Msg.MsgSetArguments
300+
(fun () -> kernel.KernelFunc ndRange values.Length uniqueKeysCount values keys offsets)
301+
)
302+
303+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
304+
305+
/// <summary>
306+
/// Exclusive scan by key.
307+
/// </summary>
308+
/// <example>
309+
/// <code>
310+
/// let arr = [| 1; 1; 1; 1; 1; 1|]
311+
/// let keys = [| 1; 2; 2; 2; 3; 3 |]
312+
/// ...
313+
/// > val result = [| 0; 0; 1; 2; 0; 1 |]
314+
/// </code>
315+
/// </example>
316+
let sequentialExclude clContext =
317+
sequentialSegments (Map.fst ()) clContext
318+
319+
/// <summary>
320+
/// Inclusive scan by key.
321+
/// </summary>
322+
/// <example>
323+
/// <code>
324+
/// let arr = [| 1; 1; 1; 1; 1; 1|]
325+
/// let keys = [| 1; 2; 2; 2; 3; 3 |]
326+
/// ...
327+
/// > val result = [| 1; 1; 2; 3; 1; 2 |]
328+
/// </code>
329+
/// </example>
330+
let sequentialInclude clContext =
331+
sequentialSegments (Map.snd ()) clContext

src/GraphBLAS-sharp.Backend/Quotes/Map.fs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ module Map =
2121
match (%map) item with
2222
| Some _ -> 1
2323
| None -> 0 @>
24+
25+
let fst () = <@ fun fst _ -> fst @>
26+
27+
let snd () = <@ fun _ snd -> snd @>

src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,30 @@ module SubSum =
3131

3232
barrierLocal () @>
3333

34-
let sequentialSum<'a> opAdd =
35-
sumGeneral<'a> <| sequentialAccess<'a> opAdd
34+
let sequentialSum<'a> = sumGeneral<'a> << sequentialAccess<'a>
3635

37-
let treeSum<'a> opAdd = sumGeneral<'a> <| treeAccess<'a> opAdd
36+
let upSweep<'a> = sumGeneral<'a> << treeAccess<'a>
37+
38+
let downSweep opAdd =
39+
<@ fun wgSize lid (localBuffer: 'a []) ->
40+
let mutable step = wgSize
41+
42+
while step > 1 do
43+
barrierLocal ()
44+
45+
if lid < wgSize / step then
46+
let i = step * (lid + 1) - 1
47+
let j = i - (step >>> 1)
48+
49+
let tmp = localBuffer.[i]
50+
51+
let operand = localBuffer.[j] // brahma error
52+
let buff = (%opAdd) tmp operand
53+
54+
localBuffer.[i] <- buff
55+
localBuffer.[j] <- tmp
56+
57+
step <- step >>> 1 @>
3858

3959
let localPrefixSum opAdd =
4060
<@ fun (lid: int) (workGroupSize: int) (array: 'a []) ->
@@ -52,4 +72,6 @@ module SubSum =
5272
barrierLocal ()
5373
array.[lid] <- value @>
5474

75+
76+
5577
let localIntPrefixSum = localPrefixSum <@ (+) @>
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Scan.ByKey
2+
3+
open GraphBLAS.FSharp.Backend.Common
4+
open GraphBLAS.FSharp.Backend.Objects.ClContext
5+
open Expecto
6+
open GraphBLAS.FSharp.Tests
7+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
8+
9+
let context = Context.defaultContext.ClContext
10+
11+
let processor = Context.defaultContext.Queue
12+
13+
let checkResult isEqual keysAndValues actual hostScan =
14+
15+
let expected =
16+
HostPrimitives.scanByKey hostScan keysAndValues
17+
18+
"Results must be the same"
19+
|> Utils.compareArrays isEqual actual expected
20+
21+
let makeTestSequentialSegments isEqual scanHost scanDevice (keysAndValues: (int * 'a) []) =
22+
if keysAndValues.Length > 0 then
23+
let keys, values =
24+
Array.sortBy fst keysAndValues |> Array.unzip
25+
26+
let offsets =
27+
HostPrimitives.getUniqueBitmapFirstOccurrence keys
28+
|> HostPrimitives.getBitPositions
29+
30+
let uniqueKeysCount = Array.distinct keys |> Array.length
31+
32+
let clKeys =
33+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys)
34+
35+
let clValues =
36+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, values)
37+
38+
let clOffsets =
39+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets)
40+
41+
scanDevice processor uniqueKeysCount clValues clKeys clOffsets
42+
43+
let actual = clValues.ToHostAndFree processor
44+
clKeys.Free processor
45+
clOffsets.Free processor
46+
47+
let keysAndValues = Array.zip keys values
48+
49+
checkResult isEqual keysAndValues actual scanHost
50+
51+
let createTest (zero: 'a) opAddQ opAdd isEqual deviceScan hostScan =
52+
53+
let hostScan = hostScan zero opAdd
54+
55+
let deviceScan =
56+
deviceScan context Utils.defaultWorkGroupSize opAddQ zero
57+
58+
makeTestSequentialSegments isEqual hostScan deviceScan
59+
|> testPropertyWithConfig Utils.defaultConfig $"test on {typeof<'a>}"
60+
61+
let sequentialSegmentsTests =
62+
let excludeTests =
63+
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
64+
65+
if Utils.isFloat64Available context.ClDevice then
66+
createTest
67+
0.0
68+
<@ (+) @>
69+
(+)
70+
Utils.floatIsEqual
71+
PrefixSum.ByKey.sequentialExclude
72+
HostPrimitives.prefixSumExclude
73+
74+
createTest
75+
0.0f
76+
<@ (+) @>
77+
(+)
78+
Utils.float32IsEqual
79+
PrefixSum.ByKey.sequentialExclude
80+
HostPrimitives.prefixSumExclude
81+
82+
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
83+
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude ]
84+
|> testList "exclude"
85+
86+
let includeTests =
87+
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
88+
89+
if Utils.isFloat64Available context.ClDevice then
90+
createTest
91+
0.0
92+
<@ (+) @>
93+
(+)
94+
Utils.floatIsEqual
95+
PrefixSum.ByKey.sequentialInclude
96+
HostPrimitives.prefixSumInclude
97+
98+
createTest
99+
0.0f
100+
<@ (+) @>
101+
(+)
102+
Utils.float32IsEqual
103+
PrefixSum.ByKey.sequentialInclude
104+
HostPrimitives.prefixSumInclude
105+
106+
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
107+
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude ]
108+
109+
|> testList "include"
110+
111+
testList "Sequential segments" [ excludeTests; includeTests ]

tests/GraphBLAS-sharp.Tests/Common/ClArray/PrefixSum.fs renamed to tests/GraphBLAS-sharp.Tests/Common/Scan/PrefixSum.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.PrefixSum
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Scan.PrefixSum
22

33
open Expecto
44
open Expecto.Logging
@@ -62,7 +62,7 @@ let makeTest plus zero isEqual scan (array: 'a []) =
6262
let testFixtures plus plusQ zero isEqual name =
6363
PrefixSum.runIncludeInplace plusQ context wgSize
6464
|> makeTest plus zero isEqual
65-
|> testPropertyWithConfig config (sprintf "Correctness on %s" name)
65+
|> testPropertyWithConfig config $"Correctness on %s{name}"
6666

6767
let tests =
6868
q.Error.Add(fun e -> failwithf "%A" e)

0 commit comments

Comments
 (0)