@@ -237,6 +237,7 @@ def prediction_step(
         **gen_kwargs,
     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
         if not self.args.predict_with_generate or prediction_loss_only:
+            inputs['_position_ids'] = inputs.get('position_ids')
             with self.template.forward_context(self.model, inputs):
                 return super().prediction_step(
                     model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
@@ -277,15 +278,19 @@ def _prepare_inputs(self, inputs):
                 compute_loss_func = get_loss_func('loss_scale')

         sample_channels = inputs.pop('channel', None)
-        if sample_channels is not None and self.args.channels is not None:
+        position_ids = inputs.pop('_position_ids', None)
+        if self.args.channels is not None:
+            assert sample_channels is not None, f'sample_channels: {sample_channels}'
             state = self.state
             setattr(state, 'local_step', getattr(state, 'local_step', 0))
             setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {}))

             loss_kwargs['sample_channels'] = sample_channels
             loss_kwargs['trainer'] = self
-        if inputs.get('position_ids') is not None:
-            loss_kwargs['position_ids'] = inputs['position_ids']
+        if position_ids is None:
+            position_ids = inputs.get('position_ids')
+        if position_ids is not None:
+            loss_kwargs['position_ids'] = position_ids

         use_logits_to_keep = self.get_use_logits_to_keep('labels' in inputs and self.label_smoother is None
                                                          and compute_loss_func is None)
@@ -352,5 +357,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         return (loss, outputs) if return_outputs else loss

     def training_step(self, model, inputs, *args, **kwargs):
+        inputs['_position_ids'] = inputs.get('position_ids')
         with self.template.forward_context(self.model, inputs):
             return super().training_step(model, inputs, *args, **kwargs)
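
Taken together, the hunks stash position_ids under a private _position_ids key in training_step / prediction_step before template.forward_context runs (presumably because that context may rewrite or consume the original key, e.g. when sequences are packed), and the loss path then prefers the stashed copy when filling loss_kwargs. Below is a minimal, self-contained sketch of that stash/restore flow; forward_context, training_step and compute_loss here are simplified stand-ins chosen for illustration, not the actual ms-swift implementation.

# A minimal sketch of the stash/restore pattern in this commit; the helpers
# below are simplified stand-ins, not the real trainer/template code.
from contextlib import contextmanager


@contextmanager
def forward_context(model, inputs):
    # Simulate a template context that consumes position_ids as a side effect.
    inputs.pop('position_ids', None)
    yield


def compute_loss(model, inputs):
    loss_kwargs = {}
    # Prefer the stashed copy; fall back to whatever is still present in inputs.
    position_ids = inputs.pop('_position_ids', None)
    if position_ids is None:
        position_ids = inputs.get('position_ids')
    if position_ids is not None:
        loss_kwargs['position_ids'] = position_ids
    return loss_kwargs


def training_step(model, inputs):
    # Snapshot position_ids under a private key before entering the context.
    inputs['_position_ids'] = inputs.get('position_ids')
    with forward_context(model, inputs):
        return compute_loss(model, inputs)


print(training_step(None, {'position_ids': [0, 1, 2, 0, 1]}))
# -> {'position_ids': [0, 1, 2, 0, 1]}, even though the context dropped the key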