@@ -366,19 +366,28 @@ def forward(self, hidden_states, attention_mask):
366366
367367
class BertEncoder(nn.Module):
    """Stack of identical BertLayer blocks with optional early exit.

    Args:
        config: model configuration; must provide ``num_hidden_layers``.
        num_output_layer: 0-based index of the last layer to execute.
            Negative values count from the end (the default ``-1`` means
            "run every layer"); the resolved index is clamped into
            ``[0, num_hidden_layers]`` so out-of-range requests are safe.
    """

    def __init__(self, config, num_output_layer=-1):
        super(BertEncoder, self).__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
        # Resolve a negative index relative to the end of the stack, then
        # clamp into the valid range so a bad value cannot break forward().
        num_output_layer = (num_output_layer if num_output_layer >= 0
                            else len(self.layer) + num_output_layer)
        self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
        if self.num_output_layer + 1 < len(self.layer):
            # Lazy %-args: the message is only formatted if INFO is emitted.
            logger.info('The transformer encoder will early exit after '
                        'layer-%d (start from 0)!', self.num_output_layer)

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """Run layers 0..num_output_layer over ``hidden_states``.

        Returns:
            A non-empty list of hidden-state tensors: one entry per executed
            layer when ``output_all_encoded_layers`` is True, otherwise a
            single entry holding the last computed hidden states.
        """
        all_encoder_layers = []
        for idx, layer_module in enumerate(self.layer):
            if idx > self.num_output_layer:
                break  # early exit: layers past num_output_layer are skipped
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        if len(all_encoder_layers) == 0:
            # Guard: if no layer ran (e.g. zero hidden layers), still return
            # the (unmodified) input so callers always get a non-empty list.
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
383392
384393
@@ -435,6 +444,9 @@ def __init__(self, config, *inputs, **kwargs):
435444 self .config = config
436445 self .hidden_size = self .config .hidden_size
437446 self .model_type = 'bert'
447+ neg_num_output_layer = kwargs .get ('neg_num_output_layer' , - 1 )
448+ pos_num_output_layer = kwargs .get ('pos_num_output_layer' , self .config .num_hidden_layers - 1 )
449+ self .num_output_layer = max (neg_num_output_layer + self .config .num_hidden_layers , pos_num_output_layer )
438450 if hasattr (config , 'sinusoidal_pos_embds' ):
439451 self .model_type = 'distilbert'
440452 elif 'model_type' in kwargs :
@@ -445,7 +457,7 @@ def __init__(self, config, *inputs, **kwargs):
445457 else :
446458 self .embeddings = BertEmbeddings (config )
447459
448- self .encoder = BertEncoder (config )
460+ self .encoder = BertEncoder (config , num_output_layer = self . num_output_layer )
449461 if self .model_type != 'distilbert' :
450462 self .pooler = BertPooler (config )
451463 else :
0 commit comments