Skip to content

Commit a6bbd4b

Browse files
authored
Upload tf model instead of writing its code (backend) (#21)
* add uploaded model handling * remove unused import
1 parent 9c9eb2f commit a6bbd4b

File tree

2 files changed

+69
-8
lines changed

2 files changed

+69
-8
lines changed

backend/server.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from translate.graph import Graph
1111

1212
app = Flask(__name__)
13+
app.config['UPLOAD_EXTENSIONS'] = ['.h5']
14+
app.config['UPLOAD_PATH'] = 'uploads'
1315
ok_status = 200
16+
error_status = 400
1417
json_type = {'ContentType': 'application/json'}
1518
text_type = {'ContentType': 'text/plain'}
1619

@@ -60,6 +63,20 @@ def replace_references(net):
6063
outp.append(i)
6164
layer.output = outp
6265

66+
def check_uploads_path_exists(identifier):
67+
if not os.path.exists(app.config['UPLOAD_PATH']):
68+
os.mkdir(app.config['UPLOAD_PATH'])
69+
if not os.path.exists(os.path.join(app.config['UPLOAD_PATH'], identifier)):
70+
os.mkdir(os.path.join(app.config['UPLOAD_PATH'], identifier))
71+
os.mkdir(os.path.join(app.config['UPLOAD_PATH'], identifier, 'visualizations'))
72+
copyfile(os.path.join('default', 'layer_types_current.json'),
73+
os.path.join(app.config['UPLOAD_PATH'], identifier, 'layer_types_current.json'))
74+
copyfile(os.path.join('default', 'preferences.json'),
75+
os.path.join(app.config['UPLOAD_PATH'], identifier, 'preferences.json'))
76+
copyfile(os.path.join('default', 'groups.json'),
77+
os.path.join(app.config['UPLOAD_PATH'], identifier, 'groups.json'))
78+
copyfile(os.path.join('default', 'legend_preferences.json'),
79+
os.path.join(app.config['UPLOAD_PATH'], identifier, 'legend_preferences.json'))
6380

6481
def check_exists(identifier):
6582
"""Check if the desired model already exists.
@@ -149,7 +166,12 @@ def get_network(identifier):
149166
object -- a http response containing the network as json
150167
"""
151168
check_exists(identifier)
152-
graph = translate_keras(os.path.join('models', identifier,
169+
check_uploads_path_exists(identifier)
170+
if 'model.h5' in ls(os.path.join(app.config['UPLOAD_PATH'], identifier)):
171+
graph = translate_keras(os.path.join(app.config['UPLOAD_PATH'], identifier,
172+
'model.h5'))
173+
else:
174+
graph = translate_keras(os.path.join('models', identifier,
153175
'model_current.py'))
154176
if isinstance(graph, Graph):
155177
net = {'layers': make_jsonifyable(graph)}
@@ -192,6 +214,26 @@ def update_code(identifier):
192214
file.write(content.decode("utf-8"))
193215
return content, ok_status, text_type
194216

217+
@app.route('/api/upload_model/<identifier>', methods=['POST'])
218+
def upload_model(identifier):
219+
"""Update the Code.
220+
221+
Arguments:
222+
identifier {String} -- the identifier for the requested network
223+
224+
Returns:
225+
object -- a http response signaling the change worked
226+
"""
227+
check_uploads_path_exists(identifier)
228+
uploaded_file = request.files['model']
229+
filename = uploaded_file.filename
230+
file_ext = os.path.splitext(filename)[1]
231+
if file_ext not in app.config['UPLOAD_EXTENSIONS']:
232+
return "", error_status, text_type
233+
file_path = os.path.join(app.config['UPLOAD_PATH'], identifier, 'model.h5')
234+
uploaded_file.save(file_path)
235+
return "", ok_status, text_type
236+
195237

196238
@app.route('/api/get_layer_types/<identifier>')
197239
def get_layer_types(identifier):

backend/translate/translate_keras.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from translate.graph import Graph
66
import translate.layer as layer
77

8+
keras_ext = '.h5'
9+
810

911
def translate_keras(filename):
1012
"""Translate a keras model defined in a file into the neural network graph.
@@ -19,13 +21,16 @@ def translate_keras(filename):
1921
epicbox.Profile('python', 'tf_plus_keras:latest')])
2022
general_reader = open('translate/keras_loader.txt', 'rb')
2123
general_code = general_reader.read()
22-
with open(filename, 'rb') as myfile:
23-
keras_code = myfile.read()
24-
try:
25-
return graph_from_external_file(keras_code, general_code)
26-
except Exception as err:
27-
return {'error_class': '', 'line_number': 1,
28-
'detail': str(err)}
24+
if keras_ext in filename:
25+
return graph_from_model_file(filename)
26+
else:
27+
with open(filename, 'rb') as myfile:
28+
keras_code = myfile.read()
29+
try:
30+
return graph_from_external_file(keras_code, general_code)
31+
except Exception as err:
32+
return {'error_class': '', 'line_number': 1,
33+
'detail': str(err)}
2934

3035

3136
def graph_from_external_file(keras_code, general_code):
@@ -61,6 +66,20 @@ def graph_from_external_file(keras_code, general_code):
6166
graph.resolve_input_names()
6267
return graph
6368

69+
def graph_from_model_file(keras_model_file):
70+
model_keras = keras.models.load_model(keras_model_file)
71+
model_json = model_keras.to_json()
72+
layers_extracted = model_json['config']['layers']
73+
graph = Graph()
74+
previous_node = ''
75+
for index, json_layer in enumerate(layers_extracted):
76+
if len(layers_extracted) > len(model_keras.layers):
77+
index = index - 1
78+
if index >= 0:
79+
previous_node = add_layer_type(json_layer, model_keras.layers[index], graph,
80+
previous_node)
81+
graph.resolve_input_names()
82+
return graph
6483

6584
def add_layer_type(layer_json, model_layer, graph, previous_node):
6685
"""Add a Layer. Layers are identified by name and equipped using the spec.

0 commit comments

Comments
 (0)