We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent db7b273 · commit 88a4d77 — Copy full SHA for 88a4d77
generate.py
@@ -260,7 +260,6 @@ def main(
260
rank = maybe_init_dist()
261
use_tp = rank is not None
262
if use_tp:
263
- torch.cuda.set_device(rank)
264
if rank != 0:
265
# only print on rank 0
266
print = lambda *args, **kwargs: None
tp.py
@@ -42,6 +42,7 @@ def maybe_init_dist() -> Optional[int]:
42
# not run via torchrun, no-op
43
return None
44
45
+ torch.cuda.set_device(rank)
46
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
47
return rank
48
0 commit comments