@@ -451,6 +451,286 @@ class wi_element {
451451 }
452452};
453453
454+ // Note that similarly to the other matrix functions, uint16_t is used here to
455+ // represent bf16 type. Since the AMX and DPAS implementations don't support
456+ // uint16_t, this interpretation is possible. This design choice was made before
457+ // the introduction of SYCL experimental bfloat16 type. Our plan is to move
458+ // towards using the SYCL bfloat16. But since it is still experimental, we will
459+ // probably keep both uint16 interpretation and SYCL bfloat16.
460+ template <size_t NumRows, size_t NumCols, matrix_layout Layout, typename Group>
461+ class wi_element <uint16_t , NumRows, NumCols, Layout, Group> {
462+ joint_matrix<uint16_t , NumRows, NumCols, Layout, Group> &M;
463+ std::size_t idx;
464+
465+ public:
466+ wi_element (joint_matrix<uint16_t , NumRows, NumCols, Layout, Group> &Mat,
467+ std::size_t i)
468+ : M(Mat), idx(i) {}
469+ operator uint16_t () {
470+ #ifdef __SYCL_DEVICE_ONLY__
471+ return __spirv_VectorExtractDynamic (M.spvm , idx);
472+ #else
473+ throw runtime_error (" joint matrix is not supported on host device." ,
474+ PI_INVALID_DEVICE);
475+ #endif // __SYCL_DEVICE_ONLY__
476+ }
477+
478+ explicit operator bool () {
479+ #ifdef __SYCL_DEVICE_ONLY__
480+ return __spirv_VectorExtractDynamic (M.spvm , idx) !=
481+ static_cast <uint16_t >(0 );
482+ #else
483+ throw runtime_error (" joint matrix is not supported on host device." ,
484+ PI_INVALID_DEVICE);
485+ #endif // __SYCL_DEVICE_ONLY__
486+ }
487+
488+ wi_element &operator =(const uint16_t &rhs) {
489+ #ifdef __SYCL_DEVICE_ONLY__
490+ M.spvm = __spirv_VectorInsertDynamic (M.spvm , rhs, idx);
491+ return *this ;
492+ #else
493+ (void )rhs;
494+ throw runtime_error (" joint matrix is not supported on host device." ,
495+ PI_INVALID_DEVICE);
496+ #endif // __SYCL_DEVICE_ONLY__
497+ }
498+
499+ wi_element &
500+ operator =(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &rhs) {
501+ #ifdef __SYCL_DEVICE_ONLY__
502+ M.spvm = __spirv_VectorInsertDynamic (
503+ M.spvm , __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ), idx);
504+ return *this ;
505+ #else
506+ (void )rhs;
507+ throw runtime_error (" joint matrix is not supported on host device." ,
508+ PI_INVALID_DEVICE);
509+ #endif // __SYCL_DEVICE_ONLY__
510+ }
511+
512+ // We use here the following functions for conversion (bf16=>fp32 and
513+ // fp32=>bf16). This is a workaround until we are able to use
514+ // __spirv_ConvertFToBF16INTEL and __spirv_ConvertBF16ToFINTEL once these are
515+ // supported in the CPU backend
516+ static float make_fp32 (uint16_t x) {
517+ unsigned int y = x;
518+ y = y << 16 ;
519+ float *res = reinterpret_cast <float *>(&y);
520+ return *res;
521+ }
522+
523+ static uint16_t make_bf16 (float x) {
524+ int *res = reinterpret_cast <int *>(&x);
525+ *res = *res >> 16 ;
526+ return (uint16_t )*res;
527+ }
528+
529+ friend uint16_t
530+ operator +(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
531+ const uint16_t &rhs) {
532+ #ifdef __SYCL_DEVICE_ONLY__
533+ return make_bf16 (
534+ make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) +
535+ make_fp32 (rhs));
536+ #else
537+ (void )lhs;
538+ (void )rhs;
539+ throw runtime_error (" joint matrix is not supported on host device." ,
540+ PI_INVALID_DEVICE);
541+ #endif // __SYCL_DEVICE_ONLY__
542+ }
543+
544+ wi_element &operator +=(const uint16_t &rhs) {
545+ #ifdef __SYCL_DEVICE_ONLY__
546+ M.spvm = __spirv_VectorInsertDynamic (
547+ M.spvm ,
548+ make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx)) +
549+ make_fp32 (rhs)),
550+ idx);
551+ return *this ;
552+ #else
553+ (void )rhs;
554+ throw runtime_error (" joint matrix is not supported on host device." ,
555+ PI_INVALID_DEVICE);
556+ #endif // __SYCL_DEVICE_ONLY__
557+ }
558+
559+ friend uint16_t
560+ operator -(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
561+ const uint16_t &rhs) {
562+ #ifdef __SYCL_DEVICE_ONLY__
563+ return make_bf16 (
564+ make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) -
565+ make_fp32 (rhs));
566+ #else
567+ (void )lhs;
568+ (void )rhs;
569+ throw runtime_error (" joint matrix is not supported on host device." ,
570+ PI_INVALID_DEVICE);
571+ #endif // __SYCL_DEVICE_ONLY__
572+ }
573+
574+ wi_element &operator -=(const uint16_t &rhs) {
575+ #ifdef __SYCL_DEVICE_ONLY__
576+ M.spvm = __spirv_VectorInsertDynamic (
577+ M.spvm ,
578+ make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx)) -
579+ make_fp32 (rhs)),
580+ idx);
581+ return *this ;
582+ #else
583+ (void )rhs;
584+ throw runtime_error (" joint matrix is not supported on host device." ,
585+ PI_INVALID_DEVICE);
586+ #endif // __SYCL_DEVICE_ONLY__
587+ }
588+
589+ friend uint16_t
590+ operator *(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
591+ const uint16_t &rhs) {
592+ #ifdef __SYCL_DEVICE_ONLY__
593+ return make_bf16 (
594+ make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) *
595+ make_fp32 (rhs));
596+ #else
597+ (void )lhs;
598+ (void )rhs;
599+ throw runtime_error (" joint matrix is not supported on host device." ,
600+ PI_INVALID_DEVICE);
601+ #endif // __SYCL_DEVICE_ONLY__
602+ }
603+
604+ wi_element &operator *=(const uint16_t &rhs) {
605+ #ifdef __SYCL_DEVICE_ONLY__
606+ M.spvm = __spirv_VectorInsertDynamic (
607+ M.spvm ,
608+ make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx)) *
609+ make_fp32 (rhs)),
610+ idx);
611+ return *this ;
612+ #else
613+ (void )rhs;
614+ throw runtime_error (" joint matrix is not supported on host device." ,
615+ PI_INVALID_DEVICE);
616+ #endif // __SYCL_DEVICE_ONLY__
617+ }
618+
619+ friend uint16_t
620+ operator /(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
621+ const uint16_t &rhs) {
622+ #ifdef __SYCL_DEVICE_ONLY__
623+ return make_bf16 (
624+ make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) /
625+ make_fp32 (rhs));
626+ #else
627+ (void )lhs;
628+ (void )rhs;
629+ throw runtime_error (" joint matrix is not supported on host device." ,
630+ PI_INVALID_DEVICE);
631+ #endif // __SYCL_DEVICE_ONLY__
632+ }
633+
634+ wi_element &operator /=(const uint16_t &rhs) {
635+ #ifdef __SYCL_DEVICE_ONLY__
636+ M.spvm = __spirv_VectorInsertDynamic (
637+ M.spvm ,
638+ make_bf16 (make_fp32 (__spirv_VectorExtractDynamic (M.spvm , idx)) /
639+ make_fp32 (rhs)),
640+ idx);
641+ return *this ;
642+ #else
643+ (void )rhs;
644+ throw runtime_error (" joint matrix is not supported on host device." ,
645+ PI_INVALID_DEVICE);
646+ #endif // __SYCL_DEVICE_ONLY__
647+ }
648+
649+ friend bool
650+ operator <(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
651+ const uint16_t &rhs) {
652+ #ifdef __SYCL_DEVICE_ONLY__
653+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) <
654+ make_fp32 (rhs);
655+ #else
656+ (void )lhs;
657+ (void )rhs;
658+ throw runtime_error (" joint matrix is not supported on host device." ,
659+ PI_INVALID_DEVICE);
660+ #endif // __SYCL_DEVICE_ONLY__
661+ }
662+
663+ friend bool
664+ operator <=(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
665+ const uint16_t &rhs) {
666+ #ifdef __SYCL_DEVICE_ONLY__
667+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) <=
668+ make_fp32 (rhs);
669+ #else
670+ (void )lhs;
671+ (void )rhs;
672+ throw runtime_error (" joint matrix is not supported on host device." ,
673+ PI_INVALID_DEVICE);
674+ #endif // __SYCL_DEVICE_ONLY__
675+ }
676+
677+ friend bool
678+ operator >(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
679+ const uint16_t &rhs) {
680+ #ifdef __SYCL_DEVICE_ONLY__
681+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) >
682+ make_fp32 (rhs);
683+ #else
684+ (void )lhs;
685+ (void )rhs;
686+ throw runtime_error (" joint matrix is not supported on host device." ,
687+ PI_INVALID_DEVICE);
688+ #endif // __SYCL_DEVICE_ONLY__
689+ }
690+
691+ friend bool
692+ operator >=(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
693+ const uint16_t &rhs) {
694+ #ifdef __SYCL_DEVICE_ONLY__
695+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) >=
696+ make_fp32 (rhs);
697+ #else
698+ (void )lhs;
699+ (void )rhs;
700+ throw runtime_error (" joint matrix is not supported on host device." ,
701+ PI_INVALID_DEVICE);
702+ #endif // __SYCL_DEVICE_ONLY__
703+ }
704+
705+ friend bool
706+ operator ==(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
707+ const uint16_t &rhs) {
708+ #ifdef __SYCL_DEVICE_ONLY__
709+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) ==
710+ make_fp32 (rhs);
711+ #else
712+ (void )lhs;
713+ (void )rhs;
714+ throw runtime_error (" joint matrix is not supported on host device." ,
715+ PI_INVALID_DEVICE);
716+ #endif // __SYCL_DEVICE_ONLY__
717+ }
718+
719+ friend bool
720+ operator !=(const wi_element<uint16_t , NumRows, NumCols, Layout, Group> &lhs,
721+ const uint16_t &rhs) {
722+ #ifdef __SYCL_DEVICE_ONLY__
723+ return make_fp32 (__spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx )) !=
724+ make_fp32 (rhs);
725+ #else
726+ (void )lhs;
727+ (void )rhs;
728+ throw runtime_error (" joint matrix is not supported on host device." ,
729+ PI_INVALID_DEVICE);
730+ #endif // __SYCL_DEVICE_ONLY__
731+ }
732+ };
733+
454734template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
455735 typename Group>
456736class wi_slice {
0 commit comments