1515 * =============================================================================
1616 */
1717
18+ import * as tfc from '@tensorflow/tfjs-core' ;
1819import { BackendTimingInfo , DataMover , DataType , fill , KernelBackend , ones , Rank , rsqrt , Scalar , scalar , ShapeMap , Tensor , Tensor1D , tensor1d , Tensor2D , tensor2d , Tensor3D , tensor3d , Tensor4D , tidy , util } from '@tensorflow/tfjs-core' ;
1920import { EPSILON_FLOAT32 } from '@tensorflow/tfjs-core/dist/backends/backend' ;
2021import { Conv2DInfo , Conv3DInfo } from '@tensorflow/tfjs-core/dist/ops/conv_util' ;
21- import * as tfc from '@tensorflow/tfjs-core' ;
22-
2322import { Activation , FusedBatchMatMulConfig } from '@tensorflow/tfjs-core/dist/ops/fused_util' ;
2423import { Tensor5D } from '@tensorflow/tfjs-core/dist/tensor' ;
2524import { BackendValues , upcastType } from '@tensorflow/tfjs-core/dist/types' ;
26- import { isNullOrUndefined , isArray } from 'util' ;
25+ import { isArray , isNullOrUndefined } from 'util' ;
2726
2827import { Int64Scalar } from './int64_tensors' ;
2928import { TensorMetadata , TFEOpAttr , TFJSBinding } from './tfjs_binding' ;
@@ -288,30 +287,26 @@ export class NodeJSKernelBackend extends KernelBackend {
288287 }
289288
290289 stridedSlice < T extends Tensor > (
291- x : T , begin : number [ ] , end : number [ ] , strides : number [ ] ,
292- beginMask : number , endMask : number , ellipsisMask : number ,
293- newAxisMask : number , shrinkAxisMask : number ) : T {
290+ x : T , begin : number [ ] , end : number [ ] , strides : number [ ] ) : T {
294291 const beginTensor = tensor1d ( begin , 'int32' ) ;
292+ for ( let axis = 0 ; axis < end . length ; axis ++ ) {
293+ // Unlike Numpy, when the strides are negative, TF C uses -n-1 instead of
294+ // -1 as the "end" in order to include the first element.
295+ if ( strides [ axis ] < 0 && end [ axis ] === - 1 ) {
296+ end [ axis ] -= x . shape [ axis ] ;
297+ }
298+ }
295299 const endTensor = tensor1d ( end , 'int32' ) ;
296300 const stridesTensor = tensor1d ( strides , 'int32' ) ;
301+ // All of the masks have already been accounted for in the high level op,
302+ // so the backend does NOT need to deal with masks.
297303 const opAttrs = [
298304 createTypeOpAttr ( 'T' , x . dtype ) , createTypeOpAttr ( 'Index' , 'int32' ) ,
299- { name : 'begin_mask' , type : this . binding . TF_ATTR_INT , value : beginMask } ,
300- { name : 'end_mask' , type : this . binding . TF_ATTR_INT , value : endMask } , {
301- name : 'ellipsis_mask' ,
302- type : this . binding . TF_ATTR_INT ,
303- value : ellipsisMask
304- } ,
305- {
306- name : 'new_axis_mask' ,
307- type : this . binding . TF_ATTR_INT ,
308- value : newAxisMask
309- } ,
310- {
311- name : 'shrink_axis_mask' ,
312- type : this . binding . TF_ATTR_INT ,
313- value : shrinkAxisMask
314- }
305+ { name : 'begin_mask' , type : this . binding . TF_ATTR_INT , value : 0 } ,
306+ { name : 'end_mask' , type : this . binding . TF_ATTR_INT , value : 0 } ,
307+ { name : 'ellipsis_mask' , type : this . binding . TF_ATTR_INT , value : 0 } ,
308+ { name : 'new_axis_mask' , type : this . binding . TF_ATTR_INT , value : 0 } ,
309+ { name : 'shrink_axis_mask' , type : this . binding . TF_ATTR_INT , value : 0 }
315310 ] ;
316311 return this . executeSingleOutput (
317312 'StridedSlice' , opAttrs ,
@@ -359,6 +354,8 @@ export class NodeJSKernelBackend extends KernelBackend {
359354 result = this . relu ( result ) ;
360355 } else if ( activation === 'prelu' ) {
361356 result = this . prelu ( result , preluActivationWeights ) as Tensor4D ;
357+ } else if ( activation === 'elu' ) {
358+ result = this . elu ( result ) ;
362359 } else {
363360 throw new Error ( `Activation: ${
364361 activation } has not been implemented for the Node.js backend`) ;
@@ -384,6 +381,8 @@ export class NodeJSKernelBackend extends KernelBackend {
384381 result = this . relu ( result ) ;
385382 } else if ( activation === 'prelu' ) {
386383 result = this . prelu ( result , preluActivationWeights ) as Tensor3D ;
384+ } else if ( activation === 'elu' ) {
385+ result = this . elu ( result ) ;
387386 } else {
388387 throw new Error ( `Activation: ${
389388 activation } has not been implemented for the Node.js backend`) ;
@@ -2028,4 +2027,4 @@ export function ensureTensorflowBackend() {
20282027 tfc . getBackend ( ) === 'tensorflow' ,
20292028 ( ) => `Expect the current backend to be "tensorflow", but got "${
20302029 tfc . getBackend ( ) } "`) ;
2031- }
2030+ }
0 commit comments