Skip to content

Commit fc69870

Browse files
authored
Add skipIfA10G decorator (#982)
1 parent 72fbdca commit fc69870

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

benchmarks/run.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from torch.utils._pytree import tree_leaves
4646
from torch.utils._pytree import tree_map
4747

48+
from helion._testing import get_nvidia_gpu_model
4849
from helion._utils import counters
4950

5051
logger: logging.Logger = logging.getLogger(__name__)
@@ -59,23 +60,6 @@ def is_cuda() -> bool:
5960
return torch.version.cuda is not None
6061

6162

62-
def get_nvidia_gpu_model() -> str:
    """
    Retrieves the model of the NVIDIA GPU being used.

    Will return the name of the first GPU listed by ``nvidia-smi``.

    Returns:
        str: The model of the NVIDIA GPU or empty str if not found.
    """
    try:
        # Query only the GPU name; csv,noheader,nounits yields one bare
        # model name per line, one line per GPU.
        model = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader,nounits"]
        )
        # Keep only the first GPU's name.
        return model.decode().strip().split("\n")[0]
    except (OSError, subprocess.CalledProcessError):
        # OSError: nvidia-smi binary not found on PATH.
        # CalledProcessError: nvidia-smi exists but exited non-zero
        # (e.g. driver not loaded) — previously this propagated and
        # crashed the caller instead of degrading to "".
        logging.getLogger(__name__).warning(
            "nvidia-smi not found. Returning empty str."
        )
        return ""
77-
78-
7963
IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"
8064

8165

helion/_testing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ def is_cuda() -> bool:
4747
)
4848

4949

50+
def get_nvidia_gpu_model() -> str:
    """
    Retrieves the model of the NVIDIA GPU being used.

    Will return the name of the current device.

    Returns:
        str: The model of the NVIDIA GPU or empty str if not found.
    """
    # Guard clause: without CUDA there is no device to inspect.
    if not torch.cuda.is_available():
        return ""
    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    # Fall back to "" defensively if the properties object lacks a name.
    return getattr(device_props, "name", "")
61+
62+
5063
def skipIfRefEager(reason: str) -> Callable[[Callable], Callable]:
5164
"""Skip test if running in ref eager mode (HELION_INTERPRET=1)."""
5265
return unittest.skipIf(os.environ.get("HELION_INTERPRET") == "1", reason)
@@ -67,6 +80,13 @@ def skipIfXPU(reason: str) -> Callable[[Callable], Callable]:
6780
return unittest.skipIf(torch.xpu.is_available(), reason) # pyright: ignore[reportAttributeAccessIssue]
6881

6982

83+
def skipIfA10G(reason: str) -> Callable[[Callable], Callable]:
    """Skip test if running on A10G GPU"""
    # Substring match against the reported device name, e.g. "NVIDIA A10G".
    return unittest.skipIf("A10G" in get_nvidia_gpu_model(), reason)
88+
89+
7090
def skipIfNotCUDA() -> Callable[[Callable], Callable]:
7191
"""Skip test if not running on CUDA (NVIDIA GPU)."""
7292
return unittest.skipIf(

0 commit comments

Comments (0)