1111
1212
def ddp_setup():
    """Initialize torch.distributed for this worker and return its device.

    Reads ``LOCAL_RANK`` from the environment (set by torchrun), binds the
    process to the matching accelerator when one is available (CPU
    otherwise), and initializes the process group with the default backend
    for that device (e.g. NCCL for CUDA, Gloo for CPU).

    Returns:
        torch.device: the device this rank should place its model/data on.
    """
    rank = int(os.environ["LOCAL_RANK"])
    if torch.accelerator.is_available():
        # One accelerator per local rank, e.g. "cuda:3" for LOCAL_RANK=3.
        device = torch.device(f"{torch.accelerator.current_accelerator()}:{rank}")
        torch.accelerator.set_device_index(rank)
        print(f"Running on rank {rank} on device {device}")
    else:
        device = torch.device("cpu")
        print(f"Running on device {device}")

    # Pick the backend that matches the device instead of hard-coding NCCL;
    # device_id lets init_process_group bind the group to this rank's device.
    backend = torch.distributed.get_default_backend_for_device(device)
    torch.distributed.init_process_group(backend=backend, device_id=device)
    return device
1626
# NOTE(review): this span is a diff hunk from a commit page — the leading
# digits on each line are fused old/new line numbers and the "@@" lines mark
# elided code, so the __init__ shown here is incomplete from this view.
1727class Trainer :
# Trainer drives one DDP worker; this hunk adds a `device` parameter that is
# stored on the instance (presumably the value returned by ddp_setup() —
# verify against main()).
1828 def __init__ (
@@ -22,6 +32,7 @@ def __init__(
2232 optimizer : torch .optim .Optimizer ,
2333 save_every : int ,
2434 snapshot_path : str ,
# New parameter added by this commit: the device this rank trains on.
35+ device : torch .device ,
2536 ) -> None :
# LOCAL_RANK / RANK are set by torchrun for each spawned worker process.
2637 self .local_rank = int (os .environ ["LOCAL_RANK" ])
2738 self .global_rank = int (os .environ ["RANK" ])
@@ -31,14 +42,15 @@ def __init__(
3142 self .save_every = save_every
3243 self .epochs_run = 0
3344 self .snapshot_path = snapshot_path
# Kept so _load_snapshot can map the checkpoint onto this rank's device.
45+ self .device = device
# Resume from an existing snapshot before wrapping the model in DDP.
3446 if os .path .exists (snapshot_path ):
3547 print ("Loading snapshot" )
3648 self ._load_snapshot (snapshot_path )
3749
# NOTE(review): DDP still pins the replica via local_rank — TODO confirm
# this stays consistent with the accelerator-based device chosen in
# ddp_setup (it matches when one accelerator maps to one local rank).
3850 self .model = DDP (self .model , device_ids = [self .local_rank ])
3951
def _load_snapshot(self, snapshot_path):
    """Restore model weights and the epoch counter from a snapshot file.

    Args:
        snapshot_path: path to a checkpoint produced by this trainer,
            containing "MODEL_STATE" and "EPOCHS_RUN" entries.
    """
    # Map tensors onto this rank's device instead of a hard-coded
    # f"cuda:{local_rank}" string, so CPU/other accelerators work too.
    loc = str(self.device)
    snapshot = torch.load(snapshot_path, map_location=loc)
    self.model.load_state_dict(snapshot["MODEL_STATE"])
    self.epochs_run = snapshot["EPOCHS_RUN"]
@@ -93,10 +105,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
93105
94106
def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    """Entry point for one torchrun worker: set up DDP, train, tear down.

    Args:
        save_every: snapshot the model every N epochs.
        total_epochs: total number of epochs to train for.
        batch_size: per-rank batch size for the dataloader.
        snapshot_path: where to read/write the training snapshot.
    """
    # ddp_setup now returns the rank's device; thread it through to Trainer
    # so checkpoint loading maps to the right device.
    device = ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
    trainer.train(total_epochs)
    destroy_process_group()
102114
0 commit comments