Skip to content

Commit e1ed6f1

Browse files
committed
move model to device in DistTrainer
1 parent 972185d commit e1ed6f1

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

fastNLP/core/dist_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,11 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
165165
self.grad_scaler = grad_scaler
166166

167167
self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
168-
169168
# init DataParallel
170169
if isinstance(model, DDP):
171170
self.ddp_model = model
172171
else:
172+
model.to(self.device)
173173
if parse_version(torch.__version__)>=parse_version('1.1'):
174174
self.ddp_model = DDP(model, device_ids=[self.local_rank],
175175
output_device=self.local_rank,
@@ -182,7 +182,6 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
182182
self._forward_func = self.model.forward
183183
self.model.to(self.device)
184184

185-
186185
optimizer = self._get_optimizer(optimizer)
187186
self.optimizer = optimizer
188187
if isinstance(self.train_data, DataSet):

0 commit comments

Comments
 (0)