 from .utils import _build_fp16_env
 from .utils import _get_func_signature
 from .utils import _move_dict_value_to_device
+from .sampler import Sampler

 __all__ = [
     'get_local_rank',
@@ -68,7 +69,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_path=None, device='auto',
-                 fp16=False, use_tqdm=True, **kwargs):
+                 fp16=False, use_tqdm=True, sampler=None, **kwargs):
         r"""

         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
@@ -101,13 +102,18 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         :param str device: which device to use; may be 'cuda', 'cpu' or 'auto'.
         :param bool fp16: whether to train with half precision (fp16).
         :param bool use_tqdm: whether to use tqdm to show training progress; if False, the loss is printed to the terminal instead.
+        :param Sampler sampler: the sampler to use. If not given, a DistributedSampler is used by default. Pass this parameter mainly when you have
+            explicitly modified the Dataset on each rank, so that every rank holds the same number of samples but the samples themselves differ.
         :param kwargs: optional keyword arguments
             bool test_use_tqdm: whether to enable tqdm when validating on the dev set
             Sampler test_sampler: the sampler used during evaluation
             int dev_batch_size: the batch size used during evaluation
             bool test_use_fp16: whether to use fp16 during evaluation
             bool set_grad_to_none: in zero_grad, set gradients to None instead of 0
             GradScaler gradscaler: a custom gradient scaler
+            bool pin_memory: whether to put the produced tensors in pinned memory, which may speed up data transfer; this usually helps when there are many tensors or the tensors are large.
+            bool find_unused_parameters: forwarded when wrapping the model with DistributedDataParallel; unless the model
+                really has parameters that are not used in forward, this should not be needed.
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
         if device == 'auto':
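As a quick illustration of the new sampler argument (this sketch is not part of the patch; RankLocalSampler and the commented-out DistTrainer call are assumptions for the example): any object implementing the fastNLP Sampler or torch.utils.data.Sampler interface can be passed when each rank already holds its own, equally sized dataset and no further sharding by DistributedSampler is wanted.

    # Illustrative only: a per-rank sampler passed via the new `sampler` argument.
    import torch

    class RankLocalSampler(torch.utils.data.Sampler):
        """Iterates the rank-local dataset in order, without any extra sharding."""

        def __init__(self, data_source):
            self.data_source = data_source

        def __iter__(self):
            return iter(range(len(self.data_source)))

        def __len__(self):
            return len(self.data_source)

    # trainer = DistTrainer(train_data=rank_local_dataset, model=model, optimizer=optimizer,
    #                       loss=loss, sampler=RankLocalSampler(rank_local_dataset),
    #                       pin_memory=True, find_unused_parameters=False)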
@@ -126,6 +132,8 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.rank = dist.get_rank()  # unique id for each process

         self.train_data = train_data
+        if kwargs.get('batch_size', None):
+            batch_size_per_gpu = int(kwargs.get('batch_size'))
         self.batch_size_per_gpu = int(batch_size_per_gpu)
         self.n_epochs = int(n_epochs)
         self.num_data_workers = int(num_workers)
@@ -163,7 +171,8 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         # init DataParallel
         if parse_version(torch.__version__) >= parse_version('1.1'):
             self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank, find_unused_parameters=True)
+                                 output_device=self.local_rank,
+                                 find_unused_parameters=kwargs.get('find_unused_parameters', False))
         else:
             self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                  output_device=self.local_rank)
@@ -172,7 +181,17 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
-            self.sampler = DistributedSampler(self.train_data)
+            if sampler is None:
+                self.sampler = DistributedSampler(self.train_data)
+            else:
+                # sampler check
+                if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
+                    raise ValueError(
+                        f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}")
+                elif hasattr(sampler, 'set_batch_size'):
+                    sampler.set_batch_size(batch_size_per_gpu)
+                self.sampler = sampler
+        self.pin_memory = kwargs.get('pin_memory', True)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()
@@ -191,7 +210,6 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
                 batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                 use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
-
         # Setup logging
         # synchronize start_time
         sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
@@ -233,7 +251,8 @@ def _get_n_steps(self):
     def _get_data_iter(self, dataset):
         if isinstance(dataset, DataSet):
             return DataSetIter(dataset=dataset, batch_size=self.batch_size_per_gpu, sampler=self.sampler,
-                               num_workers=self.num_data_workers, drop_last=self.drop_last)
+                               num_workers=self.num_data_workers, drop_last=self.drop_last,
+                               pin_memory=self.pin_memory)
         elif isinstance(dataset, BatchIter):
             return dataset
         else:
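For context on why pin_memory here is paired with non_blocking later in _train, a plain-PyTorch sketch (independent of fastNLP) of the pinned-memory transfer pattern:

    # Batches allocated in pinned (page-locked) host memory plus non_blocking copies,
    # so host-to-GPU transfer can overlap with computation.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=8, pin_memory=True)  # mirrors pin_memory=self.pin_memory

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for x, y in loader:
        # non_blocking only takes effect for pinned CPU tensors moving to CUDA,
        # which is why the trainer ties non_blocking to self.pin_memory.
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)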
@@ -347,7 +366,7 @@ def _train(self):
             for batch_x, batch_y in data_iterator:
                 self.step += 1
                 self.ddp_model.train()
-                _move_dict_value_to_device(batch_x, batch_y, device=self.device)
+                _move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory)
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
@@ -361,10 +380,9 @@ def _train(self):

                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss)
-                self.grad_scaler.scale(loss).backward()
+                self._grad_backward(loss)
                 self.callback_manager.on_backward_end()
-                if self.step % self.update_every == 0:
-                    self._update()
+                self._update()
                 self.callback_manager.on_step_end()

                 if self.step % self.print_every == 0:
@@ -390,7 +408,7 @@ def _train(self):
             self.pbar = None
         # ============ tqdm end ============== #

-    def _clear_grad_opt(self, optimizer):
+    def _clear_grad(self, optimizer):
         if self.set_grad_to_none:
             for group in optimizer.param_groups:
                 for p in group['params']:
@@ -399,13 +417,24 @@ def _clear_grad_opt(self, optimizer):
         else:
             optimizer.zero_grad()

+    def _grad_backward(self, loss):
+        r"""Compute the gradients via the chain rule (backpropagation).
+
+        :param loss: a scalar where back-prop starts
+
+        For PyTorch, just do "loss.backward()"
+        """
+        if (self.step - 1) % self.update_every == 0:
+            self._clear_grad(self.optimizer)
+        self.grad_scaler.scale(loss).backward()
+
     def _update(self):
         r"""Perform weight update on a model.

         """
-        self.grad_scaler.step(self.optimizer)
-        self.grad_scaler.update()
-        self._clear_grad_opt(self.optimizer)
+        if self.step % self.update_every == 0:
+            self.grad_scaler.step(self.optimizer)
+            self.grad_scaler.update()

     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
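A self-contained sketch (plain PyTorch, not fastNLP code) of the gradient-accumulation schedule that _grad_backward and _update now split between them: gradients are cleared at the start of each accumulation window, backward runs every step, and the scaled optimizer step plus scaler.update() run only every update_every steps.

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=False)  # enabled=False keeps this runnable on CPU
    update_every = 4

    for step in range(1, 17):  # steps are 1-based, as in DistTrainer
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        if (step - 1) % update_every == 0:       # same condition as _grad_backward
            optimizer.zero_grad(set_to_none=True)
        loss = torch.nn.functional.mse_loss(model(x), y)
        scaler.scale(loss).backward()
        if step % update_every == 0:             # same condition as _update
            scaler.step(optimizer)
            scaler.update()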