File tree Expand file tree Collapse file tree 2 files changed +22
-3
lines changed Expand file tree Collapse file tree 2 files changed +22
-3
lines changed Original file line number Diff line number Diff line change @@ -1122,15 +1122,17 @@ export class LayersModel extends Container implements tfc.InferenceModel {
11221122 * });
11231123 * model.predictOnBatch(tf.ones([8, 10])).print();
11241124 * ```
1125- * @param x: Input samples, as an Tensor
1125+ * @param x: Input samples, as an Tensor (for models with exactly one
1126+ * input) or an array of Tensors (for models with more than one input).
11261127 * @return Tensor(s) of predictions
11271128 */
11281129 /** @doc {heading: 'Models', subheading: 'Classes'} */
1129- predictOnBatch ( x : Tensor ) : Tensor | Tensor [ ] {
1130+ predictOnBatch ( x : Tensor | Tensor [ ] ) : Tensor | Tensor [ ] {
11301131 checkInputData ( x , this . inputNames , this . feedInputShapes , true ) ;
11311132 // TODO(cais): Take care of the learning_phase boolean flag.
11321133 // if (this.useLearningPhase) ...
1133- return this . predictLoop ( x , Array . isArray ( x ) ? x [ 0 ] . shape [ 0 ] : x . shape [ 0 ] ) ;
1134+ const batchSize = ( Array . isArray ( x ) ? x [ 0 ] : x ) . shape [ 0 ] ;
1135+ return this . predictLoop ( x , batchSize ) ;
11341136 }
11351137
11361138 protected standardizeUserDataXY (
Original file line number Diff line number Diff line change @@ -2090,6 +2090,23 @@ describeMathCPUAndGPU('Sequential', () => {
20902090 model . predictOnBatch ( inputBatch ) as Tensor , expectedOutput ) ;
20912091 } ) ;
20922092
2093+ it ( 'predictOnBatch() works with multi-input model.' , ( ) => {
2094+ const input1 = tfl . input ( { shape : [ 3 ] } ) ;
2095+ const input2 = tfl . input ( { shape : [ 4 ] } ) ;
2096+ const dense1 = tfl . layers . dense ( { units : 1 , activation : 'sigmoid' } ) ;
2097+ const y1 = dense1 . apply ( input1 ) as tfl . SymbolicTensor ;
2098+ const dense2 = tfl . layers . dense ( { units : 1 , activation : 'sigmoid' } ) ;
2099+ const y2 = dense2 . apply ( input2 ) as tfl . SymbolicTensor ;
2100+ const y = tfl . layers . concatenate ( ) . apply ( [ y1 , y2 ] ) as tfl . SymbolicTensor ;
2101+ const model = tfl . model ( { inputs : [ input1 , input2 ] , outputs : y } ) ;
2102+
2103+ const batchSize = 5 ;
2104+ const x1 = zeros ( [ batchSize , 3 ] ) ;
2105+ const x2 = ones ( [ batchSize , 4 ] ) ;
2106+ const out = model . predictOnBatch ( [ x1 , x2 ] ) as Tensor ;
2107+ expect ( out . shape ) . toEqual ( [ 5 , 2 ] ) ;
2108+ } ) ;
2109+
20932110 it ( 'compile() and fit()' , async ( ) => {
20942111 const batchSize = 5 ;
20952112 const inputSize = 4 ;
You can’t perform that action at this time.
0 commit comments