@@ -372,75 +372,83 @@ export function randomNormal(
372372 * For N dimensions it is a sum product over the last axis of x and the
373373 * second-to-last of y:
374374 *
375- * @param x A tensor of at least rank 2.
376- * @param y A tensor of at least rank 2.
377- * @param fusedActivation (optional) A string identifying the activation
375+ * @param a A tensor of at least rank 2.
376+ * @param b A tensor of at least rank 2.
377+ * @param activation (optional) A string identifying the activation
378378 * function.
379379 * @return Result of the dot operation.
380380 */
381381export function dot (
382- x : Tensor , y : Tensor , fusedActivation ?: tfc . fused . Activation ,
382+ a : Tensor , b : Tensor , activation ?: tfc . fused . Activation ,
383383 bias ?: Tensor ) : Tensor {
384- if ( ( x . rank < 2 ) || ( y . rank < 2 ) ) {
384+ if ( ( a . rank < 2 ) || ( b . rank < 2 ) ) {
385385 throw new NotImplementedError (
386386 `dot requires both inputs to be rank >= 2` +
387- ` but got x shape = ${ x . shape } and y shape = ${ y . shape } ` ) ;
387+ ` but got x shape = ${ a . shape } and y shape = ${ b . shape } ` ) ;
388388 }
389- if ( y . rank >= 3 ) {
390- const xLastDim = x . shape . slice ( - 1 ) [ 0 ] ;
391- const ySecondLastDim = y . shape . slice ( - 2 ) [ 0 ] ;
389+ if ( b . rank >= 3 ) {
390+ const xLastDim = a . shape . slice ( - 1 ) [ 0 ] ;
391+ const ySecondLastDim = b . shape . slice ( - 2 ) [ 0 ] ;
392392 if ( xLastDim !== ySecondLastDim ) {
393393 throw new NotImplementedError (
394394 `If rank y >= 3, then the second last dim` +
395395 ` of y must equal the last dim of x but got x shape = ${
396- x . shape } and ` +
397- ` y shape = ${ y . shape } ` ) ;
396+ a . shape } and ` +
397+ ` y shape = ${ b . shape } ` ) ;
398398 }
399399 }
400400 // Handle basic 2D x 2D case.
401- if ( ( x . rank === 2 ) && ( y . rank === 2 ) ) {
402- const transposeX = false ;
403- const transposeY = false ;
401+ if ( ( a . rank === 2 ) && ( b . rank === 2 ) ) {
402+ const transposeA = false ;
403+ const transposeB = false ;
404404 // tfc.fused.matMul only fuses certain activation functions. Unsupported
405405 // activation functions are treated as 'linear' activations, which is
406406 // equivalent to a no-op.
407- return tfc . fused . matMul (
408- x as Tensor2D , y as Tensor2D , transposeX , transposeY ,
409- bias ? reshapeBias ( x . rank , bias , imageDataFormat ( ) ) : null ,
410- fusedActivation ) ;
407+ return tfc . fused . matMul ( {
408+ a,
409+ b : b as Tensor2D ,
410+ transposeA,
411+ transposeB,
412+ bias : bias ? reshapeBias ( a . rank , bias , imageDataFormat ( ) ) : null ,
413+ activation
414+ } ) ;
411415 } else {
412416 // Reshape x into the analogous 2D Tensor.
413- const xFirstDims = x . shape . slice ( ) ; // Holds all but the last dim of x.
414- const xLastDim = xFirstDims . pop ( ) ;
415- x = x . reshape ( [ - 1 , xLastDim ] ) ;
417+ const aFirstDims = a . shape . slice ( ) ; // Holds all but the last dim of a.
418+ const aLastDim = aFirstDims . pop ( ) ;
419+ a = a . reshape ( [ - 1 , aLastDim ] ) ;
416420
417421 // Reshape y into the analogous 2D Tensor, and keep track of the
418422 // required dimensions to reproduce the output shape.
419- const yShape = y . shape . slice ( ) ;
420- const yLastDim = yShape . pop ( ) ;
421- const ySecondLastDim = yShape . pop ( ) ;
422- const yOtherDims = [ ...yShape , yLastDim ] ;
423+ const bShape = b . shape . slice ( ) ;
424+ const bLastDim = bShape . pop ( ) ;
425+ const ySecondLastDim = bShape . pop ( ) ;
426+ const yOtherDims = [ ...bShape , bLastDim ] ;
423427 // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]
424428 // where r is the rank of y.
425- const perm = Array . from ( { length : y . rank } , ( _ , i ) => {
429+ const perm = Array . from ( { length : b . rank } , ( _ , i ) => {
426430 if ( i === 0 ) {
427- return y . rank - 2 ;
428- } else if ( i <= y . rank - 2 ) {
431+ return b . rank - 2 ;
432+ } else if ( i <= b . rank - 2 ) {
429433 return i - 1 ;
430434 }
431435 return i ;
432436 } ) ;
433- y = y . transpose ( perm ) . reshape ( [ ySecondLastDim , - 1 ] ) ;
437+ b = b . transpose ( perm ) . reshape ( [ ySecondLastDim , - 1 ] ) ;
434438
435439 // Multiply x and y as 2D Tensors, and then reshape back to original.
436- const outputShape = [ ...xFirstDims , ...yOtherDims ] ;
437- const transposeX = false ;
438- const transposeY = false ;
440+ const outputShape = [ ...aFirstDims , ...yOtherDims ] ;
441+ const transposeA = false ;
442+ const transposeB = false ;
439443 return tfc . fused
440- . matMul (
441- x as Tensor2D , y as Tensor2D , transposeX , transposeY ,
442- bias ? reshapeBias ( x . rank , bias , imageDataFormat ( ) ) : null ,
443- fusedActivation )
444+ . matMul ( {
445+ a,
446+ b,
447+ transposeA,
448+ transposeB,
449+ bias : bias ? reshapeBias ( a . rank , bias , imageDataFormat ( ) ) : null ,
450+ activation
451+ } )
444452 . reshape ( outputShape ) ;
445453 }
446454}
@@ -522,8 +530,8 @@ export function square(x: Tensor): Tensor {
522530 * Element-wise exponentiation.
523531 *
524532 * Porting Note: In PyKeras, `a` (the exponent) is a Python integer, which
525- * takes advatnage of the backend's (e.g., TensorFlow's) automatic conversion
526- * to tensor. Here we allow `a` to be either a number or a tensor.
533+ * takes advantage of the backend's (e.g., TensorFlow's) automatic
534+ * conversion to tensor. Here we allow `a` to be either a number or a tensor.
527535 *
528536 * @param x The base tensor.
529537 * @param a The exponent, tensor or number. If a number, it is rounded to the
@@ -688,8 +696,9 @@ export function hardSigmoid(x: Tensor): Tensor {
688696/**
689697 * Invoke `x` in the training phase, and `alt` otherwise.
690698 *
691- * Porting Note: We do not create placeholder tensors for the `training` boolean
692- * flag here, because there is no such thing in the TF.js imperative backend.
699+ * Porting Note: We do not create placeholder tensors for the `training`
700+ * boolean flag here, because there is no such thing in the TF.js imperative
701+ * backend.
693702 *
694703 * @param x The function to invoke iff `training` is `true`.
695704 * @param alt The function to invoke iff `training` is `false`.
0 commit comments