
Commit 75052a4

[BART/PyT] Add pretraining feature support

Merge commit 75052a4 (2 parents: 0409902 + 305744e)


49 files changed: +2055 −1832 lines

PyTorch/LanguageModeling/BART/Dockerfile

Lines changed: 12 additions & 42 deletions
@@ -14,55 +14,25 @@
 # limitations under the License.
 # ==============================================================================
 
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.02-py3
-
-######
-# Tokenizers is only available pre-built on x86
-#
-FROM ${FROM_IMAGE_NAME} AS tokenizers_amd64
-WORKDIR /wheelhouse
-RUN pip download tokenizers==0.8.0
-
-FROM quay.io/pypa/manylinux2014_aarch64 as tokenizers_arm64
-ARG PYVER=38
-RUN yum install -y openssl-devel
-RUN curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain nightly-2020-05-14 -y
-ENV PATH="/root/.cargo/bin:$PATH"
-ENV PYBIN=/opt/python/cp${PYVER}-cp${PYVER}/bin
-ENV PYTHON_SYS_EXECUTABLE="$PYBIN/python"
-RUN git clone -b python-v0.8.0 https://github.com/huggingface/tokenizers.git /opt/tokenizers
-WORKDIR /opt/tokenizers/bindings/python
-RUN "${PYBIN}/pip" install setuptools-rust \
-    && "${PYBIN}/python" setup.py bdist_wheel \
-    && rm -rf build/* \
-    && for whl in dist/*.whl; do \
-        auditwheel repair "$whl" -w dist/; \
-    done \
-    && rm dist/*-linux_* \
-    && mkdir -p /wheelhouse \
-    && mv dist/*.whl /wheelhouse
-
-ARG TARGETARCH
-FROM tokenizers_${TARGETARCH} AS tokenizers
-#
-#####
-
-
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.08-py3
 FROM ${FROM_IMAGE_NAME}
-RUN apt-get update && apt-get install -y pbzip2
 
-RUN --mount=from=tokenizers,source=/wheelhouse,target=/tmp/wheelhouse \
-    pip install --no-cache-dir /tmp/wheelhouse/tokenizers*.whl
 
-RUN pip install --no-cache-dir dataclasses gitpython rouge-score pynvml==8.0.4 \
-    git+https://github.com/NVIDIA/dllogger pytorch-lightning==1.1.5 gdown sacrebleu
-
-RUN pip install tqdm --upgrade
+RUN apt-get update
+COPY requirements.txt .
+RUN pip install --upgrade --no-cache-dir pip \
+ && pip install --no-cache-dir -r requirements.txt
 
 WORKDIR /workspace
-RUN git clone https://github.com/artmatsak/cnn-dailymail.git
+RUN git clone https://github.com/abisee/cnn-dailymail.git
 RUN git clone https://github.com/gcunhase/AMICorpusXML.git
 
+# Re-build apex
+RUN git clone https://github.com/nv-joseli/apex.git
+RUN cd apex && \
+    git checkout bf16lamb && \
+    NVCC_APPEND_FLAGS='--threads 1' pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
+
 WORKDIR /workspace/bart
 
 COPY . .
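
The rebuilt Apex (nv-joseli/apex, bf16lamb branch, compiled with --cpp_ext and --cuda_ext) presumably supplies the fused LAMB optimizer for the new pretraining path; the branch name suggests BF16 LAMB support. Below is a small sanity-check sketch, not part of the commit, assuming a CUDA GPU is visible inside the container and that the branch keeps the standard apex.optimizers.FusedLAMB interface:

import torch
from apex.optimizers import FusedLAMB  # fused CUDA kernels come from the --cuda_ext build

model = torch.nn.Linear(16, 16).cuda()
optimizer = FusedLAMB(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()
optimizer.step()
print("Apex FusedLAMB step OK")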

PyTorch/LanguageModeling/BART/README.md

Lines changed: 223 additions & 140 deletions
Large diffs are not rendered by default.

PyTorch/LanguageModeling/BART/bart/configuration/configuration_bart.py

Lines changed: 1 addition & 6 deletions
@@ -1,4 +1,5 @@
 # coding=utf-8
+# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
 # Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -88,8 +89,6 @@
         Google "layerdrop arxiv", as its not explainable in one line.
     decoder_layerdrop: (:obj:`float`, optional, defaults to 0.0):
         Google "layerdrop arxiv", as its not explainable in one line.
-    extra_pos_embeddings: (:obj:`int`, optional, defaults to 2):
-        How many extra learned positional embeddings to use. Should be pad_token_id+1 for bart.
     num_labels: (:obj:`int`, optional, defaults to 2):
         for SequenceClassification
     is_encoder_decoder (:obj:`int`, optional, defaults to True):
@@ -109,7 +108,6 @@ class BartConfig(PretrainedConfig):
     def __init__(
         self,
         activation_dropout=0.0,
-        extra_pos_embeddings=2,  # FIXME(@sshleifer): delete?
         activation_function="gelu",
         vocab_size=50265,
         d_model=1024,
@@ -194,9 +192,6 @@ def __init__(
         # Classifier stuff
         self.classif_dropout = classifier_dropout
 
-        # pos embedding offset
-        self.extra_pos_embeddings = self.pad_token_id + 1
-
         self.force_bos_token_to_be_generated = force_bos_token_to_be_generated
         self.attention_bias = attention_bias
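
With the keyword argument and the pad_token_id+1 offset removed, a config built from the updated class simply carries no extra_pos_embeddings attribute. A tiny illustrative check (not part of the diff; the import path follows the repo layout shown above):

from bart.configuration.configuration_bart import BartConfig

config = BartConfig(vocab_size=50265, d_model=1024, activation_function="gelu")
print(hasattr(config, "extra_pos_embeddings"))  # expected: False after this change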

PyTorch/LanguageModeling/BART/bart/configuration/configuration_t5.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 # coding=utf-8
+# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
 # Copyright 2010, The T5 Authors and HuggingFace Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

PyTorch/LanguageModeling/BART/bart/configuration/configuration_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # coding=utf-8
+# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
-# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+class BertSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        dropout=0.0,
+        bias=True,
+        encoder_decoder_attention=False,  # otherwise self_attention
+    ):
+    def __init__(self, config):
+        super(BertSelfAttention, self).__init__()
+        if config.hidden_size % num_heads != 0:
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (config.hidden_size, num_heads))
+        self.num_heads = num_heads
+        self.attention_head_size = int(config.hidden_size / num_heads)
+        self.all_head_size = self.num_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_heads, self.attention_head_size)
+        x = torch.reshape(x, new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def transpose_key_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_heads, self.attention_head_size)
+        x = torch.reshape(x, new_x_shape)
+        return x.permute(0, 2, 3, 1)
+
+    def forward(self, hidden_states, attention_mask):
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_key_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer)
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+        attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = F.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = torch.reshape(context_layer, new_context_layer_shape)
+        return context_layer
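
As committed, the class above stacks two __init__ signatures (an embed_dim/num_heads one and a config-driven one), carries no imports, and references num_heads inside the config-driven body where it is never bound. The sketch below is a minimal, self-contained reading of the same computation, not the committed file: it keeps only the config-driven constructor, folds the key transpose into one helper, and assumes the standard BERT-style config attribute names (hidden_size, num_attention_heads, attention_probs_dropout_prob).

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class BertSelfAttentionSketch(nn.Module):
    """Config-driven reading of the committed BertSelfAttention (illustrative only)."""

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads))
        self.num_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // self.num_heads
        self.all_head_size = self.num_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def _split_heads(self, x):
        # (batch, seq, hidden) -> (batch, heads, seq, head_size)
        new_shape = x.size()[:-1] + (self.num_heads, self.attention_head_size)
        return torch.reshape(x, new_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        query_layer = self._split_heads(self.query(hidden_states))
        key_layer = self._split_heads(self.key(hidden_states))
        value_layer = self._split_heads(self.value(hidden_states))

        # Scaled dot-product attention with an additive, broadcastable mask
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask

        attention_probs = self.dropout(F.softmax(attention_scores, dim=-1))

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return torch.reshape(context_layer, new_shape)


# Throwaway smoke test with a hypothetical config object (not from the repo):
# cfg = type("Cfg", (), {"hidden_size": 64, "num_attention_heads": 4,
#                        "attention_probs_dropout_prob": 0.1})()
# attn = BertSelfAttentionSketch(cfg)
# out = attn(torch.randn(2, 10, 64), torch.zeros(2, 1, 1, 10))
# print(out.shape)  # torch.Size([2, 10, 64])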
