1111
1212
1313def ddp_setup ():
14- torch .cuda .set_device (int (os .environ ["LOCAL_RANK" ]))
15- init_process_group (backend = "nccl" )
14+ rank = int (os .environ ["LOCAL_RANK" ])
15+ if torch .accelerator .is_available ():
16+ device_type = torch .accelerator .current_accelerator ()
17+ device : torch .device = torch .device (f"{ device_type } :{ rank } " )
18+ torch .accelerator .device_index (rank )
19+ print (f"Running on rank { rank } on device { device } " )
20+ backend = torch .distributed .get_default_backend_for_device (device )
21+ torch .distributed .init_process_group (backend = backend )
22+ return device_type
23+ else :
24+ device = torch .device ("cpu" )
25+ print (f"Running on device { device } " )
26+ torch .distributed .init_process_group (backend = "gloo" )
27+ return device
1628
1729class Trainer :
1830 def __init__ (
@@ -22,6 +34,7 @@ def __init__(
2234 optimizer : torch .optim .Optimizer ,
2335 save_every : int ,
2436 snapshot_path : str ,
37+ device
2538 ) -> None :
2639 self .local_rank = int (os .environ ["LOCAL_RANK" ])
2740 self .global_rank = int (os .environ ["RANK" ])
@@ -31,14 +44,15 @@ def __init__(
3144 self .save_every = save_every
3245 self .epochs_run = 0
3346 self .snapshot_path = snapshot_path
47+ self .device = device
3448 if os .path .exists (snapshot_path ):
3549 print ("Loading snapshot" )
3650 self ._load_snapshot (snapshot_path )
3751
3852 self .model = DDP (self .model , device_ids = [self .local_rank ])
3953
4054 def _load_snapshot (self , snapshot_path ):
41- loc = f"cuda: { self .local_rank } "
55+ loc = str ( self .device )
4256 snapshot = torch .load (snapshot_path , map_location = loc )
4357 self .model .load_state_dict (snapshot ["MODEL_STATE" ])
4458 self .epochs_run = snapshot ["EPOCHS_RUN" ]
@@ -93,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):
93107
94108
def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    """Entry point for one torchrun-launched worker.

    Sets up the distributed process group, builds the training objects and
    trainer, runs training, and tears the process group down.

    Args:
        save_every: checkpoint interval, in epochs.
        total_epochs: total number of epochs to train.
        batch_size: per-rank batch size for the dataloader.
        snapshot_path: where to read/write the resumable checkpoint.
    """
    device = ddp_setup()
    try:
        dataset, model, optimizer = load_train_objs()
        train_data = prepare_dataloader(dataset, batch_size)
        trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
        trainer.train(total_epochs)
    finally:
        # Always tear the process group down, even when training raises,
        # so peer ranks are not left hanging inside a dead collective.
        destroy_process_group()
102116
0 commit comments