Skip to content

Commit fdaf25c

Browse files
authored
Merge pull request #22 from RonasIT/vglinskii/custom-metrics
2 parents 7b01dc9 + 23aa03e commit fdaf25c

File tree

20 files changed

+292
-139
lines changed

20 files changed

+292
-139
lines changed

.eslintrc.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
{
2525
"files": ["*.ts", "*.tsx"],
2626
"extends": ["plugin:@nrwl/nx/typescript"],
27-
"rules": {}
27+
"rules": {
28+
"@typescript-eslint/no-inferrable-types": "off",
29+
"semi": "error"
30+
}
2831
},
2932
{
3033
"files": ["*.js", "*.jsx"],

.prettierrc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
{
2-
"singleQuote": true
2+
"singleQuote": true,
3+
"trailingComma": "none",
4+
"printWidth": 120
35
}

README.md

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,40 @@ class GenderFeatureExtractor extends FeatureExtractor<DatasetItem, FeatureType>
6060
That's it! Now we can use the defined feature extractor to extract valuable
6161
information from our dataset.
6262

63+
### Metrics
64+
65+
After your model has been trained it's important to evaluate it.
66+
One way to do this is by analyzing *metrics*.
67+
The library helps measure model performance by passing a list of
68+
metric calculators to the model trainer.
69+
70+
We have a list of built-in metric calculators for popular metrics:
71+
- AccuracyMetricCalculator
72+
- PrecisionMetricCalculator
73+
- RecallMetricCalculator
74+
- SpecificityMetricCalculator
75+
- F1ScoreMetricCalculator
76+
- FNRMetricCalculator
77+
- FPRMetricCalculator
78+
79+
You can implement your own `MetricCalculator`. In the example below, we define
80+
a metric calculator for `precision`. For that we create a `PrecisionMetricCalculator`
81+
class extending the `MetricCalculator` base class provided by the library and
82+
implementing `calculate` method.
83+
84+
```typescript
85+
class PrecisionMetricCalculator extends MetricCalculator {
86+
public calculate(trueValues: Float32Array, predictedValues: Float32Array): Metric {
87+
const { tp, fp } = new ConfusionMatrix(trueValues, predictedValues);
88+
89+
return new Metric({
90+
title: 'Precision',
91+
value: tp / (tp + fp)
92+
});
93+
}
94+
}
95+
```
96+
6397
### Binary classification
6498

6599
This library provides two classes to train and evaluate binary classification
@@ -83,6 +117,7 @@ Before training the model, you need to create an instance of the
83117
that should be fed into the model as inputs.
84118
- `outputFeatureExtractor` – the feature extractor to extract information that
85119
we want to predict.
120+
- `metricCalculators` – a list of metric calculators that will be used during test stage.
86121

87122
An example can be found below:
88123

@@ -100,7 +135,13 @@ const trainer = new BinaryClassificationTrainer({
100135
new AnnualSalaryFeatureExtractor(),
101136
new GenderFeatureExtractor()
102137
],
103-
outputFeatureExtractor: new OwnsTheCarFeatureExtractor()
138+
outputFeatureExtractor: new OwnsTheCarFeatureExtractor(),
139+
metricCalculators: [
140+
new AccuracyMetricCalculator(),
141+
new PrecisionMetricCalculator(),
142+
new SpecificityMetricCalculator(),
143+
new FPRMetricCalculator()
144+
]
104145
});
105146
```
106147

@@ -225,7 +266,7 @@ const ownsTheCar = await classifier.predict([0.2, 0.76, 0]);
225266
- [ ] Uncertainty ([#15](https://github.com/RonasIT/tfjs-node-helpers/issues/15))
226267
- [ ] Handle class imbalance problem ([#10](https://github.com/RonasIT/tfjs-node-helpers/issues/10))
227268
- [ ] Add more metrics ([#17](https://github.com/RonasIT/tfjs-node-helpers/issues/17))
228-
- [ ] Custom metrics ([#18](https://github.com/RonasIT/tfjs-node-helpers/issues/18))
269+
- [x] Custom metrics ([#18](https://github.com/RonasIT/tfjs-node-helpers/issues/18))
229270
- [ ] Automated tests ([#6](https://github.com/RonasIT/tfjs-node-helpers/issues/6))
230271
- [ ] Continuous Integration ([#11](https://github.com/RonasIT/tfjs-node-helpers/issues/11))
231272
- [ ] Add more examples ([#8](https://github.com/RonasIT/tfjs-node-helpers/issues/8))

packages/tfjs-node-helpers-example/src/app/app.ts

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
1-
import { BinaryClassificationTrainer, BinaryClassifier, makeChunkedDataset } from '@ronas-it/tfjs-node-helpers';
1+
import {
2+
AccuracyMetricCalculator,
3+
BinaryClassificationTrainer,
4+
BinaryClassifier,
5+
F1ScoreMetricCalculator,
6+
FNRMetricCalculator,
7+
FPRMetricCalculator,
8+
makeChunkedDataset,
9+
PrecisionMetricCalculator,
10+
RecallMetricCalculator,
11+
SpecificityMetricCalculator
12+
} from '@ronas-it/tfjs-node-helpers';
213
import { data, layers, TensorContainer } from '@tensorflow/tfjs-node';
314
import { AgeFeatureExtractor } from './feature-extractors/age';
415
import { AnnualSalaryFeatureExtractor } from './feature-extractors/annual-salary';
@@ -16,10 +27,7 @@ export async function startApplication(): Promise<void> {
1627

1728
async function train(): Promise<void> {
1829
const trainer = new BinaryClassificationTrainer({
19-
hiddenLayers: [
20-
layers.dense({ units: 128, activation: 'mish' }),
21-
layers.dense({ units: 128, activation: 'mish' })
22-
],
30+
hiddenLayers: [layers.dense({ units: 128, activation: 'mish' }), layers.dense({ units: 128, activation: 'mish' })],
2331
inputFeatureExtractors: [
2432
new AgeFeatureExtractor(),
2533
new AnnualSalaryFeatureExtractor(),
@@ -29,6 +37,15 @@ async function train(): Promise<void> {
2937
inputFeatureNormalizers: [
3038
new AgeMinMaxFeatureNormalizer(),
3139
new AnnualSalaryMinMaxFeatureNormalizer()
40+
],
41+
metricCalculators: [
42+
new AccuracyMetricCalculator(),
43+
new PrecisionMetricCalculator(),
44+
new F1ScoreMetricCalculator(),
45+
new SpecificityMetricCalculator(),
46+
new RecallMetricCalculator(),
47+
new FNRMetricCalculator(),
48+
new FPRMetricCalculator()
3249
]
3350
});
3451

@@ -43,23 +60,26 @@ async function train(): Promise<void> {
4360
trainingDataService.getTestingSamplesCount()
4461
]);
4562

46-
const makeTrainingDataset = (): data.Dataset<TensorContainer> => makeChunkedDataset({
47-
loadChunk: (skip, take) => trainingDataService.getTrainingSamples(skip, take),
48-
chunkSize: 32,
49-
batchSize: 32
50-
});
63+
const makeTrainingDataset = (): data.Dataset<TensorContainer> =>
64+
makeChunkedDataset({
65+
loadChunk: (skip, take) => trainingDataService.getTrainingSamples(skip, take),
66+
chunkSize: 32,
67+
batchSize: 32
68+
});
5169

52-
const makeValidationDataset = (): data.Dataset<TensorContainer> => makeChunkedDataset({
53-
loadChunk: (skip, take) => trainingDataService.getValidationSamples(skip, take),
54-
chunkSize: 32,
55-
batchSize: validationSamplesCount
56-
});
70+
const makeValidationDataset = (): data.Dataset<TensorContainer> =>
71+
makeChunkedDataset({
72+
loadChunk: (skip, take) => trainingDataService.getValidationSamples(skip, take),
73+
chunkSize: 32,
74+
batchSize: validationSamplesCount
75+
});
5776

58-
const makeTestingDataset = (): data.Dataset<TensorContainer> => makeChunkedDataset({
59-
loadChunk: (skip, take) => trainingDataService.getTestingSamples(skip, take),
60-
chunkSize: 32,
61-
batchSize: testingSamplesCount
62-
});
77+
const makeTestingDataset = (): data.Dataset<TensorContainer> =>
78+
makeChunkedDataset({
79+
loadChunk: (skip, take) => trainingDataService.getTestingSamples(skip, take),
80+
chunkSize: 32,
81+
batchSize: testingSamplesCount
82+
});
6383

6484
const trainingDataset = makeTrainingDataset();
6585
const validationDataset = makeValidationDataset();

0 commit comments

Comments
 (0)