Skip to content

Commit 330c3a0

Browse files
Merge pull request #350 from code-dot-org/rename-simple-trainer
Rename SimpleTrainer to KNNTrainer as that is a more descriptive name
2 parents 4fe1ed2 + a0ffeb9 commit 330c3a0

File tree

7 files changed

+19
-19
lines changed

7 files changed

+19
-19
lines changed

src/oceans/models/loading.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import {AppMode, Modes} from '../constants';
55
import {initFishData} from '../../utils/fishData';
66
import {getAppMode, $time, finishLoading} from '../helpers';
77
import modeHelpers from '../modeHelpers';
8-
import SimpleTrainer from '../../utils/SimpleTrainer';
8+
import KNNTrainer from '../../utils/KNNTrainer';
99

1010
export const init = async () => {
1111
const startTime = $time();
@@ -23,7 +23,7 @@ export const init = async () => {
2323
].includes(appModeBase);
2424

2525
if (appModeBase === AppMode.CreaturesVTrashDemo) {
26-
const trainer = new SimpleTrainer(oceanObj => oceanObj.getTensor());
26+
const trainer = new KNNTrainer(oceanObj => oceanObj.getTensor());
2727
setState({trainer, word: 'fish'});
2828
}
2929

src/oceans/models/train.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import 'idempotent-babel-polyfill';
22
import {setState, getState} from '../state';
33
import {ClassType, AppMode} from '../constants';
4-
import SimpleTrainer from '../../utils/SimpleTrainer';
4+
import KNNTrainer from '../../utils/KNNTrainer';
55
import SVMTrainer from '../../utils/SVMTrainer';
66
import {generateOcean} from '../../utils/generateOcean';
77
import I18n from '../i18n';
@@ -14,7 +14,7 @@ const init = () => {
1414
if ([AppMode.FishShort, AppMode.FishLong].includes(state.appMode)) {
1515
trainer = new SVMTrainer(fish => fish.getKnnData());
1616
} else {
17-
trainer = new SimpleTrainer(oceanObj => oceanObj.getTensor());
17+
trainer = new KNNTrainer(oceanObj => oceanObj.getTensor());
1818
}
1919
}
2020

src/utils/SimpleTrainer.js renamed to src/utils/KNNTrainer.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import * as tf from '@tensorflow/tfjs';
22
import * as knnClassifier from '@tensorflow-models/knn-classifier';
33

4-
export default class SimpleTrainer {
4+
export default class KNNTrainer {
55
constructor(converterFn) {
66
this.converterFn = converterFn || (input => input); // Default to returning example as-is
77
this.knn = knnClassifier.create();
@@ -41,7 +41,7 @@ export default class SimpleTrainer {
4141
return result;
4242
}
4343

44-
// SimpleTrainer-specific methods below
44+
// KNNTrainer-specific methods below
4545

4646
setTopK(k) {
4747
this.TOPK = k;

test/unit/oceans/models/pond.test.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import {setState, getState, resetState} from '@ml/oceans/state';
33
import {TrashOceanObject} from '@ml/oceans/OceanObject';
44
import {ClassType, Modes} from '@ml/oceans/constants';
55
import {init} from '@ml/oceans/models/pond';
6-
import SimpleTrainer from '@ml/utils/SimpleTrainer';
6+
import KNNTrainer from '@ml/utils/KNNTrainer';
77
import {generateOcean} from '@ml/utils/generateOcean';
88

99
describe('Model quality test', () => {
@@ -14,7 +14,7 @@ describe('Model quality test', () => {
1414
beforeEach(() => {
1515
resetState();
1616
setState({
17-
trainer: new SimpleTrainer(),
17+
trainer: new KNNTrainer(),
1818
mode: Modes.Pond,
1919
fishData: generateOcean(100, 0, true, true)
2020
});
@@ -31,7 +31,7 @@ describe('Model quality test', () => {
3131
});
3232

3333
test('init state with predictions', async () => {
34-
const trainer = new SimpleTrainer();
34+
const trainer = new KNNTrainer();
3535
trainer.predict = jest.fn(async example => {
3636
if (example instanceof TrashOceanObject) {
3737
return {predictedClassId: 1, confidenceByClassId: {0: 0, 1: 1}};

test/unit/oceans/models/predict.test.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ const {initFishData} = require('@ml/utils/fishData');
22
import {setState, getState, resetState} from '@ml/oceans/state';
33
import {ClassType, Modes, AppMode} from '@ml/oceans/constants';
44
import {init, predictFish} from '@ml/oceans/models/predict';
5-
import SimpleTrainer from '@ml/utils/SimpleTrainer';
5+
import KNNTrainer from '@ml/utils/KNNTrainer';
66
import {TrashOceanObject} from '@ml/oceans/OceanObject';
77

88
describe('Predict test', () => {
@@ -12,7 +12,7 @@ describe('Predict test', () => {
1212
});
1313

1414
beforeEach(() => {
15-
const trainer = new SimpleTrainer();
15+
const trainer = new KNNTrainer();
1616
trainer.predict = jest.fn(async example => {
1717
if (example instanceof TrashOceanObject) {
1818
return {predictedClassId: 1, confidenceByClassId: {0: 0, 1: 1}};

test/unit/utils/SimpleTrainer.test.js renamed to test/unit/utils/KNNTrainer.test.js

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
* @jest-environment node
33
*/
44

5-
import SimpleTrainer from '@ml/utils/SimpleTrainer';
5+
import KNNTrainer from '@ml/utils/KNNTrainer';
66
import * as tf from '@tensorflow/tfjs';
77

88
describe('Simple Trainer tests', () => {
9-
test('SimpleTrainer predicts', async () => {
10-
const trainer = new SimpleTrainer();
9+
test('KNNTrainer predicts', async () => {
10+
const trainer = new KNNTrainer();
1111
trainer.setTopK(3);
1212

1313
trainer.addTrainingExample(tf.tensor([1, 1]), 0);
@@ -24,8 +24,8 @@ describe('Simple Trainer tests', () => {
2424
trainer.dispose();
2525
});
2626

27-
test('SimpleTrainer can be restored', async () => {
28-
const trainer = new SimpleTrainer();
27+
test('KNNTrainer can be restored', async () => {
28+
const trainer = new KNNTrainer();
2929
trainer.setTopK(3);
3030

3131
trainer.addTrainingExample(tf.tensor([1, 1]), 0);
@@ -42,7 +42,7 @@ describe('Simple Trainer tests', () => {
4242
const classifierDatasetString = trainer.getDatasetJSON();
4343
trainer.clearAll();
4444

45-
const retrainedTrainer = new SimpleTrainer();
45+
const retrainedTrainer = new KNNTrainer();
4646
retrainedTrainer.setTopK(3);
4747
const untrainedResult = await retrainedTrainer.predict(tf.tensor([1, 1]));
4848
expect(untrainedResult.predictedClassId).toEqual(null);

test/unit/utils/generateOcean.test.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import {initFishData} from '@ml/utils/fishData';
66
import {generateOcean, filterOcean} from '@ml/utils/generateOcean';
7-
import SimpleTrainer from '@ml/utils/SimpleTrainer';
7+
import KNNTrainer from '@ml/utils/KNNTrainer';
88

99
describe('Generate ocean test', () => {
1010
beforeAll(() => {
@@ -34,7 +34,7 @@ describe('Generate ocean test', () => {
3434
test('Can generate predictions on a random set of fish', async () => {
3535
const numFish = 25;
3636
const trainingOcean = generateOcean(numFish);
37-
const trainer = new SimpleTrainer(fish => fish.getTensor());
37+
const trainer = new KNNTrainer(fish => fish.getTensor());
3838
trainingOcean.forEach(fish => {
3939
trainer.addTrainingExample(fish, Math.round(Math.random()));
4040
});

0 commit comments

Comments
 (0)