|
5 | 5 | % Copyright 2023 The MathWorks, Inc. |
6 | 6 |
|
7 | 7 | properties(TestParameter) |
8 | | - AllModels = {"base","multilingual-cased","medium",... |
9 | | - "small","mini","tiny","japanese-base",... |
10 | | - "japanese-base-wwm"} |
| 8 | + Models = {"tiny","japanese-base-wwm"} |
11 | 9 | ValidText = iGetValidText; |
12 | 10 | end |
13 | 11 |
|
14 | 12 | methods(Test) |
15 | | - function verifyOutputDimSizes(test, AllModels, ValidText) |
| 13 | + function verifyOutputDimSizes(test, Models, ValidText) |
16 | 14 | inSize = size(ValidText); |
17 | | - mdl = bert("Model", AllModels); |
| 15 | + mdl = bert("Model", Models); |
18 | 16 | outputText = predictMaskedToken(mdl,ValidText); |
19 | 17 | test.verifyEqual(size(outputText), inSize); |
20 | 18 | end |
21 | 19 |
|
22 | | - function maskTokenIsRemoved(test, AllModels) |
| 20 | + function maskTokenIsRemoved(test, Models) |
23 | 21 | text = "This has a [MASK] token."; |
24 | | - mdl = bert("Model", AllModels); |
| 22 | + mdl = bert("Model", Models); |
25 | 23 | outputText = predictMaskedToken(mdl,text); |
26 | 24 | test.verifyFalse(contains(outputText, "[MASK]")); |
27 | 25 | end |
28 | 26 |
|
29 | | - function inputWithoutMASKRemainsTheSame(test, AllModels) |
| 27 | + function inputWithoutMASKRemainsTheSame(test, Models) |
30 | 28 | text = "This has a no mask token."; |
31 | | - mdl = bert("Model", AllModels); |
| 29 | + mdl = bert("Model", Models); |
32 | 30 | outputText = predictMaskedToken(mdl,text); |
33 | 31 | test.verifyEqual(text, outputText); |
34 | 32 | end |
|
0 commit comments