@@ -21,11 +21,11 @@ import {expectArraysClose} from '../test_util';
2121
2222import { tensor1d , tensor2d , tensor3d } from './tensor_ops' ;
2323
24- describeWithFlags ( 'inTopK ' , ALL_ENVS , async ( ) => {
24+ describeWithFlags ( 'inTopKAsync ' , ALL_ENVS , async ( ) => {
2525 it ( 'predictions 2d array, targets 1d array, with default k' , async ( ) => {
2626 const predictions = tensor2d ( [ [ 20 , 10 , 40 , 30 ] , [ 30 , 50 , - 20 , 10 ] ] ) ;
2727 const targets = tensor1d ( [ 2 , 0 ] ) ;
28- const precision = tf . inTopK ( predictions , targets ) ;
28+ const precision = await tf . inTopKAsync ( predictions , targets ) ;
2929 expect ( precision . shape ) . toEqual ( [ 2 ] ) ;
3030 expect ( precision . dtype ) . toBe ( 'bool' ) ;
3131 expectArraysClose ( await precision . data ( ) , [ 1 , 0 ] ) ;
@@ -35,7 +35,7 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
3535 const predictions = tensor2d ( [ [ 20 , 10 , 40 , 30 ] , [ 30 , 50 , - 20 , 10 ] ] ) ;
3636 const targets = tensor1d ( [ 2 , 0 ] ) ;
3737 const k = 2 ;
38- const precision = tf . inTopK ( predictions , targets , k ) ;
38+ const precision = await tf . inTopKAsync ( predictions , targets , k ) ;
3939 expect ( precision . shape ) . toEqual ( [ 2 ] ) ;
4040 expect ( precision . dtype ) . toBe ( 'bool' ) ;
4141 expectArraysClose ( await precision . data ( ) , [ 1 , 1 ] ) ;
@@ -45,7 +45,7 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
4545 const predictions =
4646 tensor3d ( [ [ [ 1 , 5 , 2 ] , [ 4 , 3 , 6 ] ] , [ [ 3 , 2 , 1 ] , [ 1 , 2 , 3 ] ] ] ) ;
4747 const targets = tensor2d ( [ [ 1 , 2 ] , [ 0 , 1 ] ] ) ;
48- const precision = tf . inTopK ( predictions , targets ) ;
48+ const precision = await tf . inTopKAsync ( predictions , targets ) ;
4949 expect ( precision . shape ) . toEqual ( [ 2 , 2 ] ) ;
5050 expect ( precision . dtype ) . toBe ( 'bool' ) ;
5151 expectArraysClose ( await precision . data ( ) , [ 1 , 1 , 1 , 0 ] ) ;
@@ -56,7 +56,7 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
5656 tensor3d ( [ [ [ 1 , 5 , 2 ] , [ 4 , 3 , 6 ] ] , [ [ 3 , 2 , 1 ] , [ 1 , 2 , 3 ] ] ] ) ;
5757 const targets = tensor2d ( [ [ 1 , 2 ] , [ 0 , 1 ] ] ) ;
5858 const k = 2 ;
59- const precision = tf . inTopK ( predictions , targets , k ) ;
59+ const precision = await tf . inTopKAsync ( predictions , targets , k ) ;
6060 expect ( precision . shape ) . toEqual ( [ 2 , 2 ] ) ;
6161 expect ( precision . dtype ) . toBe ( 'bool' ) ;
6262 expectArraysClose ( await precision . data ( ) , [ 1 , 1 , 1 , 1 ] ) ;
@@ -66,13 +66,13 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
6666 const predictions = tensor2d ( [ [ 1 , 2 , 2 , 1 ] ] ) ;
6767
6868 const targets1 = tensor1d ( [ 1 ] ) ;
69- const precision1 = tf . inTopK ( predictions , targets1 ) ;
69+ const precision1 = await tf . inTopKAsync ( predictions , targets1 ) ;
7070 expect ( precision1 . shape ) . toEqual ( [ 1 ] ) ;
7171 expect ( precision1 . dtype ) . toBe ( 'bool' ) ;
7272 expectArraysClose ( await precision1 . data ( ) , [ 1 ] ) ;
7373
7474 const targets2 = tensor1d ( [ 2 ] ) ;
75- const precision2 = tf . inTopK ( predictions , targets2 ) ;
75+ const precision2 = await tf . inTopKAsync ( predictions , targets2 ) ;
7676 expect ( precision2 . shape ) . toEqual ( [ 1 ] ) ;
7777 expect ( precision2 . dtype ) . toBe ( 'bool' ) ;
7878 expectArraysClose ( await precision2 . data ( ) , [ 0 ] ) ;
@@ -81,28 +81,72 @@ describeWithFlags('inTopK', ALL_ENVS, async () => {
8181 it ( 'accept tensor-like object, with default k' , async ( ) => {
8282 const predictions = [ [ 20 , 10 , 40 , 30 ] , [ 30 , 50 , - 20 , 10 ] ] ;
8383 const targets = [ 2 , 0 ] ;
84- const precision = tf . inTopK ( predictions , targets ) ;
84+ const precision = await tf . inTopKAsync ( predictions , targets ) ;
8585 expect ( precision . shape ) . toEqual ( [ 2 ] ) ;
8686 expect ( precision . dtype ) . toBe ( 'bool' ) ;
8787 expectArraysClose ( await precision . data ( ) , [ 1 , 0 ] ) ;
8888 } ) ;
8989
90- it ( 'throws when predictions_rank <2' , ( ) => {
90+ it ( 'doesnt leak tensors with tensor-like objects' , async ( ) => {
91+ const numTensors = tf . memory ( ) . numTensors ;
92+
93+ const predictions = [ [ 20 , 10 , 40 , 30 ] , [ 30 , 50 , - 20 , 10 ] ] ;
94+ const targets = [ 2 , 0 ] ;
95+ const precision = await tf . inTopKAsync ( predictions , targets ) ;
96+ precision . dispose ( ) ;
97+
98+ expect ( tf . memory ( ) . numTensors ) . toBe ( numTensors ) ;
99+ } ) ;
100+
101+ it ( 'throws when predictions_rank <2' , async ( ) => {
91102 const predictions = tensor1d ( [ 20 , 10 , 40 , 30 ] ) ;
92103 const targets = [ 2 ] ;
93- expect ( ( ) => tf . inTopK ( predictions , targets ) ) . toThrowError ( ) ;
104+
105+ // expect(...).toThrowError() does not support async functions.
106+ // See https://github.com/jasmine/jasmine/issues/1410
107+ try {
108+ await tf . inTopKAsync ( predictions , targets ) ;
109+ throw new Error ( 'The line above should have thrown an error' ) ;
110+ } catch ( ex ) {
111+ expect ( ex . message )
112+ . toEqual (
113+ 'inTopK() expects the predictions to ' +
114+ 'be of rank 2 or higher, but got 1' ) ;
115+ }
94116 } ) ;
95117
96- it ( 'throws when prediction_rank != targets_rank + 1' , ( ) => {
118+ it ( 'throws when prediction.rank != targets.rank + 1' , async ( ) => {
97119 const predictions = tensor2d ( [ [ 20 , 10 , 40 , 30 ] , [ 30 , 50 , - 20 , 10 ] ] ) ;
98120 const targets = tensor2d ( [ [ 0 ] , [ 0 ] ] ) ;
99- expect ( ( ) => tf . inTopK ( predictions , targets ) ) . toThrowError ( ) ;
121+
122+ // expect(...).toThrowError() does not support async functions.
123+ // See https://github.com/jasmine/jasmine/issues/1410
124+ try {
125+ await tf . inTopKAsync ( predictions , targets ) ;
126+ throw new Error ( 'The line above should have thrown an error' ) ;
127+ } catch ( ex ) {
128+ expect ( ex . message )
129+ . toEqual (
130+ 'predictions rank should be 1 larger than targets rank,' +
131+ ' but got predictions rank 2 and targets rank 2' ) ;
132+ }
100133 } ) ;
101134
102- it ( 'throws when k > size of last dimension of predictions' , ( ) => {
135+ it ( 'throws when k > size of last dimension of predictions' , async ( ) => {
103136 const predictions = tensor2d ( [ [ 20 , 10 , 40 , 30 ] , [ 30 , 50 , - 20 , 10 ] ] ) ;
104137 const targets = tensor1d ( [ 2 , 0 ] ) ;
105138 const k = 5 ;
106- expect ( ( ) => tf . inTopK ( predictions , targets , k ) ) . toThrowError ( ) ;
139+
140+ // expect(...).toThrowError() does not support async functions.
141+ // See https://github.com/jasmine/jasmine/issues/1410
142+ try {
143+ await tf . inTopKAsync ( predictions , targets , k ) ;
144+ throw new Error ( 'The line above should have thrown an error' ) ;
145+ } catch ( ex ) {
146+ expect ( ex . message )
147+ . toEqual (
148+ '\'k\' passed to inTopK() must be > 0 && <= the predictions ' +
149+ 'last dimension (4), but got 5' ) ;
150+ }
107151 } ) ;
108152} ) ;
0 commit comments