Skip to content

Commit 5354356

Browse files
authored
Normalize device name and decorate cuda-only test cases (#819)
1 parent ebb9f34 commit 5354356

File tree

6 files changed

+21
-9
lines changed

6 files changed

+21
-9
lines changed

helion/_testing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,9 +703,16 @@ def normalize_tensor_descriptors(code: str) -> str:
703703
@staticmethod
704704
def normalize_device_name(code: str) -> str:
705705
"""
706-
convert device='cuda:0' etc to device=DEVICE
706+
convert device='cuda:0' or device(type='cuda', index=0) etc to device=DEVICE
707707
"""
708-
return re.sub(r"device\s*=\s*['\"][^'\"]+['\"]", "device=DEVICE", code)
708+
# device='cuda:0'
709+
reg_pattern_for_device_str = r"device\s*=\s*['\"][^'\"]+['\"]"
710+
normalized_code = re.sub(reg_pattern_for_device_str, "device=DEVICE", code)
711+
# device(type='cuda', index=0)
712+
reg_pattern_for_torch_device = (
713+
r"device\s*\(type\s*=\s*['\"][^'\"]+['\"][^'\"\)]*\)"
714+
)
715+
return re.sub(reg_pattern_for_torch_device, "device=DEVICE", normalized_code)
709716

710717
def lookup(self, test_id: str, value: str) -> tuple[str, str]:
711718
test_id = self.normalize_id(test_id)

test/test_constexpr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from helion._testing import TestCase
1111
from helion._testing import code_and_output
1212
from helion._testing import skipIfRefEager
13+
from helion._testing import skipIfXPU
1314
import helion.language as hl
1415

1516

@@ -94,6 +95,7 @@ def fn(x: torch.Tensor, mode: str) -> torch.Tensor:
9495
self.assertExpectedJournal(code)
9596

9697
@skipIfRefEager("Triton codegen does not work in ref eager mode")
98+
@skipIfXPU("Failed on XPU due to a different configuration for min dot size")
9799
def test_block_size_constexpr_assignment_in_host_code(self) -> None:
98100
@helion.kernel(
99101
config=helion.Config(

test/test_examples.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from helion._testing import import_path
1515
from helion._testing import skipIfRefEager
1616
from helion._testing import skipIfRocm
17+
from helion._testing import skipIfXPU
1718

1819
torch.backends.cuda.matmul.fp32_precision = "tf32"
1920
torch.backends.cudnn.conv.fp32_precision = "tf32"
@@ -163,6 +164,7 @@ def test_template_via_closure0(self):
163164
)
164165
)
165166

167+
@skipIfXPU("Failed on XPU - https://github.com/pytorch/helion/issues/795")
166168
def test_template_via_closure1(self):
167169
bias = torch.randn([1, 1024], device=DEVICE, dtype=torch.float16)
168170
args = (

test/test_reductions.expected

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], o
392392
# List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:52>)
393393
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:50>), key=0)
394394
# Name: LiteralType(torch.float32) ArgumentOrigin(name='out_dtype')
395-
# Attribute: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
395+
# Attribute: LiteralType(device=DEVICE) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
396396
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
397397
# For: loop_type=GRID
398398
out = torch.empty([n], dtype=out_dtype, device=x.device)

test/test_signal_wait.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from helion._testing import RefEagerTestDisabled
1010
from helion._testing import TestCase
1111
from helion._testing import code_and_output
12+
from helion._testing import skipIfNotCUDA
1213
from helion._testing import skipIfRocm
1314
import helion.language as hl
1415

@@ -82,7 +83,7 @@ def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
8283
self.maxDiff = None
8384
self.assertExpectedJournal(code)
8485

85-
@skipIfRocm("only works on cuda")
86+
@skipIfNotCUDA()
8687
def test_wait_multi_bar_cas(self):
8788
@helion.kernel
8889
def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -156,7 +157,7 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
156157
)
157158
self.assertExpectedJournal(code)
158159

159-
@skipIfRocm("only works on cuda")
160+
@skipIfNotCUDA()
160161
def test_signal_multiple_cas(self):
161162
@helion.kernel
162163
def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -218,7 +219,7 @@ def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
218219
)
219220
self.assertExpectedJournal(code)
220221

221-
@skipIfRocm("only works on cuda")
222+
@skipIfNotCUDA()
222223
def test_global_sync_cas(self):
223224
@helion.kernel
224225
def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:

test/test_type_propagation.expected

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -499,14 +499,14 @@ def root_graph_0():
499499

500500
--- assertExpectedJournal(TestTypePropagation.test_cuda_device_properties)
501501
def use_device_properties(x: torch.Tensor):
502-
# Attribute: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
502+
# Attribute: LiteralType(device=DEVICE) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
503503
# Name: TensorType([x_size0], torch.float32) ArgumentOrigin(name='x')
504504
device = x.device
505505
# Call: ClassType({'multi_processor_count': SymIntType(u0)}) SourceOrigin(location=<SourceLocation test_type_propagation.py:104>)
506506
# Attribute: CallableType(get_device_properties) AttributeOrigin(value=AttributeOrigin(value=GlobalOrigin(name='torch'), key='cuda'), key='get_device_properties')
507507
# Attribute: PythonModuleType(torch.cuda) AttributeOrigin(value=GlobalOrigin(name='torch'), key='cuda')
508508
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
509-
# Name: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
509+
# Name: LiteralType(device=DEVICE) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
510510
props = torch.cuda.get_device_properties(device)
511511
# Attribute: SymIntType(u0) AttributeOrigin(value=SourceOrigin(location=<SourceLocation test_type_propagation.py:104>), key='multi_processor_count')
512512
# Name: ClassType({'multi_processor_count': SymIntType(u0)}) SourceOrigin(location=<SourceLocation test_type_propagation.py:104>)
@@ -737,7 +737,7 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
737737
# Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='x')
738738
# Attribute: LiteralType(torch.float32) AttributeOrigin(value=ArgumentOrigin(name='y'), key='dtype')
739739
# Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='y')
740-
# Attribute: LiteralType(device(type='cpu')) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
740+
# Attribute: LiteralType(device=DEVICE) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
741741
# Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='x')
742742
# For: loop_type=GRID
743743
out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)

0 commit comments

Comments
 (0)