Commit a08de3c

cleanup at::Tag::needs_fixed_stride_order (vllm-project#28974)

Authored by Boyuan Feng and committed by DarkLight1337.

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: LuminolT <lumischen01@gmail.com>

1 parent 6b6742e, commit a08de3c

File tree: 2 files changed (+20, -51 lines)
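
The change is mechanical: every ops.def(...) that previously passed {stride_tag} (that is, at::Tag::needs_fixed_stride_order) as an extra argument now registers the schema alone, and the helper variable/macro defining stride_tag is removed. A minimal sketch of the before/after pattern, assuming a hypothetical extension and op name (my_ext, my_gemm) rather than any schema from this commit:

#include <torch/all.h>

// Illustrative sketch only: "my_ext" and "my_gemm" are hypothetical names,
// not ops touched by this commit. With the contiguity issue fixed in
// PyTorch >= 2.7, the schema no longer needs the fixed-stride-order tag.
TORCH_LIBRARY_FRAGMENT(my_ext, ops) {
  // Before the cleanup, GEMM-style schemas carried the tag explicitly:
  //   ops.def("my_gemm(Tensor a, Tensor b) -> Tensor",
  //           {at::Tag::needs_fixed_stride_order});
  // After the cleanup, a plain registration suffices:
  ops.def("my_gemm(Tensor a, Tensor b) -> Tensor");
}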

csrc/cpu/torch_bindings.cpp

Lines changed: 2 additions & 5 deletions
@@ -172,7 +172,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantization
 #if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
     defined(__powerpc64__)
-  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
   // Helper function to release oneDNN handlers
   ops.def("release_dnnl_matmul_handler(int handler) -> ()",
           &release_dnnl_matmul_handler);
@@ -208,15 +207,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
-      "Tensor? azp) -> ()",
-      {stride_tag});
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

   // Compute int8 quantized tensor and scaling factor
   ops.def(
       "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
-      "Tensor!? azp) -> ()",
-      {stride_tag});
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
            &dynamic_scaled_int8_quant);
 #endif

csrc/torch_bindings.cpp

Lines changed: 18 additions & 46 deletions
@@ -20,18 +20,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
   //

-  // The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
-  // so we need
-  // to override this for many GEMMs with the following tag. Otherwise,
-  // torch.compile will force all input tensors to be contiguous(), which
-  // will break many custom ops that require column-major weight matrices.
-  // This was a bug and PyTorch 2.7 has since fixed this.
-#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
-  #define stride_tag at::Tag::needs_fixed_stride_order
-#else
-  #define stride_tag
-#endif
-
   ops.def(
       "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
       "y_q, Tensor! y_s,"
@@ -241,15 +229,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
-      {stride_tag});
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
-      {stride_tag});
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

   // Note about marlin kernel 'workspace' arguments:
@@ -271,8 +257,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
       "int b_q_type, "
-      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
-      {stride_tag});
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
   // conditionally compiled so impl in source file

   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -298,8 +283,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor? channel_scales,"
      " Tensor? token_scales,"
       " str? schedule"
-      ") -> Tensor",
-      {stride_tag});
+      ") -> Tensor");
   ops.def(
       "machete_prepack_B("
       " Tensor B,"
@@ -319,8 +303,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
       "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
       "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
-      {stride_tag});
+      "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
   // conditionally compiled so impl registration is in source file

   // gptq_marlin repack from GPTQ.
@@ -346,8 +329,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor token_scales,"
       " ScalarType? out_type,"
       " str? maybe_schedule"
-      ") -> Tensor",
-      {stride_tag});
+      ") -> Tensor");
   // pack scales
   ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
   // encode and reorder weight matrix
@@ -394,33 +376,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()",
-      {stride_tag});
+      " Tensor alpha) -> ()");
   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);

   // cutlass blockwise scaledgroup GEMM
   ops.def(
       "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
       "Tensor scales_a, Tensor scales_b, "
-      "Tensor problem_sizes, Tensor expert_offsets) -> ()",
-      {stride_tag});
+      "Tensor problem_sizes, Tensor expert_offsets) -> ()");
   // conditionally compiled so impl registration is in source file

   // cutlass nvfp4 block scaled group GEMM
   ops.def(
       "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
-      " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
-      {stride_tag});
+      " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
   // conditionally compiled so impl registration is in source file

   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
       "cutlass_scaled_mm(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()",
-      {stride_tag});
+      " Tensor b_scales, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -429,8 +407,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
       " Tensor b_scales, Tensor azp_adj,"
-      " Tensor? azp, Tensor? bias) -> ()",
-      {stride_tag});
+      " Tensor? azp, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);

   // Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -449,8 +426,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
       " Tensor problem_sizes, Tensor a_strides, "
       " Tensor b_strides, Tensor c_strides, bool per_act_token, "
-      " bool per_out_ch) -> ()",
-      {stride_tag});
+      " bool per_out_ch) -> ()");
   ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);

   // A function that computes data required to run fused MoE with w8a8 grouped
@@ -464,8 +440,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor! problem_sizes1, Tensor! problem_sizes2, "
       " Tensor! input_permutation, "
       " Tensor! output_permutation, int num_experts, "
-      " int n, int k, Tensor? blockscale_offsets) -> ()",
-      {stride_tag});
+      " int n, int k, Tensor? blockscale_offsets) -> "
+      "()");
   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);

   // A function that computes problem sizes for each expert's multiplication
@@ -476,8 +452,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor! problem_sizes1, "
       " Tensor! problem_sizes2, "
       " int num_experts, int n, int k, "
-      " Tensor? blockscale_offsets) -> ()",
-      {stride_tag});
+      " Tensor? blockscale_offsets) -> ()");
   ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
            &get_cutlass_moe_mm_problem_sizes);

@@ -492,8 +467,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor! problem_sizes2, "
       " Tensor expert_num_tokens, "
       " int num_local_experts, int padded_m, "
-      " int n, int k) -> ()",
-      {stride_tag});
+      " int n, int k) -> ()");
   ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
            &get_cutlass_pplx_moe_mm_data);

@@ -517,8 +491,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
       " Tensor bt_nzs,"
       " Tensor bt_meta, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()",
-      {stride_tag});
+      " Tensor b_scales, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);

   // CUTLASS sparse matrix compressor
@@ -567,8 +540,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
       "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
       "use_v2_format, int bit) "
-      "-> Tensor",
-      {stride_tag});
+      "-> Tensor");
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

   // Post processing for GPTQ.
