@@ -393,13 +393,15 @@ extern "C" int onemklHgemm_batch(syclQueue_t device_queue, onemklTranspose trans
393393 int64_t *ldb, uint16_t *beta, short **c,
394394 int64_t *ldc, int64_t group_count, int64_t *group_size) {
395395 gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
396+ device_queue->val .wait_and_throw ();
396397 auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
397398 &gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
398399 m, n, k, reinterpret_cast <sycl::half *>(alpha),
399400 reinterpret_cast <const sycl::half **>(&a[0 ]), lda,
400401 reinterpret_cast <const sycl::half **>(&b[0 ]), ldb,
401402 reinterpret_cast <sycl::half *>(beta), reinterpret_cast <sycl::half **>(&c[0 ]),
402403 ldc, group_count, group_size, {});
404+ device_queue->val .wait_and_throw ();
403405 return 0 ;
404406}
405407
@@ -410,13 +412,15 @@ extern "C" int onemklSgemm_batch(syclQueue_t device_queue, onemklTranspose trans
410412 int64_t *ldb, float *beta, float **c,
411413 int64_t *ldc, int64_t group_count, int64_t *group_size) {
412414 gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
415+ device_queue->val .wait_and_throw ();
413416 auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
414417 &gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
415418 m, n, k, alpha,
416419 (const float **)&a[0 ], lda,
417420 (const float **)&b[0 ], ldb,
418421 beta, &c[0 ], ldc,
419422 group_count, group_size, {});
423+ device_queue->val .wait_and_throw ();
420424 return 0 ;
421425}
422426
@@ -427,13 +431,15 @@ extern "C" int onemklDgemm_batch(syclQueue_t device_queue, onemklTranspose trans
427431 int64_t *ldb, double *beta, double **c,
428432 int64_t *ldc, int64_t group_count, int64_t *group_size) {
429433 gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
434+ device_queue->val .wait_and_throw ();
430435 auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
431436 &gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
432437 m, n, k, alpha,
433438 (const double **)&a[0 ], lda,
434439 (const double **)&b[0 ], ldb,
435440 beta, &c[0 ], ldc,
436441 group_count, group_size, {});
442+ device_queue->val .wait_and_throw ();
437443 return 0 ;
438444}
439445
@@ -445,6 +451,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
445451 int64_t *ldb, float _Complex *beta, float _Complex **c,
446452 int64_t *ldc, int64_t group_count, int64_t *group_size) {
447453 gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
454+ device_queue->val .wait_and_throw ();
448455 auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
449456 &gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
450457 m, n, k, reinterpret_cast <std::complex <float > *>(alpha),
@@ -455,6 +462,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
455462 reinterpret_cast <std::complex <float > *>(beta),
456463 reinterpret_cast <std::complex <float > **>(&c[0 ]), ldc,
457464 group_count, group_size, {});
465+ device_queue->val .wait_and_throw ();
458466 return 0 ;
459467}
460468
@@ -467,6 +475,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
467475 double _Complex **c,
468476 int64_t *ldc, int64_t group_count, int64_t *group_size) {
469477 gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
478+ device_queue->val .wait_and_throw ();
470479 auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
471480 &gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
472481 m, n, k, reinterpret_cast <std::complex <double > *>(alpha),
@@ -477,6 +486,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
477486 reinterpret_cast <std::complex <double > *>(beta),
478487 reinterpret_cast <std::complex <double > **>(&c[0 ]), ldc,
479488 group_count, group_size, {});
489+ device_queue->val .wait_and_throw ();
480490 return 0 ;
481491}
482492
@@ -487,12 +497,14 @@ extern "C" int onemklStrsm_batch(syclQueue_t device_queue, onemklSide left_right
487497 int64_t group_count, int64_t *group_size) {
488498 trsmBatchInfo trsmInfo (device_queue, left_right, upper_lower, transa,
489499 unit_diag, group_count);
500+ device_queue->val .wait_and_throw ();
490501
491502 auto status = oneapi::mkl::blas::column_major::trsm_batch (device_queue->val ,
492503 &trsmInfo.m_leftright [0 ], &trsmInfo.m_upperlower [0 ],
493504 &trsmInfo.m_transa [0 ], &trsmInfo.m_unitdiag [0 ],
494505 m, n, alpha, (const float **)&a[0 ], lda,
495506 &b[0 ], ldb, group_count, group_size, {});
507+ device_queue->val .wait_and_throw ();
496508 return 0 ;
497509}
498510
@@ -504,12 +516,14 @@ extern "C" int onemklDtrsm_batch(syclQueue_t device_queue, onemklSide left_right
504516 int64_t *group_size) {
505517 trsmBatchInfo trsmInfo (device_queue, left_right, upper_lower, transa,
506518 unit_diag, group_count);
519+ device_queue->val .wait_and_throw ();
507520
508521 auto status = oneapi::mkl::blas::column_major::trsm_batch (device_queue->val ,
509522 &trsmInfo.m_leftright [0 ], &trsmInfo.m_upperlower [0 ],
510523 &trsmInfo.m_transa [0 ], &trsmInfo.m_unitdiag [0 ],
511524 m, n, alpha, (const double **)&a[0 ], lda, &b[0 ],
512525 ldb, group_count, group_size, {});
526+ device_queue->val .wait_and_throw ();
513527 return 0 ;
514528}
515529
@@ -521,6 +535,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
521535 int64_t group_count, int64_t *group_size) {
522536 trsmBatchInfo trsmInfo (device_queue, left_right, upper_lower, transa,
523537 unit_diag, group_count);
538+ device_queue->val .wait_and_throw ();
524539
525540 auto status = oneapi::mkl::blas::column_major::trsm_batch (device_queue->val ,
526541 &trsmInfo.m_leftright [0 ], &trsmInfo.m_upperlower [0 ],
@@ -529,6 +544,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
529544 reinterpret_cast <const std::complex <float > **>(&a[0 ]),
530545 lda, reinterpret_cast <std::complex <float > **>(&b[0 ]),
531546 ldb, group_count, group_size, {});
547+ device_queue->val .wait_and_throw ();
532548 return 0 ;
533549}
534550
@@ -540,6 +556,7 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
540556 int64_t group_count, int64_t *group_size) {
541557 trsmBatchInfo trsmInfo (device_queue, left_right,
542558 upper_lower, transa, unit_diag, group_count);
559+ device_queue->val .wait_and_throw ();
543560
544561 auto status = oneapi::mkl::blas::column_major::trsm_batch (device_queue->val ,
545562 &trsmInfo.m_leftright [0 ], &trsmInfo.m_upperlower [0 ],
@@ -548,5 +565,6 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
548565 reinterpret_cast <const std::complex <double > **>(&a[0 ]),
549566 lda, reinterpret_cast <std::complex <double > **>(&b[0 ]),
550567 ldb, group_count, group_size, {});
568+ device_queue->val .wait_and_throw ();
551569 return 0 ;
552570}
0 commit comments