@@ -19,9 +19,10 @@ import {ENGINE} from '../engine';
1919import { Tensor } from '../tensor' ;
2020import { convertToTensor } from '../tensor_util_env' ;
2121import { TensorLike } from '../types' ;
22+
2223import { op } from './operation' ;
2324import { slice } from './slice' ;
24- import { getStridedSlicedInfo } from './slice_util' ;
25+ import { computeOutShape , maskToAxes , startForAxis , stopForAxis } from './slice_util' ;
2526
2627/**
2728 * Extracts a strided slice of a tensor.
@@ -56,30 +57,53 @@ import {getStridedSlicedInfo} from './slice_util';
5657 */
5758/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */
5859function stridedSlice_ (
59- x : Tensor | TensorLike , begin : number [ ] , end : number [ ] , strides : number [ ] ,
60+ x : Tensor | TensorLike , begin : number [ ] , end : number [ ] , strides ? : number [ ] ,
6061 beginMask = 0 , endMask = 0 , ellipsisMask = 0 , newAxisMask = 0 ,
6162 shrinkAxisMask = 0 ) : Tensor {
63+ if ( strides == null ) {
64+ strides = new Array ( begin . length ) ;
65+ }
6266 if ( ellipsisMask !== 0 ) {
6367 throw new Error ( 'ellipsis mask is not yet supported' ) ;
6468 }
65- if ( newAxisMask !== 0 ) {
66- throw new Error ( 'new axis mask is not yet supported' ) ;
69+ let $x = convertToTensor ( x , 'x' , 'stridedSlice' ) ;
70+
71+ // Expand the dims of x based on the newAxisMask.
72+ const expandAxes = maskToAxes ( newAxisMask ) ;
73+ const newShape = $x . shape . slice ( ) ;
74+ expandAxes . forEach ( axis => {
75+ begin [ axis ] = 0 ;
76+ end [ axis ] = 1 ;
77+ newShape . splice ( axis , 0 , 1 ) ;
78+ } ) ;
79+ $x = $x . reshape ( newShape ) ;
80+
81+ // Normalize the start, end and strides.
82+ for ( let axis = 0 ; axis < $x . rank ; axis ++ ) {
83+ begin [ axis ] = startForAxis ( beginMask , begin , strides , $x . shape , axis ) ;
84+ end [ axis ] = stopForAxis ( endMask , end , strides , $x . shape , axis ) ;
85+ strides [ axis ] = strides [ axis ] || 1 ;
6786 }
68- const $x = convertToTensor ( x , 'x' , 'stridedSlice' ) ;
87+
88+ const shrinkAxes = maskToAxes ( shrinkAxisMask ) ;
89+ // Adjust the ends based on the shrink mask.
90+ shrinkAxes . forEach ( axis => {
91+ end [ axis ] = begin [ axis ] + 1 ;
92+ strides [ axis ] = 1 ;
93+ } ) ;
94+
95+ // Figure out the output shape.
96+ const size = computeOutShape ( begin , end , strides ) ;
97+ // Remove the axes based on shrinkMask.
98+ const outShape = size . filter ( ( _ , axis ) => shrinkAxes . indexOf ( axis ) === - 1 ) ;
99+
69100 const nonStrided = strides . every ( v => v === 1 ) ;
70101 if ( nonStrided ) {
71- const [ beginIndex , size , shrinkAxis ] = getStridedSlicedInfo (
72- $x . shape , begin , end , strides , beginMask , endMask , ellipsisMask ,
73- newAxisMask , shrinkAxisMask ) ;
74- const outShape =
75- size . filter ( ( _ , index ) => shrinkAxis . indexOf ( index ) === - 1 ) ;
76- return slice ( $x , beginIndex , size ) . reshape ( outShape ) ;
102+ return slice ( $x , begin , size ) . reshape ( outShape ) ;
77103 }
78- return ENGINE . runKernel (
79- backend => backend . stridedSlice (
80- $x , begin , end , strides , beginMask , endMask , ellipsisMask ,
81- newAxisMask , shrinkAxisMask ) ,
82- { $x} ) ;
104+ const res = ENGINE . runKernel (
105+ backend => backend . stridedSlice ( $x , begin , end , strides ) , { $x} ) ;
106+ return res . reshape ( outShape ) ;
83107}
84108
85109export const stridedSlice = op ( { stridedSlice_} ) ;
0 commit comments