@@ -89,17 +89,15 @@ def __init__(self, bert_config):
         self.layer_norm = nn.LayerNorm(bert_config.hidden_size, epsilon=1e-12)
         self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
+    def forward(self, input_ids, token_type_ids=None):
         """
         Args:
             See class `BertModel`.
         """
-        if position_ids is None:
-            ones = paddle.ones_like(input_ids, dtype="int64")
-            seq_length = paddle.cumsum(ones, axis=-1)
-
-            position_ids = seq_length - ones
-            position_ids.stop_gradient = True
+        ones = paddle.ones_like(input_ids, dtype="int64")
+        seq_length = paddle.cumsum(ones, axis=-1)
+        position_ids = seq_length - ones
+        position_ids.stop_gradient = True
         if token_type_ids is None:
             token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
 
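Note on the hunk above: the deleted `position_ids` argument is now derived directly from `input_ids`, so callers no longer pass it in. A minimal standalone sketch of the cumsum trick (the sample ids below are made up for illustration):

import paddle

input_ids = paddle.to_tensor([[101, 2054, 2003, 102]], dtype="int64")  # toy batch of one sequence
ones = paddle.ones_like(input_ids, dtype="int64")
seq_length = paddle.cumsum(ones, axis=-1)   # running count per token: [[1, 2, 3, 4]]
position_ids = seq_length - ones            # zero-based positions:    [[0, 1, 2, 3]]
position_ids.stop_gradient = True           # positions are not trainable
print(position_ids.numpy())                 # [[0 1 2 3]]
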
@@ -174,18 +172,13 @@ def __init__(self, bert_config):
             dropout=bert_config.hidden_dropout_prob,
             activation=bert_config.hidden_act,
             attn_dropout=bert_config.attention_probs_dropout_prob,
-            act_dropout=0,
-            enable_cudnn=False)
+            act_dropout=0)
         self.encoder = nn.TransformerEncoder(encoder_layer,
                                               bert_config.num_hidden_layers)
 
         self.pooler = BertPooler(bert_config.hidden_size)
 
-    def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
         """
         Args:
             input_ids(Tensor):
@@ -198,11 +191,6 @@ def forward(self,
                 to a `sentence A` and type 1 corresponds to a `sentence B` token.
                 (see BERT paper for more details). Its data type should be `int64`
                 Defaults: None, which means we don't add segment embeddings.
-            position_ids(Tensor, optional):
-                An optional Tensor of shape [batch_size, num_tokens] with the position
-                indices of each input sequence tokens in the position embeddings.
-                Selected in the range [0, max_position_embeddings - 1].
-                Its data type should be `int64`. Defaults: None.
             attention_mask(Tensor, optional):
                 An optional Tensor of shape [batch_size, sequence_length] with indices of
                 mask used in multi-head attention to avoid performing attention on to some
@@ -234,9 +222,7 @@ def forward(self,
             attention_mask = attention_mask.unsqueeze(axis=[1, 2])
 
         embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids)
+            input_ids=input_ids, token_type_ids=token_type_ids)
 
         if self.fuse:
             encoder_output = embedding_output
@@ -263,11 +249,7 @@ def __init__(self, bert_config):
         self.bert = BertModel(bert_config)
         self.classifier = nn.Linear(bert_config.hidden_size, 2)
 
-    def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
         """
         Args:
             See class `BertModel`.
@@ -282,7 +264,6 @@ def forward(self,
         encoder_output, _ = self.bert(
             input_ids,
             token_type_ids=token_type_ids,
-            position_ids=position_ids,
             attention_mask=attention_mask)
 
         logits = self.classifier(encoder_output)
@@ -322,13 +303,7 @@ def __init__(self,
         self.decoder_bias = self.create_parameter(
             shape=[vocab_size], dtype=self.decoder_weight.dtype, is_bias=True)
 
-    def forward(self, hidden_states, masked_positions=None):
-        if masked_positions is not None:
-            hidden_states = paddle.reshape(hidden_states,
-                                           [-1, hidden_states.shape[-1]])
-            hidden_states = paddle.tensor.gather(hidden_states,
-                                                 masked_positions)
-            # gather masked tokens might be more quick
+    def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.activation(hidden_states)
         hidden_states = self.layer_norm(hidden_states)
@@ -362,7 +337,7 @@ def __init__(self,
                                                  activation, embedding_weights)
         self.seq_relationship = nn.Linear(hidden_size, 2)
 
-    def forward(self, encoder_output, pooled_output, masked_positions=None):
+    def forward(self, encoder_output, pooled_output, masked_lm_labels):
         """
         Args:
             sequence_output(Tensor):
@@ -384,7 +359,12 @@ def forward(self, encoder_output, pooled_output, masked_positions=None):
                 A Tensor of shape [batch_size, 2] with the scores of next sentence prediction.
                 Its data type should be float32.
         """
-        prediction_scores = self.predictions(encoder_output, masked_positions)
+
+        sequence_flattened = paddle.index_select(
+            encoder_output.reshape([-1, encoder_output.shape[-1]]),
+            paddle.nonzero(masked_lm_labels.reshape([-1]) != -1).squeeze(),
+            axis=0)
+        prediction_scores = self.predictions(sequence_flattened)
         seq_relationship_score = self.seq_relationship(pooled_output)
         return prediction_scores, seq_relationship_score
 
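Note on the hunk above: the head now infers which slots were masked from `masked_lm_labels` itself, assuming unmasked positions carry the label -1, so only the masked hidden states reach the decoder. A standalone sketch of that selection with toy shapes (not the real model sizes):

import paddle

batch_size, seq_len, hidden = 2, 4, 8
encoder_output = paddle.randn([batch_size, seq_len, hidden])
# -1 marks positions that were not masked; other entries are vocab ids.
masked_lm_labels = paddle.to_tensor([[-1, 7, -1, -1],
                                     [3, -1, -1, 9]], dtype="int64")

flat_hidden = encoder_output.reshape([-1, encoder_output.shape[-1]])          # [batch*seq, hidden]
masked_idx = paddle.nonzero(masked_lm_labels.reshape([-1]) != -1).squeeze()   # indices of masked slots
sequence_flattened = paddle.index_select(flat_hidden, masked_idx, axis=0)     # [num_masked, hidden]
print(sequence_flattened.shape)  # [3, 8]
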
@@ -406,18 +386,13 @@ def __init__(self, bert_config):
             bert_config.hidden_act,
             embedding_weights=self.bert.embeddings.word_embeddings.weight)
 
-    def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                masked_positions=None):
+    def forward(self, input_ids, token_type_ids, attention_mask,
+                masked_lm_labels):
         """
 
         Args:
             input_ids(Tensor): See class `BertModel`.
             token_type_ids(Tensor, optional): See class `BertModel`.
-            position_ids(Tensor, optional): See class `BertModel`.
             attention_mask(Tensor, optional): See class `BertModel`.
             masked_positions(Tensor, optional): See class `BertPretrainingHeads`.
 
@@ -434,9 +409,8 @@ def forward(self,
         outputs = self.bert(
             input_ids,
             token_type_ids=token_type_ids,
-            position_ids=position_ids,
             attention_mask=attention_mask)
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(
-            sequence_output, pooled_output, masked_positions)
+            sequence_output, pooled_output, masked_lm_labels)
         return prediction_scores, seq_relationship_score
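A rough usage sketch of the reworked pretraining forward; the import path, the `BertConfig` constructor, and all tensor values below are assumptions made for illustration only:

import paddle
from modeling import BertConfig, BertModelForPretraining  # hypothetical import path

bert_config = BertConfig()                      # hypothetical default config
model = BertModelForPretraining(bert_config)

input_ids = paddle.randint(0, bert_config.vocab_size, [2, 8], dtype="int64")
token_type_ids = paddle.zeros_like(input_ids)   # single-segment inputs
attention_mask = paddle.ones_like(input_ids)    # every token attended to (semantics per the docstring above)
# -1 marks unmasked positions; real label ids mark the masked slots.
masked_lm_labels = paddle.to_tensor(
    [[-1, 2023, -1, -1, -1, -1, -1, -1],
     [-1, -1, -1, 4937, -1, -1, -1, -1]], dtype="int64")

prediction_scores, seq_relationship_score = model(
    input_ids, token_type_ids, attention_mask, masked_lm_labels)
# prediction_scores: [num_masked, vocab_size]; seq_relationship_score: [2, 2]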