Skip to content

Commit 743491e

Browse files
committed
merge: dev
2 parents 4ea099c + cb8a791 commit 743491e

File tree

18 files changed

+1033
-566
lines changed

18 files changed

+1033
-566
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
namespace GraphBLAS.FSharp.Backend.Common
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Objects.ClContext
5+
6+
module Merge =
7+
let run<'a, 'b when 'a: struct and 'b: struct and 'a: comparison> (clContext: ClContext) workGroupSize =
8+
9+
let defaultValue = Unchecked.defaultof<'a>
10+
11+
let merge =
12+
<@ fun (ndRange: Range1D) (firstSide: int) (secondSide: int) (sumOfSides: int) (firstValues: ClArray<'a>) (secondValues: ClArray<'a>) (resultValues: ClArray<'a>) ->
13+
14+
let gid = ndRange.GlobalID0
15+
let lid = ndRange.LocalID0
16+
17+
let mutable beginIdxLocal = local ()
18+
let mutable endIdxLocal = local ()
19+
20+
if lid < 2 then
21+
// (n - 1) * wgSize - 1 for lid = 0
22+
// n * wgSize - 1 for lid = 1
23+
// where n in 1 .. wgGroupCount
24+
let x = lid * (workGroupSize - 1) + gid - 1
25+
26+
let diagonalNumber = min (sumOfSides - 1) x
27+
28+
let mutable leftEdge = max 0 (diagonalNumber + 1 - secondSide)
29+
30+
let mutable rightEdge = min (firstSide - 1) diagonalNumber
31+
32+
while leftEdge <= rightEdge do
33+
let middleIdx = (leftEdge + rightEdge) / 2
34+
35+
let firstIndex = firstValues.[middleIdx]
36+
37+
let secondIndex =
38+
secondValues.[diagonalNumber - middleIdx]
39+
40+
if firstIndex <= secondIndex then
41+
leftEdge <- middleIdx + 1
42+
else
43+
rightEdge <- middleIdx - 1
44+
45+
// Here localID equals either 0 or 1
46+
if lid = 0 then
47+
beginIdxLocal <- leftEdge
48+
else
49+
endIdxLocal <- leftEdge
50+
51+
barrierLocal ()
52+
53+
let beginIdx = beginIdxLocal
54+
let endIdx = endIdxLocal
55+
let firstLocalLength = endIdx - beginIdx
56+
57+
let mutable x = workGroupSize - firstLocalLength
58+
59+
if endIdx = firstSide then
60+
x <- secondSide - gid + lid + beginIdx
61+
62+
let secondLocalLength = x
63+
64+
//First indices are from 0 to firstLocalLength - 1 inclusive
65+
//Second indices are from firstLocalLength to firstLocalLength + secondLocalLength - 1 inclusive
66+
let localIndices = localArray<'a> workGroupSize
67+
68+
if lid < firstLocalLength then
69+
localIndices.[lid] <- firstValues.[beginIdx + lid]
70+
71+
if lid < secondLocalLength then
72+
localIndices.[firstLocalLength + lid] <- secondValues.[gid - beginIdx]
73+
74+
barrierLocal ()
75+
76+
if gid < sumOfSides then
77+
let mutable leftEdge = lid + 1 - secondLocalLength
78+
if leftEdge < 0 then leftEdge <- 0
79+
80+
let mutable rightEdge = firstLocalLength - 1
81+
82+
rightEdge <- min rightEdge lid
83+
84+
while leftEdge <= rightEdge do
85+
let middleIdx = (leftEdge + rightEdge) / 2
86+
let firstIndex = localIndices.[middleIdx]
87+
88+
let secondIndex =
89+
localIndices.[firstLocalLength + lid - middleIdx]
90+
91+
if firstIndex <= secondIndex then
92+
leftEdge <- middleIdx + 1
93+
else
94+
rightEdge <- middleIdx - 1
95+
96+
let boundaryX = rightEdge
97+
let boundaryY = lid - leftEdge
98+
99+
// boundaryX and boundaryY can't be off the right edge of array (only off the left edge)
100+
let isValidX = boundaryX >= 0
101+
let isValidY = boundaryY >= 0
102+
103+
let mutable fstIdx = defaultValue
104+
105+
if isValidX then
106+
fstIdx <- localIndices.[boundaryX]
107+
108+
let mutable sndIdx = defaultValue
109+
110+
if isValidY then
111+
sndIdx <- localIndices.[firstLocalLength + boundaryY]
112+
113+
if not isValidX || isValidY && fstIdx <= sndIdx then
114+
resultValues.[gid] <- sndIdx
115+
else
116+
resultValues.[gid] <- fstIdx @>
117+
118+
let kernel = clContext.Compile merge
119+
120+
fun (processor: MailboxProcessor<_>) (firstValues: ClArray<'a>) (secondValues: ClArray<'a>) ->
121+
122+
let firstSide = firstValues.Length
123+
124+
let secondSide = secondValues.Length
125+
126+
let sumOfSides = firstSide + secondSide
127+
128+
let resultValues =
129+
clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, sumOfSides)
130+
131+
let ndRange =
132+
Range1D.CreateValid(sumOfSides, workGroupSize)
133+
134+
let kernel = kernel.GetKernel()
135+
136+
processor.Post(
137+
Msg.MsgSetArguments
138+
(fun () ->
139+
kernel.KernelFunc ndRange firstSide secondSide sumOfSides firstValues secondValues resultValues)
140+
)
141+
142+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
143+
144+
resultValues

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
<Compile Include="Objects/AtLeastOne.fs" />
1919
<Compile Include="Objects/ClContextExtensions.fs" />
2020
<Compile Include="Objects/ClCell.fs" />
21+
2122
<Compile Include="Quotes/SubReduce.fs" />
2223
<Compile Include="Quotes/Arithmetic.fs" />
2324
<Compile Include="Quotes/Convert.fs" />
@@ -26,6 +27,7 @@
2627
<Compile Include="Quotes/PreparePositions.fs" />
2728
<Compile Include="Quotes/Predicates.fs" />
2829
<Compile Include="Quotes/Map.fs" />
30+
2931
<Compile Include="Quotes/Search.fs" />
3032
<Compile Include="Common/Scatter.fs" />
3133
<Compile Include="Common/Utils.fs" />
@@ -35,26 +37,29 @@
3537
<Compile Include="Common/Sort/Radix.fs" />
3638
<Compile Include="Common/Sort/Bitonic.fs" />
3739
<Compile Include="Common/Sum.fs" />
40+
<Compile Include="Common/Merge.fs" />
41+
3842
<Compile Include="Vector/Dense/Vector.fs" />
3943
<Compile Include="Vector/Sparse/Common.fs" />
40-
<Compile Include="Vector/Sparse/Map2AtLeastOne.fs" />
44+
<Compile Include="Vector/Sparse/Merge.fs" />
4145
<Compile Include="Vector/Sparse/Map2.fs" />
4246
<Compile Include="Vector/Sparse/Vector.fs" />
4347
<Compile Include="Vector/SpMV.fs" />
4448
<Compile Include="Vector/Vector.fs" />
45-
<Compile Include="Matrix/Common.fs" />
49+
<Compile Include="Matrix/Common.fs" />
4650
<Compile Include="Matrix/COO/Map.fs" />
51+
<Compile Include="Matrix/COO/Merge.fs" />
4752
<Compile Include="Matrix/COO/Map2.fs" />
48-
<Compile Include="Matrix/COO/Map2AtLeastOne.fs" />
4953
<Compile Include="Matrix/COO/Matrix.fs" />
50-
<Compile Include="Matrix/CSR/Map.fs" />
54+
<Compile Include="Matrix/CSR/Merge.fs" />
5155
<Compile Include="Matrix/CSR/Map2.fs" />
52-
<Compile Include="Matrix/CSR/Map2AtLeastOne.fs" />
56+
<Compile Include="Matrix/CSR/Map.fs" />
5357
<Compile Include="Matrix/CSR/Matrix.fs" />
5458
<Compile Include="Matrix/LIL/Matrix.fs" />
5559
<Compile Include="Matrix/SpGeMM/Expand.fs" />
5660
<Compile Include="Matrix/SpGeMM/Masked.fs" />
5761
<Compile Include="Matrix/Matrix.fs" />
62+
5863
<Compile Include="Algorithms/BFS.fs" />
5964
</ItemGroup>
6065
<Import Project="..\..\.paket\Paket.Restore.targets" />

src/GraphBLAS-sharp.Backend/Matrix/COO/Map2.fs

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ open GraphBLAS.FSharp.Backend
88
open GraphBLAS.FSharp.Backend.Quotes
99
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
1010
open GraphBLAS.FSharp.Backend.Objects.ClContext
11-
open GraphBLAS.FSharp.Backend.Quotes
1211

1312
module internal Map2 =
1413

@@ -134,3 +133,123 @@ module internal Map2 =
134133
Rows = resultRows
135134
Columns = resultColumns
136135
Values = resultValues }
136+
137+
module AtLeastOne =
138+
let preparePositionsAtLeastOne<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
139+
(clContext: ClContext)
140+
(opAdd: Expr<'a option -> 'b option -> 'c option>)
141+
workGroupSize
142+
=
143+
144+
let preparePositions =
145+
<@ fun (ndRange: Range1D) length (allRowsBuffer: ClArray<int>) (allColumnsBuffer: ClArray<int>) (leftValuesBuffer: ClArray<'a>) (rightValuesBuffer: ClArray<'b>) (allValuesBuffer: ClArray<'c>) (rawPositionsBuffer: ClArray<int>) (isLeftBitmap: ClArray<int>) ->
146+
147+
let i = ndRange.GlobalID0
148+
149+
if (i < length - 1
150+
&& allRowsBuffer.[i] = allRowsBuffer.[i + 1]
151+
&& allColumnsBuffer.[i] = allColumnsBuffer.[i + 1]) then
152+
153+
let result =
154+
(%opAdd) (Some leftValuesBuffer.[i + 1]) (Some rightValuesBuffer.[i])
155+
156+
(%PreparePositions.both) i result rawPositionsBuffer allValuesBuffer
157+
elif (i > 0
158+
&& i < length
159+
&& (allRowsBuffer.[i] <> allRowsBuffer.[i - 1]
160+
|| allColumnsBuffer.[i] <> allColumnsBuffer.[i - 1]))
161+
|| i = 0 then
162+
163+
let leftResult =
164+
(%opAdd) (Some leftValuesBuffer.[i]) None
165+
166+
let rightResult =
167+
(%opAdd) None (Some rightValuesBuffer.[i])
168+
169+
(%PreparePositions.leftRight)
170+
i
171+
leftResult
172+
rightResult
173+
isLeftBitmap
174+
allValuesBuffer
175+
rawPositionsBuffer @>
176+
177+
let kernel = clContext.Compile(preparePositions)
178+
179+
fun (processor: MailboxProcessor<_>) (allRows: ClArray<int>) (allColumns: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) ->
180+
let length = leftValues.Length
181+
182+
let ndRange =
183+
Range1D.CreateValid(length, workGroupSize)
184+
185+
let rawPositionsGpu =
186+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, length)
187+
188+
let allValues =
189+
clContext.CreateClArrayWithSpecificAllocationMode<'c>(DeviceOnly, length)
190+
191+
let kernel = kernel.GetKernel()
192+
193+
processor.Post(
194+
Msg.MsgSetArguments
195+
(fun () ->
196+
kernel.KernelFunc
197+
ndRange
198+
length
199+
allRows
200+
allColumns
201+
leftValues
202+
rightValues
203+
allValues
204+
rawPositionsGpu
205+
isLeft)
206+
)
207+
208+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
209+
210+
rawPositionsGpu, allValues
211+
212+
213+
///<param name="clContext">.</param>
214+
///<param name="opAdd">.</param>
215+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
216+
let run<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
217+
(clContext: ClContext)
218+
(opAdd: Expr<'a option -> 'b option -> 'c option>)
219+
workGroupSize
220+
=
221+
222+
let merge = Merge.run clContext workGroupSize
223+
224+
let preparePositions =
225+
preparePositionsAtLeastOne clContext opAdd workGroupSize
226+
227+
let setPositions =
228+
Common.setPositions<'c> clContext workGroupSize
229+
230+
fun (queue: MailboxProcessor<_>) allocationMode (matrixLeft: ClMatrix.COO<'a>) (matrixRight: ClMatrix.COO<'b>) ->
231+
232+
let allRows, allColumns, leftMergedValues, rightMergedValues, isLeft =
233+
merge queue matrixLeft matrixRight
234+
235+
let rawPositions, allValues =
236+
preparePositions queue allRows allColumns leftMergedValues rightMergedValues isLeft
237+
238+
queue.Post(Msg.CreateFreeMsg<_>(leftMergedValues))
239+
queue.Post(Msg.CreateFreeMsg<_>(rightMergedValues))
240+
241+
let resultRows, resultColumns, resultValues, _ =
242+
setPositions queue allocationMode allRows allColumns allValues rawPositions
243+
244+
queue.Post(Msg.CreateFreeMsg<_>(isLeft))
245+
queue.Post(Msg.CreateFreeMsg<_>(rawPositions))
246+
queue.Post(Msg.CreateFreeMsg<_>(allRows))
247+
queue.Post(Msg.CreateFreeMsg<_>(allColumns))
248+
queue.Post(Msg.CreateFreeMsg<_>(allValues))
249+
250+
{ Context = clContext
251+
RowCount = matrixLeft.RowCount
252+
ColumnCount = matrixLeft.ColumnCount
253+
Rows = resultRows
254+
Columns = resultColumns
255+
Values = resultValues }

src/GraphBLAS-sharp.Backend/Matrix/COO/Matrix.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module Matrix =
2323
workGroupSize
2424
=
2525

26-
Map2AtLeastOne.run (Convert.atLeastOneToOption opAdd) clContext workGroupSize
26+
Map2.AtLeastOne.run clContext (Convert.atLeastOneToOption opAdd) workGroupSize
2727

2828
let getTuples (clContext: ClContext) workGroupSize =
2929

0 commit comments

Comments
 (0)