|
| 1 | +/******************************************************************************* |
| 2 | + * Copyright (c) 2023-2024 Intel Corporation |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + *******************************************************************************/ |
| 16 | + |
| 17 | +/// @file |
| 18 | +/// C++ API |
| 19 | + |
| 20 | +#pragma once |
| 21 | + |
| 22 | +#include <experimental/kernel/col_major_shuf/api.hpp> |
| 23 | +#include <experimental/kernel/col_major_shuf/common.hpp> |
| 24 | +#include <experimental/kernel/col_major_shuf/config.hpp> |
| 25 | + |
| 26 | +namespace gpu::xetla::kernel { |
/// @brief Column-major shuffle kernel, specialization for row-major input.
///
/// Each subgroup owns a `sg_tile_x` x `sg_tile_y` tile of the output. It
/// loads `sg_tile_x` gather indices from `gidx_ptr`, then for every row of
/// its tile gathers the input elements addressed by those indices, i.e. the
/// visible load pattern produces out[y][x] = in[y][gidx[x]] for the columns
/// this subgroup covers. The permuted tile is written back with tile_store.
///
/// @tparam dtype_in_            Input element type.
/// @tparam dtype_out_           Output element type (must equal dtype_in_).
/// @tparam dtype_gidx_          Gather-index element type (must be uint32_t).
/// @tparam mem_layout_in_       Input memory layout (must be row_major).
/// @tparam col_major_shuf_attr_ Tiling attributes (wg/sg tile sizes,
///                              load_block_size).
/// @tparam arch_                Target GPU architecture.
template <
    typename dtype_in_,
    typename dtype_out_,
    typename dtype_gidx_,
    mem_layout mem_layout_in_,
    typename col_major_shuf_attr_,
    gpu_arch arch_>
struct col_major_shuf_t<
    dtype_in_,
    dtype_out_,
    dtype_gidx_,
    mem_layout_in_,
    col_major_shuf_attr_,
    arch_> {
  using dtype_in = dtype_in_; // input matrix element type
  using dtype_out = dtype_out_; // output matrix element type
  using dtype_gidx = dtype_gidx_; // gather-index element type
  using col_major_shuf_attr = col_major_shuf_attr_; // tiling configuration

  static constexpr mem_layout mem_layout_in = mem_layout_in_;

  // Current implementation restrictions, enforced at compile time.
  static_assert(
      std::is_same<dtype_in, dtype_out>::value,
      "only support in/out data type must be same now.");
  static_assert(
      mem_layout_in == mem_layout::row_major,
      "only support row_major input now.");
  static_assert(
      std::is_same<dtype_gidx, uint32_t>::value,
      "dtype_gidx must be uint32_t");

  // Tile shape (in elements) processed by one work-group / one subgroup.
  static constexpr uint32_t wg_tile_x = col_major_shuf_attr::wg_tile_x;
  static constexpr uint32_t wg_tile_y = col_major_shuf_attr::wg_tile_y;
  static constexpr uint32_t sg_tile_x = col_major_shuf_attr::sg_tile_x;
  static constexpr uint32_t sg_tile_y = col_major_shuf_attr::sg_tile_y;

  // One subgroup handles exactly one sg tile.
  static constexpr uint32_t tile_size_x = sg_tile_x;
  static constexpr uint32_t tile_size_y = sg_tile_y;

  // Width (in elements) of a single gather/load block along x; the tile is
  // processed block_size_x columns at a time in call() below.
  static constexpr uint32_t block_size_x =
      col_major_shuf_attr::load_block_size; // TODO(zhe:) add load block size
  // check under different arch

  // Assumed device-pointer alignment in bytes; converted to an element-count
  // alignment for mem_desc_t below. NOTE(review): callers presumably must
  // guarantee 64-byte-aligned buffers -- not checked here, confirm.
  static constexpr uint32_t dev_mem_align = 64;
  // Descriptor/payload types for storing the shuffled output tile.
  using mem_desc_store_tile_t = mem_desc_t<
      dtype_in,
      mem_layout_in,
      mem_space::global,
      dev_mem_align / sizeof(dtype_in)>;
  using store_tile_desc_t = subgroup::tile_desc_t<
      tile_size_x,
      tile_size_y,
      block_size_x,
      tile_size_y,
      reg_layout::tiled>;
  using store_tile_t = subgroup::tile_t<dtype_out, store_tile_desc_t>;
  using store_tile_payload_t = subgroup::mem_payload_t<
      mem_desc_store_tile_t,
      store_tile_desc_t,
      subgroup::msg_type_v<store_tile_desc_t, mem_space::global>,
      arch_>;

  // Descriptor/payload types for loading the 1-row gather-index tile.
  using mem_desc_gidx_t = mem_desc_t<
      dtype_gidx,
      mem_layout::row_major,
      mem_space::global,
      dev_mem_align / sizeof(dtype_gidx)>;
  using gidx_tile_desc_t =
      subgroup::tile_desc_t<tile_size_x, 1, block_size_x, 1, reg_layout::tiled>;
  using gidx_t = subgroup::tile_t<dtype_gidx, gidx_tile_desc_t>;
  using gidx_payload_t = subgroup::mem_payload_t<
      mem_desc_gidx_t,
      gidx_tile_desc_t,
      subgroup::msg_type_v<gidx_tile_desc_t, mem_space::global>,
      arch_>;

  /// @brief Runtime arguments for the kernel.
  struct arguments_t {
    dtype_in* mat_in_ptr; // input matrix (row-major, global memory)
    dtype_out* mat_out_ptr; // output matrix (same layout as input)
    dtype_gidx* gidx_ptr; // per-column gather indices, matrix_x entries
    // Matrix extents in elements; used below as {width, height, pitch} of
    // the output surface, so matrix_x is presumably the row pitch as well as
    // the width -- i.e. rows are assumed densely packed. TODO confirm.
    uint32_t matrix_x;
    uint32_t matrix_y;
  };

  /// @brief Kernel entry point: gathers and stores one subgroup tile.
  ///
  /// Dimension mapping (from the group/local-id reads below): nd-range
  /// dim 2 indexes tiles along x, dim 1 indexes tiles along y.
  ///
  /// @param item SYCL nd_item supplying group and local ids.
  /// @param args Pointers and matrix extents, see arguments_t.
  __XETLA_API static void call(sycl::nd_item<3>& item, arguments_t& args) {
    // Work-group origin of this tile within the matrix...
    int gid_x = item.get_group(2);
    int gid_y = item.get_group(1);
    int x_dim_offset = gid_x * wg_tile_x;
    int y_dim_offset = gid_y * wg_tile_y;
    // ...plus this subgroup's offset inside the work-group tile.
    int tid_x = item.get_local_id(2);
    int tid_y = item.get_local_id(1);
    x_dim_offset += tid_x * sg_tile_x;
    y_dim_offset += tid_y * sg_tile_y;
    // Gather-index surface: a single row of matrix_x entries; this subgroup
    // reads the slice starting at its x offset.
    mem_desc_gidx_t gidx_desc(
        args.gidx_ptr, {args.matrix_x, 1, args.matrix_x}, {x_dim_offset, 0});
    // Output surface ({width, height, pitch}) positioned at this tile.
    mem_desc_store_tile_t store_tile_desc(
        args.mat_out_ptr,
        {args.matrix_x, args.matrix_y, args.matrix_x},
        {x_dim_offset, y_dim_offset});

    // Number of x-blocks per tile and register elements per block; tile_t
    // register storage is addressed block-by-block in the gather below.
    static constexpr int block_x_num = tile_size_x / block_size_x;
    static constexpr int elt_per_block = block_size_x * tile_size_y;
    store_tile_t store_tile;
    store_tile_payload_t store_tile_payload(store_tile_desc);
    gidx_payload_t gidx_payload(gidx_desc);

#pragma unroll
    for (int block_x = 0; block_x < block_x_num; block_x++) {
      // Load block_size_x gather indices for this block's columns.
      auto gidx = xetla_load_global<
          uint32_t,
          block_size_x,
          data_size::default_size,
          cache_hint::cached,
          cache_hint::cached>(
          args.gidx_ptr, gidx_payload.base_offset + block_x * block_size_x);
#pragma unroll
      for (uint32_t row = 0; row < tile_size_y; row++) {
        // Gather block_size_x elements of input row (y_dim_offset + row),
        // addressed by the index vector, directly into the register slice
        // of store_tile that corresponds to (block_x, row).
        store_tile.reg.xetla_select<block_size_x, 1>(
            block_x * elt_per_block + row * block_size_x) =
            xetla_load_global<
                dtype_in,
                1,
                data_size::default_size,
                cache_hint::cached,
                cache_hint::cached,
                block_size_x>(
                args.mat_in_ptr + (y_dim_offset + row) * args.matrix_x,
                gidx,
                1);
      }
    }
    // Write the assembled (shuffled) tile back to global memory.
    tile_store(store_tile, store_tile_payload);
  };
};
| 161 | +} // namespace gpu::xetla::kernel |
0 commit comments