Skip to content

Commit 88a4d77

Browse files
authored
Set cuda device before init_process_group (#56)
1 parent db7b273 commit 88a4d77

File tree

2 files changed: 1 addition (+1) and 1 deletion (-1).

generate.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -260,7 +260,6 @@ def main(
260 | 260 |       rank = maybe_init_dist()
261 | 261 |       use_tp = rank is not None
262 | 262 |       if use_tp:
263 |     | -         torch.cuda.set_device(rank)
264 | 263 |           if rank != 0:
265 | 264 |               # only print on rank 0
266 | 265 |               print = lambda *args, **kwargs: None

tp.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ def maybe_init_dist() -> Optional[int]:
42 | 42 |           # not run via torchrun, no-op
43 | 43 |           return None
44 | 44 |
   | 45 | +     torch.cuda.set_device(rank)
45 | 46 |       dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
46 | 47 |       return rank
47 | 48 |
0 commit comments

Comments
 (0)