Skip to content

Commit df13769

Browse files
authored
Merge pull request #6 from matlab-deep-learning/input_mask_bug
Input mask bug for gpuArray
2 parents 774eb6d + bae5430 commit df13769

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

+bert/model.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
arguments
5252
x dlarray {mustBeNumericDlarray,mustBeNonempty}
5353
parameters {mustBeA(parameters,'struct')}
54-
nvp.InputMask {mustBeALogicalOrDlarrayLogical} = logical.empty()
54+
nvp.InputMask {mustBeNumericOrLogical} = logical.empty()
5555
nvp.DropoutProb (1,1) {mustBeNonnegative,mustBeLessThanOrEqual(nvp.DropoutProb,1),mustBeNumeric} = 0
5656
nvp.AttentionDropoutProb (1,1) {mustBeNonnegative,mustBeLessThanOrEqual(nvp.AttentionDropoutProb,1),mustBeNumeric} = 0
5757
nvp.Outputs {mustBePositive,mustBeLessThanOrEqualNumLayers(nvp.Outputs,parameters),mustBeInteger,mustBeNumeric} = parameters.Hyperparameters.NumLayers
@@ -72,7 +72,7 @@
7272
inputMask = x~=nvp.PaddingCode;
7373
else
7474
assert(isequal(size(nvp.InputMask),size(x)),"bert:model:InvalidMaskSize","Expected InputMask to have same size as input X.");
75-
inputMask = nvp.InputMask;
75+
inputMask = logical(nvp.InputMask);
7676
end
7777

7878
% Assuming CTB format of x.

test/bert/tmodel.m

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,8 @@ function canUseInputMask(test)
261261
function s = iInvalidInputMask()
262262
s = struct(...
263263
'InvalidDims', iInvalidInputCase([true; true; false],"bert:model:InvalidMaskSize"), ...
264-
'InvalidTypeString', iInvalidInputCase("foo",'MATLAB:validators:mustBeA'), ...
265-
'InvalidTypeCell', iInvalidInputCase({{1, 1, 0}},'MATLAB:validators:mustBeA'), ...
266-
'NonLogical', iInvalidInputCase(dlarray([5, 1, 0]),'MATLAB:validators:mustBeA'));
264+
'InvalidTypeString', iInvalidInputCase("foo",'MATLAB:validators:mustBeNumericOrLogical'), ...
265+
'InvalidTypeCell', iInvalidInputCase({{1, 1, 0}},'MATLAB:validators:mustBeNumericOrLogical'));
267266
end
268267

269268
function s = iInvalidInputX()

0 commit comments

Comments
 (0)