1616#include " ../include/bestla_packq_impl.hpp"
1717
1818namespace woq {
19- template <class GemmCore , BTLA_ISA ISA>
19+
20+ template <class proB >
2021void execute_qpack (repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
21- using proB = bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>;
2222 static proB ker;
23- auto qpackw = ker.createStorage (ctx->n , ctx->k , p->blocksize , wei2bestladt_map.at (p->weight_type ),
24- scale2bestladt_map.at (p->scale_type ), BTLA_DTYPE::BF16, p->asym );
23+ using WType = typename proB::StorageWeight;
24+ WType qpackw (0 );
25+ if constexpr (std::is_same_v<WType, bestla::storage::gemm::StorageWeightKBlockNInteger>) {
26+ qpackw = ker.createStorage (ctx->n , ctx->k , p->blocksize , wei2bestladt_map.at (p->weight_type ),
27+ scale2bestladt_map.at (p->scale_type ), BTLA_DTYPE::BF16, p->asym );
28+ } else {
29+ qpackw = ker.createStorage (ctx->n , ctx->k , p->blocksize , wei2bestladt_map.at (p->weight_type ),
30+ scale2bestladt_map.at (p->scale_type ));
31+ }
2532 if (p->enable_act_shuffle ) ker.enableShuffle (&qpackw);
2633 ctx->packw_size = qpackw.mSize ;
2734 if (task == WOQ_GET_PACKW_SIZE) return ;
@@ -33,6 +40,20 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
3340 p->asym ? ctx->zp ->data_ptr <int8_t >() : nullptr , &qpackw, dispatcher_utils::qbits_threading::get ());
3441}
3542
43+ template <class GemmCore , BTLA_ISA ISA>
44+ void parse_prob (repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
45+ if (p->weight_type == " int8" || p->weight_type == " int4_clip" || p->weight_type == " int3_clip" ||
46+ p->weight_type == " int2_clip" ) {
47+ return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNInteger<GemmCore, ISA>>(p, ctx, task);
48+ }
49+ if (p->weight_type == " nf4" || p->weight_type == " fp4_e2m1_bnb" || p->weight_type == " fp4_e2m1" ) {
50+ TORCH_CHECK (!p->asym , " Qbits: float-weight unsupports asym quantization." );
51+ return execute_qpack<bestla::prologue_b::gemm::WeightKBlockNFloat<GemmCore, ISA>>(p, ctx, task);
52+ }
53+ TORCH_CHECK (false , " Qbits: unsupported bestla packq config, compute_type: " + p->compute_type +
54+ " weight_type: " + p->weight_type );
55+ }
56+
3657std::string get_dtype_str (BTLA_DTYPE dtype) {
3758 switch (dtype) {
3859 case BTLA_DTYPE::F32:
@@ -183,40 +204,38 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
183204}
184205
185206void bestla_packq (repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
186- // TODO(zhe): elegant impl.
187- TORCH_CHECK (p->weight_type == " int8" || p->weight_type == " int4_clip" || p->weight_type == " int3_clip" ||
188- p->weight_type == " int2_clip" ,
189- " Qbits: only support Integer WOQ in PACKQ" );
190-
191207 if (p->compute_type == " int8" ) {
208+ TORCH_CHECK (p->weight_type == " int8" || p->weight_type == " int4_clip" || p->weight_type == " int3_clip" ||
209+ p->weight_type == " int2_clip" ,
210+ " Qbits: only support Integer weight-type with int8 compute-type" );
192211 if (dispatcher_utils::check_amx () && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64 , 16 >::KTILE == 0 ) {
193- return execute_qpack <bestla::gemm::ICoreRowNAmxint8KBlock<64 , 16 >, BTLA_ISA::AMX_INT8>(p, ctx, task);
212+ return parse_prob <bestla::gemm::ICoreRowNAmxint8KBlock<64 , 16 >, BTLA_ISA::AMX_INT8>(p, ctx, task);
194213 }
195214 if (dispatcher_utils::check_avx512_vnni () &&
196215 p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48 , 4 >::KTILE == 0 ) {
197- return execute_qpack <bestla::gemm::ICoreRowNAvx512vnniKBlock<48 , 4 >, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
216+ return parse_prob <bestla::gemm::ICoreRowNAvx512vnniKBlock<48 , 4 >, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
198217 }
199218 if (dispatcher_utils::check_avx_vnni () && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24 , 2 >::KTILE == 0 ) {
200- return execute_qpack <bestla::gemm::ICoreRowNAvxvnniKBlock<24 , 2 >, BTLA_ISA::AVX_VNNI>(p, ctx, task);
219+ return parse_prob <bestla::gemm::ICoreRowNAvxvnniKBlock<24 , 2 >, BTLA_ISA::AVX_VNNI>(p, ctx, task);
201220 }
202221 if (dispatcher_utils::check_avx2 () && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24 , 2 >::KTILE == 0 ) {
203- return execute_qpack <bestla::gemm::ICoreRowNAvx2vnniKBlock<24 , 2 >, BTLA_ISA::AVX2>(p, ctx, task);
222+ return parse_prob <bestla::gemm::ICoreRowNAvx2vnniKBlock<24 , 2 >, BTLA_ISA::AVX2>(p, ctx, task);
204223 }
205224 TORCH_CHECK (false , " Qbits: Illegal config in int8 compute_type, blocksize:" , p->blocksize ,
206225 " , ISA support avx2:" , dispatcher_utils::check_avx2 ());
207226 }
208227 if (p->compute_type == " fp32" ) {
209228 if (dispatcher_utils::check_avx512f ()) {
210- return execute_qpack <bestla::gemm::SCoreRowNAvx512f<48 , 8 >, BTLA_ISA::AVX512F>(p, ctx, task);
229+ return parse_prob <bestla::gemm::SCoreRowNAvx512f<48 , 8 >, BTLA_ISA::AVX512F>(p, ctx, task);
211230 }
212231 if (dispatcher_utils::check_avx2 ()) {
213- return execute_qpack <bestla::gemm::SCoreRowNAvx2<24 , 4 >, BTLA_ISA::AVX2>(p, ctx, task);
232+ return parse_prob <bestla::gemm::SCoreRowNAvx2<24 , 4 >, BTLA_ISA::AVX2>(p, ctx, task);
214233 }
215234 TORCH_CHECK (false , " Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32" );
216235 }
217236 if (p->compute_type == " bf16" ) {
218237 if (dispatcher_utils::check_amx ()) {
219- return execute_qpack <bestla::gemm::HCoreRowNAmxbf16<64 , 16 >, BTLA_ISA::AMX_BF16>(p, ctx, task);
238+ return parse_prob <bestla::gemm::HCoreRowNAmxbf16<64 , 16 >, BTLA_ISA::AMX_BF16>(p, ctx, task);
220239 }
221240 TORCH_CHECK (false , " Qbits: device ISA must support AMX-BF16 when compute_type==bf16" );
222241 }
0 commit comments