@@ -54,21 +54,25 @@ export function findKNNGPUCosDistNorm<T>(
5454 . runAsyncTask (
5555 'Finding nearest neighbors...' ,
5656 async ( ) => {
57- const typedArray = vector . toTypedArray ( dataPoints , accessor ) ;
58- const bigMatrix = tf . tensor ( typedArray , [ N , dim ] ) ;
59- const bigMatrixTransposed = tf . transpose ( bigMatrix ) ;
60- // A * A^T.
61- const cosSimilarityMatrix = tf . matMul ( bigMatrix , bigMatrixTransposed ) ;
62- bigMatrix . dispose ( ) ;
63- bigMatrixTransposed . dispose ( ) ;
57+ const cosSimilarityMatrix = tf . tidy ( ( ) => {
58+ const typedArray = vector . toTypedArray ( dataPoints , accessor ) ;
59+ const bigMatrix = tf . tensor ( typedArray , [ N , dim ] ) ;
60+ const bigMatrixTransposed = tf . transpose ( bigMatrix ) ;
61+ // A * A^T.
62+ return tf . matMul ( bigMatrix , bigMatrixTransposed ) ;
63+ } ) ;
6464 // `.data()` returns flattened Float32Array of B * N dimension.
6565 // For matrix of
6666 // [ 1 2 ]
6767 // [ 3 4 ],
6868 // `.data()` returns [1, 2, 3, 4].
69- const partial = await cosSimilarityMatrix . data ( ) ;
70- // Discard all tensors and free up the memory.
71- cosSimilarityMatrix . dispose ( ) ;
69+ let partial ;
70+ try {
71+ partial = await cosSimilarityMatrix . data ( ) ;
72+ } finally {
73+ // Discard all tensors and free up the memory.
74+ cosSimilarityMatrix . dispose ( ) ;
75+ }
7276 for ( let i = 0 ; i < N ; i ++ ) {
7377 let kMin = new KMin < NearestEntry > ( k ) ;
7478 for ( let j = 0 ; j < N ; j ++ ) {
0 commit comments