@@ -22,6 +22,7 @@ let wgSize = Constants.Common.defaultWorkGroupSize
2222let makeTest
2323 formatFrom
2424 ( convertFun : RawCommandQueue -> AllocationFlag -> ClVector < 'a > -> ClVector < 'a >)
25+ ( convertFunUnsorted : option < RawCommandQueue -> AllocationFlag -> ClVector < 'a > -> ClVector < 'a >>)
2526 isZero
2627 case
2728 ( array : 'a [])
@@ -37,7 +38,7 @@ let makeTest
3738
3839 let actual =
3940 let clVector = vector.ToDevice context
40- let convertedVector = convertFun q HostInterop clVector
41+ let convertedVector = convertFun q DeviceOnly clVector
4142
4243 let res = convertedVector.ToHost q
4344
@@ -56,6 +57,27 @@ let makeTest
5657
5758 Expect.equal actual expected " Vectors must be the same"
5859
60+ match convertFunUnsorted with
61+ | None -> ()
62+ | Some convertFunUnsorted ->
63+ let clVector = vector.ToDevice context
64+ let convertedVector = convertFunUnsorted q DeviceOnly clVector
65+
66+ let res = convertedVector.ToHost q
67+
68+ match res, expected with
69+ | Vector.Sparse res, Vector.Sparse expected ->
70+ let iv = Array.zip res.Indices res.Values
71+ let resSorted = Array.sortBy ( fun ( i , v ) -> i) iv
72+ let indices , values = Array.unzip resSorted
73+ Expect.equal indices expected.Indices " Indices must be the same"
74+ Expect.equal values expected.Values " Values must be the same"
75+ Expect.equal res.Size expected.Size " Size must be the same"
76+ | _ -> ()
77+
78+ clVector.Dispose()
79+ convertedVector.Dispose()
80+
5981let testFixtures case =
6082 let getCorrectnessTestName datatype formatFrom =
6183 sprintf $" Correctness on %s {datatype}, %A {formatFrom} -> %A {case.Format}"
@@ -68,19 +90,21 @@ let testFixtures case =
6890 match case.Format with
6991 | Sparse ->
7092 [ let convertFun = Vector.toSparse context wgSize
93+ let convertFunUnsorted = Vector.toSparseUnsorted context wgSize
7194
7295 Utils.listOfUnionCases< VectorFormat>
7396 |> List.map
7497 ( fun formatFrom ->
75- makeTest formatFrom convertFun ((=) 0 ) case
98+ makeTest formatFrom convertFun ( Some convertFunUnsorted ) ( (=) 0 ) case
7699 |> testPropertyWithConfig config ( getCorrectnessTestName " int" formatFrom))
77100
78101 let convertFun = Vector.toSparse context wgSize
102+ let convertFunUnsorted = Vector.toSparseUnsorted context wgSize
79103
80104 Utils.listOfUnionCases< VectorFormat>
81105 |> List.map
82106 ( fun formatFrom ->
83- makeTest formatFrom convertFun ((=) false ) case
107+ makeTest formatFrom convertFun ( Some convertFunUnsorted ) ( (=) false ) case
84108 |> testPropertyWithConfig config ( getCorrectnessTestName " bool" formatFrom)) ]
85109 |> List.concat
86110 | Dense ->
@@ -89,15 +113,15 @@ let testFixtures case =
89113 Utils.listOfUnionCases< VectorFormat>
90114 |> List.map
91115 ( fun formatFrom ->
92- makeTest formatFrom convertFun ((=) 0 ) case
116+ makeTest formatFrom convertFun None ((=) 0 ) case
93117 |> testPropertyWithConfig config ( getCorrectnessTestName " int" formatFrom))
94118
95119 let convertFun = Vector.toDense context wgSize
96120
97121 Utils.listOfUnionCases< VectorFormat>
98122 |> List.map
99123 ( fun formatFrom ->
100- makeTest formatFrom convertFun ((=) false ) case
124+ makeTest formatFrom convertFun None ((=) false ) case
101125 |> testPropertyWithConfig config ( getCorrectnessTestName " bool" formatFrom)) ]
102126 |> List.concat
103127
0 commit comments