@@ -48,6 +48,10 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
         # set torchrun variables
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])
+        # set device
+        self.acc = torch.accelerator.current_accelerator()
+        self.device: torch.device = torch.device(f"{self.acc}:{self.local_rank}")
+        self.device_type = self.device.type
         # data stuff
         self.train_dataset = train_dataset
         self.train_loader = self._prepare_dataloader(train_dataset)
@@ -58,7 +62,7 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
         self.optimizer = optimizer
         self.save_every = self.config.save_every
         if self.config.use_amp:
-            self.scaler = torch.cuda.amp.GradScaler()
+            self.scaler = torch.amp.GradScaler(self.device_type)
         # load snapshot if available. only necessary on the first node.
         if self.config.snapshot_path is None:
             self.config.snapshot_path = "snapshot.pt"
@@ -93,7 +97,7 @@ def _load_snapshot(self):


     def _run_batch(self, source, targets, train: bool = True) -> float:
-        with torch.set_grad_enabled(train), torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=(self.config.use_amp)):
+        with torch.set_grad_enabled(train), torch.amp.autocast(device_type=self.device_type, dtype=torch.float16, enabled=(self.config.use_amp)):
             _, loss = self.model(source, targets)

         if train:
@@ -119,7 +123,7 @@ def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
             targets = targets.to(self.local_rank)
             batch_loss = self._run_batch(source, targets, train)
             if iter % 100 == 0:
-                print(f"[GPU{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")
+                print(f"[RANK{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")

     def _save_snapshot(self, epoch):
         # capture snapshot
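
The commit replaces the CUDA-only AMP calls (torch.cuda.amp.GradScaler, device_type="cuda") with the device-agnostic torch.amp and torch.accelerator APIs. Below is a minimal standalone sketch of the same pattern, assuming PyTorch 2.6+ for torch.accelerator; the model, optimizer, and tensors are toy placeholders and are not part of this trainer:

import torch

# Detect the active accelerator (e.g. cuda, xpu, mps); fall back to CPU if none.
acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"
device = torch.device(device_type, 0) if acc is not None else torch.device("cpu")

# Toy placeholders standing in for the trainer's model/optimizer/batch.
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
source = torch.randn(4, 8, device=device)
targets = torch.randn(4, 1, device=device)

# GradScaler and autocast both take the device type string ("cuda", "xpu", ...)
# rather than being tied to the torch.cuda namespace.
scaler = torch.amp.GradScaler(device_type, enabled=(device_type != "cpu"))

with torch.amp.autocast(device_type=device_type, dtype=torch.float16,
                        enabled=(device_type != "cpu")):
    loss = torch.nn.functional.mse_loss(model(source), targets)

scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)          # unscale gradients, then run optimizer.step()
scaler.update()                 # adjust the scale factor for the next iteration
optimizer.zero_grad()

One note on the diff itself: current_accelerator() returns an unindexed torch.device, whose string form is just the type name (e.g. "cuda"), which is why f"{self.acc}:{self.local_rank}" yields a valid per-rank device string like "cuda:1"; the sketch uses acc.type explicitly to make that dependency visible.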