@@ -47,6 +47,19 @@ def is_cuda() -> bool:
     )
 
 
+def get_nvidia_gpu_model() -> str:
+    """
+    Retrieves the model of the NVIDIA GPU being used.
+    Returns the name of the current CUDA device.
+    Returns:
+        str: The model of the NVIDIA GPU, or an empty string if not found.
+    """
+    if torch.cuda.is_available():
+        props = torch.cuda.get_device_properties(torch.cuda.current_device())
+        return getattr(props, "name", "")
+    return ""
+
+
 def skipIfRefEager(reason: str) -> Callable[[Callable], Callable]:
     """Skip test if running in ref eager mode (HELION_INTERPRET=1)."""
     return unittest.skipIf(os.environ.get("HELION_INTERPRET") == "1", reason)
@@ -67,6 +80,13 @@ def skipIfXPU(reason: str) -> Callable[[Callable], Callable]:
     return unittest.skipIf(torch.xpu.is_available(), reason)  # pyright: ignore[reportAttributeAccessIssue]
 
 
+def skipIfA10G(reason: str) -> Callable[[Callable], Callable]:
+    """Skip test if running on an A10G GPU."""
+    gpu_model = get_nvidia_gpu_model()
+    is_a10g = "A10G" in gpu_model
+    return unittest.skipIf(is_a10g, reason)
+
+
 def skipIfNotCUDA() -> Callable[[Callable], Callable]:
     """Skip test if not running on CUDA (NVIDIA GPU)."""
     return unittest.skipIf(
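Usage sketch for the new skipIfA10G decorator (illustrative only: the import path, test class, test name, and reason string below are assumptions, not part of this change):

import unittest

# Assumed import path for the helpers added in this diff.
from helion._testing import skipIfA10G


class ExampleTest(unittest.TestCase):
    @skipIfA10G("example: behavior differs on A10G")  # hypothetical reason string
    def test_example(self) -> None:
        # Skipped whenever get_nvidia_gpu_model() reports an A10G device.
        ...


if __name__ == "__main__":
    unittest.main()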