1+ /*
2+ * @Author : victorsun
3+ * @Date : 2019-12-04 20:15:29
4+ * @LastEditors : victorsun - csxiaoyao
5+ * @LastEditTime : 2020-03-22 20:13:25
6+ * @Description : sunjianfeng@csxiaoyao.com
7+ */
8+ import * as tf from '@tensorflow/tfjs' ;
9+ import * as tfvis from '@tensorflow/tfjs-vis' ;
10+ import { getInputs } from './data' ;
11+ import { img2x , file2img } from './utils' ;
12+
13+ /**
14+ * 【 迁移学习 】
15+ * 把已训练好的模型参数迁移到新的模型来帮助新模型训练
16+ * 深度学习模型参数多,从头训练成本高
17+ * 删除原始模型的最后一层,基于此截断模型的输出训练一个新的(通常相当浅的)模型
18+ * 本案例,在 mobilenet 基础上,最后输出 ['android', 'apple', 'windows'] 三选一
19+ * 模型的保存
20+ */
21+ const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json' ;
22+ const NUM_CLASSES = 3 ;
23+ const BRAND_CLASSES = [ 'android' , 'apple' , 'windows' ] ;
24+
25+ window . onload = async ( ) => {
26+ // 1. 获取输入数据并在 visor 面板中展示
27+ const { inputs, labels } = await getInputs ( ) ;
28+ const surface = tfvis . visor ( ) . surface ( { name : '输入示例' , styles : { height : 250 } } ) ;
29+ inputs . forEach ( img => {
30+ surface . drawArea . appendChild ( img ) ;
31+ } ) ;
32+
33+ // 加载mobilenet 模型并截断 构建双层神经网络 截断模型作为输入,双层神经网络作为输出
34+ // 2. 模型迁移
35+ // 2.1 加载 mobilenet 模型, tfjs_layers_model 格式
36+ const mobilenet = await tf . loadLayersModel ( MOBILENET_MODEL_PATH ) ;
37+ // 查看模型概况
38+ mobilenet . summary ( ) ;
39+
40+ // 2.2 获取模型中间层并截断
41+ const layer = mobilenet . getLayer ( 'conv_pw_13_relu' ) ; // 根据层名获取层
42+ // 生成新的截断模型
43+ const truncatedMobilenet = tf . model ( {
44+ inputs : mobilenet . inputs ,
45+ outputs : layer . output
46+ } ) ;
47+
48+ // 3. 构建双层神经网络,tensor数据从 mobilenet 模型 flow 到 构建到双层神经网络模型
49+ // 初始化神经网络模型
50+ const model = tf . sequential ( ) ;
51+ // flatten输入
52+ model . add ( tf . layers . flatten ( {
53+ inputShape : layer . outputShape . slice ( 1 ) // [null,7,7,256] => [7,7,256],null表示个数不定,此处删除
54+ } ) ) ;
55+ // 双层神经网络
56+ model . add ( tf . layers . dense ( {
57+ units : 10 ,
58+ activation : 'relu'
59+ } ) ) ;
60+ model . add ( tf . layers . dense ( {
61+ units : NUM_CLASSES , // 输出类别数量
62+ activation : 'softmax'
63+ } ) ) ;
64+
65+ // 4. 训练
66+ // 4.1 定义损失函数和优化器
67+ model . compile ( {
68+ loss : 'categoricalCrossentropy' , // 交叉熵
69+ optimizer : tf . train . adam ( )
70+ } ) ;
71+ // 4.2 数据预处理: 处理输入为截断模型接受的数据格式,即 mobilenet 接受的格式
72+ const { xs, ys } = tf . tidy ( ( ) => {
73+ // img2x: img 转 mobilenet 接受的tensor格式,并合并单个 tensor 为一个大 tensor
74+ const xs = tf . concat ( inputs . map ( imgEl => truncatedMobilenet . predict ( img2x ( imgEl ) ) ) ) ;
75+ const ys = tf . tensor ( labels ) ;
76+ return { xs, ys } ;
77+ } ) ;
78+ // 4.3 通过 fit 方法训练
79+ await model . fit ( xs , ys , {
80+ epochs : 20 ,
81+ callbacks : tfvis . show . fitCallbacks (
82+ { name : '训练效果' } ,
83+ [ 'loss' ] ,
84+ { callbacks : [ 'onEpochEnd' ] }
85+ )
86+ } ) ;
87+
88+ // 5. 迁移学习下的模型预测
89+ window . predict = async ( file ) => {
90+ const img = await file2img ( file ) ;
91+ document . body . appendChild ( img ) ;
92+ const pred = tf . tidy ( ( ) => {
93+ // img 转 tensor
94+ const x = img2x ( img ) ;
95+ // 截断模型先执行
96+ const input = truncatedMobilenet . predict ( x ) ;
97+ // 再用新模型预测出最终结果
98+ return model . predict ( input ) ;
99+ } ) ;
100+ const index = pred . argMax ( 1 ) . dataSync ( ) [ 0 ] ;
101+ setTimeout ( ( ) => {
102+ alert ( `预测结果:${ BRAND_CLASSES [ index ] } ` ) ;
103+ } , 0 ) ;
104+ } ;
105+
106+ // 6. 模型的保存 tfjs_layers_model
107+ // json + 权重bin
108+ window . download = async ( ) => {
109+ await model . save ( 'downloads://model' ) ;
110+ } ;
111+ } ;
0 commit comments