Skip to content

Commit faecb64

Browse files
authored
Fix input signature for predictOnBatch(); Add unit test (#574)
BUG
1 parent a68baa5 commit faecb64

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

src/engine/training.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,15 +1122,17 @@ export class LayersModel extends Container implements tfc.InferenceModel {
11221122
* });
11231123
* model.predictOnBatch(tf.ones([8, 10])).print();
11241124
* ```
1125-
* @param x: Input samples, as an Tensor
1125+
* @param x: Input samples, as an Tensor (for models with exactly one
1126+
* input) or an array of Tensors (for models with more than one input).
11261127
* @return Tensor(s) of predictions
11271128
*/
11281129
/** @doc {heading: 'Models', subheading: 'Classes'} */
1129-
predictOnBatch(x: Tensor): Tensor|Tensor[] {
1130+
predictOnBatch(x: Tensor|Tensor[]): Tensor|Tensor[] {
11301131
checkInputData(x, this.inputNames, this.feedInputShapes, true);
11311132
// TODO(cais): Take care of the learning_phase boolean flag.
11321133
// if (this.useLearningPhase) ...
1133-
return this.predictLoop(x, Array.isArray(x) ? x[0].shape[0] : x.shape[0]);
1134+
const batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
1135+
return this.predictLoop(x, batchSize);
11341136
}
11351137

11361138
protected standardizeUserDataXY(

src/models_test.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,6 +2090,23 @@ describeMathCPUAndGPU('Sequential', () => {
20902090
model.predictOnBatch(inputBatch) as Tensor, expectedOutput);
20912091
});
20922092

2093+
it('predictOnBatch() works with multi-input model.', () => {
2094+
const input1 = tfl.input({shape: [3]});
2095+
const input2 = tfl.input({shape: [4]});
2096+
const dense1 = tfl.layers.dense({units: 1, activation: 'sigmoid'});
2097+
const y1 = dense1.apply(input1) as tfl.SymbolicTensor;
2098+
const dense2 = tfl.layers.dense({units: 1, activation: 'sigmoid'});
2099+
const y2 = dense2.apply(input2) as tfl.SymbolicTensor;
2100+
const y = tfl.layers.concatenate().apply([y1, y2]) as tfl.SymbolicTensor;
2101+
const model = tfl.model({inputs: [input1, input2], outputs: y});
2102+
2103+
const batchSize = 5;
2104+
const x1 = zeros([batchSize, 3]);
2105+
const x2 = ones([batchSize, 4]);
2106+
const out = model.predictOnBatch([x1, x2]) as Tensor;
2107+
expect(out.shape).toEqual([5, 2]);
2108+
});
2109+
20932110
it('compile() and fit()', async () => {
20942111
const batchSize = 5;
20952112
const inputSize = 4;

0 commit comments

Comments
 (0)