
Commit b826685

tensorflow.js

1 parent 678087d commit b826685

131 files changed: +3818 / -36 lines

File renamed without changes.
Lines changed: 37 additions & 31 deletions

```diff
@@ -1,34 +1,40 @@
+/*
+ * @Author: victorsun
+ * @Date: 2019-12-04 20:15:29
+ * @LastEditors: victorsun - csxiaoyao
+ * @LastEditTime: 2020-03-21 18:34:40
+ * @Description: sunjianfeng@csxiaoyao.com
+ */
 export function getData(numSamples) {
-  let points = [];
-
-  function genGauss(cx, cy, label) {
-    for (let i = 0; i < numSamples / 2; i++) {
-      let x = normalRandom(cx);
-      let y = normalRandom(cy);
-      points.push({ x, y, label });
-    }
+  let points = [];
+  // generate points that follow a Gaussian (normal) distribution
+  function genGauss(cx, cy, label) {
+    // numSamples is split into two batches
+    for (let i = 0; i < numSamples / 2; i++) {
+      let x = normalRandom(cx);
+      let y = normalRandom(cy);
+      points.push({ x, y, label });
+    }
   }
-
-  genGauss(2, 2, 1);
-  genGauss(-2, -2, 0);
-  return points;
 }
-
-/**
- * Samples from a normal distribution. Uses the seedrandom library as the
- * random generator.
- *
- * @param mean The mean. Default is 0.
- * @param variance The variance. Default is 1.
- */
-function normalRandom(mean = 0, variance = 1) {
-  let v1, v2, s;
-  do {
-    v1 = 2 * Math.random() - 1;
-    v2 = 2 * Math.random() - 1;
-    s = v1 * v1 + v2 * v2;
-  } while (s > 1);
-
-  let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
-  return mean + Math.sqrt(variance) * result;
-}
+  genGauss(2, 2, 1);
+  genGauss(-2, -2, 0);
+  return points;
+}
+
+/**
+ * Box-Muller transform algorithm
+ * @param mean     center of the normal distribution
+ * @param variance how tightly the samples cluster
+ */
+function normalRandom(mean = 0, variance = 1) {
+  let v1, v2, s;
+  // Math.random() is uniformly distributed over 0-1
+  // the classic formula uses sin/cos and is slower; this version is optimized
+  do {
+    v1 = 2 * Math.random() - 1;
+    v2 = 2 * Math.random() - 1;
+    s = v1 * v1 + v2 * v2;
+  } while (s > 1);
+  let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
+  return mean + Math.sqrt(variance) * result;
+}
```
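For reference, the rewritten normalRandom is the Marsaglia polar variant of the Box-Muller transform: it rejection-samples a point (v1, v2) uniformly inside the unit circle and rescales it, so no sin/cos call is needed. A quick sanity check of the sampler (a sketch of mine, not part of the commit; normalRandom is inlined because data.js does not export it):

```js
// Sketch: sample normalRandom and verify the empirical moments roughly
// match the requested mean and variance.
function normalRandom(mean = 0, variance = 1) {
  let v1, v2, s;
  do {
    v1 = 2 * Math.random() - 1;
    v2 = 2 * Math.random() - 1;
    s = v1 * v1 + v2 * v2;
  } while (s > 1);
  return mean + Math.sqrt(variance) * Math.sqrt(-2 * Math.log(s) / s) * v1;
}

const n = 100000;
const samples = Array.from({ length: n }, () => normalRandom(2, 1));
const mean = samples.reduce((a, b) => a + b, 0) / n;
const variance = samples.reduce((a, b) => a + (b - mean) ** 2, 0) / n;
console.log(mean.toFixed(2), variance.toFixed(2)); // roughly 2.00 and 1.00
```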

20-tensorflow.js/05-logistic-regression/script.js

Lines changed: 5 additions & 5 deletions

```diff
@@ -2,7 +2,7 @@
  * @Author: victorsun
  * @Date: 2019-12-04 20:15:29
  * @LastEditors: victorsun - csxiaoyao
- * @LastEditTime: 2020-03-21 13:27:25
+ * @LastEditTime: 2020-03-22 00:05:54
  * @Description: sunjianfeng@csxiaoyao.com
  */
 import * as tf from '@tensorflow/tfjs';
@@ -31,16 +31,16 @@ window.onload = async () => {

   // 2. initialize the neural-network model
   const model = tf.sequential();
-  // add a layer; dense: y = ax + b; set an activation function (keeps the output from exceeding 100% and squashes extreme values, so the data stays between 0 and 1); sigmoid is common
+  // add a layer; dense: y = ax + b; use the sigmoid activation (keeps the output from exceeding 100% and squashes extreme values, so the data stays between 0 and 1)
   model.add(tf.layers.dense({
     units: 1,        // the output is a single probability, so one neuron is enough
     inputShape: [2], // the x and y coordinates: two features
-    activation: 'sigmoid' // set the sigmoid activation function
+    activation: 'sigmoid' // sigmoid activation, output in 0-1
   }));
   // set the loss function and the optimizer
   model.compile({
-    loss: tf.losses.logLoss,
-    optimizer: tf.train.adam(0.1) // adam adjusts the learning rate automatically
+    loss: tf.losses.logLoss,      // loss function: log loss, used for logistic-regression problems
+    optimizer: tf.train.adam(0.1) // adam adjusts the learning rate automatically; initial learning rate 0.1
   });

   // 3. convert the training data to tensors
```
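The pairing of a sigmoid output with tf.losses.logLoss is what makes this logistic regression: for a label y and predicted probability p, the per-example loss is -(y*log(p) + (1-y)*log(1-p)), averaged over the batch. A small sketch (not part of the commit, assuming the library's default reduction and epsilon) checking the by-hand formula against the library call:

```js
// Sketch: log loss computed by hand vs tf.losses.logLoss.
import * as tf from '@tensorflow/tfjs';

const labels = tf.tensor([1, 0]);
const preds  = tf.tensor([0.9, 0.2]); // sigmoid outputs in (0, 1)

// by hand: mean of -(y*log(p) + (1-y)*log(1-p)), with a small epsilon
// so the log never sees an exact 0
const eps = 1e-7;
const manual = tf.neg(
  labels.mul(preds.add(eps).log())
    .add(tf.scalar(1).sub(labels).mul(tf.scalar(1).sub(preds).add(eps).log()))
).mean();

manual.print();                           // about 0.164
tf.losses.logLoss(labels, preds).print(); // about 0.164
```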

20-tensorflow.js/06-xor/data.js

Lines changed: 44 additions & 0 deletions (new file)

```js
/*
 * @Author: victorsun
 * @Date: 2019-12-04 20:15:29
 * @LastEditors: victorsun - csxiaoyao
 * @LastEditTime: 2020-03-21 18:53:29
 * @Description: sunjianfeng@csxiaoyao.com
 */
// based on the data generator from 05-logistic-regression
export function getData(numSamples) {
  let points = [];

  function genGauss(cx, cy, label) {
    for (let i = 0; i < numSamples / 2; i++) {
      let x = normalRandom(cx);
      let y = normalRandom(cy);
      points.push({ x, y, label });
    }
  }

  genGauss(2, 2, 0);
  genGauss(-2, -2, 0);
  genGauss(-2, 2, 1);
  genGauss(2, -2, 1);
  return points;
}

/**
 * Samples from a normal distribution, using Math.random as the
 * random generator.
 *
 * @param mean The mean. Default is 0.
 * @param variance The variance. Default is 1.
 */
function normalRandom(mean = 0, variance = 1) {
  let v1, v2, s;
  do {
    v1 = 2 * Math.random() - 1;
    v2 = 2 * Math.random() - 1;
    s = v1 * v1 + v2 * v2;
  } while (s > 1);

  let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
  return mean + Math.sqrt(variance) * result;
}
```
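One thing to keep in mind: each genGauss call pushes numSamples / 2 points, and the XOR version calls it four times, so getData(400) actually returns 800 points, 400 per label. A quick check (my sketch, not part of the commit):

```js
// Sketch: count the points per label produced by the XOR generator.
import { getData } from './data.js';

const data = getData(400);
console.log(data.length);                            // 800 (4 clusters x 200)
console.log(data.filter(p => p.label === 0).length); // 400: quadrants I and III
console.log(data.filter(p => p.label === 1).length); // 400: quadrants II and IV
```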

20-tensorflow.js/06-xor/index.html

Lines changed: 6 additions & 0 deletions (new file)

```html
<script src="script.js"></script>
<form action="" onsubmit="predict(this);return false;">
x: <input type="text" name="x">
y: <input type="text" name="y">
<button type="submit">Predict</button>
</form>
```
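The inline onsubmit relies on script.js assigning the handler to window.predict, and that assignment only happens after model.fit resolves, so submitting before training finishes will throw. An equivalent non-inline wiring (a sketch, not part of the commit):

```js
// Sketch: wire the form to window.predict without an inline handler.
document.querySelector('form').addEventListener('submit', (e) => {
  e.preventDefault();       // same role as "return false" in the inline version
  window.predict(e.target); // defined in script.js once training has finished
});
```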

20-tensorflow.js/06-xor/script.js

Lines changed: 70 additions & 0 deletions (new file)

```js
/*
 * @Author: victorsun
 * @Date: 2019-12-04 20:15:29
 * @LastEditors: victorsun - csxiaoyao
 * @LastEditTime: 2020-03-22 00:06:21
 * @Description: sunjianfeng@csxiaoyao.com
 */
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

/**
 * XOR (non-linear) logistic regression
 * the plane is split into four quadrants
 */
window.onload = async () => {
  // 1. prepare the training data
  const data = getData(400);
  // scatter plot
  tfvis.render.scatterplot(
    { name: 'XOR training data' },
    {
      values: [
        data.filter(p => p.label === 1),
        data.filter(p => p.label === 0),
      ]
    }
  );
  // 2. initialize the neural network
  const model = tf.sequential();

  // 3. add fully connected (dense) layers
  // 3.1 hidden layer, the core of the model
  model.add(tf.layers.dense({
    units: 4,        // number of neurons; in this case, one per quadrant
    inputShape: [2], // input features: the x and y coordinates, so inputShape is 2
    // the relu activation gives the network non-linear fitting power,
    // which is the core of fitting with complex neural networks;
    // without it the result can only be linear, never non-linear
    activation: 'relu'
  }));
  // 3.2 output layer
  model.add(tf.layers.dense({
    units: 1, // one neuron: the final output is a single probability
    // inputShape: [4], // not needed: it is already fixed by the previous layer's units
    activation: 'sigmoid' // sigmoid outputs a value between 0 and 1
  }));
  // 3.3 set the loss function and the optimizer
  model.compile({
    loss: tf.losses.logLoss,      // log loss: at heart this is still logistic regression
    optimizer: tf.train.adam(0.1) // adam optimizer, initial learning rate 0.1
  });

  // 4. prepare the training data as tensors
  const inputs = tf.tensor(data.map(p => [p.x, p.y]));
  const labels = tf.tensor(data.map(p => p.label));

  // 5. train
  await model.fit(inputs, labels, {
    epochs: 10, // 10 epochs
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training result' },
      ['loss']
    )
  });

  window.predict = (form) => {
    const pred = model.predict(tf.tensor([[form.x.value * 1, form.y.value * 1]]));
    alert(`Prediction: ${pred.dataSync()[0]}`);
  };
};
```
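Once training succeeds, the model should map each cluster center to its XOR label. A sanity probe (my sketch, not part of the commit) that could be appended at the end of the onload handler, after model.fit resolves:

```js
// Sketch: probe the four cluster centers; with the labeling in data.js,
// (2,2) and (-2,-2) should score near 0, (-2,2) and (2,-2) near 1.
const centers = tf.tensor([[2, 2], [-2, -2], [-2, 2], [2, -2]]);
const probs = model.predict(centers).dataSync();
console.log(Array.from(probs).map(p => p.toFixed(2)));
// e.g. ["0.02", "0.03", "0.97", "0.96"]; exact values vary run to run
```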
