11namespace GraphBLAS.FSharp.Backend.Matrix.COO
22
33open Brahma.FSharp
4- open GraphBLAS.FSharp .Backend .Matrix
5- open GraphBLAS.FSharp .Backend .Quotes
6- open Microsoft.FSharp .Quotations
7- open GraphBLAS.FSharp .Backend .Objects
8- open GraphBLAS.FSharp .Backend
9- open GraphBLAS.FSharp .Backend .Objects .ClMatrix
104open GraphBLAS.FSharp .Backend .Objects .ClContext
5+ open GraphBLAS.FSharp .Backend .Objects
116
12- module internal Map2AtLeastOne =
13- let preparePositionsAtLeastOne < 'a , 'b , 'c when 'a : struct and 'b : struct and 'c : struct and 'c : equality >
14- ( clContext : ClContext )
15- ( opAdd : Expr < 'a option -> 'b option -> 'c option >)
16- workGroupSize
17- =
18-
19- let preparePositions =
20- <@ fun ( ndRange : Range1D ) length ( allRowsBuffer : ClArray < int >) ( allColumnsBuffer : ClArray < int >) ( leftValuesBuffer : ClArray < 'a >) ( rightValuesBuffer : ClArray < 'b >) ( allValuesBuffer : ClArray < 'c >) ( rawPositionsBuffer : ClArray < int >) ( isLeftBitmap : ClArray < int >) ->
21-
22- let i = ndRange.GlobalID0
23-
24- if ( i < length - 1
25- && allRowsBuffer.[ i] = allRowsBuffer.[ i + 1 ]
26- && allColumnsBuffer.[ i] = allColumnsBuffer.[ i + 1 ]) then
27-
28- let result =
29- (% opAdd) ( Some leftValuesBuffer.[ i + 1 ]) ( Some rightValuesBuffer.[ i])
30-
31- (% PreparePositions.both) i result rawPositionsBuffer allValuesBuffer
32- elif ( i > 0
33- && i < length
34- && ( allRowsBuffer.[ i] <> allRowsBuffer.[ i - 1 ]
35- || allColumnsBuffer.[ i] <> allColumnsBuffer.[ i - 1 ]))
36- || i = 0 then
37-
38- let leftResult =
39- (% opAdd) ( Some leftValuesBuffer.[ i]) None
40-
41- let rightResult =
42- (% opAdd) None ( Some rightValuesBuffer.[ i])
43-
44- (% PreparePositions.leftRight)
45- i
46- leftResult
47- rightResult
48- isLeftBitmap
49- allValuesBuffer
50- rawPositionsBuffer @>
51-
52- let kernel = clContext.Compile( preparePositions)
53-
54- fun ( processor : MailboxProcessor < _ >) ( allRows : ClArray < int >) ( allColumns : ClArray < int >) ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ( isLeft : ClArray < int >) ->
55- let length = leftValues.Length
56-
57- let ndRange =
58- Range1D.CreateValid( length, workGroupSize)
59-
60- let rawPositionsGpu =
61- clContext.CreateClArrayWithSpecificAllocationMode< int>( DeviceOnly, length)
62-
63- let allValues =
64- clContext.CreateClArrayWithSpecificAllocationMode< 'c>( DeviceOnly, length)
65-
66- let kernel = kernel.GetKernel()
67-
68- processor.Post(
69- Msg.MsgSetArguments
70- ( fun () ->
71- kernel.KernelFunc
72- ndRange
73- length
74- allRows
75- allColumns
76- leftValues
77- rightValues
78- allValues
79- rawPositionsGpu
80- isLeft)
81- )
82-
83- processor.Post( Msg.CreateRunMsg<_, _>( kernel))
84-
85- rawPositionsGpu, allValues
86-
87- let merge < 'a , 'b when 'a : struct and 'b : struct > ( clContext : ClContext ) workGroupSize =
7+ module Merge =
8+ let run < 'a , 'b when 'a : struct and 'b : struct > ( clContext : ClContext ) workGroupSize =
889
8910 let merge =
9011 <@ fun ( ndRange : Range1D ) firstSide secondSide sumOfSides ( firstRowsBuffer : ClArray < int >) ( firstColumnsBuffer : ClArray < int >) ( firstValuesBuffer : ClArray < 'a >) ( secondRowsBuffer : ClArray < int >) ( secondColumnsBuffer : ClArray < int >) ( secondValuesBuffer : ClArray < 'b >) ( allRowsBuffer : ClArray < int >) ( allColumnsBuffer : ClArray < int >) ( leftMergedValuesBuffer : ClArray < 'a >) ( rightMergedValuesBuffer : ClArray < 'b >) ( isLeftBitmap : ClArray < int >) ->
@@ -209,10 +130,10 @@ module internal Map2AtLeastOne =
209130
210131 let kernel = clContext.Compile( merge)
211132
212- fun ( processor : MailboxProcessor < _ >) ( matrixLeftRows : ClArray < int >) ( matrixLeftColumns : ClArray < int >) ( matrixLeftValues : ClArray < 'a >) ( matrixRightRows : ClArray < int >) ( matrixRightColumns : ClArray < int >) ( matrixRightValues : ClArray <'b >) ->
133+ fun ( processor : MailboxProcessor < _ >) ( leftMatrix : ClMatrix.COO < 'a >) ( rightMatrix : ClMatrix.COO <'b >) ->
213134
214- let firstSide = matrixLeftValues .Length
215- let secondSide = matrixRightValues .Length
135+ let firstSide = leftMatrix.Columns .Length
136+ let secondSide = rightMatrix.Columns .Length
216137 let sumOfSides = firstSide + secondSide
217138
218139 let allRows =
@@ -243,12 +164,12 @@ module internal Map2AtLeastOne =
243164 firstSide
244165 secondSide
245166 sumOfSides
246- matrixLeftRows
247- matrixLeftColumns
248- matrixLeftValues
249- matrixRightRows
250- matrixRightColumns
251- matrixRightValues
167+ leftMatrix.Rows
168+ leftMatrix.Columns
169+ leftMatrix.Values
170+ rightMatrix.Rows
171+ rightMatrix.Columns
172+ rightMatrix.Values
252173 allRows
253174 allColumns
254175 leftMergedValues
@@ -259,54 +180,3 @@ module internal Map2AtLeastOne =
259180 processor.Post( Msg.CreateRunMsg<_, _>( kernel))
260181
261182 allRows, allColumns, leftMergedValues, rightMergedValues, isLeft
262-
263- ///<param name="clContext">.</param>
264- ///<param name="opAdd">.</param>
265- ///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
266- let run < 'a , 'b , 'c when 'a : struct and 'b : struct and 'c : struct and 'c : equality >
267- ( clContext : ClContext )
268- ( opAdd : Expr < 'a option -> 'b option -> 'c option >)
269- workGroupSize
270- =
271-
272- let merge = merge clContext workGroupSize
273-
274- let preparePositions =
275- preparePositionsAtLeastOne clContext opAdd workGroupSize
276-
277- let setPositions =
278- Common.setPositions< 'c> clContext workGroupSize
279-
280- fun ( queue : MailboxProcessor < _ >) allocationMode ( matrixLeft : ClMatrix.COO < 'a >) ( matrixRight : ClMatrix.COO < 'b >) ->
281-
282- let allRows , allColumns , leftMergedValues , rightMergedValues , isLeft =
283- merge
284- queue
285- matrixLeft.Rows
286- matrixLeft.Columns
287- matrixLeft.Values
288- matrixRight.Rows
289- matrixRight.Columns
290- matrixRight.Values
291-
292- let rawPositions , allValues =
293- preparePositions queue allRows allColumns leftMergedValues rightMergedValues isLeft
294-
295- queue.Post( Msg.CreateFreeMsg<_>( leftMergedValues))
296- queue.Post( Msg.CreateFreeMsg<_>( rightMergedValues))
297-
298- let resultRows , resultColumns , resultValues , _ =
299- setPositions queue allocationMode allRows allColumns allValues rawPositions
300-
301- queue.Post( Msg.CreateFreeMsg<_>( isLeft))
302- queue.Post( Msg.CreateFreeMsg<_>( rawPositions))
303- queue.Post( Msg.CreateFreeMsg<_>( allRows))
304- queue.Post( Msg.CreateFreeMsg<_>( allColumns))
305- queue.Post( Msg.CreateFreeMsg<_>( allValues))
306-
307- { Context = clContext
308- RowCount = matrixLeft.RowCount
309- ColumnCount = matrixLeft.ColumnCount
310- Rows = resultRows
311- Columns = resultColumns
312- Values = resultValues }
0 commit comments