@@ -17,12 +17,11 @@ def ddp_setup():
1717 torch .accelerator .set_device_index (rank )
1818 print (f"Running on rank { rank } on device { device } " )
1919 else :
20- device = torch .device ("cpu" )
21- print (f"Running on device { device } " )
22-
23- backend = torch .distributed .get_default_backend_for_device (device )
24- torch .distributed .init_process_group (backend = backend , device_id = device )
25- return device
20+ print (f"Multi-GPU environment not detected" )
21+
22+ backend = torch .distributed .get_default_backend_for_device (rank )
23+ torch .distributed .init_process_group (backend = backend , rank = rank , device_id = rank )
24+
2625
2726class Trainer :
2827 def __init__ (
@@ -32,7 +31,6 @@ def __init__(
3231 optimizer : torch .optim .Optimizer ,
3332 save_every : int ,
3433 snapshot_path : str ,
35- device : torch .device ,
3634 ) -> None :
3735 self .local_rank = int (os .environ ["LOCAL_RANK" ])
3836 self .global_rank = int (os .environ ["RANK" ])
@@ -42,15 +40,14 @@ def __init__(
4240 self .save_every = save_every
4341 self .epochs_run = 0
4442 self .snapshot_path = snapshot_path
45- self .device = device
4643 if os .path .exists (snapshot_path ):
4744 print ("Loading snapshot" )
4845 self ._load_snapshot (snapshot_path )
4946
5047 self .model = DDP (self .model , device_ids = [self .local_rank ])
5148
5249 def _load_snapshot (self , snapshot_path ):
53- loc = str (self . device )
50+ loc = str (torch . accelerator . current_accelerator () )
5451 snapshot = torch .load (snapshot_path , map_location = loc )
5552 self .model .load_state_dict (snapshot ["MODEL_STATE" ])
5653 self .epochs_run = snapshot ["EPOCHS_RUN" ]
@@ -105,10 +102,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
105102
106103
107104def main (save_every : int , total_epochs : int , batch_size : int , snapshot_path : str = "snapshot.pt" ):
108- device = ddp_setup ()
105+ ddp_setup ()
109106 dataset , model , optimizer = load_train_objs ()
110107 train_data = prepare_dataloader (dataset , batch_size )
111- trainer = Trainer (model , train_data , optimizer , save_every , snapshot_path , device )
108+ trainer = Trainer (model , train_data , optimizer , save_every , snapshot_path )
112109 trainer .train (total_epochs )
113110 destroy_process_group ()
114111
0 commit comments