1717
1818import * as tf from '../index' ;
1919import { ALL_ENVS , describeWithFlags } from '../jasmine_util' ;
20- import { Tensor } from '../tensor' ;
2120import { expectArraysClose } from '../test_util' ;
2221
23- describeWithFlags ( 'booleanMask ' , ALL_ENVS , ( ) => {
22+ describeWithFlags ( 'booleanMaskAsync ' , ALL_ENVS , ( ) => {
2423 it ( '1d array, 1d mask, default axis' , async ( ) => {
2524 const array = tf . tensor1d ( [ 1 , 2 , 3 ] ) ;
2625 const mask = tf . tensor1d ( [ 1 , 0 , 1 ] , 'bool' ) ;
27- const result = await tf . booleanMask ( array , mask ) ;
26+ const result = await tf . booleanMaskAsync ( array , mask ) ;
2827 expect ( result . shape ) . toEqual ( [ 2 ] ) ;
2928 expect ( result . dtype ) . toBe ( 'float32' ) ;
3029 expectArraysClose ( await result . data ( ) , [ 1 , 3 ] ) ;
@@ -33,7 +32,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
3332 it ( '2d array, 1d mask, default axis' , async ( ) => {
3433 const array = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
3534 const mask = tf . tensor1d ( [ 1 , 0 , 1 ] , 'bool' ) ;
36- const result = await tf . booleanMask ( array , mask ) ;
35+ const result = await tf . booleanMaskAsync ( array , mask ) ;
3736 expect ( result . shape ) . toEqual ( [ 2 , 2 ] ) ;
3837 expect ( result . dtype ) . toBe ( 'float32' ) ;
3938 expectArraysClose ( await result . data ( ) , [ 1 , 2 , 5 , 6 ] ) ;
@@ -42,7 +41,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
4241 it ( '2d array, 2d mask, default axis' , async ( ) => {
4342 const array = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
4443 const mask = tf . tensor2d ( [ 1 , 0 , 1 , 0 , 1 , 0 ] , [ 3 , 2 ] , 'bool' ) ;
45- const result = await tf . booleanMask ( array , mask ) ;
44+ const result = await tf . booleanMaskAsync ( array , mask ) ;
4645 expect ( result . shape ) . toEqual ( [ 3 ] ) ;
4746 expect ( result . dtype ) . toBe ( 'float32' ) ;
4847 expectArraysClose ( await result . data ( ) , [ 1 , 3 , 5 ] ) ;
@@ -52,7 +51,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
5251 const array = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
5352 const mask = tf . tensor1d ( [ 0 , 1 ] , 'bool' ) ;
5453 const axis = 1 ;
55- const result = await tf . booleanMask ( array , mask , axis ) ;
54+ const result = await tf . booleanMaskAsync ( array , mask , axis ) ;
5655 expect ( result . shape ) . toEqual ( [ 3 , 1 ] ) ;
5756 expect ( result . dtype ) . toBe ( 'float32' ) ;
5857 expectArraysClose ( await result . data ( ) , [ 2 , 4 , 6 ] ) ;
@@ -61,7 +60,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
6160 it ( 'accepts tensor-like object as array or mask' , async ( ) => {
6261 const array = [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ;
6362 const mask = [ 1 , 0 , 1 ] ;
64- const result = await tf . booleanMask ( array , mask ) ;
63+ const result = await tf . booleanMaskAsync ( array , mask ) ;
6564 expect ( result . shape ) . toEqual ( [ 2 , 2 ] ) ;
6665 expect ( result . dtype ) . toBe ( 'float32' ) ;
6766 expectArraysClose ( await result . data ( ) , [ 1 , 2 , 5 , 6 ] ) ;
@@ -72,13 +71,8 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
7271
7372 const array = tf . tensor1d ( [ 1 , 2 , 3 ] ) ;
7473 const mask = tf . tensor1d ( [ 1 , 0 , 1 ] , 'bool' ) ;
75- let resultPromise : Promise < Tensor > = null ;
7674
77- tf . tidy ( ( ) => {
78- resultPromise = tf . booleanMask ( array , mask ) ;
79- } ) ;
80-
81- const result = await resultPromise ;
75+ const result = await tf . booleanMaskAsync ( array , mask ) ;
8276 expect ( result . shape ) . toEqual ( [ 2 ] ) ;
8377 expect ( result . dtype ) . toBe ( 'float32' ) ;
8478 expectArraysClose ( await result . data ( ) , [ 1 , 3 ] ) ;
@@ -95,7 +89,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
9589 const mask = tf . scalar ( 1 , 'bool' ) ;
9690 let errorMessage = 'No error thrown.' ;
9791 try {
98- await tf . booleanMask ( array , mask ) ;
92+ await tf . booleanMaskAsync ( array , mask ) ;
9993 } catch ( error ) {
10094 errorMessage = error . message ;
10195 }
@@ -107,7 +101,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
107101 const mask = tf . tensor2d ( [ 1 , 0 ] , [ 1 , 2 ] , 'bool' ) ;
108102 let errorMessage = 'No error thrown.' ;
109103 try {
110- await tf . booleanMask ( array , mask ) ;
104+ await tf . booleanMaskAsync ( array , mask ) ;
111105 } catch ( error ) {
112106 errorMessage = error . message ;
113107 }
0 commit comments