Skip to content

Commit ab95fb6

Browse files
committed
Add fit()
1 parent 5773f8c commit ab95fb6

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

src/ai/models/Model.js

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,18 @@ export default class Model {
1717
return this.predict([inputX]);
1818
}
1919

20-
train(inputXs, inputYs, iterationCount = 100) {
20+
train(inputXs, inputYs) {
2121
throw new Error(
2222
'Abstract method must be implemented in the derived class.'
2323
);
2424
}
2525

26+
fit(inputXs, inputYs, iterationCount = 100) {
27+
for (let i = 0; i < iterationCount; i += 1) {
28+
this.train(inputXs, inputYs);
29+
}
30+
}
31+
2632
loss(predictedYs, labels) {
2733
const meanSquareError = predictedYs
2834
.sub(tensor(labels))

src/ai/models/genetic/GeneticModel.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ export default class GeneticModel extends Model {
77
this.mutate(offspring);
88
}
99

10+
fit(chromosomes) {
11+
this.train(chromosomes);
12+
}
13+
1014
select(chromosomes) {
1115
const parents = [chromosomes[0], chromosomes[1]];
1216
return parents;

src/ai/models/nn/NNModel.js

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,10 @@ export default class NNModel extends Model {
4747
return prediction;
4848
}
4949

50-
train(inputXs, inputYs, iterationCount = 100) {
51-
for (let i = 0; i < iterationCount; i += 1) {
52-
this.optimizer.minimize(() => {
53-
const predictedYs = this.predict(inputXs);
54-
return this.loss(predictedYs, inputYs);
55-
});
56-
}
50+
train(inputXs, inputYs) {
51+
this.optimizer.minimize(() => {
52+
const predictedYs = this.predict(inputXs);
53+
return this.loss(predictedYs, inputYs);
54+
});
5755
}
5856
}

0 commit comments

Comments
 (0)