
Commit c72bc07

1. Fix the DataSet copy_field bug found by yxg; 2. Fix the update_every bug in DistTrainer; 3. Change the default value of pin_memory in Trainer and DistTrainer to True
1 parent 7ce03a6 commit c72bc07

5 files changed, +61 -58 lines changed


fastNLP/core/dataset.py

Lines changed: 4 additions & 2 deletions

@@ -474,8 +474,8 @@ def __getitem__(self, idx):
             if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)):
                 raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}")
             data_set = DataSet()
-            for field in self.field_arrays.values():
-                data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder,
+            for field_name, field in self.field_arrays.items():
+                data_set.add_field(field_name=field_name, fields=field.content[idx], padder=field.padder,
                                    is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
             data_set.collater = self.collater.copy_from(self.collater)
             return data_set
@@ -616,6 +616,7 @@ def add_fieldarray(self, field_name, fieldarray):
         if len(self) != len(fieldarray):
             raise RuntimeError(f"The field to add must have the same size as dataset. "
                                f"Dataset size {len(self)} != field size {len(fieldarray)}")
+        fieldarray.name = field_name
         self.field_arrays[field_name] = fieldarray
 
     def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False):
@@ -673,6 +674,7 @@ def copy_field(self, field_name, new_field_name):
         if not self.has_field(field_name):
             raise KeyError(f"Field:{field_name} not found in DataSet.")
         fieldarray = deepcopy(self.get_field(field_name))
+        fieldarray.name = new_field_name
         self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray)
         return self
 
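The root cause: each FieldArray carries its own name, and DataSet.__getitem__ rebuilt sliced datasets keyed by field.name, while copy_field's deepcopy kept the old name, so the copied field vanished on slicing. A minimal standalone sketch of the failure and the fix (a toy FieldArray stand-in, not the real fastNLP class):

    # Toy stand-in for fastNLP's FieldArray: content plus a self-describing name.
    from copy import deepcopy

    class FieldArray:
        def __init__(self, name, content):
            self.name = name
            self.content = content

    fields = {'raw_chars': FieldArray('raw_chars', [[0, 1], [2]])}

    # copy_field before the fix: deepcopy keeps the OLD name.
    copied = deepcopy(fields['raw_chars'])
    fields['chars'] = copied            # but copied.name is still 'raw_chars'

    # __getitem__ before the fix: rebuild keyed by field.name, so 'chars'
    # collides with 'raw_chars' and silently disappears.
    sliced = {f.name: FieldArray(f.name, f.content[:1]) for f in fields.values()}
    assert 'chars' not in sliced

    # After the fix: reset .name on copy and iterate over items(), keying by
    # the dict key instead of the stale field.name.
    copied.name = 'chars'
    sliced = {k: FieldArray(k, f.content[:1]) for k, f in fields.items()}
    assert 'chars' in sliced and 'raw_chars' in sliced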

fastNLP/core/dist_trainer.py

Lines changed: 44 additions & 54 deletions

@@ -55,7 +55,7 @@ def get_local_rank():
     raise RuntimeError('Please use "python -m torch.distributed.launch --nproc_per_node=N train_script.py')
 
 
-class DistTrainer():
+class DistTrainer:
     r"""
     Distributed Trainer. Supports distributed training and mixed-precision training; for details of the underlying implementation, please read the official pytorch documentation.
 
@@ -110,7 +110,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         int dev_batch_size: batch size used during evaluation
         bool test_use_fp16: use fp16 at test time
         bool set_grad_to_none: in zero_grad, set grad to None instead of 0
-        GradScaler gradscaler: a custom gradient scaler
+        GradScaler grad_scaler: a custom gradient scaler
         bool pin_memory: whether produced tensors use pinned memory; may speed up data transfer. Usually helps when there are many tensors or tensors with large dimensions.
         bool find_unused_parameters: must be passed when converting the model to DistributedDataParallel; unless the model
             really has parameters unused in forward, it should not be needed.
@@ -132,6 +132,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.rank = dist.get_rank()  # unique id for each process
 
         self.train_data = train_data
+        self.kwargs = kwargs
         if kwargs.get('batch_size', None):
             batch_size_per_gpu = int(kwargs.get('batch_size'))
         self.batch_size_per_gpu = int(batch_size_per_gpu)
@@ -158,15 +159,15 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         # init fp16, must before DataParallel init
         autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
         self.auto_cast = autocast
-        user_grad_scaler = getattr(kwargs, 'gradscaler', None)
+        user_grad_scaler = kwargs.get('grad_scaler', None)
         if user_grad_scaler is not None:
-            assert self.fp16, "must set fp16=True to enable gradscaler"
+            assert self.fp16, "must set fp16=True to enable grad_scaler"
             grad_scaler = user_grad_scaler
         else:
             grad_scaler = GradScaler()
         self.grad_scaler = grad_scaler
 
-        self.set_grad_to_none = getattr(kwargs, 'set_grad_to_none', True)
+        self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
 
         # init DataParallel
         if parse_version(torch.__version__)>=parse_version('1.1'):
@@ -191,15 +192,15 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         elif hasattr(sampler, 'set_batch_size'):
             sampler.set_batch_size(batch_size_per_gpu)
         self.sampler = sampler
-        self.pin_memory = kwargs.get('pin_memory', True)
+        # concerning issue from https://github.com/pytorch/pytorch/issues/57273
+        self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__)==parse_version('1.9') else True)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()
 
         self.dev_data = dev_data
         self.metrics = metrics
         self.test_use_tqdm = True
-        self.kwargs = kwargs
         self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
         dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)
 
@@ -229,22 +230,6 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.logger.info("Num of processes: {}".format(self.world_size))
         self.logger.info("Use device: {}".format(device))
 
-    def _maybe_no_sync(self):
-        """
-        Whenever *samples* contains more than one mini-batch, we
-        want to accumulate gradients locally and only call
-        all-reduce in the last backwards pass.
-        """
-        i = self.step % self.update_every
-        if (
-            self.world_size > 1
-            and hasattr(self.ddp_model, "no_sync")
-            and i != 0
-        ):
-            return self.ddp_model.no_sync()
-        else:
-            return contextlib.ExitStack()  # dummy contextmanager
-
     def _get_n_steps(self):
         return len(self.data_iterator) * self.n_epochs
 
@@ -365,37 +350,42 @@ def _train(self):
             self.callback_manager.on_epoch_begin()
             for batch_x, batch_y in data_iterator:
                 self.step += 1
-                self.ddp_model.train()
-                _move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory)
-                indices = data_iterator.get_batch_indices()
-                # negative sampling; replace unknown; re-weight batch_y
-                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
-                with self.auto_cast():
-                    prediction = self._data_forward(self.ddp_model, batch_x)
-                    # edit prediction
-                    self.callback_manager.on_loss_begin(batch_y, prediction)
-                    loss = self._compute_loss(prediction, batch_y)
-
-                avg_loss += loss.detach()
-
-                # Is loss NaN or inf? requires_grad = False
-                self.callback_manager.on_backward_begin(loss)
-                self._grad_backward(loss)
-                self.callback_manager.on_backward_end()
-                self._update()
-                self.callback_manager.on_step_end()
-
-                if self.step % self.print_every == 0:
-                    avg_loss = float(avg_loss) / self.print_every
-                    print_output = "loss:{:<6.5f}".format(avg_loss)
-                    pbar.update(self.print_every)
-                    pbar.set_postfix_str(print_output)
-                    avg_loss = 0
-
-                self.callback_manager.on_batch_end()
-
-                if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks):
-                    self._do_validation()
+                if self.step%self.update_every!=0:
+                    no_sync = self.ddp_model.no_sync
+                else:
+                    no_sync = contextlib.ExitStack
+                with no_sync():
+                    self.ddp_model.train()
+                    _move_dict_value_to_device(batch_x, batch_y, device=self.device, non_blocking=self.pin_memory)
+                    indices = data_iterator.get_batch_indices()
+                    # negative sampling; replace unknown; re-weight batch_y
+                    self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
+                    with self.auto_cast():
+                        prediction = self._data_forward(self.ddp_model, batch_x)
+                        # edit prediction
+                        self.callback_manager.on_loss_begin(batch_y, prediction)
+                        loss = self._compute_loss(prediction, batch_y)
+
+                    avg_loss += loss.detach()
+
+                    # Is loss NaN or inf? requires_grad = False
+                    self.callback_manager.on_backward_begin(loss)
+                    self._grad_backward(loss)
+                    self.callback_manager.on_backward_end()
+                    self._update()
+                    self.callback_manager.on_step_end()
+
+                    if self.step % self.print_every == 0:
+                        avg_loss = float(avg_loss) / self.print_every
+                        print_output = "loss:{:<6.5f}".format(avg_loss)
+                        pbar.update(self.print_every)
+                        pbar.set_postfix_str(print_output)
+                        avg_loss = 0
+
+                    self.callback_manager.on_batch_end()
+
+                    if (self.validate_every > 0 and self.step % self.validate_every == 0) and len(self.test_manager.callbacks):
+                        self._do_validation()
 
                 # ================= mini-batch end ==================== #
             if self.validate_every < 0 and len(self.test_manager.callbacks):
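Two fixes here deserve a note. First, the old option reads used getattr(kwargs, 'gradscaler', None); since kwargs is a dict, getattr looks up attributes rather than keys and always returned the default, so user-supplied scalers were silently ignored (hence the switch to kwargs.get, along with the rename to grad_scaler). Second, the rewritten loop replaces the removed _maybe_no_sync helper with the standard DDP gradient-accumulation pattern: suppress gradient all-reduce under no_sync() on non-update steps and synchronize only on the step that actually updates. A minimal standalone sketch of that pattern (function and variable names are illustrative, not fastNLP API):

    import contextlib

    from torch.nn.parallel import DistributedDataParallel as DDP

    def train_with_accumulation(ddp_model: DDP, optimizer, data_loader,
                                loss_fn, update_every: int = 4):
        """Accumulate gradients locally; all-reduce them only on update steps."""
        step = 0
        for inputs, targets in data_loader:
            step += 1
            if step % update_every != 0:
                # Non-update step: suppress DDP's gradient all-reduce so that
                # backward() only accumulates into the local .grad buffers.
                sync_ctx = ddp_model.no_sync()
            else:
                # Update step: synchronize gradients across ranks as usual.
                sync_ctx = contextlib.ExitStack()  # dummy context manager
            with sync_ctx:
                loss = loss_fn(ddp_model(inputs), targets) / update_every
                loss.backward()
            if step % update_every == 0:
                optimizer.step()
                optimizer.zero_grad()

Scaling the loss by update_every keeps gradient magnitudes comparable to a single large batch of the same total size.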

fastNLP/core/trainer.py

Lines changed: 3 additions & 1 deletion

@@ -334,6 +334,7 @@ def on_epoch_end(self):
 except:
     from .utils import _pseudo_tqdm as tqdm
 import warnings
+from pkg_resources import parse_version
 
 from .batch import DataSetIter, BatchIter
 from .callback import CallbackManager, CallbackException, Callback
@@ -473,7 +474,8 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
                 warnings.warn("num_workers is ignored when train_data is BatchIter.")
             if drop_last:
                 warnings.warn("drop_last is ignored when train_data is BatchIter.")
-        self.pin_memory = kwargs.get('pin_memory', True)
+        # concerning issue from https://github.com/pytorch/pytorch/issues/57273
+        self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__)==parse_version('1.9') else True)
        if isinstance(model, nn.parallel.DistributedDataParallel):  # if distributed
            # device is None
            if device is not None:
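For context, pytorch/pytorch#57273 reports a pinned-memory problem in PyTorch 1.9's DataLoader, which is why pinning now defaults to off on exactly that release. A minimal sketch of the same version gate, factored into a hypothetical helper (default_pin_memory is illustrative, not fastNLP API):

    import torch
    from pkg_resources import parse_version

    def default_pin_memory(kwargs: dict) -> bool:
        """Pin host memory by default, except on torch 1.9, where a
        DataLoader issue (pytorch/pytorch#57273) makes pinning risky."""
        on_1_9 = parse_version(torch.__version__) == parse_version('1.9')
        return kwargs.get('pin_memory', not on_1_9)

    print(default_pin_memory({}))                    # False only on torch 1.9
    print(default_pin_memory({'pin_memory': True}))  # explicit value wins

Note that PEP 440 comparison treats '1.9' and '1.9.0' as equal, so the gate matches the plain 1.9.0 release; builds with local version tags such as '1.9.0+cu102' compare unequal and would keep the True default.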

fastNLP/io/data_bundle.py

Lines changed: 2 additions & 1 deletion

@@ -31,7 +31,8 @@ def __init__(self, vocabs: dict = None, datasets: dict = None):
     r"""
 
     :param vocabs: dict from name (str) to :class:`~fastNLP.Vocabulary`
-    :param datasets: dict from name (str) to :class:`~fastNLP.DataSet`
+    :param datasets: dict from name (str) to :class:`~fastNLP.DataSet`. Avoid passing the same DataSet
+        object in more than once, as this may cause problems when processing the data with a Pipe.
     """
     self.vocabs = vocabs or {}
     self.datasets = datasets or {}
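The added caution exists because a Pipe typically iterates over every DataSet in the bundle and mutates it in place, so one object registered under two names gets processed once per name. A minimal standalone sketch of the failure mode (plain-Python stand-ins, not fastNLP's actual DataSet/Pipe classes):

    # A "dataset" is a list of token lists; the Pipe-style pass mutates it in place.
    shared = [['hello', 'world']]
    bundle = {'train': shared, 'dev': shared}   # same object under two names

    def add_markers(dataset):
        for i, words in enumerate(dataset):
            dataset[i] = ['<bos>'] + words + ['<eos>']

    for name, ds in bundle.items():
        add_markers(ds)                          # runs once per name

    # The shared dataset got processed twice:
    print(bundle['train'][0])
    # ['<bos>', '<bos>', 'hello', 'world', '<eos>', '<eos>']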

tests/core/test_dataset.py

Lines changed: 8 additions & 0 deletions

@@ -345,6 +345,14 @@ def test_copy_padder(self):
         ds.apply_field(lambda x: x, 'idx', 'idx')
         self.assertTrue(isinstance(ds.get_field('idx').padder, AutoPadder))  # should be None, but AutoPadder
 
+    def test_instance_field_disappear_bug(self):
+        data = DataSet({'raw_chars': [[0,1],[2]], 'target': [0, 1]})
+        data.copy_field(field_name='raw_chars', new_field_name='chars')
+        _data = data[:1]
+        for field_name in ['raw_chars', 'target', 'chars']:
+            self.assertTrue(_data.has_field(field_name))
+
+
 class TestDataSetIter(unittest.TestCase):
     def test__repr__(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
