@@ -265,10 +265,10 @@ def prediction_step(
265265 labels_list = pad_sequence (labels_list , batch_first = True , padding_value = 0 )
266266 return None , response_list , labels_list
267267
268- def compute_loss (self , model , inputs , return_outputs = False , num_items_in_batch = None ):
268+ def _prepare_inputs (self , inputs ):
269+ inputs = super ()._prepare_inputs (inputs )
269270 from swift .plugin .loss import get_loss_func
270271 loss_kwargs = {}
271- labels = None
272272 compute_loss_func = self .compute_loss_func
273273 loss_scale = inputs .pop ('loss_scale' , None )
274274 if loss_scale is not None :
@@ -287,14 +287,25 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
287287 if inputs .get ('position_ids' ) is not None :
288288 loss_kwargs ['position_ids' ] = inputs ['position_ids' ]
289289
290- if (self .label_smoother is not None or compute_loss_func is not None ) and 'labels' in inputs :
291- labels = inputs .pop ('labels' )
292-
293- use_logits_to_keep = self .get_use_logits_to_keep ('labels' in inputs )
290+ use_logits_to_keep = self .get_use_logits_to_keep ('labels' in inputs and self .label_smoother is None
291+ and compute_loss_func is None )
294292 if use_logits_to_keep :
295293 inputs ['labels' ], logits_to_keep = self .get_logits_to_keep (inputs ['labels' ])
296294 if logits_to_keep is not None :
297295 inputs ['logits_to_keep' ] = logits_to_keep
296+
297+ inputs ['compute_loss_func' ] = compute_loss_func
298+ inputs ['loss_kwargs' ] = loss_kwargs
299+ return inputs
300+
301+ def compute_loss (self , model , inputs , return_outputs = False , num_items_in_batch = None ):
302+ labels = None
303+ compute_loss_func = inputs .pop ('compute_loss_func' , None )
304+ loss_kwargs = inputs .pop ('loss_kwargs' , {})
305+
306+ if (self .label_smoother is not None or compute_loss_func is not None ) and 'labels' in inputs :
307+ labels = inputs .pop ('labels' )
308+
298309 outputs = model (** inputs )
299310 # Save past state if it exists
300311 # TODO: this needs to be fixed and made cleaner later.
0 commit comments