
Commit 7ce03a6

1. Change the pin_memory parameter in Trainer; 2. Change DistTrainer so that the DistTrainer and Trainer APIs can be used as similarly as possible.
1 parent e9c6bf7 commit 7ce03a6
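For context, a minimal usage sketch of the two constructors after this change. It assumes `train_ds`, `dev_ds` and `model` are an existing fastNLP DataSet pair and model defined elsewhere; `sampler` and `pin_memory` are the keyword arguments this commit touches, the rest pre-exist.

from fastNLP import Trainer, AccuracyMetric
from fastNLP.core.dist_trainer import DistTrainer

# Single-process Trainer: pin_memory is read from **kwargs and defaults to True.
trainer = Trainer(train_data=train_ds, model=model, dev_data=dev_ds,
                  metrics=AccuracyMetric(), batch_size=32, pin_memory=True)

# DistTrainer now mirrors Trainer more closely: it accepts an explicit sampler
# (None falls back to DistributedSampler) and the same pin_memory kwarg.
dist_trainer = DistTrainer(train_data=train_ds, model=model, dev_data=dev_ds,
                           metrics=AccuracyMetric(), batch_size_per_gpu=32,
                           sampler=None, pin_memory=True)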

File tree

fastNLP/core/dist_trainer.py
fastNLP/core/tester.py
fastNLP/core/trainer.py

3 files changed: +55 -19 lines changed


fastNLP/core/dist_trainer.py

Lines changed: 42 additions & 13 deletions
@@ -32,6 +32,7 @@
 from .utils import _build_fp16_env
 from .utils import _get_func_signature
 from .utils import _move_dict_value_to_device
+from .sampler import Sampler
 
 __all__ = [
     'get_local_rank',
@@ -68,7 +69,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_path=None, device='auto',
-                 fp16=False, use_tqdm=True, **kwargs):
+                 fp16=False, use_tqdm=True, sampler=None, **kwargs):
         r"""
 
         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
@@ -101,13 +102,18 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         :param str device: the device to use; one of 'gpu', 'cpu' or 'auto'.
         :param bool fp16: whether to train with half precision.
         :param bool use_tqdm: whether to show training progress with tqdm; if False, the loss is printed to the terminal instead.
+        :param Sampler sampler: the sampler to use; defaults to DistributedSampler if not given. Typically needed when the
+            Dataset on each rank has been modified explicitly, so every rank holds the same number of samples but not the same samples.
         :param kwargs: optional configuration parameters
             bool test_use_tqdm: whether to enable tqdm while validating on the dev set
             Sampler test_sampler: the sampler to use during evaluation
             int dev_batch_size: the batch size to use during evaluation
             bool test_use_fp16: use fp16 at test time
             bool set_grad_to_none: set gradients to None instead of 0 in zero_grad
             GradScaler gradscaler: a custom gradient scaler
+            bool pin_memory: whether to place the produced tensors in pinned memory, which may speed up data transfer; usually helps when there are many or large tensors.
+            bool find_unused_parameters: forwarded when wrapping the model as DistributedDataParallel; it should only be
+                needed if the model really has parameters that are unused in forward.
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
@@ -126,6 +132,8 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.rank = dist.get_rank()  # unique id for each process
 
         self.train_data = train_data
+        if kwargs.get('batch_size', None):
+            batch_size_per_gpu = int(kwargs.get('batch_size'))
         self.batch_size_per_gpu = int(batch_size_per_gpu)
         self.n_epochs = int(n_epochs)
         self.num_data_workers = int(num_workers)
@@ -163,7 +171,8 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         # init DataParallel
         if parse_version(torch.__version__)>=parse_version('1.1'):
             self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank, find_unused_parameters=True)
+                                 output_device=self.local_rank,
+                                 find_unused_parameters=kwargs.get('find_unused_parameters', False))
         else:
             self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                  output_device=self.local_rank)
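The hunk above stops hard-coding find_unused_parameters=True and instead forwards the new keyword, defaulting to False. For reference, a sketch of the underlying torch call outside fastNLP; it assumes the script is launched with torchrun (so the process-group environment variables and LOCAL_RANK are set) and uses a toy model:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend='nccl')  # reads MASTER_ADDR/PORT, RANK, WORLD_SIZE from the environment
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(16, 4).cuda()
# find_unused_parameters=True costs an extra pass over the autograd graph every iteration,
# so it is now opt-in and should stay False unless some parameters really are skipped in forward.
ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=False)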
@@ -172,7 +181,17 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
-            self.sampler = DistributedSampler(self.train_data)
+            if sampler is None:
+                self.sampler = DistributedSampler(self.train_data)
+            else:
+                # sampler check
+                if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
+                    raise ValueError(
+                        f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}")
+                elif hasattr(sampler, 'set_batch_size'):
+                    sampler.set_batch_size(batch_size_per_gpu)
+                self.sampler = sampler
+        self.pin_memory = kwargs.get('pin_memory', True)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()
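As the check above shows, any fastNLP Sampler or plain torch.utils.data.Sampler is accepted, and set_batch_size is called when available. A minimal sketch of a sampler that would pass this check, for the docstring's use case where each rank already holds its own, different slice of the data; the class name and the commented usage are illustrative, not part of fastNLP:

import torch.utils.data


class PerRankSequentialSampler(torch.utils.data.Sampler):
    """Walk a dataset that has already been split per rank, in its stored order."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # No cross-rank coordination: each rank simply iterates over its own samples.
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


# Hypothetical usage with the new argument:
# trainer = DistTrainer(train_data=per_rank_dataset, model=model,
#                       sampler=PerRankSequentialSampler(per_rank_dataset))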
@@ -191,7 +210,6 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
                 batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                 use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
-
         # Setup logging
         # synchronize start_time
         sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
@@ -233,7 +251,8 @@ def _get_n_steps(self):
     def _get_data_iter(self, dataset):
         if isinstance(dataset, DataSet):
             return DataSetIter(dataset=dataset, batch_size=self.batch_size_per_gpu, sampler=self.sampler,
-                               num_workers=self.num_data_workers, drop_last=self.drop_last)
+                               num_workers=self.num_data_workers, drop_last=self.drop_last,
+                               pin_memory=self.pin_memory)
         elif isinstance(dataset, BatchIter):
             return dataset
         else:
@@ -347,7 +366,7 @@ def _train(self):
             for batch_x, batch_y in data_iterator:
                 self.step += 1
                 self.ddp_model.train()
-                _move_dict_value_to_device(batch_x, batch_y, device=self.device)
+                _move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory)
                 indices = data_iterator.get_batch_indices()
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
@@ -361,10 +380,9 @@ def _train(self):
 
                 # Is loss NaN or inf? requires_grad = False
                 self.callback_manager.on_backward_begin(loss)
-                self.grad_scaler.scale(loss).backward()
+                self._grad_backward(loss)
                 self.callback_manager.on_backward_end()
-                if self.step % self.update_every == 0:
-                    self._update()
+                self._update()
                 self.callback_manager.on_step_end()
 
                 if self.step % self.print_every == 0:
@@ -390,7 +408,7 @@ def _train(self):
             self.pbar = None
         # ============ tqdm end ============== #
 
-    def _clear_grad_opt(self, optimizer):
+    def _clear_grad(self, optimizer):
         if self.set_grad_to_none:
             for group in optimizer.param_groups:
                 for p in group['params']:
@@ -399,13 +417,24 @@ def _clear_grad_opt(self, optimizer):
         else:
             optimizer.zero_grad()
 
+    def _grad_backward(self, loss):
+        r"""Compute gradient with link rules.
+
+        :param loss: a scalar where back-prop starts
+
+        For PyTorch, just do "loss.backward()"
+        """
+        if (self.step-1) % self.update_every == 0:
+            self._clear_grad(self.optimizer)
+        self.grad_scaler.scale(loss).backward()
+
     def _update(self):
         r"""Perform weight update on a model.
 
         """
-        self.grad_scaler.step(self.optimizer)
-        self.grad_scaler.update()
-        self._clear_grad_opt(self.optimizer)
+        if self.step % self.update_every == 0:
+            self.grad_scaler.step(self.optimizer)
+            self.grad_scaler.update()
 
     def _data_forward(self, network, x):
         x = _build_args(self._forward_func, **x)
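The new _grad_backward/_update pair moves the update_every bookkeeping out of the training loop: gradients are cleared at the start of each accumulation window, backward runs through the GradScaler every step, and the optimizer only steps at the end of the window. A standalone sketch of that pattern in plain PyTorch, assuming a CUDA device is available; the toy model, data and update_every value are illustrative only:

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_scaler = torch.cuda.amp.GradScaler()
update_every = 4

for step in range(1, 17):  # fastNLP-style 1-based step counter
    x = torch.randn(8, 10, device='cuda')
    y = torch.randn(8, 1, device='cuda')
    if (step - 1) % update_every == 0:
        optimizer.zero_grad()              # mirrors _clear_grad at the start of a window
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    grad_scaler.scale(loss).backward()     # mirrors _grad_backward: backward on every step
    if step % update_every == 0:
        grad_scaler.step(optimizer)        # mirrors _update: optimizer steps once per window
        grad_scaler.update()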

fastNLP/core/tester.py

Lines changed: 6 additions & 2 deletions
@@ -98,6 +98,7 @@ def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=No
         :param bool fp16: whether to use float16 for evaluation
         :param kwargs:
             Sampler sampler: a sampler can be passed in to control the evaluation order
+            bool pin_memory: whether to place the produced tensors in pinned memory, which may speed up data transfer.
         """
         super(Tester, self).__init__()
 
@@ -112,6 +113,7 @@ def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=No
         self.verbose = verbose
         self.use_tqdm = use_tqdm
         self.logger = logger
+        self.pin_memory = kwargs.get('pin_memory', True)
 
         if isinstance(data, DataSet):
             sampler = kwargs.get('sampler', None)
@@ -122,7 +124,8 @@ def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=No
             if hasattr(sampler, 'set_batch_size'):
                 sampler.set_batch_size(batch_size)
             self.data_iterator = DataSetIter(dataset=data, batch_size=batch_size, sampler=sampler,
-                                             num_workers=num_workers)
+                                             num_workers=num_workers,
+                                             pin_memory=self.pin_memory)
         elif isinstance(data, BatchIter):
             self.data_iterator = data
         else:
@@ -179,7 +182,8 @@ def test(self):
             start_time = time.time()
 
             for batch_x, batch_y in data_iterator:
-                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
+                                           non_blocking=self.pin_memory)
                 with self.auto_cast():
                     pred_dict = self._data_forward(self._predict_func, batch_x)
                     if not isinstance(pred_dict, dict):
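The pin_memory / non_blocking pairing used here is the standard PyTorch idiom: batches loaded into pinned (page-locked) host memory can be copied to the GPU asynchronously. A small standalone illustration, independent of fastNLP; the toy tensors and sizes are arbitrary:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
loader = DataLoader(dataset, batch_size=32, pin_memory=True)  # batches land in pinned host memory

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for features, labels in loader:
    # non_blocking only has an effect for pinned host tensors being copied to a CUDA device;
    # otherwise the call silently degrades to a normal synchronous copy.
    features = features.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)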

fastNLP/core/trainer.py

Lines changed: 7 additions & 4 deletions
@@ -432,6 +432,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
             bool set_grad_to_none: whether zero_grad sets the gradients to None instead of to zero
             GradScaler grad_scaler: only effective when fp16 is True; if you do not want the default initialization of
                 torch.cuda.amp.GradScaler, you can pass in an already initialized grad_scaler.
+            bool pin_memory: whether to place the produced tensors in pinned memory, which may speed up data transfer.
         """
         super(Trainer, self).__init__()
         if not isinstance(model, nn.Module):
@@ -472,7 +473,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
                 warnings.warn("num_workers is ignored when train_data is BatchIter.")
             if drop_last:
                 warnings.warn("drop_last is ignored when train_data is BatchIter.")
-
+        self.pin_memory = kwargs.get('pin_memory', True)
         if isinstance(model, nn.parallel.DistributedDataParallel):  # if the model is distributed
             # device is None
             if device is not None:
@@ -502,12 +503,13 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
                 sampler(train_data)
             train_data = DataSetIter(train_data,
                                      batch_size=1, sampler=None, as_numpy=False, num_workers=num_workers,
-                                     pin_memory=False, drop_last=drop_last, timeout=0, worker_init_fn=None,
+                                     pin_memory=self.pin_memory, drop_last=drop_last, timeout=0, worker_init_fn=None,
                                      batch_sampler=sampler)
 
         if isinstance(train_data, DataSet):
             self.data_iterator = DataSetIter(dataset=train_data, batch_size=batch_size, sampler=sampler,
-                                             num_workers=num_workers, drop_last=drop_last)
+                                             num_workers=num_workers, drop_last=drop_last,
+                                             pin_memory=self.pin_memory)
         elif isinstance(train_data, BatchIter):
             self.data_iterator = train_data
             train_data = train_data.dataset
@@ -600,7 +602,8 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
                                  use_tqdm=self.test_use_tqdm,
                                  sampler=kwargs.get('test_sampler', None),
                                  fp16=self.test_use_fp16,
-                                 num_workers=num_workers)
+                                 num_workers=num_workers,
+                                 pin_memory=self.pin_memory)
 
         self.start_time = None  # start timestamp
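Since Trainer now forwards its pin_memory setting to the internal Tester as well, a standalone Tester can be configured the same way. A hedged sketch, assuming `dev_ds` and `model` already exist elsewhere; pin_memory is the only keyword here introduced by this commit:

from fastNLP import Tester, AccuracyMetric

# Disable pinning explicitly, e.g. for CPU-only evaluation where it brings no benefit.
tester = Tester(data=dev_ds, model=model, metrics=AccuracyMetric(),
                batch_size=64, pin_memory=False)
eval_result = tester.test()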
