@@ -57,20 +57,18 @@ export function findKNNGPUCosDistNorm<T>(
5757 const typedArray = vector . toTypedArray ( dataPoints , accessor ) ;
5858 const bigMatrix = tf . tensor ( typedArray , [ N , dim ] ) ;
5959 const bigMatrixTransposed = tf . transpose ( bigMatrix ) ;
60- // 1 - A * A^T.
61- const bigMatrixSquared = tf . matMul ( bigMatrix , bigMatrixTransposed ) ;
62- const cosDistMatrix = tf . sub ( 1 , bigMatrixSquared ) ;
60+ // A * A^T.
61+ const cosSimilarityMatrix = tf . matMul ( bigMatrix , bigMatrixTransposed ) ;
62+ bigMatrix . dispose ( ) ;
63+ bigMatrixTransposed . dispose ( ) ;
6364 // `.data()` returns flattened Float32Array of B * N dimension.
6465 // For matrix of
6566 // [ 1 2 ]
6667 // [ 3 4 ],
6768 // `.data()` returns [1, 2, 3, 4].
68- const partial = await cosDistMatrix . data ( ) ;
69+ const partial = await cosSimilarityMatrix . data ( ) ;
6970 // Discard all tensors and free up the memory.
70- bigMatrix . dispose ( ) ;
71- bigMatrixTransposed . dispose ( ) ;
72- bigMatrixSquared . dispose ( ) ;
73- cosDistMatrix . dispose ( ) ;
71+ cosSimilarityMatrix . dispose ( ) ;
7472 for ( let i = 0 ; i < N ; i ++ ) {
7573 let kMin = new KMin < NearestEntry > ( k ) ;
7674 for ( let j = 0 ; j < N ; j ++ ) {
@@ -81,15 +79,15 @@ export function findKNNGPUCosDistNorm<T>(
8179 // Access i * N's row at `j` column.
8280 // Reach row has N entries and j-th index has cosine distance
8381 // between i-th vs. j-th vectors.
84- const cosDist = partial [ i * N + j ] ;
82+ const cosDist = 1 - partial [ i * N + j ] ;
8583 if ( cosDist >= 0 ) {
8684 kMin . add ( cosDist , { index : j , dist : cosDist } ) ;
8785 }
8886 }
8987 nearest [ i ] = kMin . getMinKItems ( ) ;
9088 }
9189 } ,
92- KNN_MSG_ID ,
90+ KNN_MSG_ID
9391 )
9492 . then (
9593 ( ) => {
@@ -163,7 +161,7 @@ export function findKNN<T>(
163161 logging . setModalMessage ( null ! , KNN_MSG_ID ) ;
164162 return nearest ;
165163 } ,
166- KNN_MSG_ID ,
164+ KNN_MSG_ID
167165 ) ;
168166}
169167/**
@@ -194,4 +192,3 @@ export function findKNNofPoint<T>(
194192 }
195193 return kMin . getMinKItems ( ) ;
196194}
197-
0 commit comments