Skip to content

Commit 51809e2

Browse files
authored
updated10.25
1 parent 19e2fad commit 51809e2

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"collapsed": true,
7+
"pycharm": {
8+
"is_executing": false
9+
}
10+
},
11+
"source": [
12+
"3.10 多层感知机的简洁实现\n",
13+
"下面我们使用tensorflow来实现上一节中的多层感知机。首先导入所需的包或模块"
14+
]
15+
},
16+
{
17+
"cell_type": "code",
18+
"execution_count": 9,
19+
"metadata": {
20+
"collapsed": true
21+
},
22+
"outputs": [],
23+
"source": [
24+
"import tensorflow as tf\n",
25+
"from tensorflow import keras\n",
26+
"import sys\n",
27+
"sys.path.append(\"..\") \n",
28+
"from tensorflow import keras\n",
29+
"fashion_mnist = keras.datasets.fashion_mnist"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"3.10.1 定义模型\n",
37+
"和softmax回归唯一的不同在于,我们多加了一个全连接层作为隐藏层。它的隐藏单元个数为256,并使用ReLU函数作为激活函数。"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": 10,
43+
"metadata": {
44+
"collapsed": true
45+
},
46+
"outputs": [],
47+
"source": [
48+
"model = tf.keras.models.Sequential([\n",
49+
" tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
50+
" tf.keras.layers.Dense(256, activation='relu',),\n",
51+
" tf.keras.layers.Dense(10, activation='softmax')\n",
52+
" ])\n"
53+
]
54+
},
55+
{
56+
"cell_type": "markdown",
57+
"metadata": {},
58+
"source": [
59+
"3.10.2 读取数据并训练模型\n",
60+
"我们使用与3.7节中训练softmax回归几乎相同的步骤来读取数据并训练模型。"
61+
]
62+
},
63+
{
64+
"cell_type": "code",
65+
"execution_count": 12,
66+
"metadata": {},
67+
"outputs": [
68+
{
69+
"name": "stdout",
70+
"output_type": "stream",
71+
"text": [
72+
"Train on 60000 samples, validate on 10000 samples\n",
73+
"Epoch 1/5\n",
74+
"60000/60000 [==============================] - 2s 36us/sample - loss: 0.9844 - accuracy: 0.7100 - val_loss: 0.5575 - val_accuracy: 0.7945\n",
75+
"Epoch 2/5\n",
76+
"60000/60000 [==============================] - 2s 33us/sample - loss: 0.5038 - accuracy: 0.8141 - val_loss: 0.6090 - val_accuracy: 0.7702\n",
77+
"Epoch 3/5\n",
78+
"60000/60000 [==============================] - 2s 35us/sample - loss: 0.4334 - accuracy: 0.8396 - val_loss: 0.4691 - val_accuracy: 0.8341\n",
79+
"Epoch 4/5\n",
80+
"60000/60000 [==============================] - 2s 34us/sample - loss: 0.3969 - accuracy: 0.8535 - val_loss: 0.4293 - val_accuracy: 0.8494\n",
81+
"Epoch 5/5\n",
82+
"60000/60000 [==============================] - 2s 31us/sample - loss: 0.3754 - accuracy: 0.8621 - val_loss: 0.4657 - val_accuracy: 0.8288\n"
83+
]
84+
},
85+
{
86+
"data": {
87+
"text/plain": [
88+
"<tensorflow.python.keras.callbacks.History at 0x5ad1df28d0>"
89+
]
90+
},
91+
"execution_count": 12,
92+
"metadata": {},
93+
"output_type": "execute_result"
94+
}
95+
],
96+
"source": [
97+
"fashion_mnist = keras.datasets.fashion_mnist\n",
98+
"(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()\n",
99+
"x_train = x_train / 255.0\n",
100+
"x_test = x_test / 255.0\n",
101+
"model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.5),\n",
102+
" loss = 'sparse_categorical_crossentropy',\n",
103+
" metrics=['accuracy'])\n",
104+
"model.fit(x_train, y_train, epochs=5,\n",
105+
" batch_size=256,\n",
106+
" validation_data=(x_test, y_test),\n",
107+
" validation_freq=1)"
108+
]
109+
},
110+
{
111+
"cell_type": "markdown",
112+
"metadata": {},
113+
"source": [
114+
"小结\n",
115+
"通过Tensorflow2.0可以更简洁地实现多层感知机。"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"metadata": {
122+
"collapsed": true
123+
},
124+
"outputs": [],
125+
"source": []
126+
}
127+
],
128+
"metadata": {
129+
"kernelspec": {
130+
"display_name": "Python 3",
131+
"language": "python",
132+
"name": "python3"
133+
},
134+
"language_info": {
135+
"codemirror_mode": {
136+
"name": "ipython",
137+
"version": 3
138+
},
139+
"file_extension": ".py",
140+
"mimetype": "text/x-python",
141+
"name": "python",
142+
"nbconvert_exporter": "python",
143+
"pygments_lexer": "ipython3",
144+
"version": "3.6.1"
145+
}
146+
},
147+
"nbformat": 4,
148+
"nbformat_minor": 1
149+
}

0 commit comments

Comments
 (0)