Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 6311b21

Browse files
author
Nikhil Thorat
authored
Fix unit test to not use tf.tidy in boolean mask test. (#1871)
DEV Also rename booleanMask to booleanMaskAsync.
1 parent cac5b15 commit 6311b21

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

src/ops/boolean_mask.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import {gather} from './segment_ops';
2929
* ```js
3030
* const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
3131
* const mask = tf.tensor1d([1, 0, 1], 'bool');
32-
* const result = await tf.booleanMask(tensor, mask);
32+
* const result = await tf.booleanMaskAsync(tensor, mask);
3333
* result.print();
3434
* ```
3535
*
@@ -40,7 +40,7 @@ import {gather} from './segment_ops';
4040
* Otherwise K + axis <= N.
4141
*/
4242
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
43-
async function booleanMask_(
43+
async function booleanMaskAsync_(
4444
tensor: Tensor|TensorLike, mask: Tensor|TensorLike,
4545
axis?: number): Promise<Tensor> {
4646
const $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
@@ -84,4 +84,4 @@ async function booleanMask_(
8484
return res;
8585
}
8686

87-
export const booleanMask = booleanMask_;
87+
export const booleanMaskAsync = booleanMaskAsync_;

src/ops/boolean_mask_test.ts

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717

1818
import * as tf from '../index';
1919
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20-
import {Tensor} from '../tensor';
2120
import {expectArraysClose} from '../test_util';
2221

23-
describeWithFlags('booleanMask', ALL_ENVS, () => {
22+
describeWithFlags('booleanMaskAsync', ALL_ENVS, () => {
2423
it('1d array, 1d mask, default axis', async () => {
2524
const array = tf.tensor1d([1, 2, 3]);
2625
const mask = tf.tensor1d([1, 0, 1], 'bool');
27-
const result = await tf.booleanMask(array, mask);
26+
const result = await tf.booleanMaskAsync(array, mask);
2827
expect(result.shape).toEqual([2]);
2928
expect(result.dtype).toBe('float32');
3029
expectArraysClose(await result.data(), [1, 3]);
@@ -33,7 +32,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
3332
it('2d array, 1d mask, default axis', async () => {
3433
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
3534
const mask = tf.tensor1d([1, 0, 1], 'bool');
36-
const result = await tf.booleanMask(array, mask);
35+
const result = await tf.booleanMaskAsync(array, mask);
3736
expect(result.shape).toEqual([2, 2]);
3837
expect(result.dtype).toBe('float32');
3938
expectArraysClose(await result.data(), [1, 2, 5, 6]);
@@ -42,7 +41,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
4241
it('2d array, 2d mask, default axis', async () => {
4342
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
4443
const mask = tf.tensor2d([1, 0, 1, 0, 1, 0], [3, 2], 'bool');
45-
const result = await tf.booleanMask(array, mask);
44+
const result = await tf.booleanMaskAsync(array, mask);
4645
expect(result.shape).toEqual([3]);
4746
expect(result.dtype).toBe('float32');
4847
expectArraysClose(await result.data(), [1, 3, 5]);
@@ -52,7 +51,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
5251
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
5352
const mask = tf.tensor1d([0, 1], 'bool');
5453
const axis = 1;
55-
const result = await tf.booleanMask(array, mask, axis);
54+
const result = await tf.booleanMaskAsync(array, mask, axis);
5655
expect(result.shape).toEqual([3, 1]);
5756
expect(result.dtype).toBe('float32');
5857
expectArraysClose(await result.data(), [2, 4, 6]);
@@ -61,7 +60,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
6160
it('accepts tensor-like object as array or mask', async () => {
6261
const array = [[1, 2], [3, 4], [5, 6]];
6362
const mask = [1, 0, 1];
64-
const result = await tf.booleanMask(array, mask);
63+
const result = await tf.booleanMaskAsync(array, mask);
6564
expect(result.shape).toEqual([2, 2]);
6665
expect(result.dtype).toBe('float32');
6766
expectArraysClose(await result.data(), [1, 2, 5, 6]);
@@ -72,13 +71,8 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
7271

7372
const array = tf.tensor1d([1, 2, 3]);
7473
const mask = tf.tensor1d([1, 0, 1], 'bool');
75-
let resultPromise: Promise<Tensor> = null;
7674

77-
tf.tidy(() => {
78-
resultPromise = tf.booleanMask(array, mask);
79-
});
80-
81-
const result = await resultPromise;
75+
const result = await tf.booleanMaskAsync(array, mask);
8276
expect(result.shape).toEqual([2]);
8377
expect(result.dtype).toBe('float32');
8478
expectArraysClose(await result.data(), [1, 3]);
@@ -95,7 +89,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
9589
const mask = tf.scalar(1, 'bool');
9690
let errorMessage = 'No error thrown.';
9791
try {
98-
await tf.booleanMask(array, mask);
92+
await tf.booleanMaskAsync(array, mask);
9993
} catch (error) {
10094
errorMessage = error.message;
10195
}
@@ -107,7 +101,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
107101
const mask = tf.tensor2d([1, 0], [1, 2], 'bool');
108102
let errorMessage = 'No error thrown.';
109103
try {
110-
await tf.booleanMask(array, mask);
104+
await tf.booleanMaskAsync(array, mask);
111105
} catch (error) {
112106
errorMessage = error.message;
113107
}

0 commit comments

Comments
 (0)