function weightsStruct = createParameterStruct(oldWeightsStruct)
% createParameterStruct   Given the flat struct of BERT model weights, this
% function parses that into a tree-like struct of weights.
%
%   weightsStruct = createParameterStruct(oldWeightsStruct) visits every
%   field of oldWeightsStruct, wraps its value in a dlarray, and stores it
%   in weightsStruct at a location selected by the field-name prefix:
%
%     bert_encoder_layer_<k>_...       -> encoder_layers.layer_<k+1>.<attention|feedforward>.<sub>.<subsub>
%     bert_embeddings_...              -> embeddings.<param>  or  embeddings.LayerNorm.<param>
%     bert_pooler_dense_...            -> pooler.<param>
%     cls_predictions_...              -> masked_LM.<sub>.<subsub>
%     cls_seq_relationship_output_...  -> sequence_relation.<param>
%     classifier_...                   -> classifier.<param>
%
%   Fields whose names start with none of these prefixes are skipped.
%   If nothing matches (or oldWeightsStruct has no fields) the result is
%   an empty scalar struct.

% Copyright 2021 The MathWorks, Inc.

% The prefixes are the same for every field, so define them once rather
% than on each loop iteration.
encoderLayerPrefix = "bert_encoder_layer";
embeddingLayerPrefix = "bert_embeddings";
poolingLayerPrefix = "bert_pooler_dense";
langModPrefix = "cls_predictions";
nspPrefix = "cls_seq_relationship_output";
genericClassifierPrefix = "classifier_";

% Make sure the output is defined even when no field is recognised.
weightsStruct = struct;

f = fieldnames(oldWeightsStruct);
for i = 1:numel(f)
    name = f{i};
    weight = dlarray(oldWeightsStruct.(name));

    if startsWith(name, encoderLayerPrefix)
        % BERT transformer layer weights.
        % The layer index in the field name is shifted by one to give the
        % index used in the output field name (layer_<index+1>).
        layerIndex = extractBetween(name, encoderLayerPrefix + "_", "_");
        newLayerIndex = str2double(layerIndex) + 1;
        layerName = encoderLayerPrefix + "_" + layerIndex;
        shortLayerName = "layer_" + newLayerIndex;
        paramName = extractAfter(name, layerName + "_");
        attentionOrFeedforward = iParseAttentionOrFeedforward(paramName);
        [subParamName, subsubParamName] = iParseAttentionAndFeedforwardParamName(paramName, attentionOrFeedforward);
        weightsStruct.("encoder_layers").(shortLayerName).(attentionOrFeedforward).(subParamName).(subsubParamName) = weight;

    elseif startsWith(name, embeddingLayerPrefix)
        % Embedding parameters.
        paramName = extractAfter(name, embeddingLayerPrefix + "_");
        if contains(paramName, "LayerNorm")
            [subname, subsubname] = iParseLayerNorm(paramName);
            weightsStruct.("embeddings").(subname).(subsubname) = weight;
        else
            weightsStruct.("embeddings").(paramName) = weight;
        end

    elseif startsWith(name, poolingLayerPrefix)
        % Pooler parameters.
        paramName = extractAfter(name, poolingLayerPrefix + "_");
        weightsStruct.("pooler").(paramName) = weight;

    elseif startsWith(name, langModPrefix)
        % Masked language model head parameters.
        paramName = extractAfter(name, langModPrefix + "_");
        [subname, subsubname] = iParseLM(paramName);
        weightsStruct.("masked_LM").(subname).(subsubname) = weight;

    elseif startsWith(name, nspPrefix)
        % Sequence relationship head parameters.
        paramName = extractAfter(name, nspPrefix + "_");
        if strcmp(paramName, "weights")
            % This parameter wasn't renamed and transposed before
            % uploading. We can fix it here.
            paramName = "kernel";
            weight = weight.';
        end
        weightsStruct.("sequence_relation").(paramName) = weight;

    elseif startsWith(name, genericClassifierPrefix)
        % Generic classifier parameters. The prefix already ends in "_",
        % so no extra separator is appended here.
        paramName = extractAfter(name, genericClassifierPrefix);
        weightsStruct.("classifier").(paramName) = weight;
    end
end
end
65+
function name = iParseAttentionOrFeedforward(name)
% iParseAttentionOrFeedforward   Classify an encoder-layer parameter name.
%   Returns "attention" when the input text contains "attention", and
%   "feedforward" otherwise.
if contains(name, "attention")
    name = "attention";
else
    name = "feedforward";
end
end
73+
function [name, subname] = iParseAttentionAndFeedforwardParamName(name, attnOrFeedforward)
% iParseAttentionAndFeedforwardParamName   Split an encoder-layer parameter
% name into two nested field names.
%   attnOrFeedforward selects the parser: "attention" uses
%   iParseAttentionParamName and "feedforward" uses
%   iParseFeedforwardParamName. Any other value is an error.
switch attnOrFeedforward
    case "attention"
        [name, subname] = iParseAttentionParamName(name);
    case "feedforward"
        [name, subname] = iParseFeedforwardParamName(name);
    otherwise
        % Without this branch the second output would be left unassigned
        % and MATLAB would raise a less descriptive error.
        error("createParameterStruct:UnknownBlockType", ...
            "Expected ""attention"" or ""feedforward"" but got ""%s"".", ...
            string(attnOrFeedforward));
end
end
82+
function [subname, subsubname] = iParseAttentionParamName(name)
% iParseAttentionParamName   Split an attention parameter name into two
% nested field names.
%   Names containing "LayerNorm" are handled by iParseLayerNorm. For all
%   other names the text "self_" and "_dense" is removed, subname is the
%   token between "attention_" and the next "_", and subsubname is
%   whatever follows "<subname>_".
if contains(name, "LayerNorm")
    [subname, subsubname] = iParseLayerNorm(name);
else
    % Work on a string scalar: extractBetween is documented to return a
    % cell array for character-vector input, and subname has to be usable
    % as a dynamic field name by the caller.
    name = string(name);
    name = strrep(name, "self_", "");
    name = strrep(name, "_dense", "");
    subname = extractBetween(name, "attention_", "_");
    subsubname = extractAfter(name, subname + "_");
end
end
93+
function [subname, subsubname] = iParseFeedforwardParamName(name)
% iParseFeedforwardParamName   Split a feedforward parameter name into two
% nested field names.
%   Names containing "LayerNorm" are handled by iParseLayerNorm. Otherwise
%   subname is the text before the first "_" and subsubname is the text
%   after "dense_".
if contains(name, "LayerNorm")
    [subname, subsubname] = iParseLayerNorm(name);
else
    subname = extractBefore(name, "_");
    subsubname = extractAfter(name, "dense_");
end
end
102+
function [subname, subsubname] = iParseLayerNorm(name)
% iParseLayerNorm   Split a layer-normalization parameter name.
%   subname is always "LayerNorm"; subsubname is the text that follows
%   "LayerNorm_" in the input name.
subname = "LayerNorm";
subsubname = extractAfter(name, "LayerNorm_");
end
107+
function [subname, subsubname] = iParseLM(name)
% iParseLM   Split a masked-language-model parameter name into two nested
% field names.
%   Names containing "LayerNorm" are handled by iParseLayerNorm. Otherwise
%   the text "dense_" is removed, subname is the text before the first "_"
%   and subsubname is the text after it.
if contains(name, "LayerNorm")
    [subname, subsubname] = iParseLayerNorm(name);
else
    name = strrep(name, "dense_", "");
    subname = extractBefore(name, "_");
    subsubname = extractAfter(name, "_");
end
end
0 commit comments