@@ -331,7 +331,7 @@ module Operations =
331331 | _ -> failwith " Not implemented yet"
332332
333333 /// <summary>
334- /// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations.
334+ /// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations by skipping reduction stage .
335335 /// </summary>
336336 /// <param name="add">Type of binary function to reduce entries.</param>
337337 /// <param name="mul">Type of binary function to combine entries.</param>
@@ -352,6 +352,50 @@ module Operations =
352352 | ClMatrix.CSR m, ClVector.Sparse v -> Option.map ClVector.Sparse ( run queue m v)
353353 | _ -> failwith " Not implemented yet"
354354
355+ /// <summary>
356+ /// CSR Matrix - sparse vector multiplication with mask. Mask is complemented.
357+ /// </summary>
358+ /// <param name="add">Type of binary function to reduce entries.</param>
359+ /// <param name="mul">Type of binary function to combine entries.</param>
360+ /// <param name="clContext">OpenCL context.</param>
361+ /// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
362+ let SpMSpVMasked
363+ ( add : Expr < 'c option -> 'c option -> 'c option >)
364+ ( mul : Expr < 'a option -> 'b option -> 'c option >)
365+ ( clContext : ClContext )
366+ workGroupSize
367+ =
368+
369+ let run =
370+ SpMSpV.Masked.runMasked add mul clContext workGroupSize
371+
372+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix < 'a >) ( vector : ClVector < 'b >) ( mask : ClVector < 'd >) ->
373+ match matrix, vector, mask with
374+ | ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse ( run queue m v mask)
375+ | _ -> failwith " Not implemented yet"
376+
377+ /// <summary>
378+ /// CSR Matrix - sparse vector multiplication with mask. Mask is complemented. Optimized for bool OR and AND operations by skipping reduction stage.
379+ /// </summary>
380+ /// <param name="add">Type of binary function to reduce entries.</param>
381+ /// <param name="mul">Type of binary function to combine entries.</param>
382+ /// <param name="clContext">OpenCL context.</param>
383+ /// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
384+ let SpMSpVMaskedBool
385+ ( add : Expr < bool option -> bool option -> bool option >)
386+ ( mul : Expr < bool option -> bool option -> bool option >)
387+ ( clContext : ClContext )
388+ workGroupSize
389+ =
390+
391+ let run =
392+ SpMSpV.Masked.runMaskedBoolStandard add mul clContext workGroupSize
393+
394+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix < 'a >) ( vector : ClVector < 'b >) ( mask : ClVector < 'd >) ->
395+ match matrix, vector, mask with
396+ | ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse ( run queue m v mask)
397+ | _ -> failwith " Not implemented yet"
398+
355399 /// <summary>
356400 /// CSR Matrix - sparse vector multiplication.
357401 /// </summary>
0 commit comments