Skip to content

Commit 403f1d6

Browse files
authored
Merge pull request #23 from RonasIT/5-feature-normalization
2 parents 67f3228 + 77f733d commit 403f1d6

File tree

13 files changed

+139
-23
lines changed

13 files changed

+139
-23
lines changed

packages/tfjs-node-helpers-example/src/app/app.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import { GenderFeatureExtractor } from './feature-extractors/gender';
66
import { OwnsTheCarFeatureExtractor } from './feature-extractors/owns-the-car';
77
import { join } from 'node:path';
88
import { TrainingDataService } from './services/training-data';
9+
import { AgeMinMaxFeatureNormalizer } from './feature-normalizers/age';
10+
import { AnnualSalaryMinMaxFeatureNormalizer } from './feature-normalizers/annual-salary';
911

1012
export async function startApplication(): Promise<void> {
1113
await train();
@@ -23,7 +25,11 @@ async function train(): Promise<void> {
2325
new AnnualSalaryFeatureExtractor(),
2426
new GenderFeatureExtractor()
2527
],
26-
outputFeatureExtractor: new OwnsTheCarFeatureExtractor()
28+
outputFeatureExtractor: new OwnsTheCarFeatureExtractor(),
29+
inputFeatureNormalizers: [
30+
new AgeMinMaxFeatureNormalizer(),
31+
new AnnualSalaryMinMaxFeatureNormalizer()
32+
]
2733
});
2834

2935
const trainingDataService = new TrainingDataService({

packages/tfjs-node-helpers-example/src/app/feature-extractors/age.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@ export class AgeFeatureExtractor extends FeatureExtractor<DatasetItem, FeatureTy
66
public featureType = FeatureType.AGE;
77

88
public extract(item: DatasetItem): Feature<FeatureType> {
9-
const minAge = 18;
10-
const maxAge = 63;
11-
129
return new Feature({
1310
type: this.featureType,
1411
label: `${item.age} years`,
15-
value: (item.age - minAge) / (maxAge - minAge)
12+
value: item.age
1613
});
1714
}
1815
}

packages/tfjs-node-helpers-example/src/app/feature-extractors/annual-salary.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@ export class AnnualSalaryFeatureExtractor extends FeatureExtractor<DatasetItem,
66
public featureType = FeatureType.ANNUAL_SALARY;
77

88
public extract(item: DatasetItem): Feature<FeatureType> {
9-
const minAnnualSalary = 15000;
10-
const maxAnnualSalary = 152500;
11-
129
return new Feature({
1310
type: this.featureType,
1411
label: item.annual_salary.toString(),
15-
value: (item.annual_salary - minAnnualSalary) / (maxAnnualSalary - minAnnualSalary)
12+
value: item.annual_salary
1613
});
1714
}
1815
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import { MinMaxFeatureNormalizer } from '@ronas-it/tfjs-node-helpers';
2+
import { FeatureType } from '../enums/feature-type';
3+
4+
export class AgeMinMaxFeatureNormalizer extends MinMaxFeatureNormalizer<FeatureType> {
5+
public featureType = FeatureType.AGE;
6+
7+
constructor() {
8+
super({ min: 18, max: 63 });
9+
}
10+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import { MinMaxFeatureNormalizer } from '@ronas-it/tfjs-node-helpers';
2+
import { FeatureType } from '../enums/feature-type';
3+
4+
export class AnnualSalaryMinMaxFeatureNormalizer extends MinMaxFeatureNormalizer<FeatureType> {
5+
public featureType = FeatureType.ANNUAL_SALARY;
6+
7+
constructor() {
8+
super({ min: 15000, max: 152500 });
9+
}
10+
}

packages/tfjs-node-helpers-example/src/app/services/training-data.ts

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
import { extractFeatures, Sample, splitSamplesIntoTrainingValidationTestForBinaryClassification } from '@ronas-it/tfjs-node-helpers';
1+
import {
2+
extractFeatures,
3+
normalizeFeatures,
4+
Sample,
5+
splitSamplesIntoTrainingValidationTestForBinaryClassification
6+
} from '@ronas-it/tfjs-node-helpers';
27
import { AgeFeatureExtractor } from '../feature-extractors/age';
38
import { AnnualSalaryFeatureExtractor } from '../feature-extractors/annual-salary';
49
import { GenderFeatureExtractor } from '../feature-extractors/gender';
510
import { OwnsTheCarFeatureExtractor } from '../feature-extractors/owns-the-car';
611
import dataset from '../../assets/data.json';
12+
import { AgeMinMaxFeatureNormalizer } from '../feature-normalizers/age';
13+
import { AnnualSalaryMinMaxFeatureNormalizer } from '../feature-normalizers/annual-salary';
714

815
export class TrainingDataService {
916
private simulatedDelayMs: number;
@@ -16,7 +23,7 @@ export class TrainingDataService {
1623
}
1724

1825
public async initialize(): Promise<void> {
19-
const samples = await extractFeatures({
26+
const extracts = await extractFeatures({
2027
data: dataset,
2128
inputFeatureExtractors: [
2229
new AgeFeatureExtractor(),
@@ -26,6 +33,14 @@ export class TrainingDataService {
2633
outputFeatureExtractor: new OwnsTheCarFeatureExtractor()
2734
});
2835

36+
const samples = await normalizeFeatures({
37+
extracts,
38+
inputFeatureNormalizers: [
39+
new AgeMinMaxFeatureNormalizer(),
40+
new AnnualSalaryMinMaxFeatureNormalizer()
41+
]
42+
});
43+
2944
const { trainingSamples, validationSamples, testingSamples } = splitSamplesIntoTrainingValidationTestForBinaryClassification(samples);
3045

3146
this.trainingSamples = trainingSamples;

packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ import { prepareDatasetsForBinaryClassification } from '../feature-engineering/p
1919
import { ConfusionMatrix } from '../testing/confusion-matrix';
2020
import { Metrics } from '../testing/metrics';
2121
import { binarize } from '../utils/binarize';
22+
import { FeatureNormalizer } from '../feature-engineering/feature-normalizer';
2223

2324
export type BinaryClassificationTrainerOptions = {
2425
batchSize?: number;
2526
epochs?: number;
2627
patience?: number;
2728
inputFeatureExtractors?: Array<FeatureExtractor<any, any>>;
2829
outputFeatureExtractor?: FeatureExtractor<any, any>;
30+
inputFeatureNormalizers?: Array<FeatureNormalizer<any>>;
2931
model?: LayersModel;
3032
hiddenLayers?: Array<layers.Layer>;
3133
optimizer?: string | Optimizer;
@@ -39,6 +41,7 @@ export class BinaryClassificationTrainer {
3941
protected tensorBoardLogsDirectory?: string;
4042
protected inputFeatureExtractors?: Array<FeatureExtractor<any, any>>;
4143
protected outputFeatureExtractor?: FeatureExtractor<any, any>;
44+
protected inputFeatureNormalizers?: Array<FeatureNormalizer<any>>;
4245
protected model!: LayersModel;
4346

4447
protected static DEFAULT_BATCH_SIZE: number = 32;
@@ -52,6 +55,7 @@ export class BinaryClassificationTrainer {
5255
this.tensorBoardLogsDirectory = options.tensorBoardLogsDirectory;
5356
this.inputFeatureExtractors = options.inputFeatureExtractors;
5457
this.outputFeatureExtractor = options.outputFeatureExtractor;
58+
this.inputFeatureNormalizers = options.inputFeatureNormalizers;
5559

5660
this.initializeModel(options);
5761
}
@@ -63,7 +67,7 @@ export class BinaryClassificationTrainer {
6367
testingDataset,
6468
printTestingResults
6569
}: {
66-
data?: Array<any>,
70+
data?: Array<any>;
6771
trainingDataset?: data.Dataset<TensorContainer>;
6872
validationDataset?: data.Dataset<TensorContainer>;
6973
testingDataset?: data.Dataset<TensorContainer>;
@@ -90,14 +94,19 @@ export class BinaryClassificationTrainer {
9094
validationDataset === undefined ||
9195
testingDataset === undefined
9296
) {
93-
if (this.inputFeatureExtractors === undefined || this.outputFeatureExtractor === undefined) {
94-
throw new Error('trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors and outputFeatureExtractor are not provided!');
97+
if (
98+
this.inputFeatureExtractors === undefined ||
99+
this.outputFeatureExtractor === undefined ||
100+
this.inputFeatureNormalizers === undefined
101+
) {
102+
throw new Error('trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors, outputFeatureExtractor and inputFeatureNormalizers are not provided!');
95103
}
96104

97105
const datasets = await prepareDatasetsForBinaryClassification({
98106
data: data as Array<any>,
99107
inputFeatureExtractors: this.inputFeatureExtractors,
100108
outputFeatureExtractor: this.outputFeatureExtractor,
109+
inputFeatureNormalizers: this.inputFeatureNormalizers,
101110
batchSize: this.batchSize
102111
});
103112

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
import { Sample } from '../training/sample';
21
import { FeatureExtractor } from './feature-extractor';
2+
import { Feature } from './feature';
3+
4+
export type DataItemExtract<T> = {
5+
inputFeatures: Array<Feature<T>>;
6+
outputFeature: Feature<T>;
7+
};
38

49
export const extractFeatures = async <D, T>({
510
data,
@@ -9,8 +14,8 @@ export const extractFeatures = async <D, T>({
914
data: Array<D>;
1015
inputFeatureExtractors: Array<FeatureExtractor<D, T>>;
1116
outputFeatureExtractor: FeatureExtractor<D, T>;
12-
}): Promise<Array<Sample>> => {
13-
const samples = [];
17+
}): Promise<Array<DataItemExtract<T>>> => {
18+
const extracts = [];
1419

1520
for (const dataItem of data) {
1621
const [inputFeatures, outputFeature] = await Promise.all([
@@ -22,11 +27,8 @@ export const extractFeatures = async <D, T>({
2227
outputFeatureExtractor.extract(dataItem)
2328
]);
2429

25-
const input = inputFeatures.map((feature) => feature.value);
26-
const output = [outputFeature.value];
27-
28-
samples.push({ input, output });
30+
extracts.push({ inputFeatures, outputFeature });
2931
}
3032

31-
return samples;
33+
return extracts;
3234
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import { Feature } from './feature';
2+
3+
export abstract class FeatureNormalizer<T> {
4+
public abstract featureType: T;
5+
6+
public abstract normalize(feature: Feature<T>): Feature<T> | Promise<Feature<T>>;
7+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import { FeatureNormalizer } from './feature-normalizer';
2+
import { Feature } from './feature';
3+
4+
export abstract class MinMaxFeatureNormalizer<T> extends FeatureNormalizer<T> {
5+
private min: number;
6+
private max: number;
7+
8+
constructor({ min, max }: { min: number; max: number }) {
9+
super();
10+
11+
this.min = min;
12+
this.max = max;
13+
}
14+
15+
public normalize(feature: Feature<T>): Feature<T> | Promise<Feature<T>> {
16+
return new Feature({
17+
...feature,
18+
value: (feature.value - this.min) / (this.max - this.min)
19+
});
20+
}
21+
}

0 commit comments

Comments
 (0)