@@ -658,212 +658,6 @@ class gemm_t<
   }
 
  private:
-  // inline void dequantize(
-  //     matB_acc_t& matB_acc,
-  //     matB_t& matB,
-  //     scale_t& scale,
-  //     zero_pt_t& zero_pt) {
-  //   // no tail, because this is matB
-  //   constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b;
-  //   constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b;
-  // #pragma unroll
-  //   for (uint32_t i = 0; i < num_block_y; ++i) {
-  // #pragma unroll
-  //     for (uint32_t j = 0; j < num_block_x; ++j) {
-  //       int block_id = (i * num_block_x + j);
-  //       // Must be little-endian
-  //       auto matB_blk = matB.reg.xetla_format<uint8_t>()
-  //           .xetla_select<matB_acc_t::block_elems / 2, 1>(
-  //           block_id * matB_acc_t::block_elems / 2);
-
-  //       auto dst_blk = matB_acc.reg.xetla_select<matB_acc_t::block_elems,
-  //       1>(
-  //           block_id * matB_acc_t::block_elems);
-
-  //       // int8 includes 2 4bits data.
-  //       xetla_vector<int8_t, matB_acc_t::block_elems> cvt_blk_i8;
-
-  //       // lowest 4 bit
-  //       {
-  //         cvt_blk_i8.xetla_select<matB_acc_t::block_elems / 2, 2>(0) =
-  //             matB_blk & 0xf;
-  //       }
-  //       // highest 4 bit
-  //       {
-  //         cvt_blk_i8.xetla_select<matB_acc_t::block_elems / 2, 2>(1) =
-  //             matB_blk >> 4;
-  //       }
-
-  //       // (b_i8 - zero_pt_i8) x scale = fp16
-  //       constexpr uint32_t step = std::min(block_size_y_b, dequant_s);
-  // #pragma unroll
-  //       for (uint32_t jj = 0; jj < block_size_x_b; jj++) {
-  // #pragma unroll
-  //         for (uint32_t ii = 0; ii < block_size_y_b; ii += step) {
-  //           uint32_t offset_y_in_tile = i * block_size_y_b + ii;
-  //           uint32_t offset_x_in_tile = j * block_size_x_b + jj;
-
-  //           uint32_t scale_idx =
-  //               (offset_y_in_tile) / dequant_s * scale_t::block_size_x +
-  //               offset_x_in_tile;
-
-  //           if constexpr (compute_policy::quant_mode ==
-  //               quant_mode::S4_ASYM) {
-  //             uint32_t zero_pt_idx =
-  //                 offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
-  //                 offset_x_in_tile / pack_ratio;
-  //             native_type_t<dtype_b> zero_pt_pack =
-  //                 zero_pt.reg[zero_pt_idx];
-
-  //             int8_t zero_pt_i8 =
-  //                 (zero_pt_pack >>
-  //                 (4 * ((wg_start_n + offset_x_in_tile) % pack_ratio))) &
-  //                 0xf;
-  //             // sycl::ext::oneapi::experimental::printf(
-  //             //     "zero_pt.reg[%d} %x zero_pt_i8 %x
-  //             //     offset_x_in_tile:%d
-  //             //     \n", zero_pt_idx, zero_pt_pack, (int32_t)zero_pt_i8 ,
-  //             //     offset_x_in_tile);
-
-  //             cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
-  //                 cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b +
-  //                 ii) - zero_pt_i8;
-  //           } else if constexpr (
-  //               compute_policy::quant_mode ==
-  //               quant_mode::S4_FULLRANGE_NO_ZP) {
-  //             cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
-  //                 cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b +
-  //                 ii) - int8_t(8);
-  //           }
-  //           dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
-  //               cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii)
-  //               * scale.reg[scale_idx];
-
-  //           // sycl::ext::oneapi::experimental::printf(
-  //           //     "scale[%d] %f \n",
-  //           //     scale_idx,
-  //           //     float(sycl::half(scale.reg.xetla_select<1,
-  //           //     1>(scale_idx))));
-  //         }
-  //       }
-  //     }
-  //   }
-  // }
-
-  /*
-  inline void dequantize(
-      matB_acc_t& matB_acc,
-      matB_t& matB,
-      scale_t& scale,
-      zero_pt_t& zero_pt) {
-    // no tail, because this is matB
-    constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b;
-    constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b;
-
-    constexpr uint32_t block_b_y_per_scale = dequant_s / block_size_y_b;
-    constexpr uint32_t block_b_x_per_scale = dequant_s / block_size_x_b;
-    #pragma unroll
-    for (uint32_t i = 0; i < num_block_y; ++i) {
-      #pragma unroll
-      for (uint32_t j = 0; j < num_block_x; ++j) {
-        int block_id = (i * num_block_x + j);
-        auto matB_blk = matB.reg
-            .xetla_select<matB_t::block_elems, 1>(
-                block_id * matB_t::block_elems)
-            .xetla_format<int8_t>();
-        int scale_block_id = (i / block_b_y_per_scale * num_block_x + j);
-        auto scale_vec = scale.reg.xetla_select<scale_t::block_size_x, 1>(
-            scale_block_id * scale_t::block_size_x);
-        auto dst_blk = matB_acc.reg.xetla_select<matB_acc_t::block_elems, 1>(
-            block_id * matB_acc_t::block_elems);
-
-        // 2: int8 includes 2 4bits data.
-        xetla_vector<uint8_t, block_size_x_b * block_size_y_b> cvt_blk;
-
-        xetla_vector<int32_t, block_size_x_b * block_size_y_b> cvt_blk_i32;
-        if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
-          auto zero_pt_vec = zero_pt.reg
-              .xetla_select<zero_pt_t::block_size_x, 1>(
-                  scale_block_id * zero_pt_t::block_size_x)
-              .xetla_format<uint8_t>();
-          cvt_blk.xetla_select<matB_t::block_elems, 2>(0) = matB_blk & 0x0f;
-          cvt_blk.xetla_select<matB_t::block_elems, 2>(1) = matB_blk >> 4;
-          xetla_vector<uint8_t, block_size_x_b> zero_pt_sub;
-          zero_pt_sub.xetla_select<block_size_x_b / 2, 2>(0) =
-              zero_pt_vec & 0x0f;
-          zero_pt_sub.xetla_select<block_size_x_b / 2, 2>(1) =
-              zero_pt_vec >> 4;
-          xetla_vector<uint8_t, block_size_x_b * block_size_y_b> zero_pt_blk;
-          #pragma unroll
-          for (uint32_t row = 0; row < block_size_y_b; row++) {
-            zero_pt_blk.xetla_select<block_size_x_b, 1>(row * block_size_x_b)
-                .xetla_format<int8_t>() =
-                zero_pt_sub.xetla_format<int8_t>() + int8_t(1);
-          }
-          cvt_blk_i32 =
-              (cvt_blk.xetla_format<int8_t>() -
-               zero_pt_blk.xetla_format<int8_t>());
-        }
-        if constexpr (
-            compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
-          xetla_vector<int8_t, block_size_x_b * block_size_y_b> cvt_blk_i8;
-          cvt_blk_i8.xetla_select<matB_t::block_elems, 2>(0) =
-              matB_blk & 0x0f;
-          cvt_blk_i8.xetla_select<matB_t::block_elems, 2>(0) =
-              cvt_blk_i8.xetla_select<matB_t::block_elems, 2>(0) << 4;
-          cvt_blk_i8.xetla_select<matB_t::block_elems, 2>(0) =
-              cvt_blk_i8.xetla_select<matB_t::block_elems, 2>(0) >> 4;
-          cvt_blk_i8.xetla_select<matB_t::block_elems, 2>(1) =
-              matB_blk.xetla_format<int8_t>() >> 4;
-          cvt_blk_i32 = (cvt_blk_i8.xetla_format<int8_t>());
-        }
-        if constexpr (compute_policy::mma_engine == mma_engine::xmx) {
-          constexpr uint32_t vnni_rows =
-              sizeof(uint32_t) / sizeof(dtype_mma_b);
-          xetla_vector<dtype_mma_b, matB_acc_t::block_elems * vnni_rows>
-              temp_blk;
-          temp_blk.xetla_select<matB_acc_t::block_elems, vnni_rows>(0) =
-              cvt_blk_i32;
-
-          #pragma unroll
-          for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) {
-            #pragma unroll
-            for (uint32_t row = 0; row < vnni_rows; row++) {
-              temp_blk.xetla_select<block_size_x_b, vnni_rows>(
-                  row + block_size_x_b * k * vnni_rows) =
-                  temp_blk.xetla_select<block_size_x_b, vnni_rows>(
-                      (k + row) * block_size_x_b * vnni_rows);
-            }
-          }
-
-          xetla_vector<dtype_scale, block_size_x_b * vnni_rows> scale_blk;
-          #pragma unroll
-          for (uint32_t row = 0; row < vnni_rows; row++) {
-            scale_blk.xetla_select<block_size_x_b, vnni_rows>(row) =
-                scale_vec;
-          }
-
-          #pragma unroll
-          for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) {
-            dst_blk.xetla_select<block_size_x_b * vnni_rows, 1>(
-                k * block_size_x_b) =
-                temp_blk.xetla_select<block_size_x_b * vnni_rows, 1>(
-                    k * block_size_x_b * vnni_rows) *
-                scale_blk;
-          }
-        } else {
-          #pragma unroll
-          for (uint32_t k = 0; k < block_size_y_b; k++) {
-            dst_blk.xetla_select<block_size_x_b, 1>(k * block_size_x_b) =
-                cvt_blk_i32.xetla_select<block_size_x_b, 1>(
-                    k * block_size_x_b) *
-                scale_vec;
-          }
-        }
-      }
-    }
-  } */
-
   /// @brief Updates tile base descriptor based on the tid.
   __XETLA_API static void update_sg_tile_tdesc(
       arguments_t& args,
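
Note: the dead code removed above implemented the int4 weight dequantization, i.e. unpacking two 4-bit values from each byte (low nibble first) and computing (w - zero_pt) * scale, or (w - 8) * scale in the full-range, no-zero-point mode. The standalone C++ sketch below illustrates only that arithmetic; it deliberately uses plain std::vector and float instead of the kernel's xetla_vector and sycl::half types, and the function name dequant_s4_ref is made up for illustration, not part of the library.

#include <cstdint>
#include <cstdio>
#include <vector>

// Reference dequantization for int4 weights packed two per byte
// (low nibble first, matching the little-endian unpack in the removed code).
// asymmetric == true  : out = (w - zero_pt) * scale   (S4_ASYM-style)
// asymmetric == false : out = (w - 8) * scale         (full-range, no zero-point)
std::vector<float> dequant_s4_ref(
    const std::vector<uint8_t>& packed,
    float scale,
    uint8_t zero_pt,
    bool asymmetric) {
  std::vector<float> out;
  out.reserve(packed.size() * 2);
  for (uint8_t byte : packed) {
    int lo = byte & 0xf;         // lowest 4 bits
    int hi = (byte >> 4) & 0xf;  // highest 4 bits
    int zp = asymmetric ? int(zero_pt & 0xf) : 8;
    out.push_back(float(lo - zp) * scale);
    out.push_back(float(hi - zp) * scale);
  }
  return out;
}

int main() {
  // One packed byte 0x2a holds the weights 10 (low nibble) and 2 (high nibble).
  std::vector<uint8_t> packed = {0x2a};
  for (float v : dequant_s4_ref(packed, 0.5f, 0, /*asymmetric=*/false))
    std::printf("%f\n", v);  // prints 1.0 and -3.0
  return 0;
}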