Skip to content

Commit 3bb2e9a

Browse files
committed
refactor: formatting
1 parent bc7ebdf commit 3bb2e9a

File tree

6 files changed

+183
-281
lines changed

6 files changed

+183
-281
lines changed

src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs

Lines changed: 33 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ module PrefixSum =
136136
scanGeneral
137137
<@ fun (resultBuffer: ClArray<'a>) (value: 'a) (inputArrayLength: int) (gid: int) (i: int) ->
138138

139-
if gid < inputArrayLength then resultBuffer.[i] <- value @>
139+
if gid < inputArrayLength then
140+
resultBuffer.[i] <- value @>
140141
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (workGroupSize: int) (gid: int) (i: int) (localID: int) ->
141142

142143
if gid < inputArrayLength
@@ -214,7 +215,7 @@ module PrefixSum =
214215
/// <code>
215216
/// let arr = [| 1; 1; 1; 1 |]
216217
/// let sum = [| 0 |]
217-
/// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
218+
/// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
218219
/// |> ignore
219220
/// ...
220221
/// > val arr = [| 0; 1; 2; 3 |]
@@ -240,7 +241,7 @@ module PrefixSum =
240241
/// <code>
241242
/// let arr = [| 1; 1; 1; 1 |]
242243
/// let sum = [| 0 |]
243-
/// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
244+
/// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
244245
/// |> ignore
245246
/// ...
246247
/// > val arr = [| 1; 2; 3; 4 |]
@@ -259,76 +260,7 @@ module PrefixSum =
259260

260261
scan processor inputArray 0
261262

262-
263263
module ByKey =
264-
let private oneWorkGroup
265-
writeZero
266-
zero
267-
uniqueKey
268-
(opAdd: Expr<'a -> 'a -> 'a>)
269-
(clContext: ClContext)
270-
workGroupSize
271-
=
272-
273-
let scan =
274-
<@ fun (ndRange: Range1D) length (values: ClArray<'a>) (keys: ClArray<int>) ->
275-
276-
let localValues = localArray<'a> workGroupSize
277-
let localKeys = localArray<int> workGroupSize
278-
279-
let gid = ndRange.GlobalID0
280-
let lid = ndRange.LocalID0
281-
282-
if gid < length then
283-
// only one workgroup
284-
localValues.[lid] <- values.[lid]
285-
localKeys.[lid] <- keys.[gid]
286-
else
287-
localValues.[lid] <- zero
288-
localKeys.[lid] <- uniqueKey
289-
290-
barrierLocal ()
291-
292-
// Local tree reduce
293-
(%SubSum.upSweepByKey opAdd) workGroupSize lid localValues localKeys
294-
295-
// if root item
296-
if lid = workGroupSize - 1
297-
|| localValues.[lid] <> localValues.[lid + 1] then
298-
299-
(%writeZero) localValues lid zero
300-
301-
(%SubSum.downSweepByKey opAdd) workGroupSize lid localValues localKeys
302-
303-
barrierLocal ()
304-
305-
values.[lid] <- localValues.[lid] @>
306-
307-
let program = clContext.Compile(scan)
308-
309-
fun (processor: MailboxProcessor<_>) (keys: ClArray<int>) (values: ClArray<'a>) ->
310-
311-
let kernel = program.GetKernel()
312-
313-
let ndRange =
314-
Range1D.CreateValid(values.Length, workGroupSize)
315-
316-
processor.Post(
317-
Msg.MsgSetArguments
318-
(fun () ->
319-
kernel.KernelFunc
320-
ndRange
321-
values.Length
322-
values
323-
keys)
324-
)
325-
326-
processor.Post(Msg.CreateRunMsg<_, _> kernel)
327-
328-
let oneWorkGroupExclude zero = oneWorkGroup <@ (fun _ _ _ -> ()) @> zero
329-
330-
let onwWorkGroupInclude zero = oneWorkGroup <@ (fun localValues lid zero -> localValues.[lid] <- zero) @> zero
331-
332264
let private sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =
333265

334266
let kernel =
@@ -345,7 +277,7 @@ module PrefixSum =
345277
let mutable currentPosition = sourcePosition
346278

347279
while currentPosition < lenght
348-
&& keys.[currentPosition] = sourceKey do
280+
&& keys.[currentPosition] = sourceKey do
349281

350282
previousSum <- currentSum
351283
currentSum <- (%opAdd) currentSum values.[currentPosition]
@@ -365,21 +297,35 @@ module PrefixSum =
365297

366298
processor.Post(
367299
Msg.MsgSetArguments
368-
(fun () ->
369-
kernel.KernelFunc
370-
ndRange
371-
values.Length
372-
uniqueKeysCount
373-
values
374-
keys
375-
offsets)
300+
(fun () -> kernel.KernelFunc ndRange values.Length uniqueKeysCount values keys offsets)
376301
)
377302

378303
processor.Post(Msg.CreateRunMsg<_, _> kernel)
379304

380-
381-
let sequentialExclude clContext = sequentialSegments (Map.fst ()) clContext
382-
383-
let sequentialInclude clContext = sequentialSegments (Map.snd ()) clContext
384-
385-
305+
/// <summary>
306+
/// Exclude scan by key.
307+
/// </summary>
308+
/// <example>
309+
/// <code>
310+
/// let arr = [| 1; 1; 1; 1; 1; 1|]
311+
/// let keys = [| 1; 2; 2; 2; 3; 3 |]
312+
/// ...
313+
/// > val result = [| 0; 0; 1; 2; 0; 1 |]
314+
/// </code>
315+
/// </example>
316+
let sequentialExclude clContext =
317+
sequentialSegments (Map.fst ()) clContext
318+
319+
/// <summary>
320+
/// Include scan by key.
321+
/// </summary>
322+
/// <example>
323+
/// <code>
324+
/// let arr = [| 1; 1; 1; 1; 1; 1|]
325+
/// let keys = [| 1; 2; 2; 2; 3; 3 |]
326+
/// ...
327+
/// > val result = [| 1; 1; 2; 3; 1; 2 |]
328+
/// </code>
329+
/// </example>
330+
let sequentialInclude clContext =
331+
sequentialSegments (Map.snd ()) clContext

src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ module SubSum =
5757
step <- step >>> 1 @>
5858

5959
let upSweepByKey opAdd =
60-
<@ fun wgSize lid (localBuffer: 'a []) (localKeys: 'b [])->
60+
<@ fun wgSize lid (localBuffer: 'a []) (localKeys: 'b []) ->
6161
let mutable step = 2
6262

6363
while step <= wgSize do
@@ -69,8 +69,7 @@ module SubSum =
6969
let firstKey = localKeys.[firstIndex]
7070
let secondKey = localKeys.[secondIndex]
7171

72-
if lid < wgSize / step
73-
&& firstKey = secondKey then
72+
if lid < wgSize / step && firstKey = secondKey then
7473

7574
let firstValue = localBuffer.[firstIndex]
7675
let secondValue = localBuffer.[secondIndex]
@@ -94,8 +93,7 @@ module SubSum =
9493
let rightKey = localKeys.[rightIndex]
9594
let leftKey = localKeys.[leftIndex]
9695

97-
if lid < wgSize / step
98-
&& rightKey = leftKey then
96+
if lid < wgSize / step && rightKey = leftKey then
9997

10098
let tmp = localBuffer.[rightIndex]
10199

tests/GraphBLAS-sharp.Tests/Common/Scan/ByKey.fs

Lines changed: 39 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -10,53 +10,33 @@ let context = Context.defaultContext.ClContext
1010

1111
let processor = Context.defaultContext.Queue
1212

13-
let scanByKey scan keysAndValues =
14-
// select keys
15-
Array.map fst keysAndValues
16-
// get unique keys
17-
|> Array.distinct
18-
|> Array.map (fun key ->
19-
// select with certain key
20-
Array.filter (fst >> ((=) key)) keysAndValues
21-
// get values
22-
|> Array.map snd
23-
// scan values and get only values without sum
24-
|> (fst << scan))
25-
|> Array.concat
26-
2713
let checkResult isEqual keysAndValues actual hostScan =
2814

29-
let expected = scanByKey hostScan keysAndValues
30-
31-
let keys, values = Array.unzip keysAndValues
32-
printfn "---------------"
33-
34-
printfn "keys: %A" keys
35-
printfn "values: %A" values
36-
printfn $"expected: %A{expected}"
37-
38-
printfn "-----------"
15+
let expected =
16+
HostPrimitives.scanByKey hostScan keysAndValues
3917

4018
"Results must be the same"
4119
|> Utils.compareArrays isEqual actual expected
4220

4321
let makeTestSequentialSegments isEqual scanHost scanDevice (keysAndValues: (int * 'a) []) =
4422
if keysAndValues.Length > 0 then
4523
let keys, values =
46-
Array.sortBy fst keysAndValues
47-
|> Array.unzip
24+
Array.sortBy fst keysAndValues |> Array.unzip
4825

4926
let offsets =
5027
HostPrimitives.getUniqueBitmapFirstOccurrence keys
5128
|> HostPrimitives.getBitPositions
5229

5330
let uniqueKeysCount = Array.distinct keys |> Array.length
5431

55-
let clKeys = context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys)
32+
let clKeys =
33+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys)
5634

57-
let clValues = context.CreateClArrayWithSpecificAllocationMode(HostInterop, values)
35+
let clValues =
36+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, values)
5837

59-
let clOffsets = context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets)
38+
let clOffsets =
39+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets)
6040

6141
scanDevice processor uniqueKeysCount clValues clKeys clOffsets
6242

@@ -83,70 +63,21 @@ let sequentialSegmentsTests =
8363
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
8464

8565
if Utils.isFloat64Available context.ClDevice then
86-
createTest 0.0 <@ (+) @> (+) Utils.floatIsEqual PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
87-
88-
createTest 0.0f <@ (+) @> (+) Utils.float32IsEqual PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
89-
90-
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
91-
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude ]
92-
|> testList "exclude"
93-
94-
let includeTests =
95-
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
96-
97-
if Utils.isFloat64Available context.ClDevice then
98-
createTest 0.0 <@ (+) @> (+) Utils.floatIsEqual PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
99-
100-
createTest 0.0f <@ (+) @> (+) Utils.float32IsEqual PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
101-
102-
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
103-
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude ]
104-
105-
|> testList "include"
106-
107-
testList "Sequential segments" [ excludeTests; includeTests ]
108-
109-
let makeTestOneWorkGroup isEqual scanHost scanDevice (keysAndValues: (int * 'a) []) =
110-
if keysAndValues.Length > 0 then
111-
let keys, values =
112-
Array.sortBy fst keysAndValues
113-
|> Array.unzip
114-
115-
let uniqueKeysCount = Array.distinct keys |> Array.length
116-
117-
let clKeys = context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys)
118-
119-
let clValues = context.CreateClArrayWithSpecificAllocationMode(HostInterop, values)
120-
121-
scanDevice processor uniqueKeysCount clValues clKeys
122-
123-
let actual = clValues.ToHostAndFree processor
124-
clKeys.Free processor
125-
126-
let keysAndValues = Array.zip keys values
127-
128-
checkResult isEqual keysAndValues actual scanHost
129-
130-
let oneWorkGroupCreateTest (zero: 'a) opAddQ opAdd isEqual deviceScan hostScan =
131-
132-
let workGroupSize = 256
133-
134-
let hostScan = hostScan zero opAdd
135-
136-
let deviceScan =
137-
deviceScan context workGroupSize opAddQ zero
138-
139-
makeTestSequentialSegments isEqual hostScan deviceScan
140-
|> testPropertyWithConfig { Utils.defaultConfig with endSize = workGroupSize } $"test on {typeof<'a>}"
141-
142-
let oneWorkGroupTests =
143-
let excludeTests =
144-
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
145-
146-
if Utils.isFloat64Available context.ClDevice then
147-
createTest 0.0 <@ (+) @> (+) Utils.floatIsEqual PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
148-
149-
createTest 0.0f <@ (+) @> (+) Utils.float32IsEqual PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
66+
createTest
67+
0.0
68+
<@ (+) @>
69+
(+)
70+
Utils.floatIsEqual
71+
PrefixSum.ByKey.sequentialExclude
72+
HostPrimitives.prefixSumExclude
73+
74+
createTest
75+
0.0f
76+
<@ (+) @>
77+
(+)
78+
Utils.float32IsEqual
79+
PrefixSum.ByKey.sequentialExclude
80+
HostPrimitives.prefixSumExclude
15081

15182
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
15283
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude ]
@@ -156,19 +87,25 @@ let oneWorkGroupTests =
15687
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
15788

15889
if Utils.isFloat64Available context.ClDevice then
159-
createTest 0.0 <@ (+) @> (+) Utils.floatIsEqual PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
160-
161-
createTest 0.0f <@ (+) @> (+) Utils.float32IsEqual PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
90+
createTest
91+
0.0
92+
<@ (+) @>
93+
(+)
94+
Utils.floatIsEqual
95+
PrefixSum.ByKey.sequentialInclude
96+
HostPrimitives.prefixSumInclude
97+
98+
createTest
99+
0.0f
100+
<@ (+) @>
101+
(+)
102+
Utils.float32IsEqual
103+
PrefixSum.ByKey.sequentialInclude
104+
HostPrimitives.prefixSumInclude
162105

163106
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
164107
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude ]
165108

166109
|> testList "include"
167110

168111
testList "Sequential segments" [ excludeTests; includeTests ]
169-
170-
171-
172-
173-
174-

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
<Compile Include="Common/Reduce/Sum.fs" />
3030
<Compile Include="Common/Reduce/Reduce.fs" />
3131
<Compile Include="Common/Reduce/ReduceByKey.fs" />
32-
<Compile Include="Common\Scan\PrefixSum.fs" />
33-
<Compile Include="Common\Scan\ByKey.fs" />
32+
<Compile Include="Common/Scan/PrefixSum.fs" />
33+
<Compile Include="Common/Scan/ByKey.fs" />
3434
<!--Compile Include="MatrixOperationsTests/GetTuplesTests.fs" /-->
3535
<!--Compile Include="MatrixOperationsTests/MxvTests.fs" /-->
3636
<!--Compile Include="MatrixOperationsTests/VxmTests.fs" /-->

0 commit comments

Comments
 (0)