Skip to content

Commit ff3f9f3

Browse files
Merge pull request #21 from RonasIT/14-load-data-asynchronously-in-the-example-application
Load data asynchronously in the example application
2 parents 407ecfa + 628ce02 commit ff3f9f3

File tree

2 files changed

+110
-4
lines changed

2 files changed

+110
-4
lines changed

packages/tfjs-node-helpers-example/src/app/app.ts

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import { BinaryClassificationTrainer, BinaryClassifier } from '@ronas-it/tfjs-node-helpers';
2-
import { layers } from '@tensorflow/tfjs-node';
1+
import { BinaryClassificationTrainer, BinaryClassifier, makeChunkedDataset } from '@ronas-it/tfjs-node-helpers';
2+
import { data, layers, TensorContainer } from '@tensorflow/tfjs-node';
33
import { AgeFeatureExtractor } from './feature-extractors/age';
44
import { AnnualSalaryFeatureExtractor } from './feature-extractors/annual-salary';
55
import { GenderFeatureExtractor } from './feature-extractors/gender';
66
import { OwnsTheCarFeatureExtractor } from './feature-extractors/owns-the-car';
77
import { join } from 'node:path';
8-
import data from '../assets/data.json';
8+
import { TrainingDataService } from './services/training-data';
99

1010
export async function startApplication(): Promise<void> {
1111
await train();
@@ -26,8 +26,43 @@ async function train(): Promise<void> {
2626
outputFeatureExtractor: new OwnsTheCarFeatureExtractor()
2727
});
2828

29+
const trainingDataService = new TrainingDataService({
30+
simulatedDelayMs: 100
31+
});
32+
33+
await trainingDataService.initialize();
34+
35+
const [validationSamplesCount, testingSamplesCount] = await Promise.all([
36+
trainingDataService.getValidationSamplesCount(),
37+
trainingDataService.getTestingSamplesCount()
38+
]);
39+
40+
const makeTrainingDataset = (): data.Dataset<TensorContainer> => makeChunkedDataset({
41+
loadChunk: (skip, take) => trainingDataService.getTrainingSamples(skip, take),
42+
chunkSize: 32,
43+
batchSize: 32
44+
});
45+
46+
const makeValidationDataset = (): data.Dataset<TensorContainer> => makeChunkedDataset({
47+
loadChunk: (skip, take) => trainingDataService.getValidationSamples(skip, take),
48+
chunkSize: 32,
49+
batchSize: validationSamplesCount
50+
});
51+
52+
const makeTestingDataset = (): data.Dataset<TensorContainer> => makeChunkedDataset({
53+
loadChunk: (skip, take) => trainingDataService.getTestingSamples(skip, take),
54+
chunkSize: 32,
55+
batchSize: testingSamplesCount
56+
});
57+
58+
const trainingDataset = makeTrainingDataset();
59+
const validationDataset = makeValidationDataset();
60+
const testingDataset = makeTestingDataset();
61+
2962
await trainer.trainAndTest({
30-
data,
63+
trainingDataset,
64+
validationDataset,
65+
testingDataset,
3166
printTestingResults: true
3267
});
3368

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import { extractFeatures, Sample, splitSamplesIntoTrainingValidationTestForBinaryClassification } from '@ronas-it/tfjs-node-helpers';
2+
import { AgeFeatureExtractor } from '../feature-extractors/age';
3+
import { AnnualSalaryFeatureExtractor } from '../feature-extractors/annual-salary';
4+
import { GenderFeatureExtractor } from '../feature-extractors/gender';
5+
import { OwnsTheCarFeatureExtractor } from '../feature-extractors/owns-the-car';
6+
import dataset from '../../assets/data.json';
7+
8+
export class TrainingDataService {
9+
private simulatedDelayMs: number;
10+
private trainingSamples: Array<Sample>;
11+
private validationSamples: Array<Sample>;
12+
private testingSamples: Array<Sample>;
13+
14+
constructor({ simulatedDelayMs }: { simulatedDelayMs: number }) {
15+
this.simulatedDelayMs = simulatedDelayMs;
16+
}
17+
18+
public async initialize(): Promise<void> {
19+
const samples = await extractFeatures({
20+
data: dataset,
21+
inputFeatureExtractors: [
22+
new AgeFeatureExtractor(),
23+
new AnnualSalaryFeatureExtractor(),
24+
new GenderFeatureExtractor()
25+
],
26+
outputFeatureExtractor: new OwnsTheCarFeatureExtractor()
27+
});
28+
29+
const { trainingSamples, validationSamples, testingSamples } = splitSamplesIntoTrainingValidationTestForBinaryClassification(samples);
30+
31+
this.trainingSamples = trainingSamples;
32+
this.validationSamples = validationSamples;
33+
this.testingSamples = testingSamples;
34+
}
35+
36+
public async getTrainingSamples(skip: number, take: number): Promise<Array<Sample>> {
37+
await this.simulateDelay();
38+
39+
return this.trainingSamples.slice(skip, skip + take);
40+
}
41+
42+
public async getValidationSamples(skip: number, take: number): Promise<Array<Sample>> {
43+
await this.simulateDelay();
44+
45+
return this.validationSamples.slice(skip, skip + take);
46+
}
47+
48+
public async getTestingSamples(skip: number, take: number): Promise<Array<Sample>> {
49+
await this.simulateDelay();
50+
51+
return this.testingSamples.slice(skip, skip + take);
52+
}
53+
54+
public async getValidationSamplesCount(): Promise<number> {
55+
await this.simulateDelay();
56+
57+
return this.validationSamples.length;
58+
}
59+
60+
public async getTestingSamplesCount(): Promise<number> {
61+
await this.simulateDelay();
62+
63+
return this.testingSamples.length;
64+
}
65+
66+
private async simulateDelay(): Promise<void> {
67+
if (this.simulatedDelayMs > 0) {
68+
await new Promise((resolve) => setTimeout(resolve, this.simulatedDelayMs));
69+
}
70+
}
71+
}

0 commit comments

Comments
 (0)