@@ -20,18 +20,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
   //
 
-  // The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
-  // so we need
-  // to override this for many GEMMs with the following tag. Otherwise,
-  // torch.compile will force all input tensors to be contiguous(), which
-  // will break many custom ops that require column-major weight matrices.
-  // This was a bug and PyTorch 2.7 has since fixed this.
-#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
-  #define stride_tag at::Tag::needs_fixed_stride_order
-#else
-  #define stride_tag
-#endif
-
   ops.def(
       "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
       "y_q, Tensor! y_s,"
@@ -241,15 +229,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
-      {stride_tag});
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
-      {stride_tag});
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Note about marlin kernel 'workspace' arguments:
@@ -271,8 +257,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
271257 " gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
272258 " Tensor b_scales, Tensor workspace, "
273259 " int b_q_type, "
274- " SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor" ,
275- {stride_tag});
260+ " SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor" );
276261 // conditionally compiled so impl in source file
277262
278263 // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -298,8 +283,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
298283 " Tensor? channel_scales,"
299284 " Tensor? token_scales,"
300285 " str? schedule"
301- " ) -> Tensor" ,
302- {stride_tag});
286+ " ) -> Tensor" );
303287 ops.def (
304288 " machete_prepack_B("
305289 " Tensor B,"
@@ -319,8 +303,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
319303 " Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
320304 " g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
321305 " SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
322- " bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor" ,
323- {stride_tag});
306+ " bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor" );
324307 // conditionally compiled so impl registration is in source file
325308
326309 // gptq_marlin repack from GPTQ.
@@ -346,8 +329,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
346329 " Tensor token_scales,"
347330 " ScalarType? out_type,"
348331 " str? maybe_schedule"
349- " ) -> Tensor" ,
350- {stride_tag});
332+ " ) -> Tensor" );
351333 // pack scales
352334 ops.def (" cutlass_pack_scale_fp8(Tensor scales) -> Tensor" );
353335 // encode and reorder weight matrix
@@ -394,33 +376,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()",
-      {stride_tag});
+      " Tensor alpha) -> ()");
   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
 
   // cutlass blockwise scaledgroup GEMM
   ops.def(
       "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
       "Tensor scales_a, Tensor scales_b, "
-      "Tensor problem_sizes, Tensor expert_offsets) -> ()",
-      {stride_tag});
+      "Tensor problem_sizes, Tensor expert_offsets) -> ()");
   // conditionally compiled so impl registration is in source file
 
   // cutlass nvfp4 block scaled group GEMM
   ops.def(
       "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
-      " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
-      {stride_tag});
+      " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
   // conditionally compiled so impl registration is in source file
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
       "cutlass_scaled_mm(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()",
-      {stride_tag});
+      " Tensor b_scales, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 
   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -429,8 +407,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
429407 " cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
430408 " Tensor b, Tensor a_scales,"
431409 " Tensor b_scales, Tensor azp_adj,"
432- " Tensor? azp, Tensor? bias) -> ()" ,
433- {stride_tag});
410+ " Tensor? azp, Tensor? bias) -> ()" );
434411 ops.impl (" cutlass_scaled_mm_azp" , torch::kCUDA , &cutlass_scaled_mm_azp);
435412
436413 // Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -449,8 +426,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
449426 " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
450427 " Tensor problem_sizes, Tensor a_strides, "
451428 " Tensor b_strides, Tensor c_strides, bool per_act_token, "
452- " bool per_out_ch) -> ()" ,
453- {stride_tag});
429+ " bool per_out_ch) -> ()" );
454430 ops.impl (" cutlass_moe_mm" , torch::kCUDA , &cutlass_moe_mm);
455431
456432 // A function that computes data required to run fused MoE with w8a8 grouped
@@ -464,8 +440,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
464440 " Tensor! problem_sizes1, Tensor! problem_sizes2, "
465441 " Tensor! input_permutation, "
466442 " Tensor! output_permutation, int num_experts, "
467- " int n, int k, Tensor? blockscale_offsets) -> () " ,
468- {stride_tag} );
443+ " int n, int k, Tensor? blockscale_offsets) -> "
444+ " () " );
469445 ops.impl (" get_cutlass_moe_mm_data" , torch::kCUDA , &get_cutlass_moe_mm_data);
470446
471447 // A function that computes problem sizes for each expert's multiplication
@@ -476,8 +452,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
476452 " Tensor! problem_sizes1, "
477453 " Tensor! problem_sizes2, "
478454 " int num_experts, int n, int k, "
479- " Tensor? blockscale_offsets) -> ()" ,
480- {stride_tag});
455+ " Tensor? blockscale_offsets) -> ()" );
481456 ops.impl (" get_cutlass_moe_mm_problem_sizes" , torch::kCUDA ,
482457 &get_cutlass_moe_mm_problem_sizes);
483458
@@ -492,8 +467,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
492467 " Tensor! problem_sizes2, "
493468 " Tensor expert_num_tokens, "
494469 " int num_local_experts, int padded_m, "
495- " int n, int k) -> ()" ,
496- {stride_tag});
470+ " int n, int k) -> ()" );
497471 ops.impl (" get_cutlass_pplx_moe_mm_data" , torch::kCUDA ,
498472 &get_cutlass_pplx_moe_mm_data);
499473
@@ -517,8 +491,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
517491 " cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
518492 " Tensor bt_nzs,"
519493 " Tensor bt_meta, Tensor a_scales,"
520- " Tensor b_scales, Tensor? bias) -> ()" ,
521- {stride_tag});
494+ " Tensor b_scales, Tensor? bias) -> ()" );
522495 ops.impl (" cutlass_scaled_sparse_mm" , torch::kCUDA , &cutlass_scaled_sparse_mm);
523496
524497 // CUTLASS sparse matrix compressor
@@ -567,8 +540,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
567540 " gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
568541 " Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
569542 " use_v2_format, int bit) "
570- " -> Tensor" ,
571- {stride_tag});
543+ " -> Tensor" );
572544 ops.impl (" gptq_gemm" , torch::kCUDA , &gptq_gemm);
573545
574546 // Post processing for GPTQ.
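
For context on what this commit removes: the deleted comment concerns stride handling under torch.compile. Below is a minimal, self-contained libtorch sketch (illustrative only, not part of this commit) of the layout problem it describes. A column-major weight matrix is just a transposed view; forcing it contiguous() replaces it with a row-major copy, which is what PyTorch 2.6's default did to custom-op inputs unless the op schema carried at::Tag::needs_fixed_stride_order. Since PyTorch 2.7 fixed that default, the stride_tag macro and every {stride_tag} argument can be dropped.

// Sketch (libtorch): why forcing contiguous() breaks column-major weights.
#include <torch/torch.h>
#include <iostream>

int main() {
  // A column-major "weight" is a transposed view of a row-major tensor.
  torch::Tensor w = torch::randn({128, 64}).t();  // shape [64, 128]
  std::cout << "is_contiguous: " << w.is_contiguous() << '\n';  // 0
  std::cout << "strides: " << w.strides() << '\n';              // [1, 64]

  // What PyTorch 2.6's torch.compile did to untagged custom-op inputs:
  // materialize a row-major copy, silently discarding the layout that
  // stride-sensitive GEMM kernels (AWQ, GPTQ, Marlin, CUTLASS) rely on.
  torch::Tensor forced = w.contiguous();
  std::cout << "is_contiguous: " << forced.is_contiguous() << '\n';  // 1
  std::cout << "strides: " << forced.strides() << '\n';              // [128, 1]
  return 0;
}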