@@ -206,3 +206,93 @@ module internal Map2 =
206206 Values = resultValues
207207 Indices = resultIndices
208208 Size = rightVector.Size }
209+
module AtLeastOne =
    /// <summary>
    /// Compiles the position-preparation kernel for <c>op</c> once and returns a function that enqueues it.
    /// </summary>
    /// <param name="clContext">Context used to compile the kernel and to allocate the two output buffers.</param>
    /// <param name="op">Quoted operation spliced into the kernel; it is applied to two optional operands.</param>
    /// <param name="workGroupSize">Work-group size handed to <c>Range1D.CreateValid</c>.</param>
    /// <returns>
    /// A function that takes the processor and the merged buffers, allocates <c>allValues</c> and
    /// <c>positions</c> (both as long as <c>allIndices</c>), posts the kernel and returns the two new buffers.
    /// </returns>
    let private preparePositions<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct>
        (clContext: ClContext)
        op
        workGroupSize
        =

        // Kernel quotation. `opAdd` is spliced in, so one compiled program corresponds to one operation.
        let kernelSource opAdd =
            <@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->

                let i = ndRange.GlobalID0

                if i < length - 1 && allIndices.[i] = allIndices.[i + 1] then
                    // Slot i and its successor hold the same index: feed both operands to the operation.
                    let combined =
                        (%opAdd) (Some leftValues.[i]) (Some rightValues.[i + 1])

                    (%PreparePositions.both) i combined positions allValues
                elif (i < length && i > 0 && allIndices.[i - 1] <> allIndices.[i]) || i = 0 then
                    // Slot i is the first slot, or its index differs from the predecessor's, and the
                    // first branch did not match: evaluate the operation with each side absent in turn.
                    // NOTE(review): choosing between the two results is left to PreparePositions.leftRight,
                    // presumably based on isLeft - confirm against its definition.
                    let onlyLeft = (%opAdd) (Some leftValues.[i]) None
                    let onlyRight = (%opAdd) None (Some rightValues.[i])

                    (%PreparePositions.leftRight) i onlyLeft onlyRight isLeft allValues positions @>
        // Work items matching neither branch perform no writes in this quotation.

        let program = clContext.Compile(kernelSource op)

        fun (processor: MailboxProcessor<_>) (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) ->

            let length = allIndices.Length

            // Both outputs are allocated with the DeviceOnly mode and sized like the merged index buffer.
            let allValues =
                clContext.CreateClArrayWithSpecificAllocationMode<'c>(DeviceOnly, length)

            let positions =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, length)

            let ndRange =
                Range1D.CreateValid(length, workGroupSize)

            let kernel = program.GetKernel()

            // Arguments are bound first, then the run message is queued on the same processor.
            Msg.MsgSetArguments
                (fun () ->
                    kernel.KernelFunc ndRange length allIndices leftValues rightValues isLeft allValues positions)
            |> processor.Post

            Msg.CreateRunMsg<_, _>(kernel) |> processor.Post

            allValues, positions

    /// <summary>
    /// Builds an element-wise operation over two sparse vectors: the inputs are merged, positions are
    /// prepared with <c>op</c>, and the result buffers are produced by <c>Common.setPositions</c>.
    /// The resulting record's <c>Size</c> is the larger of the two input sizes.
    /// </summary>
    /// <param name="clContext">Context used to build every stage and stored in the result's <c>Context</c> field.</param>
    /// <param name="op">Quoted operation over two optional operands; it is spliced into the position-preparation kernel.</param>
    /// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
    let run<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct> (clContext: ClContext) op workGroupSize =

        // The three stages are built once here, outside the returned function.
        let mergeStage = Merge.run clContext workGroupSize

        let prepareStage =
            preparePositions<'a, 'b, 'c> clContext op workGroupSize

        let setPositionsStage =
            Common.setPositions clContext workGroupSize

        fun (processor: MailboxProcessor<_>) allocationMode (leftVector: ClVector.Sparse<'a>) (rightVector: ClVector.Sparse<'b>) ->

            let allIndices, leftValues, rightValues, isLeft =
                mergeStage processor leftVector rightVector

            let allValues, positions =
                prepareStage processor allIndices leftValues rightValues isLeft

            // The merge by-products are not used past this point; queue their release.
            processor.Post(Msg.CreateFreeMsg<_>(leftValues))
            processor.Post(Msg.CreateFreeMsg<_>(rightValues))
            processor.Post(Msg.CreateFreeMsg<_>(isLeft))

            let resultValues, resultIndices =
                setPositionsStage processor allocationMode allValues allIndices positions

            // The remaining intermediates are released only after setPositions has been given them.
            processor.Post(Msg.CreateFreeMsg<_>(allIndices))
            processor.Post(Msg.CreateFreeMsg<_>(allValues))
            processor.Post(Msg.CreateFreeMsg<_>(positions))

            let resultSize = max leftVector.Size rightVector.Size

            { Context = clContext
              Values = resultValues
              Indices = resultIndices
              Size = resultSize }
0 commit comments