From 821bd2b7985f26743ef7644a60e7380cb16e8c26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 07:41:27 -0700 Subject: [PATCH 1/5] Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 22 ++++++++++-- torchao/testing/training/roofline_utils.py | 41 +++++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 4bf54538df..547b0a40e4 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -180,7 +180,7 @@ def get_gemm_times( scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) else: - assert False, "TODO add cutlass mx gemm here" + assert False, f"unsupported {float8_recipe_name=} {mx_recipe_name=}" def do_matmul(A, B): return torch._scaled_mm( @@ -233,6 +233,20 @@ def run( print(f"mx_recipe_name: {mx_recipe_name}") print(f"enable_fusion_modeling: {enable_fusion_modeling}") + assert mx_recipe_name in ( + # real mxfp8_cublas recipe + "mxfp8_cublas", + # real mxfp8_cublas_rceil recipe + "mxfp8_cublas_rceil", + # modeling of what mxfp8 with 32x32 block size and without gemm + # operand layout restrictions would look like + "mxfp8_32x32_flexible_gemm_layout", + # modeling of what mxfp8 with 32x32 block size for weight + "mxfp8_32x32_weight", + # real mxfp4_cutlass recipe + "mxfp4_cutlass", + ), f"unsupported {mx_recipe_name=}" + M, K, N = sympy.symbols("M K N") fp8_ovhd_time_sympy = get_float8_mem_sympy( @@ -309,7 +323,11 @@ def run( rb_fp8_gemm_ratio = -1 if do_benchmarks: - assert mx_recipe_name != "mxfp4_cutlass", "unsupported" + assert mx_recipe_name not in ( + "mxfp4_cutlass", + "mxfp8_32x32_flexible_gemm_layout", + "mxfp8_32x32_weight", + ), f"do_benchmarks unsupported with {mx_recipe_name=}" # TODO(future): make the bf16 gemm times exactly match the e2e # benchmarks, there is a slight deviation, probably related to gemm diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index f57705333a..6610654bf1 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -187,13 +187,52 @@ def get_tensor_memory_traffic_ovhd_s( else: assert False, "unsupported" + elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout": + # modeling the following: + # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense + # across dim0 and dim1 + # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in + # PyTorch right now) + # x_bf16 = ... + # kernel 1: x_bf16 -> x_mxfp8_dim0 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw] + + elif mx_recipe_name == "mxfp8_32x32_weight": + # modeling the following: + # 1. mxfp8 scaling with 32x32 weights, so the format makes sense + # across dim0 and dim1. input and grad_output still 1x32. 
+ + if tensor_role in ("input", "grad_output"): + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_bf16 -> x_mxfp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + + elif tensor_role == "weight": + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1 + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2 + + else: + assert False, "unsupported" + + res_bytes = [kernel_1_rw, kernel_2_rw] + else: assert mx_recipe_name in ( "mxfp8_emulated", "mxfp8_cublas", "mxfp8_cublas_rceil", "mxfp4_cutlass", - ), "unsupported" + ), f"unsupported {mx_recipe_name=}" # For now, assume that we can't profitably fuse kernel 1 and kernel 2 # x_bf16 = ... # kernel 1: x_bf16 -> x_mxfp8_dim0 From 5bd4e3b4ff6617d6bb7eec8b13f6be99b1aeb40d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 13:32:59 -0700 Subject: [PATCH 2/5] Update [ghstack-poisoned] --- torchao/testing/training/roofline_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index 6610654bf1..e391a4d44b 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -207,6 +207,7 @@ def get_tensor_memory_traffic_ovhd_s( # across dim0 and dim1. input and grad_output still 1x32. if tensor_role in ("input", "grad_output"): + # TODO(future): update all of the mx rooflines to just read once # kernel 1: x_bf16 -> x_mxfp8_dim0 # kernel 2: x_bf16 -> x_mxfp8_dim1 if fuse_with_prev: From ea2d54f578ef0fb39d0556699429598419ce8927 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 14:09:19 -0700 Subject: [PATCH 3/5] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 106 +++++++++++++----- 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index fbfead161a..6c8113e8cb 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -38,6 +38,14 @@ ) import torchao +from torchao.prototype.mx_formats.config import ( + MXGemmKernelChoice, +) +from torchao.prototype.mx_formats.inference_workflow import ( + MXFPInferenceConfig, + NVFP4InferenceConfig, + NVFP4MMConfig, +) from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, PerRow, @@ -80,40 +88,67 @@ def get_gemm_times( fast_accum: bool, recipe_name: Optional[str], ): - assert recipe_name in {"rowwise"}, ( - "Only support real benchmarks for 'rowwise' recipe for now" - ) device = torch.device("cuda") # bf16 time x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) - # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t() w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16) - e4m3_dtype = torch.float8_e4m3fn - if torch.version.hip and torch.cuda.is_available() and is_MI300(): - e4m3_dtype = torch.float8_e4m3fnuz - d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 - A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) - B = ( - torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) - .view(d2) - .t() - .contiguous() - .t() - ) + if 
recipe_name in ("mxfp4_cutlass", "nvfp4"): + d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16 + A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view( + d1 + ) + B = ( + torch.randint(0, 255, (K // 2, N), device=device, dtype=torch.uint8) + .t() + .contiguous() + .t() + .view(d2) + ) + else: + e4m3_dtype = torch.float8_e4m3fn + if torch.version.hip and torch.cuda.is_available() and is_MI300(): + e4m3_dtype = torch.float8_e4m3fnuz + d1, d2, d3 = e4m3_dtype, e4m3_dtype, torch.bfloat16 + A = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(d1) + B = ( + torch.randint(0, 255, (K, N), device=device, dtype=torch.uint8) + .view(d2) + .t() + .contiguous() + .t() + ) + if recipe_name == "rowwise": scale_a = torch.ones(M, 1, device=device) scale_b = torch.ones(1, N, device=device) + elif recipe_name == "mxfp8_cublas": + scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) + scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) + elif recipe_name == "mxfp4_cutlass": + scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu) + scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu) + elif recipe_name == "nvfp4": + scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn) + scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn) + else: assert False, "unsupported" def do_matmul(A, B): - return torch._scaled_mm( - A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum - ) + if recipe_name == "mxfp4_cutlass": + return torchao.ops.mx_fp4_bf16(A, B, scale_a, scale_b) + if recipe_name == "nvfp4": + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False + ) + else: + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum + ) f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) @@ -259,12 +294,33 @@ def run( # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() - config = Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - # for now, use TORCH. In the future might be interesting - # to benchmark AUTO and FBGEMM. - kernel_preference=KernelPreference.TORCH, - ) + if recipe_name == "rowwise": + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + # for now, use TORCH. In the future might be interesting + # to benchmark AUTO and FBGEMM. 
+ kernel_preference=KernelPreference.TORCH, + ) + elif recipe_name == "mxfp8_cublas": + config = MXFPInferenceConfig( + activation_dtype=torch.float8_e4m3fn, + weight_dtype=torch.float8_e4m3fn, + gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + ) + elif recipe_name == "mxfp4_cutlass": + config = MXFPInferenceConfig( + activation_dtype=torch.float4_e2m1fn_x2, + weight_dtype=torch.float4_e2m1fn_x2, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + ) + elif recipe_name == "nvfp4": + config = NVFP4InferenceConfig( + mm_config=NVFP4MMConfig.DYNAMIC, + use_dynamic_per_tensor_scale=False, + ) + else: + assert False, "unsupported" + m_fp8_dyn = copy.deepcopy(m_orig) quantize_(m_fp8_dyn, config) From b88850f0d83a7cac38b83868da00ddfaf2f9ab26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 17:44:34 -0700 Subject: [PATCH 4/5] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 6c8113e8cb..3365fba923 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -60,7 +60,7 @@ @torch.no_grad() -def get_gpu_kernel_time(m, x): +def get_gpu_kernel_time(m, x, trace_filename=None): # warm up for _ in range(2): __ = m(x) @@ -72,6 +72,12 @@ def get_gpu_kernel_time(m, x): for _ in range(n_iter): __ = m(x) torch.cuda.synchronize() + + # save a trace, if requested + if trace_filename is not None: + print(f"exporting trace to {trace_filename}") + prof.export_chrome_trace(trace_filename) + # get the gpu kernel time and aggregate it num_leaf_tensors = 1 + len(list(m.parameters())) ref_times = profiler_output_to_filtered_time_by_kernel_name( @@ -161,6 +167,7 @@ def run( do_benchmarks: bool = True, shape_gen_name: str = "pow2", n_limit: Optional[int] = None, + save_profile_traces: bool = False, ): """ Args: @@ -168,6 +175,7 @@ def run( * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep` * `n_limit (optional)`: if specified, only runs `n_limit` iterations + # `save_profile_traces (optional)`: if True, saves profiling traces """ config_table = [ ["GPU", torch.cuda.get_device_name(0)], @@ -289,7 +297,11 @@ def run( # get the bf16 gpu kernel time torch._dynamo.reset() m_bf16 = torch.compile(copy.deepcopy(m_orig)) - b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x) + + bf16_trace_filename = None + if save_profile_traces: + bf16_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_bf16.json" + b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, bf16_trace_filename) # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() @@ -325,7 +337,11 @@ def run( quantize_(m_fp8_dyn, config) m_fp8_dyn = torch.compile(m_fp8_dyn) - b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x) + + fp8_trace_filename = None + if save_profile_traces: + fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json" + b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename) r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) From 07749746ad0700c590d6b2f491b343e79218bcb5 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 16 Oct 2025 18:51:11 -0700 Subject: [PATCH 5/5] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 54 +++++++++++++++++++++ torchao/prototype/mx_formats/kernels.py | 4 +- 
torchao/prototype/mx_formats/mx_tensor.py | 5 +- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 577112b16a..1a3631cc53 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -662,3 +662,57 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): rtol=0.0, msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}", ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") +@pytest.mark.parametrize("transpose", [False, True]) +@pytest.mark.parametrize( + "shape", + ( + (128, 64), + (1, 128, 64), + ), +) +def test_scale_shape_matches_qdata(transpose, shape): + if len(shape) == 3 and transpose: + pytest.skip("transpose not yet implemented for 3D MXTensor") + + block_size = 32 + + x_hp = torch.randn(*shape, device="cuda") + x = MXTensor.to_mx( + x_hp, + torch.float8_e4m3fn, + block_size, + ScaleCalculationMode.FLOOR, + ) + + if len(shape) == 2: + m_dim, k_dim = 0, 1 + if transpose: + x_hp = x_hp.t() + x = x.t() + m_dim, k_dim = 1, 0 + else: + assert len(shape) == 3, "unsupported" + m_dim, k_dim = 1, 2 + if transpose: + x_hp = x_hp.transpose(-2, -1) + x = x.transpose(-2, -1) + m_dim, k_dim = 2, 1 + + orig_m = x_hp.shape[m_dim] + expected_padded_m = orig_m + actual_padded_m = x.scale.shape[m_dim] + assert expected_padded_m == actual_padded_m, ( + f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}" + ) + + orig_k = x_hp.shape[k_dim] + expected_padded_k = orig_k // block_size + actual_padded_k = x.scale.shape[k_dim] + + assert expected_padded_k == actual_padded_k, ( + f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}" + ) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 69bb076b40..c69da4d076 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1264,7 +1264,7 @@ def triton_to_mxfp8_dim1( return ( output_col_major.t(), - col_scale.view(torch.float8_e8m0fnu), + col_scale.view(torch.float8_e8m0fnu).squeeze(-1), ) @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default) @@ -1293,7 +1293,7 @@ def triton_to_mxfp8_dim1_reference( scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu) return ( x_hp_d1_normalized.t(), - scale_e8m0_dim1.unsqueeze(-1), + scale_e8m0_dim1, ) @triton.jit diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 05c8fdc8e4..a5e50b2468 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -362,6 +362,7 @@ def to_dtype( # unpacking and unscaling if is_transposed: data_lp = data_lp.t() + scale_e8m0 = scale_e8m0.t() assert data_lp.is_contiguous() orig_shape = (orig_shape[1], orig_shape[0]) @@ -688,7 +689,7 @@ def _addmm_mx_dispatch( assert b._block_size == 32, f"Invalid block size {b._block_size}" a_scale = a.scale.view(M, K // a._block_size) - b_scale = b.scale.view(N, K // b._block_size) + b_scale = b.scale.t().view(N, K // b._block_size) a_scale_block = to_blocked(a_scale) b_scale_block = to_blocked(b_scale) @@ -759,7 +760,7 @@ def mx_t(func, types, args, kwargs): old = args[0] new = MXTensor( old.qdata.t(), - 
old.scale, + old.scale.t(), old._elem_dtype, old._block_size, old._orig_dtype,
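
For reference, below is a minimal sketch (plain PyTorch only; the tensor names are illustrative stand-ins, not the torchao MXTensor API) of the scale/qdata shape invariant that PATCH 5/5 enforces by transposing old.scale alongside old.qdata in mx_t and that test_scale_shape_matches_qdata asserts: with a 1x32 block along K, the scale carries one element per block, so transposing the quantized data must transpose the scale the same way.

import torch

M, K, block_size = 128, 64, 32

qdata = torch.randn(M, K)               # stand-in for MXTensor.qdata
scale = torch.ones(M, K // block_size)  # one scale element per 1x32 block

# row-major view: scale rows track qdata rows, scale cols track K // block_size
assert scale.shape == (qdata.shape[0], qdata.shape[1] // block_size)

# transposed view: the scale has to be transposed together with the data,
# which is what changing `old.scale` to `old.scale.t()` in mx_t provides
qdata_t, scale_t = qdata.t(), scale.t()
assert scale_t.shape == (qdata_t.shape[0] // block_size, qdata_t.shape[1])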