
Commit 972185d

Add the batch_per_epoch attribute to DistTrainer

1 parent c0aa5bd
3 files changed: +20 -15 lines

fastNLP/core/dist_trainer.py

Lines changed: 16 additions & 11 deletions

@@ -73,7 +73,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         r"""
 
         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
-        :param nn.modules model: the model to be trained
+        :param nn.modules, DDP model: the model to be trained
         :param optimizer: a `torch.optim.Optimizer` optimizer. If None, the Trainer uses the default Adam(model.parameters(), lr=4e-3)
         :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. If None, :class:`~fastNLP.LossInForward` is used by default
         :param list callbacks_all: callbacks used to adjust the training process, applied in all training processes.
@@ -146,16 +146,13 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.losser = _prepare_losser(loss)
         self.fp16 = fp16
         self.local_rank = get_local_rank()
-        self._forward_func = model.forward
         self.callback_manager = DistCallbackManager(
             env={"trainer": self}, callbacks_all=callbacks_all,
             callbacks_master=callbacks_master)
         self.test_manager = DistCallbackManager(env={'trainer': self})
         self.metric_key = metric_key
         self.use_tqdm = use_tqdm
 
-        model.to(self.device)
-
         # init fp16, must before DataParallel init
         autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
         self.auto_cast = autocast
@@ -170,15 +167,22 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
 
         # init DataParallel
-        if parse_version(torch.__version__)>=parse_version('1.1'):
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank,
-                                 find_unused_parameters=kwargs.get('find_unused_parameters', False))
+        if isinstance(model, DDP):
+            self.ddp_model = model
         else:
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank)
+            if parse_version(torch.__version__)>=parse_version('1.1'):
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank,
+                                     find_unused_parameters=kwargs.get('find_unused_parameters', False))
+            else:
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank)
         self.model = self.ddp_model.module
 
+        self._forward_func = self.model.forward
+        self.model.to(self.device)
+
+
         optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
@@ -207,7 +211,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
-                dev_data, model, metrics,
+                dev_data, self.model, metrics,
                 batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                 use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
@@ -343,6 +347,7 @@ def _train(self):
         avg_loss = 0
         data_iterator = self.data_iterator
         self.ddp_model.zero_grad()
+        self.batch_per_epoch = self.data_iterator.num_batches
         for epoch in range(1, self.n_epochs + 1):
             self.epoch = epoch
             pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
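
With these dist_trainer.py changes, DistTrainer accepts a model that is already wrapped in DistributedDataParallel (it checks isinstance(model, DDP) and skips re-wrapping), and callbacks can read the new batch_per_epoch attribute once _train() starts. The sketch below shows one way this might be used; it is illustrative only: TinyModel, the toy DataSet, BatchCountCallback and the launch assumptions (torchrun / torch.distributed.launch setting LOCAL_RANK) are not part of the commit.

    # Hedged sketch: TinyModel, the toy data and BatchCountCallback are invented for
    # illustration; only the DDP pass-through and batch_per_epoch come from this commit.
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fastNLP import Callback, CrossEntropyLoss, DataSet
    from fastNLP.core.dist_trainer import DistTrainer, get_local_rank

    class TinyModel(torch.nn.Module):
        """Toy model in the fastNLP style: forward returns a dict with a 'pred' key."""
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(100, 16)
            self.fc = torch.nn.Linear(16, 2)

        def forward(self, words):
            return {'pred': self.fc(self.embed(words).mean(dim=1))}

    class BatchCountCallback(Callback):
        """Reads the batch_per_epoch attribute added by this commit."""
        def on_epoch_begin(self):
            # batch_per_epoch is set at the start of DistTrainer._train()
            print('epoch %d/%d: %d batches per epoch' % (
                self.trainer.epoch, self.trainer.n_epochs, self.trainer.batch_per_epoch))

    dist.init_process_group('nccl')  # assumes a torchrun / torch.distributed.launch launch
    local_rank = get_local_rank()
    torch.cuda.set_device(local_rank)

    train_data = DataSet({'words': [[1, 2, 3, 4]] * 32, 'target': [0, 1] * 16})
    train_data.set_input('words')
    train_data.set_target('target')

    model = TinyModel().to(torch.device('cuda', local_rank))
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # A pre-wrapped DDP model can now be passed directly instead of a bare nn.Module.
    trainer = DistTrainer(train_data=train_data, model=ddp_model,
                          loss=CrossEntropyLoss(pred='pred', target='target'),
                          n_epochs=3, callbacks_all=[BatchCountCallback()])
    trainer.train()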

fastNLP/io/data_bundle.py

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ def __init__(self, vocabs: dict = None, datasets: dict = None):
 
         :param vocabs: a dict from name (str) to :class:`~fastNLP.Vocabulary`
         :param datasets: a dict from name (str) to :class:`~fastNLP.DataSet`. Avoid passing the same DataSet object in more than once, as that may cause
-            problems when the data is processed with a Pipe.
+            problems when the data is processed with a Pipe; if several datasets really need to be identical, deepcopy them manually before passing them in
         """
         self.vocabs = vocabs or {}
         self.datasets = datasets or {}
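
The updated docstring recommends deep-copying a DataSet when two entries of a bundle really must hold the same data. A minimal sketch of that advice; the toy fields are invented for illustration:

    from copy import deepcopy
    from fastNLP import DataSet
    from fastNLP.io import DataBundle

    ds = DataSet({'raw_words': ['a b c', 'd e'], 'target': [0, 1]})  # toy data

    # Reusing the same object under two names means a Pipe would process it twice;
    # a deepcopy gives each entry its own independent DataSet.
    bundle = DataBundle(datasets={'train': ds, 'dev': deepcopy(ds)})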

fastNLP/io/pipe/pipe.py

Lines changed: 3 additions & 3 deletions

@@ -27,15 +27,15 @@ def process(self, data_bundle: DataBundle) -> DataBundle:
         Process the input DataBundle, then return that DataBundle.
 
         :param ~fastNLP.DataBundle data_bundle: the DataBundle object to be processed
-        :return:
+        :return: DataBundle
         """
         raise NotImplementedError
 
-    def process_from_file(self, paths) -> DataBundle:
+    def process_from_file(self, paths: str) -> DataBundle:
         r"""
         Takes file path(s) and builds a processed DataBundle object. For the path formats supported by paths, see :meth:`fastNLP.io.Loader.load()`
 
-        :param paths:
+        :param str paths:
         :return: DataBundle
         """
         raise NotImplementedError
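
Both methods still raise NotImplementedError, so the sharpened signatures mainly document what a subclass must provide. A minimal sketch of such a subclass; LowerCasePipe, the CSV columns and the field names are invented for illustration:

    from fastNLP.io import DataBundle
    from fastNLP.io.pipe import Pipe
    from fastNLP.io.loader import CSVLoader

    class LowerCasePipe(Pipe):
        """Toy Pipe that lower-cases the 'raw_words' field in every DataSet of the bundle."""

        def process(self, data_bundle: DataBundle) -> DataBundle:
            for dataset in data_bundle.datasets.values():
                dataset.apply_field(str.lower, field_name='raw_words',
                                    new_field_name='raw_words')
            return data_bundle

        def process_from_file(self, paths: str) -> DataBundle:
            # load the raw files first, then reuse the in-memory processing above
            data_bundle = CSVLoader(headers=['raw_words', 'target']).load(paths)
            return self.process(data_bundle)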
