Skip to content

Commit 3901ab3

Browse files
authored
Support query params, headers, and non-json payloads (bytes and forms) (#1062)
1 parent 25e8148 commit 3901ab3

File tree

4 files changed

+120
-33
lines changed

4 files changed

+120
-33
lines changed

docs/deployments/predictors.md

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,13 @@ class PythonPredictor:
5555
"""
5656
pass
5757

58-
def predict(self, payload):
58+
def predict(self, payload, query_params, headers):
5959
"""Called once per request. Preprocesses the request payload (if necessary), runs inference, and postprocesses the inference output (if necessary).
6060
6161
Args:
62-
payload: The parsed JSON request payload.
62+
payload: The request payload (see below for the possible payload types) (optional).
63+
query_params: A dictionary of the query parameters used in the request (optional).
64+
headers: A dictionary of the headers sent in the request (optional).
6365
6466
Returns:
6567
Prediction or a batch of predictions.
@@ -69,6 +71,8 @@ class PythonPredictor:
6971

7072
For proper separation of concerns, it is recommended to use the constructor's `config` paramater for information such as from where to download the model and initialization files, or any configurable model parameters. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.
7173

74+
The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `string` for form data, or `starlette.datastructures.UploadFile` for file uploads, see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will the the raw `bytes` of the request body.
75+
7276
### Examples
7377

7478
<!-- CORTEX_VERSION_MINOR -->
@@ -173,11 +177,13 @@ class TensorFlowPredictor:
173177
self.client = tensorflow_client
174178
# Additional initialization may be done here
175179

176-
def predict(self, payload):
180+
def predict(self, payload, query_params, headers):
177181
"""Called once per request. Preprocesses the request payload (if necessary), runs inference (e.g. by calling self.client.predict(model_input)), and postprocesses the inference output (if necessary).
178182
179183
Args:
180-
payload: The parsed JSON request payload.
184+
payload: The request payload (see below for the possible payload types) (optional).
185+
query_params: A dictionary of the query parameters used in the request (optional).
186+
headers: A dictionary of the headers sent in the request (optional).
181187
182188
Returns:
183189
Prediction or a batch of predictions.
@@ -190,6 +196,8 @@ Cortex provides a `tensorflow_client` to your Predictor's constructor. `tensorfl
190196

191197
For proper separation of concerns, it is recommended to use the constructor's `config` paramater for information such as configurable model parameters or download links for initialization files. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.
192198

199+
The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `string` for form data, or `starlette.datastructures.UploadFile` for file uploads, see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will the the raw `bytes` of the request body.
200+
193201
### Examples
194202

195203
<!-- CORTEX_VERSION_MINOR -->
@@ -249,11 +257,13 @@ class ONNXPredictor:
249257
self.client = onnx_client
250258
# Additional initialization may be done here
251259

252-
def predict(self, payload):
260+
def predict(self, payload, query_params, headers):
253261
"""Called once per request. Preprocesses the request payload (if necessary), runs inference (e.g. by calling self.client.predict(model_input)), and postprocesses the inference output (if necessary).
254262
255263
Args:
256-
payload: The parsed JSON request payload.
264+
payload: The request payload (see below for the possible payload types) (optional).
265+
query_params: A dictionary of the query parameters used in the request (optional).
266+
headers: A dictionary of the headers sent in the request (optional).
257267
258268
Returns:
259269
Prediction or a batch of predictions.
@@ -266,6 +276,8 @@ Cortex provides an `onnx_client` to your Predictor's constructor. `onnx_client`
266276

267277
For proper separation of concerns, it is recommended to use the constructor's `config` paramater for information such as configurable model parameters or download links for initialization files. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.
268278

279+
The `payload` parameter is parsed according to the `Content-Type` header in the request. For `Content-Type: application/json`, `payload` will be the parsed JSON body. For `Content-Type: multipart/form` or `Content-Type: application/x-www-form-urlencoded`, `payload` will be `starlette.datastructures.FormData` (key-value pairs where the value is a `string` for form data, or `starlette.datastructures.UploadFile` for file uploads, see [Starlette's documentation](https://www.starlette.io/requests/#request-files)). For all other `Content-Type` values, `payload` will the the raw `bytes` of the request body.
280+
269281
### Examples
270282

271283
<!-- CORTEX_VERSION_MINOR -->

pkg/workloads/cortex/lib/type/predictor.py

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,12 @@ def initialize_client(self, model_dir=None, tf_serving_host=None, tf_serving_por
5757
def initialize_impl(self, project_dir, client=None):
5858
class_impl = self.class_impl(project_dir)
5959
try:
60-
if self.type == "python":
61-
return class_impl(self.config)
60+
if self.type == "onnx":
61+
return class_impl(onnx_client=client, config=self.config)
62+
elif self.type == "tensorflow":
63+
return class_impl(tensorflow_client=client, config=self.config)
6264
else:
63-
return class_impl(client, self.config)
65+
return class_impl(config=self.config)
6466
except Exception as e:
6567
raise UserRuntimeException(self.path, "__init__", str(e)) from e
6668
finally:
@@ -128,55 +130,90 @@ def _load_module(self, module_name, impl_path):
128130

129131
PYTHON_CLASS_VALIDATION = {
130132
"required": [
131-
{"name": "__init__", "args": ["self", "config"]},
132-
{"name": "predict", "args": ["self", "payload"]},
133+
{"name": "__init__", "required_args": ["self", "config"]},
134+
{
135+
"name": "predict",
136+
"required_args": ["self"],
137+
"optional_args": ["payload", "query_params", "headers"],
138+
},
133139
]
134140
}
135141

136142
TENSORFLOW_CLASS_VALIDATION = {
137143
"required": [
138-
{"name": "__init__", "args": ["self", "tensorflow_client", "config"]},
139-
{"name": "predict", "args": ["self", "payload"]},
144+
{"name": "__init__", "required_args": ["self", "tensorflow_client", "config"]},
145+
{
146+
"name": "predict",
147+
"required_args": ["self"],
148+
"optional_args": ["payload", "query_params", "headers"],
149+
},
140150
]
141151
}
142152

143153
ONNX_CLASS_VALIDATION = {
144154
"required": [
145-
{"name": "__init__", "args": ["self", "onnx_client", "config"]},
146-
{"name": "predict", "args": ["self", "payload"]},
155+
{"name": "__init__", "required_args": ["self", "onnx_client", "config"]},
156+
{
157+
"name": "predict",
158+
"required_args": ["self"],
159+
"optional_args": ["payload", "query_params", "headers"],
160+
},
147161
]
148162
}
149163

150164

151165
def _validate_impl(impl, impl_req):
152-
for optional_func in impl_req.get("optional", []):
153-
_validate_optional_fn_args(impl, optional_func["name"], optional_func["args"])
166+
for optional_func_signature in impl_req.get("optional", []):
167+
_validate_optional_fn_args(impl, optional_func_signature)
154168

155-
for required_func in impl_req.get("required", []):
156-
_validate_required_fn_args(impl, required_func["name"], required_func["args"])
169+
for required_func_signature in impl_req.get("required", []):
170+
_validate_required_fn_args(impl, required_func_signature)
157171

158172

159-
def _validate_optional_fn_args(impl, fn_name, args):
160-
if fn_name in vars(impl):
161-
_validate_required_fn_args(impl, fn_name, args)
173+
def _validate_optional_fn_args(impl, func_signature):
174+
if getattr(impl, func_signature["name"], None):
175+
_validate_required_fn_args(impl, func_signature)
162176

163177

164-
def _validate_required_fn_args(impl, fn_name, args):
165-
fn = getattr(impl, fn_name, None)
178+
def _validate_required_fn_args(impl, func_signature):
179+
fn = getattr(impl, func_signature["name"], None)
166180
if not fn:
167-
raise UserException('required function "{}" is not defined'.format(fn_name))
181+
raise UserException(f'required function "{func_signature["name"]}" is not defined')
168182

169183
if not callable(fn):
170-
raise UserException('"{}" is defined, but is not a function'.format(fn_name))
184+
raise UserException(f'"{func_signature["name"]}" is defined, but is not a function')
171185

172186
argspec = inspect.getfullargspec(fn)
173187

174-
if argspec.args != args:
175-
raise UserException(
176-
'invalid signature for function "{}": expected arguments ({}) but found ({})'.format(
177-
fn_name, ", ".join(args), ", ".join(argspec.args)
188+
required_args = func_signature.get("required_args", [])
189+
optional_args = func_signature.get("optional_args", [])
190+
fn_str = f'{func_signature["name"]}({", ".join(argspec.args)})'
191+
192+
for arg_name in required_args:
193+
if arg_name not in argspec.args:
194+
raise UserException(
195+
f'invalid signature for function "{fn_str}": "{arg_name}" is a required argument, but was not provided'
196+
)
197+
198+
if arg_name == "self":
199+
if argspec.args[0] != "self":
200+
raise UserException(
201+
f'invalid signature for function "{fn_str}": "self" must be the first argument'
202+
)
203+
204+
seen_args = []
205+
for arg_name in argspec.args:
206+
if arg_name not in required_args and arg_name not in optional_args:
207+
raise UserException(
208+
f'invalid signature for function "{fn_str}": "{arg_name}" is not a supported argument'
178209
)
179-
)
210+
211+
if arg_name in seen_args:
212+
raise UserException(
213+
f'invalid signature for function "{fn_str}": "{arg_name}" is duplicated'
214+
)
215+
216+
seen_args.append(arg_name)
180217

181218

182219
tf_expected_dir_structure = """tensorflow model directories must have the following structure:

pkg/workloads/cortex/serve/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ dill==0.3.1.1
44
fastapi==0.54.1
55
msgpack==1.0.0
66
numpy==1.18.4
7+
python-multipart==0.0.5
78
pyyaml==5.3.1
89
requests==2.23.0
910
uvicorn==0.11.5

pkg/workloads/cortex/serve/serve.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import sys
1616
import os
1717
import argparse
18+
import inspect
1819
import time
1920
import json
2021
import msgpack
@@ -154,11 +155,31 @@ async def register_request(request: Request, call_next):
154155
return response
155156

156157

157-
def predict(request: Any = Body(..., media_type="application/json")):
158+
@app.middleware("http")
159+
async def parse_payload(request: Request, call_next):
160+
if "payload" not in local_cache["predict_fn_args"]:
161+
return await call_next(request)
162+
163+
content_type = request.headers.get("content-type", "").lower()
164+
165+
if content_type.startswith("multipart/form") or content_type.startswith(
166+
"application/x-www-form-urlencoded"
167+
):
168+
request.state.payload = await request.form()
169+
elif content_type.startswith("application/json"):
170+
request.state.payload = await request.json()
171+
else:
172+
request.state.payload = await request.body()
173+
174+
return await call_next(request)
175+
176+
177+
def predict(request: Request):
158178
api = local_cache["api"]
159179
predictor_impl = local_cache["predictor_impl"]
180+
args = build_predict_args(request)
160181

161-
prediction = predictor_impl.predict(request)
182+
prediction = predictor_impl.predict(**args)
162183

163184
if isinstance(prediction, bytes):
164185
response = Response(content=prediction, media_type="application/octet-stream")
@@ -194,6 +215,19 @@ def predict(request: Any = Body(..., media_type="application/json")):
194215
return response
195216

196217

218+
def build_predict_args(request: Request):
219+
args = {}
220+
221+
if "payload" in local_cache["predict_fn_args"]:
222+
args["payload"] = request.state.payload
223+
if "headers" in local_cache["predict_fn_args"]:
224+
args["headers"] = request.headers
225+
if "query_params" in local_cache["predict_fn_args"]:
226+
args["query_params"] = request.query_params
227+
228+
return args
229+
230+
197231
def get_summary():
198232
response = {"message": API_SUMMARY_MESSAGE}
199233

@@ -230,6 +264,7 @@ def start():
230264
storage = LocalStorage(os.getenv("CORTEX_CACHE_DIR"))
231265
else:
232266
storage = S3(bucket=os.environ["CORTEX_BUCKET"], region=os.environ["AWS_REGION"])
267+
233268
try:
234269
raw_api_spec = get_spec(provider, storage, cache_dir, spec_path)
235270
api = API(provider=provider, storage=storage, cache_dir=cache_dir, **raw_api_spec)
@@ -243,13 +278,15 @@ def start():
243278
local_cache["provider"] = provider
244279
local_cache["client"] = client
245280
local_cache["predictor_impl"] = predictor_impl
281+
local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args
246282
predict_route = "/"
247283
if provider != "local":
248284
predict_route = "/predict"
249285
local_cache["predict_route"] = predict_route
250286
except:
251287
cx_logger().exception("failed to start api")
252288
sys.exit(1)
289+
253290
if (
254291
provider != "local"
255292
and api.monitoring is not None

0 commit comments

Comments
 (0)