1616
1717import asyncio
1818import collections
19- import concurrent .futures
2019import contextlib
2120import datetime
2221import functools
22+ import io
2323import multiprocessing
2424import multiprocessing .pool
2525import os
5050
5151.. py:attribute:: filename
5252
53- Complete file path to the temporary file in the filesyste ,
53+ Complete file path to the temporary file in the filesystem ,
5454
5555.. py:attribute:: content_type
5656
6161 Filename of the original file being uploaded.
6262"""
6363
ReturnedFile = collections.namedtuple("ReturnedFile", ("name",
                                                       "filename",
                                                       "content_type",
                                                       "original_filename"))
"""Class to pass the files returned from predict in a picklable way

.. py:attribute:: name

   Name of the argument where this file is being sent.

.. py:attribute:: filename

   Complete file path to the temporary file in the filesystem,

.. py:attribute:: content_type

   Content-type of the uploaded file

.. py:attribute:: original_filename

   Filename of the original file being uploaded.
"""
86+
87+
# Backwards-compatible defaults: make every namedtuple field optional (vkoz).
UploadedFile.__new__.__defaults__ = (None,) * len(UploadedFile._fields)
ReturnedFile.__new__.__defaults__ = (None,) * len(ReturnedFile._fields)
6691
6792
6893class ModelWrapper (object ):
@@ -75,7 +100,7 @@ class ModelWrapper(object):
75100 :param name: Model name
76101 :param model: Model object
77102 :raises HTTPInternalServerError: in case that a model has defined
78- a reponse schema that is nod JSON schema valid (DRAFT 4)
103+ a response schema that is not JSON schema valid (DRAFT 4)
79104 """
80105 def __init__ (self , name , model_obj , app ):
81106 self .name = name
@@ -84,11 +109,8 @@ def __init__(self, name, model_obj, app):
84109
85110 self ._loop = asyncio .get_event_loop ()
86111
87- self ._predict_workers = CONF .predict_workers
88- self ._predict_executor = self ._init_predict_executor ()
89-
90- self ._train_workers = CONF .train_workers
91- self ._train_executor = self ._init_train_executor ()
112+ self ._workers = CONF .workers
113+ self ._executor = self ._init_executor ()
92114
93115 self ._setup_cleanup ()
94116
@@ -125,16 +147,10 @@ def _setup_cleanup(self):
125147 self ._app .on_cleanup .append (self ._close_executors )
126148
127149 async def _close_executors (self , app ):
128- self ._train_executor .shutdown ()
129- self ._predict_executor .shutdown ()
130-
131- def _init_predict_executor (self ):
132- n = self ._predict_workers
133- executor = concurrent .futures .ThreadPoolExecutor (max_workers = n )
134- return executor
150+ self ._executor .shutdown ()
135151
136- def _init_train_executor (self ):
137- n = self ._train_workers
152+ def _init_executor (self ):
153+ n = self ._workers
138154 executor = CancellablePool (max_workers = n )
139155 return executor
140156
@@ -168,7 +184,7 @@ def validate_response(self, response):
168184 If the wrapped model has defined a ``response`` attribute we will
169185 validate the response that
170186
171- :param response: The reponse that will be validated.
187+ :param response: The response that will be validated.
172188 :raises exceptions.InternalServerError: in case the response cannot be
173189 validated.
174190 """
@@ -213,18 +229,10 @@ def get_metadata(self):
213229 }
214230 return d
215231
216- def _run_in_predict_pool (self , func , * args , ** kwargs ):
217- async def task (fn ):
218- return await self ._loop .run_in_executor (self ._predict_executor , fn )
219-
220- return self ._loop .create_task (
221- task (functools .partial (func , * args , ** kwargs ))
222- )
223-
224- def _run_in_train_pool (self , func , * args , ** kwargs ):
232+ def _run_in_pool (self , func , * args , ** kwargs ):
225233 fn = functools .partial (func , * args , ** kwargs )
226234 ret = self ._loop .create_task (
227- self ._train_executor .apply (fn )
235+ self ._executor .apply (fn )
228236 )
229237 return ret
230238
@@ -243,17 +251,27 @@ async def warm(self):
243251 LOG .debug ("Cannot warm (initialize) model '%s'" % self .name )
244252 return
245253
246- run = self ._loop .run_in_executor
247- executor = self ._predict_executor
248- n = self ._predict_workers
249254 try :
255+ n = self ._workers
250256 LOG .debug ("Warming '%s' model with %s workers" % (self .name , n ))
251- fs = [run ( executor , func ) for i in range (0 , n )]
257+ fs = [self . _run_in_pool ( func ) for _ in range (0 , n )]
252258 await asyncio .gather (* fs )
253259 LOG .debug ("Model '%s' has been warmed" % self .name )
254260 except NotImplementedError :
255261 LOG .debug ("Cannot warm (initialize) model '%s'" % self .name )
256262
263+ @staticmethod
264+ def predict_wrap (predict_func , * args , ** kwargs ):
265+ """Wrapper function to allow returning files from predict
266+ This wrapper exists because buffer objects are not pickable,
267+ thus cannot be returned from the executor.
268+ """
269+ ret = predict_func (* args , ** kwargs )
270+ if isinstance (ret , io .BufferedReader ):
271+ ret = ReturnedFile (filename = ret .name )
272+
273+ return ret
274+
257275 def predict (self , * args , ** kwargs ):
258276 """Perform a prediction on wrapped model's ``predict`` method.
259277
@@ -280,8 +298,8 @@ def predict(self, *args, **kwargs):
280298 # FIXME(aloga); cleanup of tmpfile here
281299
282300 with self ._catch_error ():
283- return self ._run_in_predict_pool (
284- self .model_obj .predict , * args , ** kwargs
301+ return self ._run_in_pool (
302+ self .predict_wrap , self . model_obj .predict , * args , ** kwargs
285303 )
286304
287305 def train (self , * args , ** kwargs ):
@@ -296,7 +314,7 @@ def train(self, *args, **kwargs):
296314 """
297315
298316 with self ._catch_error ():
299- return self ._run_in_train_pool (
317+ return self ._run_in_pool (
300318 self .model_obj .train , * args , ** kwargs
301319 )
302320
0 commit comments