55from typing import List , Union
66from urllib .parse import urlparse
77
8+ from tornado .iostream import StreamClosedError
89from tornado .web import HTTPError
910
1011from ads .aqua .common .decorator import handle_exceptions
@@ -175,21 +176,9 @@ def list_shapes(self):
175176 )
176177
177178
178- class AquaDeploymentInferenceHandler (AquaAPIhandler ):
179- @staticmethod
180- def validate_predict_url (endpoint ):
181- try :
182- url = urlparse (endpoint )
183- if url .scheme != "https" :
184- return False
185- if not url .netloc :
186- return False
187- return url .path .endswith ("/predict" )
188- except Exception :
189- return False
190-
179+ class AquaDeploymentStreamingInferenceHandler (AquaAPIhandler ):
191180 @handle_exceptions
192- def post (self , * args , ** kwargs ): # noqa: ARG002
181+ async def post (self , * args , ** kwargs ): # noqa: ARG002
193182 """
194183 Handles inference request for the Active Model Deployments
195184 Raises
@@ -205,12 +194,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
205194 if not input_data :
206195 raise HTTPError (400 , Errors .NO_INPUT_DATA )
207196
208- endpoint = input_data .get ("endpoint" )
209- if not endpoint :
210- raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("endpoint" ))
211-
212- if not self .validate_predict_url (endpoint ):
213- raise HTTPError (400 , Errors .INVALID_INPUT_DATA_FORMAT .format ("endpoint" ))
197+ model_deployment_id = input_data .get ("id" )
214198
215199 prompt = input_data .get ("prompt" )
216200 if not prompt :
@@ -226,11 +210,24 @@ def post(self, *args, **kwargs): # noqa: ARG002
226210 400 , Errors .INVALID_INPUT_DATA_FORMAT .format ("model_params" )
227211 ) from ex
228212
229- return self .finish (
230- MDInferenceResponse (prompt , model_params_obj ).get_model_deployment_response (
231- endpoint
232- )
233- )
213+ self .set_header ("Content-Type" , "text/event-stream" )
214+ self .set_header ("Cache-Control" , "no-cache" )
215+ self .set_header ("Transfer-Encoding" , "chunked" )
216+ await self .flush ()
217+
218+ try :
219+ response_gen = MDInferenceResponse (
220+ prompt , model_params_obj
221+ ).get_model_deployment_response (model_deployment_id )
222+ for chunk in response_gen :
223+ if not chunk :
224+ continue
225+ self .write (f"data: { chunk } \n \n " )
226+ await self .flush ()
227+ except StreamClosedError :
228+ self .log .warning ("Client disconnected." )
229+ finally :
230+ self .finish ()
234231
235232
236233class AquaDeploymentParamsHandler (AquaAPIhandler ):
@@ -294,5 +291,5 @@ def post(self, *args, **kwargs): # noqa: ARG002
294291 ("deployments/?([^/]*)" , AquaDeploymentHandler ),
295292 ("deployments/?([^/]*)/activate" , AquaDeploymentHandler ),
296293 ("deployments/?([^/]*)/deactivate" , AquaDeploymentHandler ),
297- ("inference" , AquaDeploymentInferenceHandler ),
294+ ("inference" , AquaDeploymentStreamingInferenceHandler ),
298295]