diff --git a/README.md b/README.md index 0f998ea4..2641d6ac 100644 --- a/README.md +++ b/README.md @@ -137,11 +137,8 @@ def main(): world_size = torchcomm.get_size() # Calculate device ID - num_devices = torch.cuda.device_count() - device_id = rank % num_devices - target_device = torch.device(f"cuda:{device_id}") - - print(f"Rank {rank}/{world_size}: Running on device {device_id}") + target_device = torchcomm.get_device() + print(f"Rank {rank}/{world_size}: Running on device {target_device.index}") # Create a tensor with rank-specific data tensor = torch.full( @@ -212,8 +209,7 @@ device = torch.device("cuda") torchcomm = new_comm("nccl", device, name="main_comm") rank = torchcomm.get_rank() -device_id = rank % torch.cuda.device_count() -target_device = torch.device(f"cuda:{device_id}") +target_device = torchcomm.get_device() # Create tensor tensor = torch.full((1024,), float(rank + 1), dtype=torch.float32, device=target_device)