2222using namespace sycl ;
2323
2424template <typename T>
25- void test (queue &Q, int M, int N, int K)
25+ static
26+ bool test (queue &Q, int M, int N, int K)
2627{
27- std::cout << " \n Benchmarking (" << M << " x " << K << " ) x (" << K << " x " << N << " ) matrix multiplication, " << type_string<T>() << std::endl ;;
28+ std::cout << " \n Benchmarking (" << M << " x " << K << " ) x (" << K << " x " << N << " ) matrix multiplication, " << type_string<T>() << " \n " ;;
2829
2930 std::cout << " -> Initializing data...\n " ;
3031
@@ -38,7 +39,8 @@ void test(queue &Q, int M, int N, int K)
3839 auto C = malloc_device<T>(ldc * N, Q);
3940
4041 constexpr int rd_size = 1048576 ;
41- auto host_data = malloc_host<T>(rd_size, Q);
42+ std::vector<T> host_vector (rd_size);
43+ auto host_data = host_vector.data ();
4244
4345 /* Measure time for a given number of GEMM calls */
4446 auto time_gemms = [=, &Q](int runs) -> double {
@@ -74,10 +76,9 @@ void test(queue &Q, int M, int N, int K)
7476 }
7577 if (linear_id >= elems) break ;
7678 }
77- std::cout << (ok ? " passes." : " FAILS!" ) << std::endl;
78- if (!ok) {
79- exit (1 );
80- }
79+
80+ std::cout << " gemm " << (ok ? " passes." : " FAILS!" ) << " for type: " << type_string<T>() << " \n " ;
81+ if (!ok) { return false ; }
8182
8283 /* Fill A/B with random data */
8384 generate_random_data (rd_size, host_data);
@@ -114,15 +115,17 @@ void test(queue &Q, int M, int N, int K)
114115 unit = ' P' ;
115116 }
116117
117- std::cout << " \n Average performance: " << flops << unit << ' F' << std::endl ;
118+ std::cout << " \n Average performance: " << flops << unit << ' F' << " \n " ;
118119
119120 /* Free data */
120- free (A, Q);
121- free (B, Q);
122121 free (C, Q);
123- free (host_data, Q);
122+ free (B, Q);
123+ free (A, Q);
124+
125+ return true ;
124126}
125127
128+ static
126129void usage (const char *pname)
127130{
128131 std::cerr << " Usage:\n "
@@ -133,17 +136,37 @@ void usage(const char *pname)
133136 << " double [default]\n "
134137 << " single\n "
135138 << " half\n "
139+ << " all (runs all above)\n "
136140 << " \n "
137141 << " This benchmark uses the default DPC++ device, which can be controlled using\n "
138142 << " the ONEAPI_DEVICE_SELECTOR environment variable\n " ;
139143 std::exit (1 );
140144}
141145
146+ static
147+ bool device_has_fp64 (sycl::device const & D) {
148+ return (D.get_info <sycl::info::device::double_fp_config>().size () != 0 );
149+ }
150+
151+ static
152+ void device_info (sycl::device const & D) {
153+ std::cout << " oneMKL DPC++ GEMM benchmark\n "
154+ << " ---------------------------\n "
155+ << " Platform: " << D.get_platform ().get_info <info::platform::name>() << " \n "
156+ << " Device: " << D.get_info <info::device::name>() << " \n "
157+ << " Driver_version: " << D.get_info <info::device::driver_version>() << " \n "
158+ << " Core/EU count: " << D.get_info <info::device::max_compute_units>() << " \n "
159+ << " Maximum clock frequency: " << D.get_info <info::device::max_clock_frequency>() << " MHz" << " \n "
160+ << " FP64 capability: " << (device_has_fp64 (D) ? " yes" : " no" ) << " \n "
161+ << " \n "
162+ ;
163+ }
164+
142165int main (int argc, char **argv)
143166{
144167 auto pname = argv[0 ];
145168 int M = 4096 , N = 4096 , K = 4096 ;
146- std::string type = " double " ;
169+ std::string type = " none " ;
147170
148171 if (argc <= 1 )
149172 usage (pname);
@@ -163,20 +186,55 @@ int main(int argc, char **argv)
163186 if (M <= 0 || N <= 0 || K <= 0 )
164187 usage (pname);
165188
166- queue Q;
189+ bool g_success = true ;
190+ try {
191+ device D (default_selector_v);
192+ device_info (D);
167193
168- std::cout << " oneMKL DPC++ GEMM benchmark\n "
169- << " ---------------------------\n "
170- << " Device: " << Q.get_device ().get_info <info::device::name>() << std::endl
171- << " Core/EU count: " << Q.get_device ().get_info <info::device::max_compute_units>() << std::endl
172- << " Maximum clock frequency: " << Q.get_device ().get_info <info::device::max_clock_frequency>() << " MHz" << std::endl;
173-
174- if (type == " double" )
175- test<double >(Q, M, N, K);
176- else if (type == " single" || type == " float" )
177- test<float >(Q, M, N, K);
178- else if (type == " half" )
179- test<half>(Q, M, N, K);
180- else
181- usage (pname);
194+ context C (D);
195+ queue Q (C, D);
196+
197+ if (" none" == type)
198+ std::string type = device_has_fp64 (D) ? " double" : " float" ;
199+
200+ if (type == " double" ) {
201+ if (device_has_fp64 (D))
202+ test<double >(Q, M, N, K);
203+ else {
204+ std::cout << " no FP64 capability on given SYCL device and type == \" double\" " ;
205+ return 1 ;
206+ }
207+ }
208+ else if (type == " single" || type == " float" )
209+ g_success = g_success && test<float >(Q, M, N, K);
210+ else if (type == " half" )
211+ g_success = g_success && test<half>(Q, M, N, K);
212+ else if (type == " all" ) {
213+ type = " half" ;
214+ g_success = g_success && test<half>(Q, M, N, K);
215+
216+ type = " float" ;
217+ g_success = g_success && test<float >(Q, M, N, K);
218+
219+ if (device_has_fp64 (D)) {
220+ type = " double" ;
221+ g_success = g_success && test<double >(Q, M, N, K);
222+ }
223+ } else {
224+ type = " none" ;
225+ usage (pname);
226+ }
227+ } catch (sycl::exception const & e) {
228+ std::cerr << " SYCL exception: " << e.what () << " \n " ;
229+ std::cerr << " while performing GEMM for"
230+ << " M=" << M
231+ << " , N=" << N
232+ << " , K=" << K
233+ << " , type `" << type << " `"
234+ << " \n " ;
235+ return 139 ;
236+ }
237+
238+ return g_success ? 0 : 1 ;
182239}
240+
0 commit comments