
Commit 8cdaba1 (merge of parents dff4935 and b9178e8)

[BERT/Paddle] Update LDDL

16 files changed: +720 -474 lines

PaddlePaddle/LanguageModeling/BERT/Dockerfile

Lines changed: 13 additions & 8 deletions

@@ -1,15 +1,20 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:22.08-py3
-
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:22.12-py3
 FROM ${FROM_IMAGE_NAME}
-
 RUN apt-get update && apt-get install -y pbzip2 pv bzip2 cabextract
 
 ENV BERT_PREP_WORKING_DIR /workspace/bert/data
-ADD requirements.txt /workspace/
+
 WORKDIR /workspace/
-RUN pip install --no-cache-dir -r requirements.txt
-RUN git clone https://github.com/attardi/wikiextractor.git && cd wikiextractor && git checkout 6408a430fc504a38b04d37ce5e7fc740191dee16 && cd ..
-RUN git clone https://github.com/soskek/bookcorpus.git
 
-ADD . /workspace/bert
 WORKDIR /workspace/bert
+RUN pip install --no-cache-dir \
+    tqdm boto3 requests six ipdb h5py nltk progressbar tokenizers>=0.7\
+    git+https://github.com/NVIDIA/dllogger wget
+
+RUN apt-get install -y iputils-ping
+
+COPY . .
+
+RUN apt-get install -y libjemalloc-dev
+RUN pip install git+https://github.com/NVIDIA/lddl.git
+RUN python -m nltk.downloader punkt
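The image now installs LDDL straight from the NVIDIA/lddl repository and pre-downloads NLTK's punkt model. As a rough illustration (not part of the commit), a sentence-splitting check like the one below should work inside the built container, since punkt is the model that NLTK's sent_tokenize relies on:

# Illustrative check only: verifies the model fetched by
# `RUN python -m nltk.downloader punkt` is usable from Python.
from nltk.tokenize import sent_tokenize

text = "LDDL prepares pretraining shards. NLTK splits raw text into sentences."
print(sent_tokenize(text))  # ['LDDL prepares pretraining shards.', 'NLTK splits raw text into sentences.']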

PaddlePaddle/LanguageModeling/BERT/README.md

Lines changed: 121 additions & 60 deletions
Large diffs are not rendered by default.

PaddlePaddle/LanguageModeling/BERT/data/create_datasets_from_start.sh

Lines changed: 1 addition & 32 deletions
@@ -13,36 +13,5 @@
 # limitations under the License.
 
 #Download
-to_download=${1:-"wiki_only"}
-
-#Download
-if [ "$to_download" = "wiki_books" ] ; then
-    python3 /workspace/bert/data/bertPrep.py --action download --dataset bookscorpus
-fi
-
-python3 /workspace/bert/data/bertPrep.py --action download --dataset wikicorpus_en
+download_wikipedia --outdir ${BERT_PREP_WORKING_DIR}/wikipedia/
 python3 /workspace/bert/data/bertPrep.py --action download --dataset squad
-
-# Properly format the text files
-if [ "$to_download" = "wiki_books" ] ; then
-    python3 /workspace/bert/data/bertPrep.py --action text_formatting --dataset bookscorpus
-fi
-python3 /workspace/bert/data/bertPrep.py --action text_formatting --dataset wikicorpus_en
-
-if [ "$to_download" = "wiki_books" ] ; then
-    DATASET="books_wiki_en_corpus"
-else
-    DATASET="wikicorpus_en"
-    # Shard the text files
-fi
-
-# Shard the text files
-python3 /workspace/bert/data/bertPrep.py --action sharding --dataset $DATASET
-
-# Create HDF5 files Phase 1
-python3 /workspace/bert/data/bertPrep.py --action create_hdf5_files --dataset $DATASET --max_seq_length 128 \
-    --max_predictions_per_seq 20 --vocab_file /workspace/bert/vocab/bert-large-uncased-vocab.txt --do_lower_case 1
-
-# Create HDF5 files Phase 2
-python3 /workspace/bert/data/bertPrep.py --action create_hdf5_files --dataset $DATASET --max_seq_length 512 \
-    --max_predictions_per_seq 80 --vocab_file /workspace/bert/vocab/bert-large-uncased-vocab.txt --do_lower_case 1

PaddlePaddle/LanguageModeling/BERT/loss.py

Lines changed: 9 additions & 11 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import paddle
-import paddle.nn.functional as F
 
 
 class CrossEntropyLossForSQuAD(paddle.nn.Layer):
@@ -53,7 +52,7 @@ def __init__(self, vocab_size):
         self.vocab_size = vocab_size
 
     def forward(self, prediction_scores, seq_relationship_score,
-                masked_lm_labels, next_sentence_labels, masked_lm_scale):
+                masked_lm_labels, next_sentence_labels):
         """
         Args:
             prediction_scores(Tensor):
@@ -80,12 +79,11 @@ def forward(self, prediction_scores, seq_relationship_score,
                 Its data type should be float32 and its shape is [1].
         """
         with paddle.static.amp.fp16_guard():
-            masked_lm_loss = F.cross_entropy(
-                prediction_scores,
-                masked_lm_labels,
-                reduction='none',
-                ignore_index=-1)
-            masked_lm_loss = masked_lm_loss / masked_lm_scale
-            next_sentence_loss = F.cross_entropy(
-                seq_relationship_score, next_sentence_labels, reduction='none')
-            return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)
+            masked_lm_labels_flat = masked_lm_labels.reshape([-1])
+            mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != -1]
+            masked_lm_loss = self.loss_fn(prediction_scores, mlm_labels)
+            if next_sentence_labels.ndim == 1:
+                next_sentence_labels = next_sentence_labels.unsqueeze(axis=-1)
+            next_sentence_loss = self.loss_fn(seq_relationship_score,
+                                              next_sentence_labels)
+            return masked_lm_loss + next_sentence_loss
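In the new loss path, prediction_scores already covers only the masked positions (see the BertPretrainingHeads change below), so the labels just need to be flattened and filtered by the -1 padding value before a standard cross-entropy. A minimal standalone sketch of that label handling follows; it assumes loss_fn plays the role of self.loss_fn (a paddle.nn.CrossEntropyLoss instance defined outside this hunk) and uses random scores:

# Standalone sketch, not the repo's code: reworked masked-LM loss shape handling.
import paddle

loss_fn = paddle.nn.CrossEntropyLoss()  # assumed equivalent of self.loss_fn

# -1 marks unmasked positions, exactly as in the diff above.
masked_lm_labels = paddle.to_tensor([[-1, 7, -1], [3, -1, -1]])   # [batch, seq_len]
prediction_scores = paddle.randn([2, 30522])                      # [num_masked_tokens, vocab_size]

labels_flat = masked_lm_labels.reshape([-1])
mlm_labels = labels_flat[labels_flat != -1]                       # keeps the 2 masked labels
masked_lm_loss = loss_fn(prediction_scores, mlm_labels)
print(float(masked_lm_loss))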

PaddlePaddle/LanguageModeling/BERT/modeling.py

Lines changed: 20 additions & 46 deletions
@@ -89,17 +89,15 @@ def __init__(self, bert_config):
         self.layer_norm = nn.LayerNorm(bert_config.hidden_size, epsilon=1e-12)
         self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
+    def forward(self, input_ids, token_type_ids=None):
         """
         Args:
             See class `BertModel`.
         """
-        if position_ids is None:
-            ones = paddle.ones_like(input_ids, dtype="int64")
-            seq_length = paddle.cumsum(ones, axis=-1)
-
-            position_ids = seq_length - ones
-            position_ids.stop_gradient = True
+        ones = paddle.ones_like(input_ids, dtype="int64")
+        seq_length = paddle.cumsum(ones, axis=-1)
+        position_ids = seq_length - ones
+        position_ids.stop_gradient = True
         if token_type_ids is None:
             token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
 
@@ -174,18 +172,13 @@ def __init__(self, bert_config):
             dropout=bert_config.hidden_dropout_prob,
             activation=bert_config.hidden_act,
             attn_dropout=bert_config.attention_probs_dropout_prob,
-            act_dropout=0,
-            enable_cudnn=False)
+            act_dropout=0)
         self.encoder = nn.TransformerEncoder(encoder_layer,
                                              bert_config.num_hidden_layers)
 
         self.pooler = BertPooler(bert_config.hidden_size)
 
-    def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
         """
         Args:
             input_ids(Tensor):
@@ -198,11 +191,6 @@ def forward(self,
                 to a `sentence A` and type 1 corresponds to a `sentence B` token.
                 (see BERT paper for more details). Its data type should be `int64`
                 Defaults: None, which means we don't add segment embeddings.
-            position_ids(Tensor, optional):
-                An optional Tensor of shape [batch_size, num_tokens] with the position
-                indices of each input sequence tokens in the position embeddings.
-                Selected in the range [0, max_position_embeddings - 1].
-                Its data type should be `int64`. Defaults: None.
             attention_mask(Tensor, optional):
                 An optional Tensor of shape [batch_size, sequence_length] with indices of
                 mask used in multi-head attention to avoid performing attention on to some
@@ -234,9 +222,7 @@ def forward(self,
             attention_mask = attention_mask.unsqueeze(axis=[1, 2])
 
         embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids)
+            input_ids=input_ids, token_type_ids=token_type_ids)
 
         if self.fuse:
             encoder_output = embedding_output
@@ -263,11 +249,7 @@ def __init__(self, bert_config):
         self.bert = BertModel(bert_config)
         self.classifier = nn.Linear(bert_config.hidden_size, 2)
 
-    def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
         """
         Args:
             See class `BertModel`.
@@ -282,7 +264,6 @@ def forward(self,
         encoder_output, _ = self.bert(
             input_ids,
             token_type_ids=token_type_ids,
-            position_ids=position_ids,
             attention_mask=attention_mask)
 
         logits = self.classifier(encoder_output)
@@ -322,13 +303,7 @@ def __init__(self,
         self.decoder_bias = self.create_parameter(
             shape=[vocab_size], dtype=self.decoder_weight.dtype, is_bias=True)
 
-    def forward(self, hidden_states, masked_positions=None):
-        if masked_positions is not None:
-            hidden_states = paddle.reshape(hidden_states,
-                                           [-1, hidden_states.shape[-1]])
-            hidden_states = paddle.tensor.gather(hidden_states,
-                                                 masked_positions)
-        # gather masked tokens might be more quick
+    def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.activation(hidden_states)
         hidden_states = self.layer_norm(hidden_states)
@@ -362,7 +337,7 @@ def __init__(self,
                                                 activation, embedding_weights)
         self.seq_relationship = nn.Linear(hidden_size, 2)
 
-    def forward(self, encoder_output, pooled_output, masked_positions=None):
+    def forward(self, encoder_output, pooled_output, masked_lm_labels):
         """
         Args:
             sequence_output(Tensor):
@@ -384,7 +359,12 @@ def forward(self, encoder_output, pooled_output, masked_positions=None):
                 A Tensor of shape [batch_size, 2] with the scores of next sentence prediction.
                 Its data type should be float32.
         """
-        prediction_scores = self.predictions(encoder_output, masked_positions)
+
+        sequence_flattened = paddle.index_select(
+            encoder_output.reshape([-1, encoder_output.shape[-1]]),
+            paddle.nonzero(masked_lm_labels.reshape([-1]) != -1).squeeze(),
+            axis=0)
+        prediction_scores = self.predictions(sequence_flattened)
         seq_relationship_score = self.seq_relationship(pooled_output)
         return prediction_scores, seq_relationship_score
 
@@ -406,18 +386,13 @@ def __init__(self, bert_config):
             bert_config.hidden_act,
             embedding_weights=self.bert.embeddings.word_embeddings.weight)
 
-    def forward(self,
-                input_ids,
-                token_type_ids=None,
-                position_ids=None,
-                attention_mask=None,
-                masked_positions=None):
+    def forward(self, input_ids, token_type_ids, attention_mask,
+                masked_lm_labels):
        """
 
         Args:
             input_ids(Tensor): See class `BertModel`.
             token_type_ids(Tensor, optional): See class `BertModel`.
-            position_ids(Tensor, optional): See class `BertModel`.
             attention_mask(Tensor, optional): See class `BertModel`.
             masked_positions(Tensor, optional): See class `BertPretrainingHeads`.
 
@@ -434,9 +409,8 @@ def forward(self,
         outputs = self.bert(
             input_ids,
             token_type_ids=token_type_ids,
-            position_ids=position_ids,
             attention_mask=attention_mask)
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(
-            sequence_output, pooled_output, masked_lm_labels)
+            sequence_output, pooled_output, masked_lm_labels)
         return prediction_scores, seq_relationship_score
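The MLM head now receives only the hidden states of masked tokens: the encoder output is flattened to [batch_size * seq_len, hidden_size] and paddle.index_select picks the rows whose label is not -1, which is what keeps prediction_scores aligned with the filtered labels in loss.py. A small shape-level sketch of that gather (illustrative only, random values):

# Illustrative sketch of the index_select gather added to BertPretrainingHeads.forward.
import paddle

batch_size, seq_len, hidden_size = 2, 4, 8
encoder_output = paddle.randn([batch_size, seq_len, hidden_size])
masked_lm_labels = paddle.to_tensor([[-1, 7, -1, -1],
                                     [3, -1, 9, -1]])                # -1 = not masked

flat = encoder_output.reshape([-1, encoder_output.shape[-1]])        # [batch*seq_len, hidden]
masked_idx = paddle.nonzero(masked_lm_labels.reshape([-1]) != -1).squeeze()
sequence_flattened = paddle.index_select(flat, masked_idx, axis=0)
print(sequence_flattened.shape)                                      # [3, 8]: one row per masked token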
