@@ -55,7 +55,7 @@ def get_local_rank():
     raise RuntimeError('Please use "python -m torch.distributed.launch --nproc_per_node=N train_script.py')


-class DistTrainer():
+class DistTrainer:
     r"""
     Distributed Trainer that supports distributed training and mixed-precision training. For the underlying implementation, please read the official pytorch documentation.

@@ -110,7 +110,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         int dev_batch_size: the batch size used during evaluation
         bool test_use_fp16: use fp16 during testing
         bool set_grad_to_none: in zero_grad, set gradients to None instead of 0
-        GradScaler gradscaler: user-provided gradient scaler
+        GradScaler grad_scaler: user-provided gradient scaler
         bool pin_memory: whether produced tensors should use pinned memory, which may speed up data transfer. This usually gives a speed gain when there are many tensors or the tensors have large dimensions.
         bool find_unused_parameters: passed through when wrapping the model into DistributedDataParallel; unless the model really has
             parameters that are unused in forward, this argument should not be needed.
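For orientation, a minimal usage sketch of these keyword arguments; the train_data/model/optimizer/loss/dev_data/metrics variables are placeholders and the exact constructor signature is not shown in this diff:

    from torch.cuda.amp import GradScaler

    # Hypothetical call; every data/model variable below is a placeholder.
    trainer = DistTrainer(
        train_data=train_data, model=model, optimizer=optimizer, loss=loss,
        dev_data=dev_data, metrics=metrics,
        fp16=True,                                  # required when passing a custom scaler
        grad_scaler=GradScaler(init_scale=2.**14),  # kwarg renamed from `gradscaler`
        set_grad_to_none=False,
        pin_memory=True,
    )
    trainer.train()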
@@ -132,6 +132,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.rank = dist.get_rank()  # unique id for each process

         self.train_data = train_data
+        self.kwargs = kwargs
         if kwargs.get('batch_size', None):
             batch_size_per_gpu = int(kwargs.get('batch_size'))
         self.batch_size_per_gpu = int(batch_size_per_gpu)
@@ -158,15 +159,15 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         # init fp16, must before DataParallel init
         autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
         self.auto_cast = autocast
-        user_grad_scaler = getattr(kwargs, 'gradscaler', None)
+        user_grad_scaler = kwargs.get('grad_scaler', None)
         if user_grad_scaler is not None:
-            assert self.fp16, "must set fp16=True to enable gradscaler"
+            assert self.fp16, "must set fp16=True to enable grad_scaler"
             grad_scaler = user_grad_scaler
         else:
             grad_scaler = GradScaler()
         self.grad_scaler = grad_scaler

-        self.set_grad_to_none = getattr(kwargs, 'set_grad_to_none', True)
+        self.set_grad_to_none = kwargs.get('set_grad_to_none', False)

         # init DataParallel
         if parse_version(torch.__version__) >= parse_version('1.1'):
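The lookup change above is a genuine bug fix: getattr on a dict always returns the default because dict keys are not attributes, so a user-supplied scaler (and the old set_grad_to_none override) was silently ignored. A quick plain-Python illustration:

    kwargs = {'grad_scaler': 'my_scaler'}

    getattr(kwargs, 'grad_scaler', None)   # -> None: dict keys are not attributes
    kwargs.get('grad_scaler', None)        # -> 'my_scaler': correct dictionary lookup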
@@ -191,15 +192,15 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         elif hasattr(sampler, 'set_batch_size'):
             sampler.set_batch_size(batch_size_per_gpu)
         self.sampler = sampler
-        self.pin_memory = kwargs.get('pin_memory', True)
+        # concerning issue from https://github.com/pytorch/pytorch/issues/57273
+        self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__) == parse_version('1.9') else True)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()

         self.dev_data = dev_data
         self.metrics = metrics
         self.test_use_tqdm = True
-        self.kwargs = kwargs
         self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
         dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)

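For context, pin_memory is the usual PyTorch DataLoader flag that this setting presumably reaches through the trainer's data iterator; a plain-PyTorch sketch (dataset is a placeholder) of what it enables:

    from torch.utils.data import DataLoader

    # Page-locked host memory lets .to(device, non_blocking=True) overlap the copy with compute;
    # the default is switched off on torch 1.9 because of pytorch issue 57273 referenced above.
    loader = DataLoader(dataset, batch_size=32, pin_memory=True)
    for batch in loader:
        batch = batch.to('cuda', non_blocking=True)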
@@ -229,22 +230,6 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.logger.info("Num of processes: {}".format(self.world_size))
         self.logger.info("Use device: {}".format(device))

-    def _maybe_no_sync(self):
-        """
-        Whenever *samples* contains more than one mini-batch, we
-        want to accumulate gradients locally and only call
-        all-reduce in the last backwards pass.
-        """
-        i = self.step % self.update_every
-        if (
-            self.world_size > 1
-            and hasattr(self.ddp_model, "no_sync")
-            and i != 0
-        ):
-            return self.ddp_model.no_sync()
-        else:
-            return contextlib.ExitStack()  # dummy contextmanager
-
     def _get_n_steps(self):
         return len(self.data_iterator) * self.n_epochs

@@ -365,37 +350,42 @@ def _train(self):
                 self.callback_manager.on_epoch_begin()
                 for batch_x, batch_y in data_iterator:
                     self.step += 1
-                    self.ddp_model.train()
-                    _move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory)
-                    indices = data_iterator.get_batch_indices()
-                    # negative sampling; replace unknown; re-weight batch_y
-                    self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
-                    with self.auto_cast():
-                        prediction = self._data_forward(self.ddp_model, batch_x)
-                        # edit prediction
-                        self.callback_manager.on_loss_begin(batch_y, prediction)
-                        loss = self._compute_loss(prediction, batch_y)
-
-                    avg_loss += loss.detach()
-
-                    # Is loss NaN or inf? requires_grad = False
-                    self.callback_manager.on_backward_begin(loss)
-                    self._grad_backward(loss)
-                    self.callback_manager.on_backward_end()
-                    self._update()
-                    self.callback_manager.on_step_end()
-
-                    if self.step % self.print_every == 0:
-                        avg_loss = float(avg_loss) / self.print_every
-                        print_output = "loss:{:<6.5f}".format(avg_loss)
-                        pbar.update(self.print_every)
-                        pbar.set_postfix_str(print_output)
-                        avg_loss = 0
-
-                    self.callback_manager.on_batch_end()
-
-                    if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks):
-                        self._do_validation()
+                    if self.step % self.update_every != 0:
+                        no_sync = self.ddp_model.no_sync
+                    else:
+                        no_sync = contextlib.ExitStack
+                    with no_sync():
+                        self.ddp_model.train()
+                        _move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory)
+                        indices = data_iterator.get_batch_indices()
+                        # negative sampling; replace unknown; re-weight batch_y
+                        self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
+                        with self.auto_cast():
+                            prediction = self._data_forward(self.ddp_model, batch_x)
+                            # edit prediction
+                            self.callback_manager.on_loss_begin(batch_y, prediction)
+                            loss = self._compute_loss(prediction, batch_y)
+
+                        avg_loss += loss.detach()
+
+                        # Is loss NaN or inf? requires_grad = False
+                        self.callback_manager.on_backward_begin(loss)
+                        self._grad_backward(loss)
+                        self.callback_manager.on_backward_end()
+                        self._update()
+                        self.callback_manager.on_step_end()
+
+                        if self.step % self.print_every == 0:
+                            avg_loss = float(avg_loss) / self.print_every
+                            print_output = "loss:{:<6.5f}".format(avg_loss)
+                            pbar.update(self.print_every)
+                            pbar.set_postfix_str(print_output)
+                            avg_loss = 0
+
+                        self.callback_manager.on_batch_end()
+
+                        if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks):
+                            self._do_validation()

                 # ================= mini-batch end ==================== #
                 if self.validate_every < 0 and len(self.test_manager.callbacks):
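The rewritten loop above folds the removed _maybe_no_sync helper into _train: every step that does not trigger an optimizer update runs under DDP's no_sync(), so gradients accumulate locally and the all-reduce happens only on the update step. A self-contained sketch of the same pattern, with ddp_model, loader, loss_fn and update_every as placeholders rather than fastNLP objects:

    import contextlib

    def train_with_accumulation(ddp_model, optimizer, loader, loss_fn, update_every):
        """Gradient accumulation under DistributedDataParallel."""
        for step, (x, y) in enumerate(loader, start=1):
            # Skip gradient synchronization on non-update steps; sync on the update step.
            ctx = ddp_model.no_sync() if step % update_every != 0 else contextlib.ExitStack()
            with ctx:
                loss = loss_fn(ddp_model(x), y) / update_every
                loss.backward()
            if step % update_every == 0:
                optimizer.step()
                optimizer.zero_grad()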