1- import { BinaryClassificationTrainer , BinaryClassifier , makeChunkedDataset } from '@ronas-it/tfjs-node-helpers' ;
1+ import {
2+ AccuracyMetricCalculator ,
3+ BinaryClassificationTrainer ,
4+ BinaryClassifier ,
5+ F1ScoreMetricCalculator ,
6+ FNRMetricCalculator ,
7+ FPRMetricCalculator ,
8+ makeChunkedDataset ,
9+ PrecisionMetricCalculator ,
10+ RecallMetricCalculator ,
11+ SpecificityMetricCalculator
12+ } from '@ronas-it/tfjs-node-helpers' ;
213import { data , layers , TensorContainer } from '@tensorflow/tfjs-node' ;
314import { AgeFeatureExtractor } from './feature-extractors/age' ;
415import { AnnualSalaryFeatureExtractor } from './feature-extractors/annual-salary' ;
@@ -16,10 +27,7 @@ export async function startApplication(): Promise<void> {
1627
1728async function train ( ) : Promise < void > {
1829 const trainer = new BinaryClassificationTrainer ( {
19- hiddenLayers : [
20- layers . dense ( { units : 128 , activation : 'mish' } ) ,
21- layers . dense ( { units : 128 , activation : 'mish' } )
22- ] ,
30+ hiddenLayers : [ layers . dense ( { units : 128 , activation : 'mish' } ) , layers . dense ( { units : 128 , activation : 'mish' } ) ] ,
2331 inputFeatureExtractors : [
2432 new AgeFeatureExtractor ( ) ,
2533 new AnnualSalaryFeatureExtractor ( ) ,
@@ -29,6 +37,15 @@ async function train(): Promise<void> {
2937 inputFeatureNormalizers : [
3038 new AgeMinMaxFeatureNormalizer ( ) ,
3139 new AnnualSalaryMinMaxFeatureNormalizer ( )
40+ ] ,
41+ metricCalculators : [
42+ new AccuracyMetricCalculator ( ) ,
43+ new PrecisionMetricCalculator ( ) ,
44+ new F1ScoreMetricCalculator ( ) ,
45+ new SpecificityMetricCalculator ( ) ,
46+ new RecallMetricCalculator ( ) ,
47+ new FNRMetricCalculator ( ) ,
48+ new FPRMetricCalculator ( )
3249 ]
3350 } ) ;
3451
@@ -43,23 +60,26 @@ async function train(): Promise<void> {
4360 trainingDataService . getTestingSamplesCount ( )
4461 ] ) ;
4562
46- const makeTrainingDataset = ( ) : data . Dataset < TensorContainer > => makeChunkedDataset ( {
47- loadChunk : ( skip , take ) => trainingDataService . getTrainingSamples ( skip , take ) ,
48- chunkSize : 32 ,
49- batchSize : 32
50- } ) ;
63+ const makeTrainingDataset = ( ) : data . Dataset < TensorContainer > =>
64+ makeChunkedDataset ( {
65+ loadChunk : ( skip , take ) => trainingDataService . getTrainingSamples ( skip , take ) ,
66+ chunkSize : 32 ,
67+ batchSize : 32
68+ } ) ;
5169
52- const makeValidationDataset = ( ) : data . Dataset < TensorContainer > => makeChunkedDataset ( {
53- loadChunk : ( skip , take ) => trainingDataService . getValidationSamples ( skip , take ) ,
54- chunkSize : 32 ,
55- batchSize : validationSamplesCount
56- } ) ;
70+ const makeValidationDataset = ( ) : data . Dataset < TensorContainer > =>
71+ makeChunkedDataset ( {
72+ loadChunk : ( skip , take ) => trainingDataService . getValidationSamples ( skip , take ) ,
73+ chunkSize : 32 ,
74+ batchSize : validationSamplesCount
75+ } ) ;
5776
58- const makeTestingDataset = ( ) : data . Dataset < TensorContainer > => makeChunkedDataset ( {
59- loadChunk : ( skip , take ) => trainingDataService . getTestingSamples ( skip , take ) ,
60- chunkSize : 32 ,
61- batchSize : testingSamplesCount
62- } ) ;
77+ const makeTestingDataset = ( ) : data . Dataset < TensorContainer > =>
78+ makeChunkedDataset ( {
79+ loadChunk : ( skip , take ) => trainingDataService . getTestingSamples ( skip , take ) ,
80+ chunkSize : 32 ,
81+ batchSize : testingSamplesCount
82+ } ) ;
6383
6484 const trainingDataset = makeTrainingDataset ( ) ;
6585 const validationDataset = makeValidationDataset ( ) ;