@@ -136,9 +136,13 @@ def create_prediction_request(transformed_sample):
136136 shape = [1 ]
137137 if util .is_list (value ):
138138 shape = [len (value )]
139- tensor_proto = tf .make_tensor_proto ([value ], dtype = data_type , shape = shape )
140- prediction_request .inputs [column_name ].CopyFrom (tensor_proto )
141-
139+ try :
140+ tensor_proto = tf .make_tensor_proto ([value ], dtype = data_type , shape = shape )
141+ prediction_request .inputs [column_name ].CopyFrom (tensor_proto )
142+ except Exception as e :
143+ raise UserException (
144+ 'key "{}"' .format (column_name ), "expected shape {}" .format (shape )
145+ ) from e
142146 return prediction_request
143147
144148
@@ -160,8 +164,15 @@ def create_raw_prediction_request(sample):
160164 shape = [1 ]
161165 value = [value ]
162166 sig_type = signature_def [signature_key ]["inputs" ][column_name ]["dtype" ]
163- tensor_proto = tf .make_tensor_proto (value , dtype = DTYPE_TO_TF_TYPE [sig_type ], shape = shape )
164- prediction_request .inputs [column_name ].CopyFrom (tensor_proto )
167+ try :
168+ tensor_proto = tf .make_tensor_proto (
169+ value , dtype = DTYPE_TO_TF_TYPE [sig_type ], shape = shape
170+ )
171+ prediction_request .inputs [column_name ].CopyFrom (tensor_proto )
172+ except Exception as e :
173+ raise UserException (
174+ 'key "{}"' .format (column_name ), "expected shape {}" .format (shape )
175+ ) from e
165176
166177 return prediction_request
167178
@@ -248,7 +259,7 @@ def create_get_model_metadata_request():
248259
249260def run_get_model_metadata ():
250261 request = create_get_model_metadata_request ()
251- resp = local_cache ["stub" ].GetModelMetadata (request , timeout = 10 .0 )
262+ resp = local_cache ["stub" ].GetModelMetadata (request , timeout = 30 .0 )
252263 sigAny = resp .metadata ["signature_def" ]
253264 signature_def_map = get_model_metadata_pb2 .SignatureDefMap ()
254265 sigAny .Unpack (signature_def_map )
@@ -272,14 +283,11 @@ def run_predict(sample):
272283 ctx = local_cache ["ctx" ]
273284 request_handler = local_cache .get ("request_handler" )
274285
275- logger .info ("sample: " + util .pp_str_flat (sample ))
276-
277286 prepared_sample = sample
278287 if request_handler is not None and util .has_function (request_handler , "pre_inference" ):
279288 prepared_sample = request_handler .pre_inference (
280289 sample , local_cache ["metadata" ]["signatureDef" ]
281290 )
282- logger .info ("pre_inference: " + util .pp_str_flat (prepared_sample ))
283291
284292 validate_sample (prepared_sample )
285293
@@ -291,24 +299,18 @@ def run_predict(sample):
291299 )
292300
293301 transformed_sample = transform_sample (prepared_sample )
294- logger .info ("transformed_sample: " + util .pp_str_flat (transformed_sample ))
295-
296302 prediction_request = create_prediction_request (transformed_sample )
297- response_proto = local_cache ["stub" ].Predict (prediction_request , timeout = 10 .0 )
303+ response_proto = local_cache ["stub" ].Predict (prediction_request , timeout = 100 .0 )
298304 result = parse_response_proto (response_proto )
299305
300306 result ["transformed_sample" ] = transformed_sample
301- logger .info ("inference: " + util .pp_str_flat (result ))
302307 else :
303308 prediction_request = create_raw_prediction_request (prepared_sample )
304- response_proto = local_cache ["stub" ].Predict (prediction_request , timeout = 10 .0 )
309+ response_proto = local_cache ["stub" ].Predict (prediction_request , timeout = 100 .0 )
305310 result = parse_response_proto_raw (response_proto )
306311
307- logger .info ("inference: " + util .pp_str_flat (result ))
308-
309312 if request_handler is not None and util .has_function (request_handler , "post_inference" ):
310313 result = request_handler .post_inference (result , local_cache ["metadata" ]["signatureDef" ])
311- logger .info ("post_inference: " + util .pp_str_flat (result ))
312314
313315 return result
314316
@@ -335,10 +337,8 @@ def validate_sample(sample):
335337 raise UserException ('missing key "{}"' .format (input_name ))
336338
337339
def prediction_failed(reason):
    """Log a failed prediction and build the error response for it.

    Args:
        reason: Description of why the prediction failed. Callers in this
            file pass a string (a fixed message or ``str(e)``); any other
            object is rendered with its ``str()`` form.

    Returns:
        A ``(message, status.HTTP_406_NOT_ACCEPTABLE)`` tuple, which the
        ``predict`` handler returns directly as its response.
    """
    # Use str.format rather than "+" so that a non-string reason (e.g. an
    # exception object) cannot raise TypeError while reporting a failure.
    # For string reasons the resulting message is unchanged.
    message = "prediction failed: {}".format(reason)

    logger.error(message)
    return message, status.HTTP_406_NOT_ACCEPTABLE
@@ -363,16 +363,12 @@ def predict(deployment_name, api_name):
363363 response = {}
364364
365365 if not util .is_dict (payload ) or "samples" not in payload :
366- util .log_pretty_flat (payload , logging_func = logger .error )
367- return prediction_failed (payload , "top level `samples` key not found in request" )
366+ return prediction_failed ('top level "samples" key not found in request' )
368367
369368 predictions = []
370369 samples = payload ["samples" ]
371370 if not util .is_list (samples ):
372- util .log_pretty_flat (samples , logging_func = logger .error )
373- return prediction_failed (
374- payload , "expected the value of key `samples` to be a list of json objects"
375- )
371+ return prediction_failed ('expected the value of key "samples" to be a list of json objects' )
376372
377373 for i , sample in enumerate (payload ["samples" ]):
378374 try :
@@ -385,14 +381,14 @@ def predict(deployment_name, api_name):
385381 api ["name" ]
386382 )
387383 )
388- return prediction_failed (sample , str (e ))
384+ return prediction_failed (str (e ))
389385 except Exception as e :
390386 logger .exception (
391387 "An error occurred, see `cortex logs -v api {}` for more details." .format (
392388 api ["name" ]
393389 )
394390 )
395- return prediction_failed (sample , str (e ))
391+ return prediction_failed (str (e ))
396392
397393 predictions .append (result )
398394
0 commit comments