|
10 | 10 | from translate.graph import Graph |
11 | 11 |
|
12 | 12 | app = Flask(__name__) |
| 13 | +app.config['UPLOAD_EXTENSIONS'] = ['.h5'] |
| 14 | +app.config['UPLOAD_PATH'] = 'uploads' |
13 | 15 | ok_status = 200 |
| 16 | +error_status = 400 |
14 | 17 | json_type = {'ContentType': 'application/json'} |
15 | 18 | text_type = {'ContentType': 'text/plain'} |
16 | 19 |
|
@@ -60,6 +63,20 @@ def replace_references(net): |
60 | 63 | outp.append(i) |
61 | 64 | layer.output = outp |
62 | 65 |
|
| 66 | +def check_uploads_path_exists(identifier): |
| 67 | + if not os.path.exists(app.config['UPLOAD_PATH']): |
| 68 | + os.mkdir(app.config['UPLOAD_PATH']) |
| 69 | + if not os.path.exists(os.path.join(app.config['UPLOAD_PATH'], identifier)): |
| 70 | + os.mkdir(os.path.join(app.config['UPLOAD_PATH'], identifier)) |
| 71 | + os.mkdir(os.path.join(app.config['UPLOAD_PATH'], identifier, 'visualizations')) |
| 72 | + copyfile(os.path.join('default', 'layer_types_current.json'), |
| 73 | + os.path.join(app.config['UPLOAD_PATH'], identifier, 'layer_types_current.json')) |
| 74 | + copyfile(os.path.join('default', 'preferences.json'), |
| 75 | + os.path.join(app.config['UPLOAD_PATH'], identifier, 'preferences.json')) |
| 76 | + copyfile(os.path.join('default', 'groups.json'), |
| 77 | + os.path.join(app.config['UPLOAD_PATH'], identifier, 'groups.json')) |
| 78 | + copyfile(os.path.join('default', 'legend_preferences.json'), |
| 79 | + os.path.join(app.config['UPLOAD_PATH'], identifier, 'legend_preferences.json')) |
63 | 80 |
|
64 | 81 | def check_exists(identifier): |
65 | 82 | """Check if the desired model already exists. |
@@ -149,7 +166,12 @@ def get_network(identifier): |
149 | 166 | object -- a http response containing the network as json |
150 | 167 | """ |
151 | 168 | check_exists(identifier) |
152 | | - graph = translate_keras(os.path.join('models', identifier, |
| 169 | + check_uploads_path_exists(identifier) |
| 170 | + if 'model.h5' in ls(os.path.join(app.config['UPLOAD_PATH'], identifier)): |
| 171 | + graph = translate_keras(os.path.join(app.config['UPLOAD_PATH'], identifier, |
| 172 | + 'model.h5')) |
| 173 | + else: |
| 174 | + graph = translate_keras(os.path.join('models', identifier, |
153 | 175 | 'model_current.py')) |
154 | 176 | if isinstance(graph, Graph): |
155 | 177 | net = {'layers': make_jsonifyable(graph)} |
@@ -192,6 +214,26 @@ def update_code(identifier): |
192 | 214 | file.write(content.decode("utf-8")) |
193 | 215 | return content, ok_status, text_type |
194 | 216 |
|
| 217 | +@app.route('/api/upload_model/<identifier>', methods=['POST']) |
| 218 | +def upload_model(identifier): |
| 219 | + """Update the Code. |
| 220 | +
|
| 221 | + Arguments: |
| 222 | + identifier {String} -- the identifier for the requested network |
| 223 | +
|
| 224 | + Returns: |
| 225 | + object -- a http response signaling the change worked |
| 226 | + """ |
| 227 | + check_uploads_path_exists(identifier) |
| 228 | + uploaded_file = request.files['model'] |
| 229 | + filename = uploaded_file.filename |
| 230 | + file_ext = os.path.splitext(filename)[1] |
| 231 | + if file_ext not in app.config['UPLOAD_EXTENSIONS']: |
| 232 | + return "", error_status, text_type |
| 233 | + file_path = os.path.join(app.config['UPLOAD_PATH'], identifier, 'model.h5') |
| 234 | + uploaded_file.save(file_path) |
| 235 | + return "", ok_status, text_type |
| 236 | + |
195 | 237 |
|
196 | 238 | @app.route('/api/get_layer_types/<identifier>') |
197 | 239 | def get_layer_types(identifier): |
|
0 commit comments