Commit 774eb6d

Merge pull request #5 from matlab-deep-learning/bert
BERT + FinBERT
2 parents 59abf30 + d40bd93

67 files changed (+3924 −78 lines)
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
function dirpath = convertModelNameToDirectories(name)
% convertModelNameToDirectories Converts the user-facing model name to
% the directory name used by support files.

% Copyright 2021 The MathWorks, Inc.
arguments
    name (1,1) string
end
modelName = userInputToSupportFileName(name);
dirpath = {"data","networks","bert",modelName};
end

function supportfileName = userInputToSupportFileName(name)
persistent map;
if isempty(map)
    names = namesArray();
    map = containers.Map(names(:,1),names(:,2));
end
supportfileName = map(name);
end

function names = namesArray()
names = [
    "base",               "uncased_L12_H768_A12";
    "multilingual-cased", "multicased_L12_H768_A12";
    "medium",             "uncased_L8_H512_A8";
    "small",              "uncased_L4_H512_A8";
    "mini",               "uncased_L4_H256_A4";
    "tiny",               "uncased_L2_H128_A2"];
end
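
For reference, a quick sketch of the mapping in action. This assumes the function lives in the +bert/+internal package, as the call in getSupportFilePath further down suggests:

mdl = "tiny";
dirpath = bert.internal.convertModelNameToDirectories(mdl);
% dirpath is {"data","networks","bert","uncased_L2_H128_A2"}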
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
function weightsStruct = createParameterStruct(oldWeightsStruct)
% createParameterStruct Given the flat struct of BERT model weights, this
% function parses that into a tree-like struct of weights.

% Copyright 2021 The MathWorks, Inc.

f = fieldnames(oldWeightsStruct);
for i = 1:numel(f)
    name = f{i};
    encoderLayerPrefix = "bert_encoder_layer";
    embeddingLayerPrefix = "bert_embeddings";
    poolingLayerPrefix = "bert_pooler_dense";
    langModPrefix = "cls_predictions";
    nspPrefix = "cls_seq_relationship_output";
    genericClassifierPrefix = "classifier_";

    weight = dlarray(oldWeightsStruct.(name));

    if startsWith(name,encoderLayerPrefix)
        % BERT transformer layer weights.
        layerIndex = extractBetween(name,encoderLayerPrefix+"_","_");
        newLayerIndex = str2double(layerIndex)+1;
        layerName = encoderLayerPrefix+"_"+layerIndex;
        shortLayerName = "layer_"+newLayerIndex;
        paramName = extractAfter(name,layerName+"_");
        attentionOrFeedforward = iParseAttentionOrFeedforward(paramName);
        [subParamName,subsubParamName] = iParseAttentionAndFeedforwardParamName(paramName,attentionOrFeedforward);
        weightsStruct.("encoder_layers").(shortLayerName).(attentionOrFeedforward).(subParamName).(subsubParamName) = weight;

    elseif startsWith(name,embeddingLayerPrefix)
        % Embedding parameters
        paramName = extractAfter(name,embeddingLayerPrefix+"_");
        if contains(paramName,"LayerNorm")
            [subname,subsubname] = iParseLayerNorm(paramName);
            weightsStruct.("embeddings").(subname).(subsubname) = weight;
        else
            weightsStruct.("embeddings").(paramName) = weight;
        end

    elseif startsWith(name,poolingLayerPrefix)
        paramName = extractAfter(name,poolingLayerPrefix+"_");
        weightsStruct.("pooler").(paramName) = weight;

    elseif startsWith(name,langModPrefix)
        paramName = extractAfter(name,langModPrefix+"_");
        [subname,subsubname] = iParseLM(paramName);
        weightsStruct.("masked_LM").(subname).(subsubname) = weight;

    elseif startsWith(name,nspPrefix)
        paramName = extractAfter(name,nspPrefix+"_");
        if strcmp(paramName,"weights")
            % This parameter wasn't renamed and transposed before
            % uploading. We can fix it here.
            paramName = "kernel";
            weight = weight.';
        end
        weightsStruct.("sequence_relation").(paramName) = weight;

    elseif startsWith(name,genericClassifierPrefix)
        paramName = extractAfter(name,genericClassifierPrefix);
        weightsStruct.("classifier").(paramName) = weight;
    end
end
end

function name = iParseAttentionOrFeedforward(name)
if contains(name, "attention")
    name = "attention";
else
    name = "feedforward";
end
end

function [name,subname] = iParseAttentionAndFeedforwardParamName(name,attnOrFeedforward)
switch attnOrFeedforward
    case "attention"
        [name,subname] = iParseAttentionParamName(name);
    case "feedforward"
        [name,subname] = iParseFeedforwardParamName(name);
end
end

function [subname,subsubname] = iParseAttentionParamName(name)
if contains(name,"LayerNorm")
    [subname,subsubname] = iParseLayerNorm(name);
else
    name = strrep(name,"self_","");
    name = strrep(name,"_dense","");
    subname = extractBetween(name,"attention_","_");
    subsubname = extractAfter(name,subname+"_");
end
end

function [subname,subsubname] = iParseFeedforwardParamName(name)
if contains(name,"LayerNorm")
    [subname,subsubname] = iParseLayerNorm(name);
else
    subname = extractBefore(name,"_");
    subsubname = extractAfter(name,"dense_");
end
end

function [subname,subsubname] = iParseLayerNorm(name)
subname = "LayerNorm";
subsubname = extractAfter(name,"LayerNorm_");
end

function [subname,subsubname] = iParseLM(name)
if contains(name,"LayerNorm")
    [subname,subsubname] = iParseLayerNorm(name);
else
    name = strrep(name,"dense_","");
    subname = extractBefore(name,"_");
    subsubname = extractAfter(name,"_");
end
end
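
To make the parsing concrete, here is a sketch tracing one flat field through the function. The field name follows the TensorFlow-style naming the prefixes above imply, and the unqualified call assumes the function is on the path; both are illustrative assumptions, not taken from this commit:

% Hypothetical flat checkpoint field for layer 0's query projection.
old.bert_encoder_layer_0_attention_self_query_kernel = rand(4);
params = createParameterStruct(old);
% "layer_0" becomes 1-based "layer_1", and the remainder splits into
% attention -> query -> kernel:
w = params.encoder_layers.layer_1.attention.query.kernel;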
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
function [x,untokenizedPieces,ismask] = encodeWithMaskToken(tok,str)
% encodeWithMaskToken This function handles the case of encoding an input
% string that includes tokens such as [MASK].

% Copyright 2021 The MathWorks, Inc.
arguments
    tok bert.tokenizer.BERTTokenizer
    str (1,:) string
end
[seqs,untokenizedPieces] = arrayfun(@(s)encodeScalarString(tok,s),str,'UniformOutput',false);
x = padsequences(seqs,2,'PaddingValue',tok.PaddingCode);
maskCode = tok.MaskCode;
ismask = x==maskCode;
end

function [x,pieces] = encodeScalarString(tok,str)
pieces = split(str,tok.MaskToken);
fulltok = tok.FullTokenizer;
maskCode = fulltok.encode(tok.MaskToken);
x = [];

for i = 1:numel(pieces)
    tokens = fulltok.tokenize(pieces(i));
    if ~isempty(tokens)
        % "" tokenizes to empty - awkward
        x = cat(2,x,fulltok.encode(tokens));
    end
    if i<numel(pieces)
        x = cat(2,x,maskCode);
    end
end
x = [fulltok.encode(tok.StartToken),x,fulltok.encode(tok.SeparatorToken)];
end
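
A hedged usage sketch; it assumes this helper sits in the +bert/+internal package and that a tokenizer is available from a loaded model (e.g. mdl = bert()):

mdl = bert();
[x,~,ismask] = bert.internal.encodeWithMaskToken(mdl.Tokenizer, ...
    "All roads lead to [MASK].");
% x is the padded code sequence; ismask is true at the [MASK] position.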
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
function filePath = getSupportFilePath(modelName,fileName)
% getSupportFilePath Converts between the model names presented to the
% user and the support file URLs, and returns the local path of the
% downloaded support file.

% Copyright 2021 The MathWorks, Inc.
arguments
    modelName (1,1) string
    fileName (1,1) string
end
directory = bert.internal.convertModelNameToDirectories(modelName);
sd = matlab.internal.examples.utils.getSupportFileDir();
localFile = fullfile(sd,"nnet",directory{:},fileName);
if exist(localFile,'file')~=2
    disp("Downloading "+fileName+" to: "+localFile);
end
fileURL = strjoin([directory,fileName],"/");
filePath = matlab.internal.examples.downloadSupportFile("nnet",fileURL);
end
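
Usage sketch; the package-qualified name and the file name "parameters.mat" are placeholders for illustration, not names confirmed by this diff:

filePath = bert.internal.getSupportFilePath("tiny","parameters.mat");
% First use triggers a download; later calls return the cached path.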

+bert/+internal/inferTypeID.m

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
function types = inferTypeID(x,separatorCode)
% infer the typeIDs from a CTB unlabeled array x
xsz = size(x);
types = ones(xsz);
sepId = x==separatorCode;
if isa(sepId,'dlarray')
    sepId = extractdata(sepId);
end
% Find which observations have >1 separator - when there is 1 separator,
% any padding is considered "type 1".
cs = cumsum(sepId,2);
obsNeedsType2 = cs(:,end,:)>1;
% Type 2 tokens are those between the first (exclusive) and second
% separator (inclusive) if a second separator was present.
type2positions = circshift(cs==1,1) & obsNeedsType2;
types(type2positions) = 2;
end
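
A worked example of the rule the comments describe (the token codes are illustrative, not taken from the shipped vocabulary):

x = [101 2054 102 2003 102];               % [CLS] A [SEP] B [SEP]
types = bert.internal.inferTypeID(x,102);
% types is [1 1 1 2 2]: everything after the first separator, up to and
% including the second separator, is marked type 2.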
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
function [toks,probs] = predictMaskedToken(mdl,x,maskIdx,k)
% predictMaskedToken Decode the k most likely tokens for the masked
% position maskIdx of the encoded input x, given a BERT model struct mdl.
arguments
    mdl
    x
    maskIdx
    k (1,1) double {mustBePositive,mustBeInteger} = 1
end
probs = bert.languageModel(x,mdl.Parameters);
probs = extractdata(probs(:,maskIdx));
[~,idx] = maxk(probs,k);
toks = mdl.Tokenizer.FullTokenizer.decode(idx);
end
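
A sketch of the masked-token workflow, combining this with encodeWithMaskToken above (the package-qualified names are assumptions; only inferTypeID's location is shown in this diff):

mdl = bert();
[x,~,ismask] = bert.internal.encodeWithMaskToken(mdl.Tokenizer, ...
    "Paris is the [MASK] of France.");
toks = bert.internal.predictMaskedToken(mdl,x,find(ismask),5);
% toks holds the 5 highest-probability sub-word tokens for the mask.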

+bert/+layer/block.m

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
function z = block(z,weights,hyperParameters,nvp)
% block Transformer block for BERT
%
%   Z = block(X,weights,hyperParameters) computes the BERT-style
%   transformer block on the input X as described in [1]. Here X is a
%   (numFeatures*numHeads)-by-numInputSubwords array. The weights and
%   hyperParameters must be structs in the same format as returned by the
%   bert() function.
%
%   Z = block(X,weights,hyperParameters,'PARAM1',VAL1,'PARAM2',VAL2)
%   specifies the optional parameter name/value pairs:
%
%     'HiddenDropout'    - The dropout probability applied between the
%                          self-attention mechanism and the residual
%                          connection. The default is 0.
%
%     'AttentionDropout' - The dropout probability applied to the
%                          attention probabilities. The default is 0.
%
%     'InputMask'        - A logical mask used in the attention
%                          mechanism, for example to block attending to
%                          padding tokens. The default is [], meaning no
%                          masking is applied.
%
% References:
% [1] https://arxiv.org/abs/1810.04805

% Copyright 2021 The MathWorks, Inc.
arguments
    z
    weights
    hyperParameters
    nvp.HiddenDropout (1,1) double {mustBeNonnegative, mustBeLessThanOrEqual(nvp.HiddenDropout,1)} = 0
    nvp.AttentionDropout (1,1) double {mustBeNonnegative, mustBeLessThanOrEqual(nvp.AttentionDropout,1)} = 0
    nvp.InputMask = []
end
z = attention(z,weights.attention,hyperParameters.NumHeads,nvp.AttentionDropout,nvp.HiddenDropout,nvp.InputMask);
z = ffn(z,weights.feedforward,nvp.HiddenDropout);
end

function z = attention(z,w,num_heads,attentionDropout,dropout,mask)
% The self-attention part of the transformer layer.
layer_input = z;

% Get weights
Q_w = w.query.kernel;
Q_b = w.query.bias;
K_w = w.key.kernel;
K_b = w.key.bias;
V_w = w.value.kernel;
V_b = w.value.bias;

% Put weights into the format expected by transformer.layer.attention
weights.attn_c_attn_w_0 = cat(1,Q_w,K_w,V_w);
weights.attn_c_attn_b_0 = cat(1,Q_b,K_b,V_b);
weights.attn_c_proj_w_0 = w.output.kernel;
weights.attn_c_proj_b_0 = w.output.bias;
hyperparameters.NumHeads = num_heads;
z = transformer.layer.attention(z,[],weights,hyperparameters,'CausalMask',false,'Dropout',attentionDropout,'InputMask',mask);

% Dropout
z = transformer.layer.dropout(z,dropout);

% Residual connection
z = layer_input+z;

% Layer normalize.
z = transformer.layer.normalization(z,w.LayerNorm.gamma,w.LayerNorm.beta);
end

function z = ffn(z,w,dropout)
% The feed-forward network part of the transformer layer.

% Weights for embedding in a higher-dimensional space
int_w = w.intermediate.kernel;
int_b = w.intermediate.bias;

% Weights for projecting back down to the original space
out_w = w.output.kernel;
out_b = w.output.bias;

% Create weights struct for multiLayerPerceptron
weights.mlp_c_fc_w_0 = int_w;
weights.mlp_c_fc_b_0 = int_b;
weights.mlp_c_proj_w_0 = out_w;
weights.mlp_c_proj_b_0 = out_b;
ffn_out = transformer.layer.multiLayerPerceptron(z,weights);

% Dropout
ffn_out = transformer.layer.dropout(ffn_out,dropout);

% Layer normalize.
out_g = w.LayerNorm.gamma;
out_b = w.LayerNorm.beta;
z = transformer.layer.normalization(ffn_out+z,out_g,out_b);
end
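
A minimal usage sketch under stated assumptions: bert() returns a model whose Parameters.Weights matches the struct built by createParameterStruct above, and the only hyperparameter the block reads is NumHeads:

mdl = bert();
hp.NumHeads = 12;                        % BERT-base uses 12 heads
z = dlarray(randn(768,11,'single'));     % hiddenSize-by-numInputSubwords
z = bert.layer.block(z,mdl.Parameters.Weights.encoder_layers.layer_1, ...
    hp,'HiddenDropout',0.1);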

+bert/+layer/classifier.m

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
function y = classifier(x,p)
% classifier The standard BERT classifier, a single fullyconnect.
%
%   Z = classifier(X,classifierWeights) applies a fullyconnect operation
%   to the input X with weights classifierWeights.kernel and bias
%   classifierWeights.bias. The input X must be an unformatted dlarray of
%   size hiddenSize-by-numObs. The classifierWeights.kernel must be of
%   size outputSize-by-hiddenSize, and the classifierWeights.bias must be
%   of size outputSize-by-1.

% Copyright 2021 The MathWorks, Inc.
y = transformer.layer.convolution1d(x,p.kernel,p.bias);
end
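
A minimal sketch on random weights, matching the sizes in the help text:

p.kernel = dlarray(randn(3,768,'single'));   % outputSize-by-hiddenSize
p.bias   = dlarray(zeros(3,1,'single'));     % outputSize-by-1
x = dlarray(randn(768,5,'single'));          % hiddenSize-by-numObs
y = bert.layer.classifier(x,p);              % 3-by-5 class scores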

+bert/+layer/classifierHead.m

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
function z = classifierHead(x,poolerWeights,classifierWeights)
% classifierHead The standard classification head for a BERT model.
%
%   Z = classifierHead(X,poolerWeights,classifierWeights) applies
%   bert.layer.pooler and bert.layer.classifier to X with poolerWeights
%   and classifierWeights respectively. Both poolerWeights and
%   classifierWeights must be structs with fields 'kernel' and 'bias'.

% Copyright 2021 The MathWorks, Inc.
z = bert.layer.pooler(x,poolerWeights);
z = bert.layer.classifier(z,classifierWeights);
end
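
And the full head chained together. Random placeholder weights stand in for real ones from mdl.Parameters.Weights; this sketch assumes the pooler reduces the sequence output to a hiddenSize-by-numObs array, which is not shown in this diff:

pw.kernel = dlarray(randn(768,768,'single'));
pw.bias   = dlarray(zeros(768,1,'single'));
cw.kernel = dlarray(randn(3,768,'single'));
cw.bias   = dlarray(zeros(3,1,'single'));
x = dlarray(randn(768,11,'single'));    % sequence output for one obs
z = bert.layer.classifierHead(x,pw,cw);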

+bert/+layer/embedding.m

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
function z = embedding(x,types,positions,w,dropout)
% embedding The BERT embeddings of encoded tokens, token types and token
% positions.
%
%   Z = embedding(X,types,positions,weights,dropoutProbability) computes
%   the embedding of encoded tokens X, token types specified by types,
%   and token positions. Inputs X, types and positions are
%   1-by-numInputTokens-by-numObs unformatted dlarrays. The types take
%   values 1 or 2. The weights input is a struct of embedding weights
%   such as mdl.Parameters.Weights.embeddings where mdl = bert(). The
%   dropoutProbability is a scalar double between 0 and 1 corresponding
%   to the post-embedding dropout probability.

% Copyright 2021 The MathWorks, Inc.
wordEmbedding = embed(x,w.word_embeddings,'DataFormat','CTB');
typeEmbedding = embed(types,w.token_type_embeddings,'DataFormat','CTB');
positionEmbedding = embed(positions,w.position_embeddings,'DataFormat','CTB');
z = wordEmbedding+typeEmbedding+positionEmbedding;
z = transformer.layer.normalization(z,w.LayerNorm.gamma,w.LayerNorm.beta);
z = transformer.layer.dropout(z,dropout);
end
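
A usage sketch for a single three-token sequence (the token codes are illustrative; in practice they come from the tokenizer):

mdl = bert();
x = dlarray([101 7592 102]);        % 1-by-numInputTokens
types = dlarray(ones(size(x)));     % single-segment input, all type 1
positions = dlarray(1:numel(x));    % 1-based token positions
z = bert.layer.embedding(x,types,positions, ...
    mdl.Parameters.Weights.embeddings,0);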
