@@ -374,20 +374,18 @@ def __init__(self, config, num_output_layer=-1):
         self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
         if self.num_output_layer + 1 < len(self.layer):
             logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
-                        f'(start from 0)!')
+                        f'(layer 0 means embedding layer)!')

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
         for idx, layer_module in enumerate(self.layer):
-            if idx > self.num_output_layer:
+            if idx >= self.num_output_layer:
                 break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        if len(all_encoder_layers) == 0:
-            all_encoder_layers.append(hidden_states)
         return all_encoder_layers

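Note on this hunk: layer indices are re-based so that layer 0 now means the embedding output rather than the first transformer layer. Producing "layer k" therefore only requires running transformer layers 0 through k-1, hence the >= comparison, and the deleted empty-list fallback becomes unnecessary because the caller now prepends the embedding output (third hunk below). A minimal runnable sketch of the new loop; the toy class and its Linear layers are assumptions for illustration, and only the attribute names and exit logic come from the diff:

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self, num_hidden_layers=4, hidden=8, num_output_layer=2):
        super().__init__()
        # Stand-in for the stack of transformer layers in the real encoder.
        self.layer = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_hidden_layers))
        # Clamp exactly as the diff does: 0 <= num_output_layer <= len(self.layer).
        self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)

    def forward(self, hidden_states, output_all_encoded_layers=True):
        all_encoder_layers = []
        for idx, layer_module in enumerate(self.layer):
            # New check: the output of "layer k" needs transformer layers 0..k-1 only.
            if idx >= self.num_output_layer:
                break
            hidden_states = layer_module(hidden_states)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

enc = ToyEncoder(num_output_layer=2)
print(len(enc(torch.zeros(1, 8))))  # 2: layers 0 and 1 ran; the caller adds "layer 0" (embedding)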
@@ -445,8 +443,8 @@ def __init__(self, config, *inputs, **kwargs):
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
         neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
-        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
-        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers)
+        self.num_output_layer = max(neg_num_output_layer + 1 + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
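Note on the index arithmetic above: counting the embedding as layer 0 gives num_hidden_layers + 1 addressable outputs, so the default positive index moves from num_hidden_layers - 1 to num_hidden_layers, and a negative index n resolves to n + 1 + num_hidden_layers (so -1 still selects the last transformer layer). A sketch of the resolution, where the helper function itself is an assumption for illustration; the max() combination of the two kwargs mirrors the diff:

def resolve_output_layer(num_hidden_layers, neg_num_output_layer=-1, pos_num_output_layer=None):
    if pos_num_output_layer is None:
        # Layer 0 is now the embedding, so the last transformer layer is index num_hidden_layers.
        pos_num_output_layer = num_hidden_layers
    # A negative index counts back over num_hidden_layers + 1 outputs (embedding included).
    return max(neg_num_output_layer + 1 + num_hidden_layers, pos_num_output_layer)

assert resolve_output_layer(12) == 12  # defaults: -1 -> last transformer layer
assert resolve_output_layer(12, neg_num_output_layer=-13, pos_num_output_layer=0) == 0  # -13 -> embedding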
@@ -535,15 +533,14 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        encoded_layers.insert(0, embedding_output)
         sequence_output = encoded_layers[-1]
         if self.model_type != 'distilbert':
             pooled_output = self.pooler(sequence_output)
         else:
             pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
-        else:
-            encoded_layers.insert(0, embedding_output)
         return encoded_layers, pooled_output

     @classmethod
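Note on this hunk: the embedding output is now prepended unconditionally, so encoded_layers[i] is the output of "layer i" under the new numbering while encoded_layers[-1] keeps meaning the last computed hidden state. This also covers the case the deleted encoder fallback used to handle, where early exit at layer 0 returns an empty list. A toy check of the invariant, in pure Python with all values assumed:

embedding_output = 'emb'
encoder_outputs = ['h1', 'h2', 'h3']      # encoder early-exited after num_output_layer = 3
encoded_layers = list(encoder_outputs)
encoded_layers.insert(0, embedding_output)
assert encoded_layers[0] == 'emb' and encoded_layers[-1] == 'h3'

# num_output_layer = 0: the encoder now returns [], and the prepended
# embedding keeps encoded_layers[-1] well defined.
encoded_layers = []
encoded_layers.insert(0, embedding_output)
assert encoded_layers[-1] == 'emb'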