Skip to content

Commit 1313667

Browse files
authored
Merge pull request #24 from RonasIT/vglinskii/more-metrics
2 parents 3ae8dda + 60e011b commit 1313667

File tree

25 files changed

+346
-76
lines changed

25 files changed

+346
-76
lines changed

README.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,25 @@ We have a list of built-in metric calculators for popular metrics:
7272
- PrecisionMetricCalculator
7373
- RecallMetricCalculator
7474
- SpecificityMetricCalculator
75-
- F1ScoreMetricCalculator
7675
- FNRMetricCalculator
7776
- FPRMetricCalculator
77+
- NPVMetricCalculator
78+
- MCCMetricCalculator
79+
- FBetaScoreMetricCalculator
80+
- ROCAUCMetricCalculator
81+
- PRAUCMetricCalculator
82+
- BrierLossMetricCalculator
83+
- BinaryCrossentropyMetricCalculator
84+
- CohenKappaMetricCalculator
7885

7986
You can implement your own `MetricCalculator`. In the example below, we define
8087
a metric calculator for `precision`. For that we create a `PrecisionMetricCalculator`
8188
class extending the `MetricCalculator` base class provided by the library and
8289
implementing `calculate` method.
8390

8491
```typescript
85-
class PrecisionMetricCalculator extends MetricCalculator {
86-
public calculate(trueValues: Float32Array, predictedValues: Float32Array): Metric {
92+
export class PrecisionMetricCalculator extends MetricCalculator {
93+
public calculate({ trueValues, predictedValues }: TestingResult): Metric {
8794
const { tp, fp } = new ConfusionMatrix(trueValues, predictedValues);
8895

8996
return new Metric({
@@ -258,7 +265,7 @@ const ownsTheCar = await classifier.predict([0.2, 0.76, 0]);
258265
- [x] Asynchronously loaded datasets ([#14](https://github.com/RonasIT/tfjs-node-helpers/issues/14))
259266
- [x] Feature normalization ([#5](https://github.com/RonasIT/tfjs-node-helpers/issues/5))
260267
- [x] Custom metrics ([#18](https://github.com/RonasIT/tfjs-node-helpers/issues/18))
261-
- [ ] Add more metrics ([#17](https://github.com/RonasIT/tfjs-node-helpers/issues/17))
268+
- [x] Add more metrics ([#17](https://github.com/RonasIT/tfjs-node-helpers/issues/17))
262269
- [ ] Refactor features ([#25](https://github.com/RonasIT/tfjs-node-helpers/issues/25))
263270
- [ ] Task-oriented architecture ([#26](https://github.com/RonasIT/tfjs-node-helpers/issues/26))
264271
- [ ] Categorical features ([#19](https://github.com/RonasIT/tfjs-node-helpers/issues/19))

packages/tfjs-node-helpers-example/src/app/app.ts

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,20 @@ import {
22
AccuracyMetricCalculator,
33
BinaryClassificationTrainer,
44
BinaryClassifier,
5-
F1ScoreMetricCalculator,
5+
BrierLossMetricCalculator,
66
FNRMetricCalculator,
77
FPRMetricCalculator,
8+
BinaryCrossentropyMetricCalculator,
89
makeChunkedDataset,
10+
PRAUCMetricCalculator,
911
PrecisionMetricCalculator,
1012
RecallMetricCalculator,
11-
SpecificityMetricCalculator
13+
ROCAUCMetricCalculator,
14+
SpecificityMetricCalculator,
15+
CohenKappaMetricCalculator,
16+
NPVMetricCalculator,
17+
MCCMetricCalculator,
18+
FBetaScoreMetricCalculator
1219
} from '@ronas-it/tfjs-node-helpers';
1320
import { data, layers, TensorContainer } from '@tensorflow/tfjs-node';
1421
import { AgeFeatureExtractor } from './feature-extractors/age';
@@ -34,18 +41,23 @@ async function train(): Promise<void> {
3441
new GenderFeatureExtractor()
3542
],
3643
outputFeatureExtractor: new OwnsTheCarFeatureExtractor(),
37-
inputFeatureNormalizers: [
38-
new AgeMinMaxFeatureNormalizer(),
39-
new AnnualSalaryMinMaxFeatureNormalizer()
40-
],
44+
inputFeatureNormalizers: [new AgeMinMaxFeatureNormalizer(), new AnnualSalaryMinMaxFeatureNormalizer()],
4145
metricCalculators: [
4246
new AccuracyMetricCalculator(),
4347
new PrecisionMetricCalculator(),
44-
new F1ScoreMetricCalculator(),
48+
new FBetaScoreMetricCalculator(1),
4549
new SpecificityMetricCalculator(),
4650
new RecallMetricCalculator(),
4751
new FNRMetricCalculator(),
48-
new FPRMetricCalculator()
52+
new FPRMetricCalculator(),
53+
new NPVMetricCalculator(),
54+
new MCCMetricCalculator(),
55+
new FBetaScoreMetricCalculator(2),
56+
new ROCAUCMetricCalculator(),
57+
new PRAUCMetricCalculator(),
58+
new BrierLossMetricCalculator(),
59+
new BinaryCrossentropyMetricCalculator(),
60+
new CohenKappaMetricCalculator()
4961
]
5062
});
5163

packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,13 @@ export class BinaryClassificationTrainer {
189189
const predictions = this.model.predict(testXs) as Tensor;
190190
const binarizedPredictions = binarize(predictions);
191191

192-
const trueValues = await testYs.data<'float32'>();
193-
const predictedValues = await binarizedPredictions.data<'float32'>();
194-
195-
const confusionMatrix = new ConfusionMatrix(trueValues, predictedValues);
192+
const confusionMatrix = new ConfusionMatrix(testYs, binarizedPredictions);
196193
const metrics = calculateMetrics({
197-
trueValues,
198-
predictedValues,
194+
testingResult: {
195+
trueValues: testYs,
196+
predictedValues: binarizedPredictions,
197+
probabilities: predictions
198+
},
199199
metricCalculators: this.metricCalculators
200200
});
201201

packages/tfjs-node-helpers/src/index.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,19 @@ export * from './utils/make-chunked-dataset';
1313
export * from './testing/calculate-metrics';
1414
export * from './testing/confusion-matrix';
1515
export * from './testing/metric';
16+
export * from './testing/result';
1617
export * from './testing/metric-calculator';
1718
export * from './testing/metrics/accuracy';
1819
export * from './testing/metrics/precision';
19-
export * from './testing/metrics/f1-score';
2020
export * from './testing/metrics/recall';
2121
export * from './testing/metrics/specificity';
2222
export * from './testing/metrics/fpr';
2323
export * from './testing/metrics/fnr';
24+
export * from './testing/metrics/fbeta-score';
25+
export * from './testing/metrics/mcc';
26+
export * from './testing/metrics/npv';
27+
export * from './testing/metrics/roc-auc';
28+
export * from './testing/metrics/pr-auc';
29+
export * from './testing/metrics/brier-loss';
30+
export * from './testing/metrics/binary-crossentropy';
31+
export * from './testing/metrics/cohen-kappa';
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import { Metric } from './metric';
22
import { MetricCalculator } from './metric-calculator';
3+
import { TestingResult } from './result';
34

45
export const calculateMetrics = ({
5-
trueValues,
6-
predictedValues,
6+
testingResult,
77
metricCalculators
88
}: {
9-
trueValues: Float32Array;
10-
predictedValues: Float32Array;
9+
testingResult: TestingResult;
1110
metricCalculators: Array<MetricCalculator>;
12-
}): Array<Metric> => metricCalculators.map((calculator) => calculator.calculate(trueValues, predictedValues));
11+
}): Array<Metric> => metricCalculators.map((calculator) => calculator.calculate(testingResult));
Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1+
import { math, Tensor } from '@tensorflow/tfjs-node';
2+
13
export class ConfusionMatrix {
24
public tp: number;
35
public tn: number;
46
public fp: number;
57
public fn: number;
68

7-
constructor(trueValues: Float32Array, predictedValues: Float32Array) {
8-
this.tp = 0;
9-
this.tn = 0;
10-
this.fp = 0;
11-
this.fn = 0;
9+
constructor(trueValues: Tensor, predictedValues: Tensor) {
10+
const trueY = trueValues.as1D();
11+
const predY = predictedValues.as1D();
12+
const numberOfClasses = 2;
13+
14+
const confusionMatrix = math.confusionMatrix(trueY, predY, numberOfClasses);
15+
16+
trueY.dispose();
17+
predY.dispose();
18+
19+
const [[tn, fp], [fn, tp]] = confusionMatrix.arraySync();
1220

13-
for (let index = 0; index < trueValues.length; index++) {
14-
const trueValue = trueValues[index];
15-
const predictedValue = predictedValues[index];
21+
confusionMatrix.dispose();
1622

17-
if (trueValue === 1 && predictedValue === 1) {
18-
this.tp++;
19-
} else if (trueValue === 0 && predictedValue === 0) {
20-
this.tn++;
21-
} else if (trueValue === 0 && predictedValue === 1) {
22-
this.fp++;
23-
} else {
24-
this.fn++;
25-
}
26-
}
23+
this.tp = tp;
24+
this.fp = fp;
25+
this.fn = fn;
26+
this.tn = tn;
2727
}
2828
}
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { Metric } from './metric';
2+
import { TestingResult } from './result';
23

34
export abstract class MetricCalculator {
4-
public abstract calculate(trueValues: Float32Array, predictedValues: Float32Array): Metric;
5+
public abstract calculate(testingResult: TestingResult): Metric;
56
}
Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
1-
import { ConfusionMatrix } from '../confusion-matrix';
1+
import { metrics } from '@tensorflow/tfjs-node';
22
import { Metric } from '../metric';
33
import { MetricCalculator } from '../metric-calculator';
4+
import { TestingResult } from '../result';
45

56
export class AccuracyMetricCalculator extends MetricCalculator {
6-
public calculate(trueValues: Float32Array, predictedValues: Float32Array): Metric {
7-
const { tp, tn, fp, fn } = new ConfusionMatrix(trueValues, predictedValues);
7+
public calculate({ trueValues, predictedValues }: TestingResult): Metric {
8+
const trueY = trueValues.as1D();
9+
const predY = predictedValues.as1D();
10+
const valueTensor = metrics.binaryAccuracy(trueY, predY);
11+
12+
trueY.dispose();
13+
predY.dispose();
14+
15+
const [value] = valueTensor.dataSync();
16+
17+
valueTensor.dispose();
818

919
return new Metric({
1020
title: 'Accuracy',
11-
value: (tp + tn) / (tp + tn + fp + fn)
21+
value
1222
});
1323
}
1424
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import { metrics } from '@tensorflow/tfjs-node';
2+
import { Metric } from '../metric';
3+
import { MetricCalculator } from '../metric-calculator';
4+
import { TestingResult } from '../result';
5+
6+
export class BinaryCrossentropyMetricCalculator extends MetricCalculator {
7+
public calculate({ trueValues, probabilities }: TestingResult): Metric {
8+
const trueY = trueValues.as1D();
9+
const predY = probabilities.as1D();
10+
const valueTensor = metrics.binaryCrossentropy(trueY, predY);
11+
12+
trueY.dispose();
13+
predY.dispose();
14+
15+
const [value] = valueTensor.dataSync();
16+
17+
valueTensor.dispose();
18+
19+
return new Metric({
20+
title: 'Binary Cross-entropy',
21+
value
22+
});
23+
}
24+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import { metrics } from '@tensorflow/tfjs-node';
2+
import { Metric } from '../metric';
3+
import { MetricCalculator } from '../metric-calculator';
4+
import { TestingResult } from '../result';
5+
6+
export class BrierLossMetricCalculator extends MetricCalculator {
7+
public calculate({ trueValues, probabilities }: TestingResult): Metric {
8+
const trueY = trueValues.as1D();
9+
const predY = probabilities.as1D();
10+
const valueTensor = metrics.MSE(trueY, predY);
11+
12+
trueY.dispose();
13+
predY.dispose();
14+
15+
const [value] = valueTensor.dataSync();
16+
17+
valueTensor.dispose();
18+
19+
return new Metric({
20+
title: 'Brier Loss',
21+
value
22+
});
23+
}
24+
}

0 commit comments

Comments
 (0)