Skip to content

Commit aa92053

Browse files
authored
[tfjs-core] Support newAxisMask in stridedSlice and make squeez… (#1829)
FEATURE BUG * Add support for the `newAxisMask` param in `stridedSlice()`. Because stridedSlice() is difficult to implement correctly, and all backends have to repeat fragile logic, shift common preprocess and postprocess logic to the logical stridedSlice op and simplify the actual kernel. * Make tf.squeeze(x, axis) consistent with TF when axis is empty array. These fixes were found by running a converted TF model in the browser.
1 parent 28923ef commit aa92053

File tree

9 files changed

+143
-107
lines changed

9 files changed

+143
-107
lines changed

tfjs-core/src/backends/backend.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,7 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
141141
throw new Error('Not yet implemented');
142142
}
143143
stridedSlice<T extends Tensor>(
144-
x: T, begin: number[], end: number[], strides: number[],
145-
beginMask: number, endMask: number, ellipsisMask: number,
146-
newAxisMask: number, shrinkAxisMask: number): T {
144+
x: T, begin: number[], end: number[], strides: number[]): T {
147145
throw new Error('Not yet implemented');
148146
}
149147
unstack(x: Tensor, axis: number): Tensor[] {

tfjs-core/src/backends/cpu/backend_cpu.ts

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import * as ops from '../../ops/ops';
3333
import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
3434
import * as scatter_nd_util from '../../ops/scatter_nd_util';
3535
import * as selu_util from '../../ops/selu_util';
36-
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util';
36+
import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util';
3737
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
3838
import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
3939
import * as util from '../../util';
@@ -313,34 +313,28 @@ export class MathBackendCPU implements KernelBackend {
313313
}
314314

315315
stridedSlice<T extends Tensor>(
316-
x: T, begin: number[], end: number[], strides: number[],
317-
beginMask: number, endMask: number, ellipsisMask: number,
318-
newAxisMask: number, shrinkAxisMask: number): T {
316+
x: T, begin: number[], end: number[], strides: number[]): T {
319317
this.assertNotComplex(x, 'stridedSlice');
320318

321-
const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo(
322-
x.shape, begin, end, strides, beginMask, endMask, ellipsisMask,
323-
newAxisMask, shrinkAxisMask);
319+
const outShape = computeOutShape(begin, end, strides);
324320

325-
const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1);
326-
327-
if (shape.some(axis => axis === 0)) {
328-
return ops.tensor([], shape) as T;
321+
if (outShape.some(axis => axis === 0)) {
322+
return ops.tensor([], outShape) as T;
329323
}
330324

331-
const buffer = ops.buffer(size, x.dtype);
325+
const buffer = ops.buffer(outShape, x.dtype);
332326
const xBuf = this.bufferSync(x);
333327
for (let i = 0; i < buffer.size; i++) {
334328
const loc = buffer.indexToLoc(i);
335329

336330
const newLoc: number[] = new Array(loc.length);
337331
for (let j = 0; j < newLoc.length; j++) {
338-
newLoc[j] = loc[j] * strides[j] + beginIndex[j];
332+
newLoc[j] = loc[j] * strides[j] + begin[j];
339333
}
340334
buffer.set(xBuf.get(...newLoc), ...loc);
341335
}
342336

343-
return buffer.toTensor().reshape(shape) as T;
337+
return buffer.toTensor() as T;
344338
}
345339

346340
diag(x: Tensor): Tensor {

tfjs-core/src/backends/webgl/backend_webgl.ts

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import * as gather_nd_util from '../../ops/gather_nd_util';
3434
import * as reduce_util from '../../ops/reduce_util';
3535
import * as scatter_nd_util from '../../ops/scatter_nd_util';
3636
import * as segment_util from '../../ops/segment_util';
37-
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util';
37+
import * as slice_util from '../../ops/slice_util';
3838
import {softmax} from '../../ops/softmax';
3939
import {range, scalar, tensor} from '../../ops/tensor_ops';
4040
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
@@ -729,7 +729,7 @@ export class MathBackendWebGL implements KernelBackend {
729729
return tensor([], size, x.dtype) as T;
730730
}
731731
const {isPacked} = this.texData.get(x.dataId);
732-
const isContinous = isSliceContinous(x.shape, begin, size);
732+
const isContinous = slice_util.isSliceContinous(x.shape, begin, size);
733733
if (isPacked || !isContinous) {
734734
const program = ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
735735
new SlicePackedProgram(size) :
@@ -749,7 +749,7 @@ export class MathBackendWebGL implements KernelBackend {
749749
Object.assign(newTexData, xTexData);
750750
newTexData.shape = size;
751751
newTexData.dtype = x.dtype;
752-
let flatOffset = computeFlatOffset(begin, x.strides);
752+
let flatOffset = slice_util.computeFlatOffset(begin, x.strides);
753753
if (xTexData.slice) {
754754
// We are slicing an already sliced tensor, so we have to accumulate
755755
// the offset.
@@ -769,26 +769,18 @@ export class MathBackendWebGL implements KernelBackend {
769769
}
770770

771771
stridedSlice<T extends Tensor>(
772-
x: T, begin: number[], end: number[], strides: number[],
773-
beginMask: number, endMask: number, ellipsisMask: number,
774-
newAxisMask: number, shrinkAxisMask: number): T {
772+
x: T, begin: number[], end: number[], strides: number[]): T {
775773
if (this.shouldExecuteOnCPU([x])) {
776-
return this.cpuBackend.stridedSlice(
777-
x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask,
778-
shrinkAxisMask);
774+
return this.cpuBackend.stridedSlice(x, begin, end, strides);
779775
}
780776

781-
const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo(
782-
x.shape, begin, end, strides, beginMask, endMask, ellipsisMask,
783-
newAxisMask, shrinkAxisMask);
777+
const outShape = slice_util.computeOutShape(begin, end, strides);
784778

785-
const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1);
786-
if (shape.some(axis => axis === 0)) {
787-
return tensor([], shape) as T;
779+
if (outShape.some(axis => axis === 0)) {
780+
return tensor([], outShape) as T;
788781
}
789782

790-
const program =
791-
new StridedSliceProgram(beginIndex, strides, size, shrinkAxis);
783+
const program = new StridedSliceProgram(begin, strides, outShape);
792784
return this.compileAndRun(program, [x]);
793785
}
794786

tfjs-core/src/backends/webgl/strided_slice_gpu.ts

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@ export class StridedSliceProgram implements GPGPUProgram {
2424
userCode: string;
2525

2626
constructor(
27-
begin: number[], strides: number[], size: number[],
28-
shrinkAxis: number[]) {
29-
const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1);
30-
this.outputShape = shape;
27+
begin: number[], strides: number[], size: number[]) {
28+
this.outputShape = size;
3129
const rank = size.length;
3230
const inputDtype = getCoordsDataType(size.length);
33-
const dtype = getCoordsDataType(shape.length);
31+
const dtype = getCoordsDataType(size.length);
3432

3533
let newCoords = '';
3634
if (rank === 1) {
@@ -39,14 +37,10 @@ export class StridedSliceProgram implements GPGPUProgram {
3937
let outputAxis = 0;
4038
newCoords =
4139
size.map((_, i) => {
42-
if (shrinkAxis.indexOf(i) === -1) {
43-
outputAxis++;
44-
return shape.length === 1 ?
45-
`coords * strides[${i}] + begin[${i}]` :
46-
`coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;
47-
} else {
48-
return `begin[${i}]`;
49-
}
40+
outputAxis++;
41+
return size.length === 1 ?
42+
`coords * strides[${i}] + begin[${i}]` :
43+
`coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;
5044
})
5145
.join(',');
5246
}

tfjs-core/src/ops/slice_util.ts

Lines changed: 19 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -38,49 +38,28 @@ export function assertParamsValid(
3838
}
3939
}
4040

41-
/**
42-
* Calculate the start index and output tensor shape for strided slice op.
43-
* @returns array of [startIndex, size, shrinkAxis]
44-
*/
45-
export function getStridedSlicedInfo(
46-
shape: number[], begin: number[], end: number[], strides: number[],
47-
beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0,
48-
shrinkAxisMask = 0): [number[], number[], number[]] {
49-
if (ellipsisMask !== 0) {
50-
throw new Error('ellipsis mask is not yet supported');
51-
}
52-
if (newAxisMask !== 0) {
53-
throw new Error('new axis mask is not yet supported');
54-
}
55-
// Note that the axis orders are reversed for runtime ops, so the indices,
56-
// strides and masks must be as well too.
57-
const startIndex: number[] = [];
58-
const endIndex: number[] = [];
59-
const shrinkAxis: number[] = [];
60-
for (let i = 0; i < shape.length; i++) {
61-
startIndex[i] = startForAxis(beginMask, begin, strides, shape, i);
62-
endIndex[i] = stopForAxis(endMask, end, strides, shape, i);
63-
// When shrinking an axis, use startIndex + 1 for endIndex.
64-
// Check the axis bit from right of shrinkAxisMask
65-
if (shrinkAxisMask & 1 << i) {
66-
endIndex[i] = startIndex[i] + 1;
67-
shrinkAxis.push(i);
41+
/** Converts a binary mask to an array of axes. Used in stridedSlice(). */
42+
export function maskToAxes(mask: number): number[] {
43+
const axes = [];
44+
let axis = 0;
45+
while (mask > 0) {
46+
if (mask & 1) {
47+
axes.push(axis);
6848
}
49+
mask /= 2;
50+
axis++;
6951
}
52+
return axes;
53+
}
7054

71-
let size = new Array(shape.length).fill(0);
72-
size = size.map((d, i) => {
73-
let count = 0;
74-
const stride = strides[i] || 1;
75-
for (let start = startIndex[i];
76-
!(stride > 0 ? start >= endIndex[i] : start <= endIndex[i]);
77-
start += stride) {
78-
count += 1;
79-
}
80-
return count;
81-
});
82-
83-
return [startIndex, size, shrinkAxis];
55+
/** Computes the output shape given the strided slice params. */
56+
export function computeOutShape(
57+
begin: number[], end: number[], strides: number[]): number[] {
58+
const size = [];
59+
for (let axis = 0; axis < begin.length; axis++) {
60+
size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
61+
}
62+
return size;
8463
}
8564

8665
export function startForAxis(

tfjs-core/src/ops/strided_slice.ts

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ import {ENGINE} from '../engine';
1919
import {Tensor} from '../tensor';
2020
import {convertToTensor} from '../tensor_util_env';
2121
import {TensorLike} from '../types';
22+
2223
import {op} from './operation';
2324
import {slice} from './slice';
24-
import {getStridedSlicedInfo} from './slice_util';
25+
import {computeOutShape, maskToAxes, startForAxis, stopForAxis} from './slice_util';
2526

2627
/**
2728
* Extracts a strided slice of a tensor.
@@ -56,30 +57,53 @@ import {getStridedSlicedInfo} from './slice_util';
5657
*/
5758
/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */
5859
function stridedSlice_(
59-
x: Tensor|TensorLike, begin: number[], end: number[], strides: number[],
60+
x: Tensor|TensorLike, begin: number[], end: number[], strides?: number[],
6061
beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0,
6162
shrinkAxisMask = 0): Tensor {
63+
if (strides == null) {
64+
strides = new Array(begin.length);
65+
}
6266
if (ellipsisMask !== 0) {
6367
throw new Error('ellipsis mask is not yet supported');
6468
}
65-
if (newAxisMask !== 0) {
66-
throw new Error('new axis mask is not yet supported');
69+
let $x = convertToTensor(x, 'x', 'stridedSlice');
70+
71+
// Expand the dims of x based on the newAxisMask.
72+
const expandAxes = maskToAxes(newAxisMask);
73+
const newShape = $x.shape.slice();
74+
expandAxes.forEach(axis => {
75+
begin[axis] = 0;
76+
end[axis] = 1;
77+
newShape.splice(axis, 0, 1);
78+
});
79+
$x = $x.reshape(newShape);
80+
81+
// Normalize the start, end and strides.
82+
for (let axis = 0; axis < $x.rank; axis++) {
83+
begin[axis] = startForAxis(beginMask, begin, strides, $x.shape, axis);
84+
end[axis] = stopForAxis(endMask, end, strides, $x.shape, axis);
85+
strides[axis] = strides[axis] || 1;
6786
}
68-
const $x = convertToTensor(x, 'x', 'stridedSlice');
87+
88+
const shrinkAxes = maskToAxes(shrinkAxisMask);
89+
// Adjust the ends based on the shrink mask.
90+
shrinkAxes.forEach(axis => {
91+
end[axis] = begin[axis] + 1;
92+
strides[axis] = 1;
93+
});
94+
95+
// Figure out the output shape.
96+
const size = computeOutShape(begin, end, strides);
97+
// Remove the axes based on shrinkMask.
98+
const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1);
99+
69100
const nonStrided = strides.every(v => v === 1);
70101
if (nonStrided) {
71-
const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo(
72-
$x.shape, begin, end, strides, beginMask, endMask, ellipsisMask,
73-
newAxisMask, shrinkAxisMask);
74-
const outShape =
75-
size.filter((_, index) => shrinkAxis.indexOf(index) === -1);
76-
return slice($x, beginIndex, size).reshape(outShape);
102+
return slice($x, begin, size).reshape(outShape);
77103
}
78-
return ENGINE.runKernel(
79-
backend => backend.stridedSlice(
80-
$x, begin, end, strides, beginMask, endMask, ellipsisMask,
81-
newAxisMask, shrinkAxisMask),
82-
{$x});
104+
const res = ENGINE.runKernel(
105+
backend => backend.stridedSlice($x, begin, end, strides), {$x});
106+
return res.reshape(outShape);
83107
}
84108

85109
export const stridedSlice = op({stridedSlice_});

tfjs-core/src/ops/strided_slice_test.ts

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,53 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
2020
import {expectArraysClose} from '../test_util';
2121

2222
describeWithFlags('stridedSlice', ALL_ENVS, () => {
23-
it('stridedSlice should fail if new axis mask is set', () => {
24-
const tensor = tf.tensor1d([0, 1, 2, 3]);
25-
expect(() => tf.stridedSlice(tensor, [0], [3], [2], 0, 0, 0, 1)).toThrow();
23+
it('stridedSlice with first axis being new', async () => {
24+
// Python slice code: t[tf.newaxis,0:3]
25+
const t = tf.tensor1d([0, 1, 2, 3]);
26+
const begin = [0, 0];
27+
const end = [1, 3];
28+
const strides = [1, 2];
29+
const beginMask = 0;
30+
const endMask = 0;
31+
const ellipsisMask = 0;
32+
const newAxisMask = 1;
33+
34+
const output = tf.stridedSlice(
35+
t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
36+
expect(output.shape).toEqual([1, 2]);
37+
expectArraysClose(await output.data(), [0, 2]);
38+
});
39+
40+
it('strided slice with several new axes', () => {
41+
// Python slice code: t[1:2,tf.newaxis,0:3,tf.newaxis,2:5]
42+
const t = tf.zeros([2, 3, 4, 5]);
43+
const begin = [1, 0, 0, 0, 2];
44+
const end = [2, 1, 3, 1, 5];
45+
const strides: number[] = null;
46+
const beginMask = 0;
47+
const endMask = 0;
48+
const ellipsisMask = 0;
49+
const newAxisMask = 0b1010;
50+
const output = tf.stridedSlice(
51+
t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
52+
expect(output.shape).toEqual([1, 1, 3, 1, 2, 5]);
53+
});
54+
55+
it('strided slice with new axes and shrink axes', () => {
56+
// Python slice code: t[1:2,tf.newaxis,1,tf.newaxis,2,2:5]
57+
const t = tf.zeros([2, 3, 4, 5]);
58+
const begin = [1, 0, 1, 0, 2, 2];
59+
const end = [2, 1, 2, 1, 3, 5];
60+
const strides: number[] = null;
61+
const beginMask = 0;
62+
const endMask = 0;
63+
const ellipsisMask = 0;
64+
const newAxisMask = 0b1010;
65+
const shrinkAxisMask = 0b10100;
66+
const output = tf.stridedSlice(
67+
t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask,
68+
shrinkAxisMask);
69+
expect(output.shape).toEqual([1, 1, 1, 3]);
2670
});
2771

2872
it('stridedSlice should fail if ellipsis mask is set', () => {

tfjs-core/src/tensor_test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,14 @@ describeWithFlags('tensor', ALL_ENVS, () => {
14611461
expect(res.shape).toEqual([0, 0]);
14621462
});
14631463

1464+
it('squeeze can take an empty list of axis', () => {
1465+
const a = tf.zeros([2, 1, 3, 1, 4]);
1466+
const axes: number[] = [];
1467+
// Empty axes list means all possible axes.
1468+
const res = tf.squeeze(a, axes);
1469+
expect(res.shape).toEqual([2, 3, 4]);
1470+
});
1471+
14641472
it('squeeze a complex64 tensor', async () => {
14651473
const a = tf.complex([[4], [1], [5]], [[2], [3], [6]]);
14661474
const b = a.squeeze();

tfjs-core/src/util.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,10 @@ export function squeezeShape(shape: number[], axis?: number[]):
349349
{newShape: number[], keptDims: number[]} {
350350
const newShape: number[] = [];
351351
const keptDims: number[] = [];
352-
const axes = axis == null ? null : parseAxisParam(axis, shape).sort();
352+
const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
353+
const axes = (axis == null || isEmptyArray) ?
354+
null :
355+
parseAxisParam(axis, shape).sort();
353356
let j = 0;
354357
for (let i = 0; i < shape.length; ++i) {
355358
if (axes != null) {

0 commit comments

Comments
 (0)