@@ -73,7 +73,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         r"""
 
         :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
-        :param nn.modules model: the model to be trained
+        :param nn.modules, DDP model: the model to be trained
         :param optimizer: a `torch.optim.Optimizer` optimizer. If None, the Trainer uses the default Adam(model.parameters(), lr=4e-3) optimizer.
         :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. If None, :class:`~fastNLP.LossInForward` is used by default.
         :param list callbacks_all: callback functions used to adjust the training process, applied in all training processes.
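
For context, a minimal usage sketch of what the updated docstring describes: the model argument may now be a model the caller has already wrapped in DistributedDataParallel. This sketch is not part of the patch; the import path, the process-group setup, and the names MyModel and train_ds are placeholders assumed from the caller's launch script (DistTrainer may also initialize the process group itself).

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fastNLP.core.dist_trainer import DistTrainer  # import path assumed

    # the launcher (e.g. torchrun) is assumed to set LOCAL_RANK and the rendezvous env vars
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = MyModel().to(local_rank)          # MyModel: placeholder nn.Module
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # With this change, DistTrainer detects the existing wrapper and reuses it
    # instead of wrapping the model a second time.
    trainer = DistTrainer(train_data=train_ds, model=ddp_model)
    trainer.train()
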
@@ -146,16 +146,13 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.losser = _prepare_losser(loss)
         self.fp16 = fp16
         self.local_rank = get_local_rank()
-        self._forward_func = model.forward
         self.callback_manager = DistCallbackManager(
             env={"trainer": self}, callbacks_all=callbacks_all,
             callbacks_master=callbacks_master)
         self.test_manager = DistCallbackManager(env={'trainer': self})
         self.metric_key = metric_key
         self.use_tqdm = use_tqdm
 
-        model.to(self.device)
-
         # init fp16, must before DataParallel init
         autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
         self.auto_cast = autocast
@@ -170,15 +167,22 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         self.set_grad_to_none = kwargs.get('set_grad_to_none', False)
 
         # init DataParallel
-        if parse_version(torch.__version__)>=parse_version('1.1'):
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank,
-                                 find_unused_parameters=kwargs.get('find_unused_parameters', False))
+        if isinstance(model, DDP):
+            self.ddp_model = model
         else:
-            self.ddp_model = DDP(model, device_ids=[self.local_rank],
-                                 output_device=self.local_rank)
+            if parse_version(torch.__version__)>=parse_version('1.1'):
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank,
+                                     find_unused_parameters=kwargs.get('find_unused_parameters', False))
+            else:
+                self.ddp_model = DDP(model, device_ids=[self.local_rank],
+                                     output_device=self.local_rank)
         self.model = self.ddp_model.module
 
+        self._forward_func = self.model.forward
+        self.model.to(self.device)
+
+
         optimizer = self._get_optimizer(optimizer)
         self.optimizer = optimizer
         if isinstance(self.train_data, DataSet):
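
A follow-on note on this hunk (reviewer observation, not from the patch): when a pre-wrapped DDP model is passed, the find_unused_parameters kwarg read in the else branch is never consulted, so a caller who needs that behaviour has to set it on the wrapper directly, for example:

    # sketch only; local_rank and model come from the caller's own setup
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                    find_unused_parameters=True)
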
@@ -207,7 +211,7 @@ def __init__(self, train_data, model, optimizer=None, loss=None,
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
-                dev_data, model, metrics,
+                dev_data, self.model, metrics,
                 batch_size=dev_batch_size, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
                 use_tqdm=self.test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
@@ -343,6 +347,7 @@ def _train(self):
             avg_loss = 0
             data_iterator = self.data_iterator
             self.ddp_model.zero_grad()
+            self.batch_per_epoch = self.data_iterator.num_batches
             for epoch in range(1, self.n_epochs + 1):
                 self.epoch = epoch
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))