@@ -17,24 +17,20 @@ def ddp_setup(rank, world_size):
         world_size: Total number of processes
     """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12453"
+    os.environ["MASTER_PORT"] = "12455"
 
 
+    rank = int(os.environ["LOCAL_RANK"])
     if torch.accelerator.is_available():
         device_type = torch.accelerator.current_accelerator()
-        torch.accelerator.set_device_idx(rank)
-        device: torch.device = torch.device(f"{device_type}:{rank}")
+        device = torch.device(f"{device_type}:{rank}")
         torch.accelerator.device_index(rank)
         print(f"Running on rank {rank} on device {device}")
-        backend = torch.distributed.get_default_backend_for_device(device)
-        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=device)
     else:
         device = torch.device("cpu")
         print(f"Running on device {device}")
-        torch.distributed.init_process_group(backend="gloo", device_id=device)
 
-    # torch.cuda.set_device(rank)
-    # init_process_group(backend="xccl", rank=rank, world_size=world_size)
+    backend = torch.distributed.get_default_backend_for_device(device)
 
 class Trainer:
     def __init__(
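Note on the hunk above: after this change, ddp_setup selects the device through torch.accelerator and computes backend via torch.distributed.get_default_backend_for_device(device), but both init_process_group calls disappear from the shown lines. A minimal sketch of how the pieces presumably fit together is below; the set_device_index call and the trailing init_process_group(...) are assumptions for illustration, not code from this commit.

import os

import torch
import torch.distributed as dist


def ddp_setup(rank: int, world_size: int) -> torch.device:
    # Rendezvous info for a single-node run (values taken from the hunk above).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12455"

    if torch.accelerator.is_available():
        # current_accelerator() returns a torch.device such as "cuda" or "xpu".
        device_type = torch.accelerator.current_accelerator()
        device = torch.device(f"{device_type}:{rank}")
        torch.accelerator.set_device_index(rank)  # assumed setter: bind this process to its device
    else:
        device = torch.device("cpu")

    # Let PyTorch pick the matching backend (nccl, xccl, gloo, ...) for the device.
    backend = dist.get_default_backend_for_device(device)
    # Assumed follow-up (not in the hunk): actually create the process group.
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    return device

The sketch keeps the rank argument passed by mp.spawn rather than reading LOCAL_RANK, since that environment variable is set by torchrun, not by mp.spawn.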
@@ -116,5 +112,4 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_s
     args = parser.parse_args()
 
     world_size = torch.accelerator.device_count()
-    print(world_size)
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
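For reference, mp.spawn passes the worker index (0 through nprocs - 1) as the first positional argument, so each spawned process receives its own rank ahead of the args tuple. A sketch of a matching launcher is below; the body of main is an assumption based on its signature, not code from this commit.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    # `rank` is supplied by mp.spawn; the remaining arguments come from `args=` below.
    device = ddp_setup(rank, world_size)
    # ... build the dataset, model, and Trainer here, then train for total_epochs ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.accelerator.device_count()  # one worker process per available device
    # Illustrative values in place of the script's argparse arguments.
    mp.spawn(main, args=(world_size, 10, 50, 32), nprocs=world_size)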