Skip to content

Commit 2b4a5b7

Browse files
committed
add: Vector.Merge
1 parent e0485c5 commit 2b4a5b7

File tree

5 files changed

+344
-91
lines changed

5 files changed

+344
-91
lines changed
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
namespace GraphBLAS.FSharp.Backend.Common
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Objects
5+
open GraphBLAS.FSharp.Backend.Objects.ClContext
6+
7+
/// Device-side merge of the index/value arrays of two sparse vectors.
module Merge =
8+
module Vector =
9+
/// Compiles the merge kernel once via `clContext.Compile` and returns a function
/// `processor -> firstVector -> secondVector -> (allIndices, firstValues, secondValues, isLeftBitmap)`.
///
/// The returned function allocates four device-only arrays of length
/// `firstVector.Indices.Length + secondVector.Indices.Length`, posts the kernel
/// arguments and a run message to `processor`, and returns the four arrays.
/// For every output position the kernel writes the merged index, a 0/1 flag in
/// `isLeftBitmap` (1 = element taken from the first vector), and the value into
/// exactly one of `firstValues` / `secondValues` (the other slot is left unwritten).
///
/// NOTE(review): the kernel's binary searches presumably rely on both index arrays
/// being sorted in ascending order - confirm against callers.
/// NOTE(review): nothing here frees the returned arrays; presumably the caller owns
/// them - confirm.
let run<'a, 'b when 'a: struct and 'b: struct> (clContext: ClContext) workGroupSize =
10+
11+
// Kernel quotation. Each work-item with gid < sumOfSides produces one element of the merged output.
let merge =
12+
<@ fun (ndRange: Range1D) (firstSide: int) (secondSide: int) (sumOfSides: int) (firstIndicesBuffer: ClArray<int>) (firstValuesBuffer: ClArray<'a>) (secondIndicesBuffer: ClArray<int>) (secondValuesBuffer: ClArray<'b>) (allIndicesBuffer: ClArray<int>) (firstResultValues: ClArray<'a>) (secondResultValues: ClArray<'b>) (isLeftBitMap: ClArray<int>) ->
13+
14+
let gid = ndRange.GlobalID0
15+
let lid = ndRange.LocalID0
16+
17+
// Cells created with `local ()`; presumably shared by the whole work-group
// (they are written by lid 0 / lid 1 and read by every work-item after the barrier) - TODO confirm.
let mutable beginIdxLocal = local ()
18+
let mutable endIdxLocal = local ()
19+
20+
// Phase 1: only the work-items with local id 0 and 1 search for the group's
// begin / end position inside the first index array.
// NOTE(review): endIdxLocal is written only by lid = 1, so this looks like it
// requires workGroupSize >= 2 - confirm.
if lid < 2 then
21+
// (n - 1) * wgSize - 1 for lid = 0
22+
// n * wgSize - 1 for lid = 1
23+
// where n in 1 .. wgGroupCount
24+
let x = lid * (workGroupSize - 1) + gid - 1
25+
26+
// Clamp to the last position of the merged output.
let diagonalNumber = min (sumOfSides - 1) x
27+
28+
// Search interval inside the first array for this diagonal.
let mutable leftEdge = max 0 (diagonalNumber + 1 - secondSide)
29+
let mutable rightEdge = min (firstSide - 1) diagonalNumber
30+
31+
// Binary search: compares first[middleIdx] with second[diagonalNumber - middleIdx];
// on `<=` the left edge moves past middleIdx, otherwise the right edge moves before it.
while leftEdge <= rightEdge do
32+
let middleIdx = (leftEdge + rightEdge) / 2
33+
34+
let firstIndex = firstIndicesBuffer.[middleIdx]
35+
36+
let secondIndex =
37+
secondIndicesBuffer.[diagonalNumber - middleIdx]
38+
39+
if firstIndex <= secondIndex then
40+
leftEdge <- middleIdx + 1
41+
else
42+
rightEdge <- middleIdx - 1
43+
44+
// Here localID equals either 0 or 1
45+
if lid = 0 then
46+
beginIdxLocal <- leftEdge
47+
else
48+
endIdxLocal <- leftEdge
49+
50+
// Wait until both boundaries have been stored before anyone reads them.
barrierLocal ()
51+
52+
let beginIdx = beginIdxLocal
53+
let endIdx = endIdxLocal
54+
// Number of first-array elements handled by this work-group.
let firstLocalLength = endIdx - beginIdx
55+
// Default share of the second array: the rest of the work-group's slots.
let mutable x = workGroupSize - firstLocalLength
56+
57+
// When the group reaches the end of the first array, the second-array share is
// recomputed from secondSide; (gid - lid - beginIdx) is the same offset that is
// used below to address the group's first second-array element.
if endIdx = firstSide then
58+
x <- secondSide - gid + lid + beginIdx
59+
60+
let secondLocalLength = x
61+
62+
//First indices are from 0 to firstLocalLength - 1 inclusive
63+
//Second indices are from firstLocalLength to firstLocalLength + secondLocalLength - 1 inclusive
64+
let localIndices = localArray<int> workGroupSize
65+
66+
// Phase 2: stage this group's slice of both index arrays into localIndices.
if lid < firstLocalLength then
67+
localIndices.[lid] <- firstIndicesBuffer.[beginIdx + lid]
68+
69+
if lid < secondLocalLength then
70+
localIndices.[firstLocalLength + lid] <- secondIndicesBuffer.[gid - beginIdx]
71+
72+
// All staged indices must be written before the per-item search reads them.
barrierLocal ()
73+
74+
// Phase 3: work-items beyond the merged length produce no output.
if gid < sumOfSides then
75+
let mutable leftEdge = lid + 1 - secondLocalLength
76+
if leftEdge < 0 then leftEdge <- 0
77+
78+
let mutable rightEdge = firstLocalLength - 1
79+
80+
rightEdge <- min rightEdge lid
81+
82+
// Same binary search as in phase 1, but over the staged local indices
// and along the diagonal given by lid.
while leftEdge <= rightEdge do
83+
let middleIdx = (leftEdge + rightEdge) / 2
84+
let firstIndex = localIndices.[middleIdx]
85+
86+
let secondIndex =
87+
localIndices.[firstLocalLength + lid - middleIdx]
88+
89+
if firstIndex <= secondIndex then
90+
leftEdge <- middleIdx + 1
91+
else
92+
rightEdge <- middleIdx - 1
93+
94+
// Candidate positions in the first (X) and second (Y) local slices.
let boundaryX = rightEdge
95+
let boundaryY = lid - leftEdge
96+
97+
// boundaryX and boundaryY can't be off the right edge of array (only off the left edge)
98+
let isValidX = boundaryX >= 0
99+
let isValidY = boundaryY >= 0
100+
101+
let mutable fstIdx = 0
102+
103+
if isValidX then
104+
fstIdx <- localIndices.[boundaryX]
105+
106+
let mutable sndIdx = 0
107+
108+
if isValidY then
109+
sndIdx <- localIndices.[firstLocalLength + boundaryY]
110+
111+
// `&&` binds tighter than `||`: the second-side candidate is emitted when there is
// no valid first-side candidate, or when both are valid and fstIdx <= sndIdx.
if not isValidX || isValidY && fstIdx <= sndIdx then
112+
allIndicesBuffer.[gid] <- sndIdx
113+
secondResultValues.[gid] <- secondValuesBuffer.[gid - lid - beginIdx + boundaryY]
114+
isLeftBitMap.[gid] <- 0
115+
else
116+
allIndicesBuffer.[gid] <- fstIdx
117+
firstResultValues.[gid] <- firstValuesBuffer.[beginIdx + boundaryX]
118+
isLeftBitMap.[gid] <- 1 @>
119+
120+
// Compiled once per `run` call and shared by every invocation of the returned function.
let kernel = clContext.Compile merge
121+
122+
fun (processor: MailboxProcessor<_>) (firstVector: ClVector.Sparse<'a>) (secondVector: ClVector.Sparse<'b>) ->
123+
124+
let firstSide = firstVector.Indices.Length
125+
126+
let secondSide = secondVector.Indices.Length
127+
128+
// Length of the merged output and of every result array.
let sumOfSides = firstSide + secondSide
129+
130+
// Result buffers, all requested with the DeviceOnly allocation mode.
let allIndices =
131+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, sumOfSides)
132+
133+
let firstValues =
134+
clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, sumOfSides)
135+
136+
let secondValues =
137+
clContext.CreateClArrayWithSpecificAllocationMode<'b>(DeviceOnly, sumOfSides)
138+
139+
let isLeftBitmap =
140+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, sumOfSides)
141+
142+
// Presumably rounds the global size up to a multiple of workGroupSize - TODO confirm;
// the kernel guards its output writes with `gid < sumOfSides`.
let ndRange =
143+
Range1D.CreateValid(sumOfSides, workGroupSize)
144+
145+
// Shadows the compiled program with a kernel instance obtained for this launch.
let kernel = kernel.GetKernel()
146+
147+
// Argument order must match the parameter order of the kernel lambda above.
processor.Post(
148+
Msg.MsgSetArguments
149+
(fun () ->
150+
kernel.KernelFunc
151+
ndRange
152+
firstSide
153+
secondSide
154+
sumOfSides
155+
firstVector.Indices
156+
firstVector.Values
157+
secondVector.Indices
158+
secondVector.Values
159+
allIndices
160+
firstValues
161+
secondValues
162+
isLeftBitmap)
163+
)
164+
165+
// Enqueue the launch; this function does not wait for completion itself.
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
166+
167+
allIndices, firstValues, secondValues, isLeftBitmap
168+

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
<Compile Include="Common/Sort/Radix.fs" />
3636
<Compile Include="Common/Sort/Bitonic.fs" />
3737
<Compile Include="Common/Sum.fs" />
38+
<Compile Include="Common/Merge.fs" />
3839
<Compile Include="Matrix/Common.fs" />
3940
<Compile Include="Matrix/COOMatrix/Map2.fs" />
4041
<Compile Include="Matrix/COOMatrix/Map2AtLeastOne.fs" />
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
module GraphBLAS.FSharp.Tests.Common.Merge
2+
3+
open GraphBLAS.FSharp.Backend.Vector
4+
open GraphBLAS.FSharp.Backend.Common
5+
open GraphBLAS.FSharp.Objects
6+
open GraphBLAS.FSharp.Tests
7+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
8+
open Brahma.FSharp
9+
open Expecto
10+
11+
// Command queue taken from the shared default test context.
let processor = Context.defaultContext.Queue
12+
13+
// ClContext taken from the shared default test context.
let context = Context.defaultContext.ClContext
14+
15+
// NOTE(review): this type is not referenced anywhere else in the visible part of this file,
// and its `None` case shadows `Option.None` for unqualified use in the code below - confirm whether it is still needed.
type Result<'a>= None | Left of 'a | Right of 'a
16+
17+
// Default test configuration with `endSize` overridden to 100000.
let config = { Utils.defaultConfig with endSize = 100000 }
18+
19+
/// Property body: merges two sparse vectors with `testFun` and compares the result
/// with a host-side reference built by concatenation and sorting by index.
///
/// `isEqual` compares values; `isEqual zero` is passed to `Vector.Sparse.FromArray`
/// (presumably as the "is zero" predicate used to drop zero elements - TODO confirm).
/// `testFun` is called as `testFun processor clFirstVector clSecondVector` and must return
/// `(allIndices, firstValues, secondValues, isLeftBitmap)`.
/// Nothing is checked when either vector has no stored elements (NNZ = 0).
let makeTest isEqual zero testFun (firstArray: 'a []) (secondArray: 'a []) =
20+
let firstVector = Vector.Sparse.FromArray(firstArray, isEqual zero)
21+
22+
let secondVector = Vector.Sparse.FromArray(secondArray, isEqual zero)
23+
24+
if firstVector.NNZ > 0 && secondVector.NNZ > 0 then
25+
26+
// actual run
27+
let clFirstVector = firstVector.ToDevice context
28+
29+
let clSecondVector = secondVector.ToDevice context
30+
31+
let (allIndices: ClArray<int>), (firstValues: ClArray<'a>), (secondValues: ClArray<'a>), (isLeftBitmap: ClArray<int>) =
32+
testFun processor clFirstVector clSecondVector
33+
34+
// The device inputs are no longer needed once the merge has been queued.
clFirstVector.Dispose processor
35+
clSecondVector.Dispose processor
36+
37+
// Copy each result array back to the host and release it.
let actualIndices = allIndices.ToHostAndFree processor
38+
let actualFirstValues = firstValues.ToHostAndFree processor
39+
let actualSecondValues = secondValues.ToHostAndFree processor
40+
let actualIsLeftBitmap = isLeftBitmap.ToHostAndFree processor
41+
42+
// Pick, per position, the value from the side indicated by the bitmap (1 = first vector).
let actualValues =
43+
(actualFirstValues, actualSecondValues, actualIsLeftBitmap)
44+
|||> Array.map3 (fun leftValue rightValue isLeft -> if isLeft = 1 then leftValue else rightValue)
45+
46+
// expected run
47+
let firstValuesAndIndices =
48+
Array.map2 (fun value index -> (value, index)) firstVector.Values firstVector.Indices
49+
50+
let secondValuesAndIndices =
51+
Array.map2 (fun value index -> (value, index)) secondVector.Values secondVector.Indices
52+
53+
// preserve order of values then use stable sort
54+
let allValuesAndIndices =
55+
Array.concat [ firstValuesAndIndices; secondValuesAndIndices ]
56+
57+
// stable sort
// First-vector pairs are concatenated first, so for equal indices they stay ahead of second-vector pairs.
58+
let expectedValues, expectedIndices =
59+
Seq.sortBy snd allValuesAndIndices
60+
|> Seq.toArray
61+
|> Array.unzip
62+
63+
"Values should be the same"
64+
|> Utils.compareArrays isEqual actualValues expectedValues
65+
66+
"Indices should be the same"
67+
|> Utils.compareArrays (=) actualIndices expectedIndices
68+
69+
/// Builds a property test for element type 'a: creates the merge function with
/// `Merge.Vector.run context Utils.defaultWorkGroupSize`, wraps it in `makeTest isEqual zero`,
/// and registers it with `config` under a name that includes the element type.
let createTest<'a when 'a : struct> isEqual (zero: 'a) =
70+
Merge.Vector.run context Utils.defaultWorkGroupSize
71+
|> makeTest isEqual zero
72+
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
73+
74+
/// Test list "Merge": one property test per element type (int, float32, bool),
/// plus float when `Utils.isFloat64Available` reports support on the context's device.
let tests =
75+
[ createTest<int> (=) 0
76+
77+
// The float case is included only when the device check passes.
if Utils.isFloat64Available context.ClDevice then
78+
createTest<float> (=) 0.0
79+
80+
createTest<float32> Utils.float32IsEqual 0.0f
81+
createTest<bool> (=) false ]
82+
|> testList "Merge"
83+

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
<Compile Include="Common/Reduce/ReduceByKey.fs" />
3333
<Compile Include="Common/Scan/PrefixSum.fs" />
3434
<Compile Include="Common/Scan/ByKey.fs" />
35+
<Compile Include="Common/Merge.fs" />
3536
<!--Compile Include="MatrixOperationsTests/GetTuplesTests.fs" /-->
3637
<!--Compile Include="MatrixOperationsTests/MxvTests.fs" /-->
3738
<!--Compile Include="MatrixOperationsTests/VxmTests.fs" /-->

0 commit comments

Comments
 (0)