@@ -1558,40 +1558,52 @@ export class MathBackendCPU implements KernelBackend {
15581558 const dilationWidth = convInfo . dilationWidth ;
15591559 const padLeft = convInfo . padInfo . left ;
15601560 const padTop = convInfo . padInfo . top ;
1561+ const isChannelsLast = convInfo . dataFormat === 'channelsLast' ;
1562+
15611563 const y = ops . buffer ( convInfo . outShape , x . dtype as 'float32' ) ;
15621564
1565+ const xBatchStride = x . strides [ 0 ] ;
1566+ const xRowStride = isChannelsLast ? x . strides [ 1 ] : x . strides [ 2 ] ;
1567+ const xColStride = isChannelsLast ? x . strides [ 2 ] : 1 ;
1568+ const xChannelStride = isChannelsLast ? 1 : x . strides [ 1 ] ;
1569+ const yBatchStride = y . strides [ 0 ] ;
1570+ const yRowStride = isChannelsLast ? y . strides [ 1 ] : y . strides [ 2 ] ;
1571+ const yColStride = isChannelsLast ? y . strides [ 2 ] : 1 ;
1572+ const yChannelStride = isChannelsLast ? 1 : y . strides [ 1 ] ;
1573+
15631574 const xVals = this . readSync ( x . dataId ) as TypedArray ;
15641575 const wVals = this . readSync ( filter . dataId ) as TypedArray ;
15651576 const yVals = y . values ;
15661577
15671578 for ( let b = 0 ; b < convInfo . batchSize ; ++ b ) {
1568- const xOffset1 = b * x . strides [ 0 ] ;
1569- const yOffset1 = b * y . strides [ 0 ] ;
1579+ const xOffset1 = b * xBatchStride ;
1580+ const yOffset1 = b * yBatchStride ;
15701581 for ( let yR = 0 ; yR < convInfo . outHeight ; ++ yR ) {
1571- const yOffset2 = yOffset1 + yR * y . strides [ 1 ] ;
1582+ const yOffset2 = yOffset1 + yR * yRowStride ;
15721583 const xRCorner = yR * convInfo . strideHeight - padTop ;
15731584 for ( let wR = 0 ; wR < filterHeight ; wR ++ ) {
15741585 const xR = xRCorner + wR * dilationHeight ;
15751586 if ( xR < 0 || xR >= convInfo . inHeight ) {
15761587 continue ;
15771588 }
15781589 const wOffset1 = wR * filter . strides [ 0 ] ;
1579- const xOffset2 = xOffset1 + xR * x . strides [ 1 ] ;
1590+ const xOffset2 = xOffset1 + xR * xRowStride ;
15801591 for ( let yC = 0 ; yC < convInfo . outWidth ; ++ yC ) {
1581- const yOffset3 = yOffset2 + yC * convInfo . outChannels ;
1592+ const yOffset3 = yOffset2 + yC * yColStride ;
15821593 const xCCorner = yC * convInfo . strideWidth - padLeft ;
15831594 for ( let wC = 0 ; wC < filterWidth ; wC ++ ) {
15841595 const xC = xCCorner + wC * dilationWidth ;
15851596 if ( xC < 0 || xC >= convInfo . inWidth ) {
15861597 continue ;
15871598 }
15881599 const wOffset2 = wOffset1 + wC * filter . strides [ 1 ] ;
1589- const xOffset3 = xOffset2 + xC * convInfo . inChannels ;
1600+ const xOffset3 = xOffset2 + xC * xColStride ;
15901601 let wOffset3 = wOffset2 ;
15911602 for ( let d1 = 0 ; d1 < convInfo . inChannels ; ++ d1 ) {
1592- const xVal = xVals [ xOffset3 + d1 ] ;
1603+ const xVal = xVals [ xOffset3 + d1 * xChannelStride ] ;
15931604 for ( let d2 = 0 ; d2 < convInfo . outChannels ; ++ d2 ) {
1594- yVals [ yOffset3 + d2 ] += xVal * wVals [ wOffset3 + d2 ] ;
1605+ yVals [ yOffset3 + d2 * yChannelStride ] +=
1606+ xVal * wVals [ wOffset3 + d2 ] ;
15951607 }
15961608 wOffset3 += convInfo . outChannels ;
15971609 }
@@ -1677,9 +1689,7 @@ export class MathBackendCPU implements KernelBackend {
16771689
16781690 const dx = ops . buffer < Rank . R4 > ( convInfo . inShape , 'float32' ) ;
16791691 const dxValues = dx . values ;
1680- const [ dxS0 , dxS1 , dxS2 ] = dx . strides ;
16811692 const dyValues = this . readSync ( dy . dataId ) as TypedArray ;
1682- const [ dyS0 , dyS1 , dyS2 ] = dy . strides ;
16831693 const fltValues = this . readSync ( filter . dataId ) as TypedArray ;
16841694 const [ fltS0 , fltS1 , fltS2 ] = filter . strides ;
16851695 const {
@@ -1693,11 +1703,22 @@ export class MathBackendCPU implements KernelBackend {
16931703 outHeight,
16941704 outWidth,
16951705 strideHeight,
1696- strideWidth
1706+ strideWidth,
1707+ dataFormat
16971708 } = convInfo ;
16981709 const topPad = filterHeight - 1 - convInfo . padInfo . top ;
16991710 const leftPad = filterWidth - 1 - convInfo . padInfo . left ;
17001711
1712+ const isChannelsLast = dataFormat === 'channelsLast' ;
1713+ const xBatchStride = dx . strides [ 0 ] ;
1714+ const xRowStride = isChannelsLast ? dx . strides [ 1 ] : dx . strides [ 2 ] ;
1715+ const xColStride = isChannelsLast ? dx . strides [ 2 ] : 1 ;
1716+ const xChannelStride = isChannelsLast ? 1 : dx . strides [ 1 ] ;
1717+ const yBatchStride = dy . strides [ 0 ] ;
1718+ const yRowStride = isChannelsLast ? dy . strides [ 1 ] : dy . strides [ 2 ] ;
1719+ const yColStride = isChannelsLast ? dy . strides [ 2 ] : 1 ;
1720+ const yChannelStride = isChannelsLast ? 1 : dy . strides [ 1 ] ;
1721+
17011722 for ( let b = 0 ; b < batchSize ; ++ b ) {
17021723 for ( let d1 = 0 ; d1 < inChannels ; ++ d1 ) {
17031724 for ( let xR = 0 ; xR < inHeight ; ++ xR ) {
@@ -1718,18 +1739,21 @@ export class MathBackendCPU implements KernelBackend {
17181739
17191740 for ( let yC = xCMin ; yC < yCMax ; ++ yC ) {
17201741 const wC = yC * strideWidth - xCCorner ;
1721- const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC ;
1742+ const dyOffset =
1743+ yBatchStride * b + yRowStride * yR + yColStride * yC ;
17221744 const fltOffset = fltS0 * ( filterHeight - 1 - wR ) +
17231745 fltS1 * ( filterWidth - 1 - wC ) + fltS2 * d1 ;
17241746
17251747 for ( let d2 = 0 ; d2 < outChannels ; ++ d2 ) {
1726- const pixel = dyValues [ dyOffset + d2 ] ;
1748+ const pixel = dyValues [ dyOffset + yChannelStride * d2 ] ;
17271749 const weight = fltValues [ fltOffset + d2 ] ;
17281750 dotProd += pixel * weight ;
17291751 }
17301752 }
17311753 }
1732- dxValues [ dxS0 * b + dxS1 * xR + dxS2 * xC + d1 ] = dotProd ;
1754+ const dxOffset = xBatchStride * b + xRowStride * xR +
1755+ xColStride * xC + xChannelStride * d1 ;
1756+ dxValues [ dxOffset ] = dotProd ;
17331757 }
17341758 }
17351759 }
@@ -1829,6 +1853,7 @@ export class MathBackendCPU implements KernelBackend {
18291853 const strideWidth = convInfo . strideWidth ;
18301854 const filterHeight = convInfo . filterHeight ;
18311855 const filterWidth = convInfo . filterWidth ;
1856+ const isChannelsLast = convInfo . dataFormat === 'channelsLast' ;
18321857 const dW = ops . buffer < Rank . R4 > ( convInfo . filterShape , 'float32' ) ;
18331858
18341859 const leftPad = convInfo . padInfo . left ;
@@ -1854,7 +1879,13 @@ export class MathBackendCPU implements KernelBackend {
18541879 const xR = wR + yR * strideHeight - topPad ;
18551880 for ( let yC = yCMin ; yC < yCMax ; ++ yC ) {
18561881 const xC = wC + yC * strideWidth - leftPad ;
1857- dotProd += xBuf . get ( b , xR , xC , d1 ) * dyBuf . get ( b , yR , yC , d2 ) ;
1882+ if ( isChannelsLast ) {
1883+ dotProd +=
1884+ xBuf . get ( b , xR , xC , d1 ) * dyBuf . get ( b , yR , yC , d2 ) ;
1885+ } else {
1886+ dotProd +=
1887+ xBuf . get ( b , d1 , xR , xC ) * dyBuf . get ( b , d2 , yR , yC ) ;
1888+ }
18581889 }
18591890 }
18601891 }
0 commit comments