22 * @jest -environment node
33 */
44
5- import SimpleTrainer from '@ml/utils/SimpleTrainer ' ;
5+ import KNNTrainer from '@ml/utils/KNNTrainer ' ;
66import * as tf from '@tensorflow/tfjs' ;
77
88describe ( 'Simple Trainer tests' , ( ) => {
9- test ( 'SimpleTrainer predicts' , async ( ) => {
10- const trainer = new SimpleTrainer ( ) ;
9+ test ( 'KNNTrainer predicts' , async ( ) => {
10+ const trainer = new KNNTrainer ( ) ;
1111 trainer . setTopK ( 3 ) ;
1212
1313 trainer . addTrainingExample ( tf . tensor ( [ 1 , 1 ] ) , 0 ) ;
@@ -24,8 +24,8 @@ describe('Simple Trainer tests', () => {
2424 trainer . dispose ( ) ;
2525 } ) ;
2626
27- test ( 'SimpleTrainer can be restored' , async ( ) => {
28- const trainer = new SimpleTrainer ( ) ;
27+ test ( 'KNNTrainer can be restored' , async ( ) => {
28+ const trainer = new KNNTrainer ( ) ;
2929 trainer . setTopK ( 3 ) ;
3030
3131 trainer . addTrainingExample ( tf . tensor ( [ 1 , 1 ] ) , 0 ) ;
@@ -42,7 +42,7 @@ describe('Simple Trainer tests', () => {
4242 const classifierDatasetString = trainer . getDatasetJSON ( ) ;
4343 trainer . clearAll ( ) ;
4444
45- const retrainedTrainer = new SimpleTrainer ( ) ;
45+ const retrainedTrainer = new KNNTrainer ( ) ;
4646 retrainedTrainer . setTopK ( 3 ) ;
4747 const untrainedResult = await retrainedTrainer . predict ( tf . tensor ( [ 1 , 1 ] ) ) ;
4848 expect ( untrainedResult . predictedClassId ) . toEqual ( null ) ;
0 commit comments