
Commit 5c4cd17

Properly thread through use_triton_kernel (#3155)
1 parent cdf48f0

2 files changed: +29 -1 lines changed

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 27 additions & 1 deletion
@@ -6,10 +6,12 @@
 
 import copy
 import tempfile
+from contextlib import contextmanager
 
 import pytest
 import torch
 import torch.nn as nn
+from torch.profiler import ProfilerActivity, profile
 
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -44,6 +46,23 @@ def run_around_tests():
     torch._dynamo.reset()
 
 
+@contextmanager
+def cuda_kernel_profiler(kernel_pattern):
+    """Context manager for profiling CUDA kernels."""
+    result = {"found": False, "kernel_names": []}
+
+    with profile(activities=[ProfilerActivity.CUDA]) as prof:
+        yield result
+
+    kernel_names = [
+        evt.name
+        for evt in prof.events()
+        if evt.device_type == torch.autograd.DeviceType.CUDA and evt.name
+    ]
+    result["kernel_names"] = kernel_names
+    result["found"] = any(kernel_pattern in name for name in kernel_names)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
@@ -178,7 +197,14 @@ def test_inference_workflow_nvfp4(
 
     x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype)
     y_ref = m(x)
-    y_mx = m_mx(x)
+
+    if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY:
+        with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
+            y_mx = m_mx(x)
+        assert result["found"], "Expected quantize_nvfp4 kernel to be found"
+    else:
+        y_mx = m_mx(x)
+
     sqnr = compute_error(y_ref, y_mx)
 
     if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
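
The new cuda_kernel_profiler helper is general-purpose: it captures every CUDA kernel launched inside the with-block and reports whether any name matches a substring. A minimal standalone sketch, assuming the helper above is importable; the linear layer and the "gemm" pattern are illustrative, not part of the commit:

import torch

# Assumed: cuda_kernel_profiler is the helper defined in the diff above.
model = torch.nn.Linear(64, 64, device="cuda", dtype=torch.bfloat16)
x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)

with cuda_kernel_profiler("gemm") as result:
    _ = model(x)

print(result["found"])         # True if any captured kernel name contained "gemm"
print(result["kernel_names"])  # every CUDA kernel name seen by the profiler

Because result is populated only after the profile context exits, the dict must be inspected after the with-block, which is exactly how the test uses it to assert that the Triton quantize kernel actually ran.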

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 2 additions & 0 deletions
@@ -188,6 +188,8 @@ def _nvfp4_inference_linear_transform(
     if config.mm_config == NVFP4MMConfig.DYNAMIC:
         act_quant_kwargs = QuantizeTensorToNVFP4Kwargs(
             use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale,
+            use_triton_kernel=config.use_triton_kernel,
+            is_swizzled_scales=True,
         )
 
     quantized_weight = NVFP4Tensor.to_nvfp4(
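
With this fix, the use_triton_kernel flag set on the inference config is forwarded into the activation-quantization kwargs for the DYNAMIC path instead of being silently dropped. A hedged end-to-end sketch, assuming NVFP4InferenceConfig is the config consumed by _nvfp4_inference_linear_transform (with the mm_config and use_triton_kernel fields read above) and that it is importable from the mx_formats prototype alongside NVFP4MMConfig; the layer shapes are illustrative:

import torch
import torch.nn as nn
from torchao.quantization import quantize_
# Assumed import path for the config class; field names mirror the
# attributes the transform above reads off `config`.
from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig

m = nn.Linear(128, 256, bias=False, device="cuda", dtype=torch.bfloat16)
quantize_(
    m,
    NVFP4InferenceConfig(
        mm_config=NVFP4MMConfig.DYNAMIC,  # DYNAMIC builds act_quant_kwargs above
        use_triton_kernel=True,           # now threaded through to activation quant
    ),
)
y = m(torch.randn(4, 128, device="cuda", dtype=torch.bfloat16))

Note that is_swizzled_scales=True is hardcoded rather than threaded from the config, so dynamically quantized activations always use the swizzled scale layout on this path.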
