Skip to content

Commit f17343e

Browse files
authored
[bugfix]修复fitlogcallback在disttrainner中无法添加dev_data 的问题 (#348)
fix the distTrainer dev_data
1 parent 8477669 commit f17343e

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

fastNLP/core/dist_trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,13 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
177177
self.batch_size = self.world_size * self.batch_size_per_gpu
178178
self.n_steps = self._get_n_steps()
179179

180+
self.dev_data = dev_data
181+
self.metrics = metrics
182+
self.test_use_tqdm = True
183+
self.kwargs = kwargs
180184
self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
181185
dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)
186+
182187
# for evaluation, only run eval on master proc
183188
if dev_data and metrics:
184189
cb = _TesterCallback(

0 commit comments

Comments
 (0)