
Commit 4ebc201 (parent 3a2c48f)

Fixing predictMaskedToken and updating README.

File tree: 3 files changed, +55 −4 lines

README.md

Lines changed: 4 additions & 2 deletions

```diff
@@ -71,12 +71,14 @@ Download or [clone](https://www.mathworks.com/help/matlab/matlab_prog/use-source
 ## Example: Classify Text Data Using BERT
 The simplest use of a pretrained BERT model is to use it as a feature extractor. In particular, you can use the BERT model to convert documents to feature vectors which you can then use as inputs to train a deep learning classification network.
 
-The example [`ClassifyTextDataUsingBERT.m`](./ClassifyTextDataUsingBERT.m) shows how to use a pretrained BERT model to classify failure events given a data set of factory reports.
+The example [`ClassifyTextDataUsingBERT.m`](./ClassifyTextDataUsingBERT.m) shows how to use a pretrained BERT model to classify failure events given a data set of factory reports. This example requires the `factoryReports.csv` data set from the Text Analytics example (Prepare Text Data for Analysis)[https://www.mathworks.com/help/textanalytics/ug/prepare-text-data-for-analysis.html].
 
 ## Example: Fine-Tune Pretrained BERT Model
 To get the most out of a pretrained BERT model, you can retrain and fine tune the BERT parameters weights for your task.
 
-The example [`FineTuneBERT.m`](./FineTuneBERT.m) shows how to fine-tune a pretrained BERT model to classify failure events given a data set of factory reports.
+The example [`FineTuneBERT.m`](./FineTuneBERT.m) shows how to fine-tune a pretrained BERT model to classify failure events given a data set of factory reports. This example requires the `factoryReports.csv` data set from the Text Analytics example (Prepare Text Data for Analysis)[https://www.mathworks.com/help/textanalytics/ug/prepare-text-data-for-analysis.html].
+
+The example [`FineTuneBERTJapanese.m`](./FineTuneBERTJapanese.m) shows the same workflow using a pretrained Japanese-BERT model. This example requires the `factoryReportsJP.csv` data set from the Text Analytics example (Analyze Japanese Text Data)[https://www.mathworks.com/help/textanalytics/ug/analyze-japanese-text.html], available in R2023a or later.
 
 ## Example: Analyze Sentiment with FinBERT
 FinBERT is a sentiment analysis model trained on financial text data and fine-tuned for sentiment analysis.
```
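The feature-extractor workflow described in the README section above can be sketched as follows. This is a rough, hypothetical outline: the `encode` call on the tokenizer follows the pattern visible elsewhere in this commit, but the pooling step and input string are illustrative assumptions; the complete workflow lives in `ClassifyTextDataUsingBERT.m`.

```matlab
% Sketch of BERT-as-feature-extractor (assumptions noted above; see
% ClassifyTextDataUsingBERT.m for the real workflow).
mdl = bert();                                            % pretrained BERT model
seq = encode(mdl.Tokenizer, "Coolant leaked onto the assembly floor.");
% Run the encoded sequence through the model, then pool the final hidden
% states into one fixed-length feature vector per document (for example by
% taking the mean over tokens); that vector feeds a downstream classifier.
```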

predictMaskedToken.m

Lines changed: 2 additions & 2 deletions

```diff
@@ -6,7 +6,7 @@
 % replaces instances of mdl.Tokenizer.MaskToken in the string text with
 % the most likely token according to the BERT model mdl.
 
-% Copyright 2021 The MathWorks, Inc.
+% Copyright 2021-2023 The MathWorks, Inc.
 arguments
     mdl {mustBeA(mdl,'struct')}
     str {mustBeText}
@@ -44,7 +44,7 @@
         tokens = fulltok.tokenize(pieces(i));
         if ~isempty(tokens)
             % "" tokenizes to empty - awkward
-            x = cat(2,x,fulltok.encode(tokens));
+            x = cat(2,x,fulltok.encode(tokens{1}));
         end
         if i<numel(pieces)
             x = cat(2,x,maskCode);
```
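The fixed `encode` call can be exercised with a quick sanity check. This sketch uses only names that appear in this commit (`bert`, the `"tiny"` model option, `predictMaskedToken`, and an example string from the new unit test); the displayed result depends on the downloaded model weights, so no specific output is assumed.

```matlab
% Sanity check for the fulltok.encode(tokens{1}) fix: load a small
% pretrained BERT model and fill in a [MASK] token.
mdl = bert("Model","tiny");
out = predictMaskedToken(mdl, "The [MASK] set beautifully.")
% out is the input string with [MASK] replaced by the model's most
% likely token; before the fix, encode received a cell rather than
% the token array it expects.
```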

test/tpredictMaskedToken.m

Lines changed: 49 additions & 0 deletions

```diff
@@ -0,0 +1,49 @@
+classdef(SharedTestFixtures={
+        DownloadBERTFixture, DownloadJPBERTFixture}) tpredictMaskedToken < matlab.unittest.TestCase
+    % tpredictMaskedToken Unit test for predictMaskedToken
+
+    % Copyright 2023 The MathWorks, Inc.
+
+    properties(TestParameter)
+        AllModels = {"base","multilingual-cased","medium",...
+            "small","mini","tiny","japanese-base",...
+            "japanese-base-wwm"}
+        ValidText = iGetValidText;
+    end
+
+    methods(Test)
+        function verifyOutputDimSizes(test, AllModels, ValidText)
+            inSize = size(ValidText);
+            mdl = bert("Model", AllModels);
+            outputText = predictMaskedToken(mdl,ValidText);
+            test.verifyEqual(size(outputText), inSize);
+        end
+
+        function maskTokenIsRemoved(test, AllModels)
+            text = "This has a [MASK] token.";
+            mdl = bert("Model", AllModels);
+            outputText = predictMaskedToken(mdl,text);
+            test.verifyFalse(contains(outputText, "[MASK]"));
+        end
+
+        function inputWithoutMASKRemainsTheSame(test, AllModels)
+            text = "This has a no mask token.";
+            mdl = bert("Model", AllModels);
+            outputText = predictMaskedToken(mdl,text);
+            test.verifyEqual(text, outputText);
+        end
+    end
+end
+
+function validText = iGetValidText
+manyStrs = ["Accelerating the pace of [MASK] and science";
+    "The cat [MASK] soundly.";
+    "The [MASK] set beautifully."];
+singleStr = "Artificial intelligence continues to shape the future of industries," + ...
+    " as innovative applications emerge in fields such as healthcare, transportation," + ...
+    " entertainment, and finance, driving productivity and enhancing human capabilities.";
+validText = struct('StringsAsColumns',manyStrs,...
+    'StringsAsRows',manyStrs',...
+    'ManyStrings',repmat(singleStr,3),...
+    'SingleString',singleStr);
+end
```
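The new test class can be run with MATLAB's standard `runtests` function. A minimal sketch, assuming you are at the repository root; the shared fixtures (`DownloadBERTFixture`, `DownloadJPBERTFixture`) download model weights on first use, so the initial run needs network access.

```matlab
% Run the new parameterized unit test and display a per-test summary.
results = runtests("test/tpredictMaskedToken.m");
table(results)
```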
