Skip to content

Commit 850ca14

Browse files
committed
add: Common.Merge with explicit error
1 parent 661e085 commit 850ca14

File tree

6 files changed

+201
-2
lines changed

6 files changed

+201
-2
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
namespace GraphBLAS.FSharp.Backend.Common
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Objects.ClContext
5+
6+
module Merge =
7+
let run<'a, 'b when 'a: struct and 'b: struct and 'a: comparison> (clContext: ClContext) workGroupSize =
8+
9+
let defaultValue = Unchecked.defaultof<'a>
10+
11+
let merge =
12+
<@ fun (ndRange: Range1D) (firstSide: int) (secondSide: int) (sumOfSides: int) (firstValues: ClArray<'a>) (secondValues: ClArray<'a>) (resultValues: ClArray<'a>) ->
13+
14+
let gid = ndRange.GlobalID0
15+
let lid = ndRange.LocalID0
16+
17+
let mutable beginIdxLocal = local ()
18+
let mutable endIdxLocal = local ()
19+
20+
if lid < 2 then
21+
// (n - 1) * wgSize - 1 for lid = 0
22+
// n * wgSize - 1 for lid = 1
23+
// where n in 1 .. wgGroupCount
24+
let x = lid * (workGroupSize - 1) + gid - 1
25+
26+
let diagonalNumber = min (sumOfSides - 1) x
27+
28+
let mutable leftEdge = max 0 (diagonalNumber + 1 - secondSide)
29+
30+
let mutable rightEdge = min (firstSide - 1) diagonalNumber
31+
32+
while leftEdge <= rightEdge do
33+
let middleIdx = (leftEdge + rightEdge) / 2
34+
35+
let firstIndex = firstValues.[middleIdx]
36+
37+
let secondIndex =
38+
secondValues.[diagonalNumber - middleIdx]
39+
40+
if firstIndex <= secondIndex then
41+
leftEdge <- middleIdx + 1
42+
else
43+
rightEdge <- middleIdx - 1
44+
45+
// Here localID equals either 0 or 1
46+
if lid = 0 then
47+
beginIdxLocal <- leftEdge
48+
else
49+
endIdxLocal <- leftEdge
50+
51+
barrierLocal ()
52+
53+
let beginIdx = beginIdxLocal
54+
let endIdx = endIdxLocal
55+
let firstLocalLength = endIdx - beginIdx
56+
let mutable x = workGroupSize - firstLocalLength
57+
58+
if endIdx = firstSide then
59+
x <- secondSide - gid + lid + beginIdx
60+
61+
let secondLocalLength = x
62+
63+
//First indices are from 0 to firstLocalLength - 1 inclusive
64+
//Second indices are from firstLocalLength to firstLocalLength + secondLocalLength - 1 inclusive
65+
let localIndices = localArray<'a> workGroupSize
66+
67+
if lid < firstLocalLength then
68+
localIndices.[lid] <- firstValues.[beginIdx + lid]
69+
70+
if lid < secondLocalLength then
71+
localIndices.[firstLocalLength + lid] <- firstValues.[gid - beginIdx]
72+
73+
barrierLocal ()
74+
75+
if gid < sumOfSides then
76+
let mutable leftEdge = lid + 1 - secondLocalLength
77+
if leftEdge < 0 then leftEdge <- 0
78+
79+
let mutable rightEdge = firstLocalLength - 1
80+
81+
rightEdge <- min rightEdge lid
82+
83+
while leftEdge <= rightEdge do
84+
let middleIdx = (leftEdge + rightEdge) / 2
85+
let firstIndex = localIndices.[middleIdx]
86+
87+
let secondIndex =
88+
localIndices.[firstLocalLength + lid - middleIdx]
89+
90+
if firstIndex <= secondIndex then
91+
leftEdge <- middleIdx + 1
92+
else
93+
rightEdge <- middleIdx - 1
94+
95+
let boundaryX = rightEdge
96+
let boundaryY = lid - leftEdge
97+
98+
// boundaryX and boundaryY can't be off the right edge of array (only off the left edge)
99+
let isValidX = boundaryX >= 0
100+
let isValidY = boundaryY >= 0
101+
102+
let mutable fstIdx = defaultValue
103+
104+
if isValidX then
105+
fstIdx <- localIndices.[boundaryX]
106+
107+
let mutable sndIdx = defaultValue
108+
109+
if isValidY then
110+
sndIdx <- localIndices.[firstLocalLength + boundaryY]
111+
112+
if not isValidX || isValidY && fstIdx <= sndIdx then
113+
resultValues.[gid] <- sndIdx
114+
else
115+
resultValues.[gid] <- fstIdx @>
116+
117+
let kernel = clContext.Compile merge
118+
119+
fun (processor: MailboxProcessor<_>) (firstValues: ClArray<'a>) (secondValues: ClArray<'a>) ->
120+
121+
let firstSide = firstValues.Length
122+
123+
let secondSide = secondValues.Length
124+
125+
let sumOfSides = firstSide + secondSide
126+
127+
let resultValues =
128+
clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, sumOfSides)
129+
130+
let ndRange =
131+
Range1D.CreateValid(sumOfSides, workGroupSize)
132+
133+
let kernel = kernel.GetKernel()
134+
135+
processor.Post(
136+
Msg.MsgSetArguments
137+
(fun () ->
138+
kernel.KernelFunc ndRange firstSide secondSide sumOfSides firstValues secondValues resultValues)
139+
)
140+
141+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
142+
143+
resultValues

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
<Compile Include="Common/Sort/Radix.fs" />
3636
<Compile Include="Common/Sort/Bitonic.fs" />
3737
<Compile Include="Common/Sum.fs" />
38+
<Compile Include="Common\Merge.fs" />
3839
<Compile Include="Matrix/Common.fs" />
3940
<Compile Include="Matrix\COO\Map.fs" />
4041
<Compile Include="Matrix\COO\Merge.fs" />
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
module GraphBLAS.FSharp.Tests.Common.Merge
2+
3+
open Expecto
4+
open Brahma.FSharp
5+
open GraphBLAS.FSharp.Tests
6+
open GraphBLAS.FSharp.Backend.Common
7+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
8+
9+
let context = Context.defaultContext.ClContext
10+
11+
let processor = Context.defaultContext.Queue
12+
13+
let config =
14+
{ Utils.defaultConfig with
15+
endSize = 10 }
16+
17+
let makeTest isEqual testFun (leftArray: 'a []) (rightArray: 'a []) =
18+
if leftArray.Length > 0 && rightArray.Length > 0 then
19+
20+
let leftArray = Array.sort leftArray |> Array.distinct
21+
22+
let rightArray = Array.sort rightArray |> Array.distinct
23+
24+
let clLeftArray = context.CreateClArray leftArray
25+
let clRightArray = context.CreateClArray rightArray
26+
27+
let clResult: ClArray<'a> =
28+
testFun processor clLeftArray clRightArray
29+
30+
let result = clResult.ToHostAndFree processor
31+
clLeftArray.Free processor
32+
clRightArray.Free processor
33+
34+
let expected =
35+
Array.concat [ leftArray; rightArray ]
36+
|> Array.sort
37+
38+
"Results must be the same"
39+
|> Utils.compareArrays isEqual result expected
40+
41+
let createTest<'a> isEqual =
42+
Merge.run context Utils.defaultWorkGroupSize
43+
|> makeTest isEqual
44+
|> testPropertyWithConfig config $"test on {typeof<'a>}"
45+
46+
let tests =
47+
[ createTest<int> (=)
48+
49+
if Utils.isFloat64Available context.ClDevice then
50+
createTest<float> (=)
51+
52+
createTest<float32> (=)
53+
createTest<bool> (=) ]
54+
|> testList "Merge"

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
<Compile Include="QuickGraph/Algorithms/BFS.fs" />
1919
<Compile Include="Common/Scatter.fs" />
2020
<Compile Include="Common/Gather.fs" />
21+
<Compile Include="Common\Merge.fs" />
2122
<Compile Include="Common/ClArray/Choose.fs" />
2223
<Compile Include="Common/ClArray/Exists.fs" />
2324
<Compile Include="Common/ClArray/Map.fs" />

tests/GraphBLAS-sharp.Tests/Matrix/Merge.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ let makeTestCSR isEqual zero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
161161
let createTestCSR isEqual (zero: 'a) =
162162
Matrix.CSR.Merge.run context Utils.defaultWorkGroupSize
163163
|> makeTestCSR isEqual zero
164-
|> testPropertyWithConfig { config with endSize = 10 } $"test on {typeof<'a>}"
164+
|> testPropertyWithConfig config $"test on {typeof<'a>}"
165165

166166
let testsCSR =
167167
[ createTestCSR (=) 0

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,6 @@ open GraphBLAS.FSharp.Tests
9595

9696
[<EntryPoint>]
9797
let main argv =
98-
Matrix.Merge.testsCOO
98+
Common.Merge.tests
9999
|> testSequenced
100100
|> runTestsWithCLIArgs [] argv

0 commit comments

Comments
 (0)