 
 import copy
 import tempfile
+from contextlib import contextmanager
 
 import pytest
 import torch
 import torch.nn as nn
+from torch.profiler import ProfilerActivity, profile
 
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -44,6 +46,23 @@ def run_around_tests():
     torch._dynamo.reset()
 
 
+@contextmanager
+def cuda_kernel_profiler(kernel_pattern):
+    """Context manager for profiling CUDA kernels."""
+    result = {"found": False, "kernel_names": []}
+
+    with profile(activities=[ProfilerActivity.CUDA]) as prof:
+        yield result
+
+    kernel_names = [
+        evt.name
+        for evt in prof.events()
+        if evt.device_type == torch.autograd.DeviceType.CUDA and evt.name
+    ]
+    result["kernel_names"] = kernel_names
+    result["found"] = any(kernel_pattern in name for name in kernel_names)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
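Note on the helper above: `result` is yielded while still empty and is only populated after the `with` block exits, once the profiler has collected its events, so callers must inspect it outside the block (as the test below does). A minimal usage sketch, assuming the helper is importable and a CUDA device is present; the "gemm" substring and the matmul are illustrative, not from this commit:

    # Hypothetical standalone usage of cuda_kernel_profiler; requires a CUDA GPU.
    with cuda_kernel_profiler("gemm") as result:
        a = torch.randn(512, 512, device="cuda")
        _ = a @ a  # launches a cuBLAS/CUTLASS GEMM kernel
    # Populated only after the context exits.
    print(result["found"], result["kernel_names"][:3])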
@@ -178,7 +197,14 @@ def test_inference_workflow_nvfp4(
 
     x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype)
     y_ref = m(x)
-    y_mx = m_mx(x)
+
+    if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY:
+        with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
+            y_mx = m_mx(x)
+        assert result["found"], "Expected quantize_nvfp4 kernel to be found"
+    else:
+        y_mx = m_mx(x)
+
     sqnr = compute_error(y_ref, y_mx)
 
     if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
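The profiler check runs only when `use_triton_kernel` is set and the config is not WEIGHT_ONLY: with weight-only quantization the activations are not quantized at run time, so the Triton quantize kernel is not expected to launch. A hypothetical negative check for that branch, sketched with the same helper (not part of this commit):

    # Hypothetical inverse assertion for the WEIGHT_ONLY branch; not in the commit.
    if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
        with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
            y_mx = m_mx(x)
        assert not result["found"], "quantize kernel should not launch in weight-only mode"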