Skip to content

Commit a68baa5

Browse files
FremyCompanycaisq
authored andcommitted
Fix batch size detection for multi-input models (#572)
1 parent 8c00b47 commit a68baa5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/engine/training.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,7 @@ export class LayersModel extends Container implements tfc.InferenceModel {
11301130
checkInputData(x, this.inputNames, this.feedInputShapes, true);
11311131
// TODO(cais): Take care of the learning_phase boolean flag.
11321132
// if (this.useLearningPhase) ...
1133-
return this.predictLoop(x, x.shape[0]);
1133+
return this.predictLoop(x, Array.isArray(x) ? x[0].shape[0] : x.shape[0]);
11341134
}
11351135

11361136
protected standardizeUserDataXY(

0 commit comments

Comments
 (0)