@@ -30,7 +30,7 @@ def printVector(f, vector, name):
3030 f .write ('\n };\n \n ' )
3131 return ;
3232
33- def printLayer (f , hf , layer ):
33+ def printLayer (f , layer ):
3434 weights = layer .get_weights ()
3535 printVector (f , weights [0 ], layer .name + '_weights' )
3636 if len (weights ) > 2 :
@@ -39,19 +39,24 @@ def printLayer(f, hf, layer):
3939 name = layer .name
4040 activation = re .search ('function (.*) at' , str (layer .activation )).group (1 ).upper ()
4141 if len (weights ) > 2 :
42- f .write ('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n }};\n \n '
42+ f .write ('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n }};\n \n '
4343 .format (name , name , name , name , weights [0 ].shape [0 ], weights [0 ].shape [1 ]/ 3 , activation ))
44- hf .write ('#define {}_SIZE {}\n ' .format (name .upper (), weights [0 ].shape [1 ]/ 3 ))
45- hf .write ('extern const GRULayer {};\n \n ' .format (name ));
4644 else :
47- f .write ('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n }};\n \n '
45+ f .write ('static const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n }};\n \n '
4846 .format (name , name , name , weights [0 ].shape [0 ], weights [0 ].shape [1 ], activation ))
49- hf .write ('#define {}_SIZE {}\n ' .format (name .upper (), weights [0 ].shape [1 ]))
50- hf .write ('extern const DenseLayer {};\n \n ' .format (name ));
47+
48+ def structLayer (f , layer ):
49+ weights = layer .get_weights ()
50+ name = layer .name
51+ if len (weights ) > 2 :
52+ f .write (' {},\n ' .format (weights [0 ].shape [1 ]/ 3 ))
53+ else :
54+ f .write (' {},\n ' .format (weights [0 ].shape [1 ]))
55+ f .write (' &{},\n ' .format (name ))
5156
5257
5358def foo (c , name ):
54- return 1
59+ return None
5560
5661def mean_squared_sqrt_error (y_true , y_pred ):
5762 return K .mean (K .square (K .sqrt (y_pred ) - K .sqrt (y_true )), axis = - 1 )
@@ -62,27 +67,26 @@ def mean_squared_sqrt_error(y_true, y_pred):
6267weights = model .get_weights ()
6368
6469f = open (sys .argv [2 ], 'w' )
65- hf = open (sys .argv [3 ], 'w' )
6670
6771f .write ('/*This file is automatically generated from a Keras model*/\n \n ' )
6872f .write ('#ifdef HAVE_CONFIG_H\n #include "config.h"\n #endif\n \n #include "rnn.h"\n \n ' )
6973
70- hf .write ('/*This file is automatically generated from a Keras model*/\n \n ' )
71- hf .write ('#ifndef RNN_DATA_H\n #define RNN_DATA_H\n \n #include "rnn.h"\n \n ' )
72-
7374layer_list = []
7475for i , layer in enumerate (model .layers ):
7576 if len (layer .get_weights ()) > 0 :
76- printLayer (f , hf , layer )
77+ printLayer (f , layer )
7778 if len (layer .get_weights ()) > 2 :
7879 layer_list .append (layer .name )
7980
80- hf .write ('struct RNNState {\n ' )
81- for i , name in enumerate (layer_list ):
82- hf .write (' float {}_state[{}_SIZE];\n ' .format (name , name .upper ()))
83- hf .write ('};\n ' )
81+ f .write ('const struct RNNModel rnnoise_model_{} = {{\n ' .format (sys .argv [3 ]))
82+ for i , layer in enumerate (model .layers ):
83+ if len (layer .get_weights ()) > 0 :
84+ structLayer (f , layer )
85+ f .write ('};\n ' )
8486
85- hf .write ('\n \n #endif\n ' )
87+ #hf.write('struct RNNState {\n')
88+ #for i, name in enumerate(layer_list):
89+ # hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper()))
90+ #hf.write('};\n')
8691
8792f .close ()
88- hf .close ()
0 commit comments