@@ -499,14 +499,14 @@ def root_graph_0():
499499
500500--- assertExpectedJournal(TestTypePropagation.test_cuda_device_properties)
501501def use_device_properties(x: torch.Tensor):
502- # Attribute: LiteralType(device(type='cuda', index=0) ) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
502+ # Attribute: LiteralType(device=DEVICE ) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
503503 # Name: TensorType([x_size0], torch.float32) ArgumentOrigin(name='x')
504504 device = x.device
505505 # Call: ClassType({'multi_processor_count': SymIntType(u0)}) SourceOrigin(location=<SourceLocation test_type_propagation.py:104>)
506506 # Attribute: CallableType(get_device_properties) AttributeOrigin(value=AttributeOrigin(value=GlobalOrigin(name='torch'), key='cuda'), key='get_device_properties')
507507 # Attribute: PythonModuleType(torch.cuda) AttributeOrigin(value=GlobalOrigin(name='torch'), key='cuda')
508508 # Name: PythonModuleType(torch) GlobalOrigin(name='torch')
509- # Name: LiteralType(device(type='cuda', index=0) ) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
509+ # Name: LiteralType(device=DEVICE ) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
510510 props = torch.cuda.get_device_properties(device)
511511 # Attribute: SymIntType(u0) AttributeOrigin(value=SourceOrigin(location=<SourceLocation test_type_propagation.py:104>), key='multi_processor_count')
512512 # Name: ClassType({'multi_processor_count': SymIntType(u0)}) SourceOrigin(location=<SourceLocation test_type_propagation.py:104>)
@@ -737,7 +737,7 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
737737 # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='x')
738738 # Attribute: LiteralType(torch.float32) AttributeOrigin(value=ArgumentOrigin(name='y'), key='dtype')
739739 # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='y')
740- # Attribute: LiteralType(device(type='cpu') ) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
740+ # Attribute: LiteralType(device=DEVICE ) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
741741 # Name: TensorType([512, 512], torch.float32) ArgumentOrigin(name='x')
742742 # For: loop_type=GRID
743743 out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
0 commit comments