Skip to content

Commit 848a11b

Browse files
committed
refactor: Vector.Merge
1 parent 2b4a5b7 commit 848a11b

File tree

8 files changed

+266
-440
lines changed

8 files changed

+266
-440
lines changed

src/GraphBLAS-sharp.Backend/Common/Merge.fs

Lines changed: 0 additions & 168 deletions
This file was deleted.

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
<Compile Include="Common/Sort/Radix.fs" />
3636
<Compile Include="Common/Sort/Bitonic.fs" />
3737
<Compile Include="Common/Sum.fs" />
38-
<Compile Include="Common\Merge.fs" />
3938
<Compile Include="Matrix/Common.fs" />
4039
<Compile Include="Matrix/COOMatrix/Map2.fs" />
4140
<Compile Include="Matrix/COOMatrix/Map2AtLeastOne.fs" />
@@ -49,8 +48,8 @@
4948
<Compile Include="Matrix/CSRMatrix/Matrix.fs" />
5049
<Compile Include="Matrix/Matrix.fs" />
5150
<Compile Include="Vector/SparseVector/Common.fs" />
51+
<Compile Include="Vector/SparseVector/Merge.fs" />
5252
<Compile Include="Vector/SparseVector/Map2.fs" />
53-
<Compile Include="Vector/SparseVector/Map2AtLeastOne.fs" />
5453
<Compile Include="Vector/SparseVector/SparseVector.fs" />
5554
<Compile Include="Vector/DenseVector/DenseVector.fs" />
5655
<Compile Include="Vector/Vector.fs" />

src/GraphBLAS-sharp.Backend/Vector/SparseVector/Map2.fs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,93 @@ module internal Map2 =
206206
Values = resultValues
207207
Indices = resultIndices
208208
Size = rightVector.Size }
209+
210+
module AtLeastOne =
211+
let private preparePositions<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct>
212+
(clContext: ClContext)
213+
op
214+
workGroupSize
215+
=
216+
217+
let preparePositions opAdd =
218+
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
219+
220+
let gid = ndRange.GlobalID0
221+
222+
if gid < length - 1
223+
&& allIndices.[gid] = allIndices.[gid + 1] then
224+
let result =
225+
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1])
226+
227+
(%PreparePositions.both) gid result positions allValues
228+
elif (gid < length
229+
&& gid > 0
230+
&& allIndices.[gid - 1] <> allIndices.[gid])
231+
|| gid = 0 then
232+
let leftResult = (%opAdd) (Some leftValues.[gid]) None
233+
let rightResult = (%opAdd) None (Some rightValues.[gid])
234+
235+
(%PreparePositions.leftRight) gid leftResult rightResult isLeft allValues positions @>
236+
237+
let kernel = clContext.Compile <| preparePositions op
238+
239+
fun (processor: MailboxProcessor<_>) (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) ->
240+
241+
let length = allIndices.Length
242+
243+
let allValues =
244+
clContext.CreateClArrayWithSpecificAllocationMode<'c>(DeviceOnly, length)
245+
246+
let positions =
247+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, length)
248+
249+
let ndRange =
250+
Range1D.CreateValid(length, workGroupSize)
251+
252+
let kernel = kernel.GetKernel()
253+
254+
processor.Post(
255+
Msg.MsgSetArguments
256+
(fun () ->
257+
kernel.KernelFunc ndRange length allIndices leftValues rightValues isLeft allValues positions)
258+
)
259+
260+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
261+
262+
allValues, positions
263+
264+
///<param name="clContext">.</param>
265+
///<param name="op">.</param>
266+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
267+
let run<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct> (clContext: ClContext) op workGroupSize =
268+
269+
let merge = Merge.run clContext workGroupSize
270+
271+
let prepare =
272+
preparePositions<'a, 'b, 'c> clContext op workGroupSize
273+
274+
let setPositions =
275+
Common.setPositions clContext workGroupSize
276+
277+
fun (processor: MailboxProcessor<_>) allocationMode (leftVector: ClVector.Sparse<'a>) (rightVector: ClVector.Sparse<'b>) ->
278+
279+
let allIndices, leftValues, rightValues, isLeft = merge processor leftVector rightVector
280+
281+
let allValues, positions =
282+
prepare processor allIndices leftValues rightValues isLeft
283+
284+
processor.Post(Msg.CreateFreeMsg<_>(leftValues))
285+
processor.Post(Msg.CreateFreeMsg<_>(rightValues))
286+
processor.Post(Msg.CreateFreeMsg<_>(isLeft))
287+
288+
let resultValues, resultIndices =
289+
setPositions processor allocationMode allValues allIndices positions
290+
291+
processor.Post(Msg.CreateFreeMsg<_>(allIndices))
292+
processor.Post(Msg.CreateFreeMsg<_>(allValues))
293+
processor.Post(Msg.CreateFreeMsg<_>(positions))
294+
295+
{ Context = clContext
296+
Values = resultValues
297+
Indices = resultIndices
298+
Size = max leftVector.Size rightVector.Size }

0 commit comments

Comments
 (0)