
Commit bffde78

fix a bug in early exit of bert
1 parent bf9d834 commit bffde78

1 file changed (+5, -8 lines)


fastNLP/modules/encoder/bert.py

Lines changed: 5 additions & 8 deletions
@@ -374,20 +374,18 @@ def __init__(self, config, num_output_layer=-1):
         self.num_output_layer = max(min(num_output_layer, len(self.layer)), 0)
         if self.num_output_layer + 1 < len(self.layer):
             logger.info(f'The transformer encoder will early exit after layer-{self.num_output_layer} '
-                        f'(start from 0)!')
+                        f'(layer 0 means embedding layer)!')
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
         for idx, layer_module in enumerate(self.layer):
-            if idx > self.num_output_layer:
+            if idx >= self.num_output_layer:
                 break
             hidden_states = layer_module(hidden_states, attention_mask)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
-        if len(all_encoder_layers) == 0:
-            all_encoder_layers.append(hidden_states)
         return all_encoder_layers
 
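The behavioural change in the encoder loop is easiest to check in isolation. The snippet below is only an illustrative sketch, not the fastNLP code: DummyLayer, num_layers and the list-valued "hidden states" are made-up stand-ins. It shows that with the fixed condition idx >= self.num_output_layer the loop runs exactly num_output_layer transformer layers before breaking; the old condition idx > self.num_output_layer would have run one layer more.

# Illustrative sketch of the early-exit loop (not fastNLP code).
class DummyLayer:
    """Stand-in for BertLayer: just records which layer touched the input."""
    def __init__(self, idx):
        self.idx = idx

    def __call__(self, hidden_states, attention_mask):
        return hidden_states + [self.idx]


def run_encoder(num_layers, num_output_layer, output_all_encoded_layers=True):
    layers = [DummyLayer(i) for i in range(num_layers)]
    hidden_states, attention_mask = [], None
    all_encoder_layers = []
    for idx, layer_module in enumerate(layers):
        if idx >= num_output_layer:  # fixed condition: stop before layer num_output_layer
            break
        hidden_states = layer_module(hidden_states, attention_mask)
        if output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
    if not output_all_encoded_layers:
        all_encoder_layers.append(hidden_states)
    return all_encoder_layers


print(run_encoder(num_layers=12, num_output_layer=2))  # [[0], [0, 1]]: exactly 2 layers ran
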
@@ -445,8 +443,8 @@ def __init__(self, config, *inputs, **kwargs):
         self.hidden_size = self.config.hidden_size
         self.model_type = 'bert'
         neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
-        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers - 1)
-        self.num_output_layer = max(neg_num_output_layer + self.config.num_hidden_layers, pos_num_output_layer)
+        pos_num_output_layer = kwargs.get('pos_num_output_layer', self.config.num_hidden_layers)
+        self.num_output_layer = max(neg_num_output_layer + 1 + self.config.num_hidden_layers, pos_num_output_layer)
         if hasattr(config, 'sinusoidal_pos_embds'):
             self.model_type = 'distilbert'
         elif 'model_type' in kwargs:
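
The second hunk changes how the two kwargs are folded into a single layer index. A worked example may help; the helper below is a hypothetical stand-alone function, not part of fastNLP, and num_hidden_layers=12 simply mimics a bert-base config. Under the new convention layer 0 is the embedding layer, so a negative index of -1 now maps to num_hidden_layers (the last transformer layer) rather than num_hidden_layers - 1, and taking the max keeps the deeper of the two requests.

# Hypothetical helper mirroring the arithmetic in the diff (not fastNLP code).
def resolve_num_output_layer(num_hidden_layers=12, **kwargs):
    # Layer 0 is the embedding layer; layers 1..num_hidden_layers are transformer layers.
    neg_num_output_layer = kwargs.get('neg_num_output_layer', -1)
    pos_num_output_layer = kwargs.get('pos_num_output_layer', num_hidden_layers)
    # A negative index counts from the end: -1 -> num_hidden_layers.
    return max(neg_num_output_layer + 1 + num_hidden_layers, pos_num_output_layer)


print(resolve_num_output_layer())                        # 12: keep every layer (no early exit)
print(resolve_num_output_layer(neg_num_output_layer=-13,
                               pos_num_output_layer=4))  # 4: early exit after layer 4
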
@@ -535,15 +533,14 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_al
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        encoded_layers.insert(0, embedding_output)
         sequence_output = encoded_layers[-1]
         if self.model_type != 'distilbert':
             pooled_output = self.pooler(sequence_output)
         else:
             pooled_output = sequence_output[:, 0]
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
-        else:
-            encoded_layers.insert(0, embedding_output)
         return encoded_layers, pooled_output
 
     @classmethod
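
Moving encoded_layers.insert(0, embedding_output) out of the else branch means the embedding output is now always element 0 of the returned list, so encoded_layers[-1] stays well defined even when the encoder early-exits before running any transformer layer; that is also why the len(all_encoder_layers) == 0 guard removed in the first hunk is no longer needed. A minimal sketch with placeholder strings instead of tensors:

# Placeholder sketch (strings instead of tensors, not fastNLP code).
embedding_output = 'emb'
encoded_layers = []                          # encoder early-exited before any transformer layer
encoded_layers.insert(0, embedding_output)   # now done unconditionally, before sequence_output
sequence_output = encoded_layers[-1]         # 'emb': no IndexError, no empty-list guard needed
print(sequence_output)

With this ordering, encoded_layers[i] for i >= 1 is the output of transformer layer i, which matches the "layer 0 means embedding layer" wording in the new log message.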
