Skip to content

Commit cb8a791

Browse files
authored
Merge pull request #76 from IgorErin/merge
Merge
2 parents e0485c5 + becee4d commit cb8a791

File tree

25 files changed

+1042
-585
lines changed

25 files changed

+1042
-585
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
namespace GraphBLAS.FSharp.Backend.Common
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Objects.ClContext
5+
6+
// Work-group-parallel merge of two device arrays. The kernel below locates positions on
// merge diagonals with binary search, first per work-group and then per work-item.
module Merge =
7+
/// <summary>
/// Compiles the merge kernel once and returns a function that, given a processor and two
/// arrays, allocates a result array of length firstValues.Length + secondValues.Length,
/// posts the kernel launch to the processor and returns the result array.
/// </summary>
/// <param name="clContext">Context used to compile the kernel and to allocate the result array.</param>
/// <param name="workGroupSize">Work-group size passed to the launch range; also the length of the
/// kernel's local buffer. The kernel relies on local ids 0 and 1, so presumably it must be
/// at least 2 (TODO confirm with callers).</param>
/// <remarks>
/// The binary searches compare neighbouring elements, so both inputs are presumably expected to be
/// sorted in ascending order (TODO confirm). NOTE(review): the second type parameter is not used here.
/// </remarks>
let run<'a, 'b when 'a: struct and 'b: struct and 'a: comparison> (clContext: ClContext) workGroupSize =
8+
9+
// Initial value of the two candidates (fstIdx / sndIdx) before they are read from the local buffer.
let defaultValue = Unchecked.defaultof<'a>
10+
11+
// Kernel quotation; it is compiled below with clContext.Compile.
let merge =
12+
<@ fun (ndRange: Range1D) (firstSide: int) (secondSide: int) (sumOfSides: int) (firstValues: ClArray<'a>) (secondValues: ClArray<'a>) (resultValues: ClArray<'a>) ->
13+
14+
let gid = ndRange.GlobalID0
15+
let lid = ndRange.LocalID0
16+
17+
// Written by local ids 0 and 1 and read by every work-item after the barrier
// (presumably work-group-local storage provided by Brahma's `local` - confirm).
let mutable beginIdxLocal = local ()
18+
let mutable endIdxLocal = local ()
19+
20+
// Stage 1: only the first two work-items of a group search for the group's boundaries
// in the first array.
if lid < 2 then
21+
// x = (n - 1) * workGroupSize - 1 for lid = 0
22+
// x = n * workGroupSize - 1 for lid = 1
23+
// where n in 1 .. number of work-groups
24+
let x = lid * (workGroupSize - 1) + gid - 1
25+
26+
// Clamp to the last diagonal so the final (possibly partial) group stays in range.
let diagonalNumber = min (sumOfSides - 1) x
27+
28+
// Search range in the first array: indices whose partner
// diagonalNumber - index is a valid index of the second array.
let mutable leftEdge = max 0 (diagonalNumber + 1 - secondSide)
29+
30+
let mutable rightEdge = min (firstSide - 1) diagonalNumber
31+
32+
// Binary search along the diagonal; on exit leftEdge is the boundary in the first array.
while leftEdge <= rightEdge do
33+
let middleIdx = (leftEdge + rightEdge) / 2
34+
35+
let firstIndex = firstValues.[middleIdx]
36+
37+
let secondIndex =
38+
secondValues.[diagonalNumber - middleIdx]
39+
40+
if firstIndex <= secondIndex then
41+
leftEdge <- middleIdx + 1
42+
else
43+
rightEdge <- middleIdx - 1
44+
45+
// Here lid equals either 0 or 1: lid 0 publishes the start boundary, lid 1 the end boundary.
46+
if lid = 0 then
47+
beginIdxLocal <- leftEdge
48+
else
49+
endIdxLocal <- leftEdge
50+
51+
// Wait until both boundaries are written before any work-item reads them.
barrierLocal ()
52+
53+
let beginIdx = beginIdxLocal
54+
let endIdx = endIdxLocal
55+
// Number of first-array elements handled by this work-group.
let firstLocalLength = endIdx - beginIdx
56+
57+
// Default: the remaining slots of the local buffer are filled from the second array.
let mutable x = workGroupSize - firstLocalLength
58+
59+
// The group's end boundary reached the end of the first array: take what is left of the
// second array instead. gid - lid - beginIdx is the offset into the second array at which
// this group starts (assuming gid - lid is the group's first global id).
if endIdx = firstSide then
60+
x <- secondSide - gid + lid + beginIdx
61+
62+
let secondLocalLength = x
63+
64+
// Stage 2: copy both slices into one local buffer.
// First-array elements occupy indices 0 .. firstLocalLength - 1 inclusive.
65+
// Second-array elements occupy indices firstLocalLength .. firstLocalLength + secondLocalLength - 1 inclusive.
66+
let localIndices = localArray<'a> workGroupSize
67+
68+
if lid < firstLocalLength then
69+
localIndices.[lid] <- firstValues.[beginIdx + lid]
70+
71+
// gid - beginIdx = (gid - lid - beginIdx) + lid: the group's offset into the second array plus lid.
if lid < secondLocalLength then
72+
localIndices.[firstLocalLength + lid] <- secondValues.[gid - beginIdx]
73+
74+
// Wait until the local buffer is completely filled.
barrierLocal ()
75+
76+
// Stage 3: each in-range work-item searches its own diagonal (number lid) inside the
// local buffer and writes one element of the result.
if gid < sumOfSides then
77+
let mutable leftEdge = lid + 1 - secondLocalLength
78+
if leftEdge < 0 then leftEdge <- 0
79+
80+
let mutable rightEdge = firstLocalLength - 1
81+
82+
rightEdge <- min rightEdge lid
83+
84+
while leftEdge <= rightEdge do
85+
let middleIdx = (leftEdge + rightEdge) / 2
86+
let firstIndex = localIndices.[middleIdx]
87+
88+
let secondIndex =
89+
localIndices.[firstLocalLength + lid - middleIdx]
90+
91+
if firstIndex <= secondIndex then
92+
leftEdge <- middleIdx + 1
93+
else
94+
rightEdge <- middleIdx - 1
95+
96+
// Candidate positions: boundaryX in the first slice, boundaryY in the second slice.
let boundaryX = rightEdge
97+
let boundaryY = lid - leftEdge
98+
99+
// boundaryX and boundaryY can't be off the right edge of the array (only off the left edge)
100+
let isValidX = boundaryX >= 0
101+
let isValidY = boundaryY >= 0
102+
103+
let mutable fstIdx = defaultValue
104+
105+
if isValidX then
106+
fstIdx <- localIndices.[boundaryX]
107+
108+
let mutable sndIdx = defaultValue
109+
110+
if isValidY then
111+
sndIdx <- localIndices.[firstLocalLength + boundaryY]
112+
113+
// Write the second-slice candidate when the first one is missing, or when both exist and
// the first is not greater; otherwise write the first-slice candidate.
if not isValidX || isValidY && fstIdx <= sndIdx then
114+
resultValues.[gid] <- sndIdx
115+
else
116+
resultValues.[gid] <- fstIdx @>
117+
118+
// Compiled once per call of `run`; the returned function below reuses it.
let kernel = clContext.Compile merge
119+
120+
fun (processor: MailboxProcessor<_>) (firstValues: ClArray<'a>) (secondValues: ClArray<'a>) ->
121+
122+
let firstSide = firstValues.Length
123+
124+
let secondSide = secondValues.Length
125+
126+
let sumOfSides = firstSide + secondSide
127+
128+
// Result buffer of length firstSide + secondSide, created with the DeviceOnly allocation mode.
let resultValues =
129+
clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, sumOfSides)
130+
131+
// The kernel guards result writes with gid < sumOfSides, so the range may presumably
// be larger than sumOfSides (depends on Range1D.CreateValid - confirm).
let ndRange =
132+
Range1D.CreateValid(sumOfSides, workGroupSize)
133+
134+
let kernel = kernel.GetKernel()
135+
136+
// Post the argument binding first, then the run message.
processor.Post(
137+
Msg.MsgSetArguments
138+
(fun () ->
139+
kernel.KernelFunc ndRange firstSide secondSide sumOfSides firstValues secondValues resultValues)
140+
)
141+
142+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
143+
144+
resultValues

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,22 @@
3535
<Compile Include="Common/Sort/Radix.fs" />
3636
<Compile Include="Common/Sort/Bitonic.fs" />
3737
<Compile Include="Common/Sum.fs" />
38+
<Compile Include="Common/Merge.fs" />
3839
<Compile Include="Matrix/Common.fs" />
39-
<Compile Include="Matrix/COOMatrix/Map2.fs" />
40-
<Compile Include="Matrix/COOMatrix/Map2AtLeastOne.fs" />
41-
<Compile Include="Matrix/COOMatrix/Map.fs" />
42-
<Compile Include="Matrix/COOMatrix/Matrix.fs" />
43-
<Compile Include="Matrix/CSRMatrix/Map2.fs" />
44-
<Compile Include="Matrix/CSRMatrix/SpGEMM/Expand.fs" />
45-
<Compile Include="Matrix/CSRMatrix/SpGEMM/Masked.fs" />
46-
<Compile Include="Matrix/CSRMatrix/Map2AtLeastOne.fs" />
47-
<Compile Include="Matrix/CSRMatrix/Map.fs" />
48-
<Compile Include="Matrix/CSRMatrix/Matrix.fs" />
40+
<Compile Include="Matrix/COO/Map.fs" />
41+
<Compile Include="Matrix/COO/Merge.fs" />
42+
<Compile Include="Matrix/COO/Map2.fs" />
43+
<Compile Include="Matrix/COO/Matrix.fs" />
44+
<Compile Include="Matrix/CSR/Merge.fs" />
45+
<Compile Include="Matrix/CSR/Map2.fs" />
46+
<Compile Include="Matrix/CSR/SpGEMM/Expand.fs" />
47+
<Compile Include="Matrix/CSR/SpGEMM/Masked.fs" />
48+
<Compile Include="Matrix/CSR/Map.fs" />
49+
<Compile Include="Matrix/CSR/Matrix.fs" />
4950
<Compile Include="Matrix/Matrix.fs" />
5051
<Compile Include="Vector/SparseVector/Common.fs" />
52+
<Compile Include="Vector/SparseVector/Merge.fs" />
5153
<Compile Include="Vector/SparseVector/Map2.fs" />
52-
<Compile Include="Vector/SparseVector/Map2AtLeastOne.fs" />
5354
<Compile Include="Vector/SparseVector/SparseVector.fs" />
5455
<Compile Include="Vector/DenseVector/DenseVector.fs" />
5556
<Compile Include="Vector/Vector.fs" />

src/GraphBLAS-sharp.Backend/Matrix/COOMatrix/Map2.fs renamed to src/GraphBLAS-sharp.Backend/Matrix/COO/Map2.fs

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@ open GraphBLAS.FSharp.Backend
88
open GraphBLAS.FSharp.Backend.Quotes
99
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
1010
open GraphBLAS.FSharp.Backend.Objects.ClContext
11-
open GraphBLAS.FSharp.Backend.Quotes
1211

1312
module internal Map2 =
14-
1513
let preparePositions<'a, 'b, 'c> (clContext: ClContext) workGroupSize opAdd =
1614

1715
let preparePositions (op: Expr<'a option -> 'b option -> 'c option>) =
@@ -134,3 +132,123 @@ module internal Map2 =
134132
Rows = resultRows
135133
Columns = resultColumns
136134
Values = resultValues }
135+
136+
module AtLeastOne =
137+
/// <summary>
/// Compiles a kernel that, for every index of the merged arrays, applies opAdd either to a pair of
/// neighbouring elements with equal (row, column) or to a single element, and hands the result to the
/// PreparePositions quotations. Returns a function that allocates the raw-positions and values buffers,
/// posts the kernel to the processor and returns both buffers.
/// </summary>
/// <param name="clContext">Context used to compile the kernel and to allocate the two output buffers.</param>
/// <param name="opAdd">Quoted operation spliced into the kernel; it receives optional left and right values.</param>
/// <param name="workGroupSize">Work-group size passed to the launch range.</param>
let preparePositionsAtLeastOne<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
138+
(clContext: ClContext)
139+
(opAdd: Expr<'a option -> 'b option -> 'c option>)
140+
workGroupSize
141+
=
142+
143+
let preparePositions =
144+
<@ fun (ndRange: Range1D) length (allRowsBuffer: ClArray<int>) (allColumnsBuffer: ClArray<int>) (leftValuesBuffer: ClArray<'a>) (rightValuesBuffer: ClArray<'b>) (allValuesBuffer: ClArray<'c>) (rawPositionsBuffer: ClArray<int>) (isLeftBitmap: ClArray<int>) ->
145+
146+
let i = ndRange.GlobalID0
147+
148+
// Case 1: elements i and i + 1 share the same (row, column), so both operands are present.
if (i < length - 1
149+
&& allRowsBuffer.[i] = allRowsBuffer.[i + 1]
150+
&& allColumnsBuffer.[i] = allColumnsBuffer.[i + 1]) then
151+
152+
// NOTE(review): the left value is read at i + 1 and the right value at i; presumably this
// mirrors the order in which the merge step places equal keys - confirm against the merge.
let result =
153+
(%opAdd) (Some leftValuesBuffer.[i + 1]) (Some rightValuesBuffer.[i])
154+
155+
// PreparePositions.both is defined elsewhere; presumably it stores the result and marks position i - confirm.
(%PreparePositions.both) i result rawPositionsBuffer allValuesBuffer
156+
// Case 2: i is 0, or element i differs in (row, column) from element i - 1.
// An index that matches its predecessor but not its successor falls through both branches.
elif (i > 0
157+
&& i < length
158+
&& (allRowsBuffer.[i] <> allRowsBuffer.[i - 1]
159+
|| allColumnsBuffer.[i] <> allColumnsBuffer.[i - 1]))
160+
|| i = 0 then
161+
162+
// Both one-sided results are computed; isLeftBitmap is passed along so that
// PreparePositions.leftRight (defined elsewhere) can presumably pick between them - confirm.
let leftResult =
163+
(%opAdd) (Some leftValuesBuffer.[i]) None
164+
165+
let rightResult =
166+
(%opAdd) None (Some rightValuesBuffer.[i])
167+
168+
(%PreparePositions.leftRight)
169+
i
170+
leftResult
171+
rightResult
172+
isLeftBitmap
173+
allValuesBuffer
174+
rawPositionsBuffer @>
175+
176+
// Compiled once; the returned function below reuses it.
let kernel = clContext.Compile(preparePositions)
177+
178+
fun (processor: MailboxProcessor<_>) (allRows: ClArray<int>) (allColumns: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) ->
179+
// The launch size and both output buffers are sized by leftValues.Length
// (assumes all input arrays have this same length - TODO confirm).
let length = leftValues.Length
180+
181+
let ndRange =
182+
Range1D.CreateValid(length, workGroupSize)
183+
184+
let rawPositionsGpu =
185+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, length)
186+
187+
let allValues =
188+
clContext.CreateClArrayWithSpecificAllocationMode<'c>(DeviceOnly, length)
189+
190+
let kernel = kernel.GetKernel()
191+
192+
// Post the argument binding first, then the run message.
processor.Post(
193+
Msg.MsgSetArguments
194+
(fun () ->
195+
kernel.KernelFunc
196+
ndRange
197+
length
198+
allRows
199+
allColumns
200+
leftValues
201+
rightValues
202+
allValues
203+
rawPositionsGpu
204+
isLeft)
205+
)
206+
207+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
208+
209+
// Both buffers are returned to the caller, which is responsible for releasing them.
rawPositionsGpu, allValues
210+
211+
212+
///<summary>
/// Builds the three stages (merge, prepare positions, set positions) once and returns a function that
/// applies them to two COO matrices and produces a new COO matrix.
///</summary>
///<param name="clContext">Context passed to all three stages and stored in the resulting matrix.</param>
213+
///<param name="opAdd">Quoted operation on optional left and right values; passed to the prepare-positions stage.</param>
214+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
215+
let run<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
216+
(clContext: ClContext)
217+
(opAdd: Expr<'a option -> 'b option -> 'c option>)
218+
workGroupSize
219+
=
220+
221+
// Presumably the COO-matrix Merge module rather than the one in Backend.Common:
// its result is destructured into five arrays below - confirm.
let merge = Merge.run clContext workGroupSize
222+
223+
let preparePositions =
224+
preparePositionsAtLeastOne clContext opAdd workGroupSize
225+
226+
let setPositions =
227+
Common.setPositions<'c> clContext workGroupSize
228+
229+
fun (queue: MailboxProcessor<_>) allocationMode (matrixLeft: ClMatrix.COO<'a>) (matrixRight: ClMatrix.COO<'b>) ->
230+
231+
// Stage 1: merge both matrices into combined row/column arrays plus per-side value arrays.
let allRows, allColumns, leftMergedValues, rightMergedValues, isLeft =
232+
merge queue matrixLeft matrixRight
233+
234+
// Stage 2: apply opAdd and compute raw positions.
let rawPositions, allValues =
235+
preparePositions queue allRows allColumns leftMergedValues rightMergedValues isLeft
236+
237+
// The per-side value arrays are not used after stage 2, so their free messages are posted now.
queue.Post(Msg.CreateFreeMsg<_>(leftMergedValues))
238+
queue.Post(Msg.CreateFreeMsg<_>(rightMergedValues))
239+
240+
// Stage 3: build the result arrays; the fourth component of the result is discarded.
let resultRows, resultColumns, resultValues, _ =
241+
setPositions queue allocationMode allRows allColumns allValues rawPositions
242+
243+
// Post free messages for the remaining intermediate buffers.
queue.Post(Msg.CreateFreeMsg<_>(isLeft))
244+
queue.Post(Msg.CreateFreeMsg<_>(rawPositions))
245+
queue.Post(Msg.CreateFreeMsg<_>(allRows))
246+
queue.Post(Msg.CreateFreeMsg<_>(allColumns))
247+
queue.Post(Msg.CreateFreeMsg<_>(allValues))
248+
249+
// Dimensions are taken from the left matrix (assumes both operands have equal dimensions - TODO confirm).
{ Context = clContext
250+
RowCount = matrixLeft.RowCount
251+
ColumnCount = matrixLeft.ColumnCount
252+
Rows = resultRows
253+
Columns = resultColumns
254+
Values = resultValues }

src/GraphBLAS-sharp.Backend/Matrix/COOMatrix/Matrix.fs renamed to src/GraphBLAS-sharp.Backend/Matrix/COO/Matrix.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ module Matrix =
2121
workGroupSize
2222
=
2323

24-
Map2AtLeastOne.run clContext (Convert.atLeastOneToOption opAdd) workGroupSize
24+
Map2.AtLeastOne.run clContext (Convert.atLeastOneToOption opAdd) workGroupSize
2525

2626
let getTuples (clContext: ClContext) workGroupSize =
2727

0 commit comments

Comments
 (0)