
Commit 429d4ea

Author: Ziqun Ye
Commit message: add implementation for predict
Parent: 072a048

File tree: 2 files changed (+46 −9 lines)

ads/model/deployment/model_deployment.py

Lines changed: 18 additions & 5 deletions
@@ -828,6 +828,8 @@ def predict(
         data: Any = None,
         serializer: "ads.model.ModelInputSerializer" = model_input_serializer,
         auto_serialize_data: bool = False,
+        model_name: str = None,
+        model_version: str = None,
         **kwargs,
     ) -> dict:
         """Returns prediction of input data run against the model deployment endpoint.
@@ -860,6 +862,10 @@ def predict(
             If `auto_serialize_data=False`, `data` required to be bytes or json serializable
             and `json_input` required to be json serializable. If `auto_serialize_data` set
             to True, data will be serialized before sending to model deployment endpoint.
+        model_name: str
+            Defaults to None. When `inference_server="triton"`, the name of the model to invoke.
+        model_version: str
+            Defaults to None. When `inference_server="triton"`, the version of the model to invoke.
         kwargs:
             content_type: str
                 Used to indicate the media type of the resource.
@@ -917,9 +923,16 @@ def predict(
             raise TypeError(
                 "`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
             )
-
+        if model_name and model_version:
+            header['model-name'] = model_name
+            header['model-version'] = model_version
+        elif not model_version and not model_name:
+            pass
+        else:
+            raise ValueError("`model_name` and `model_version` have to be provided together.")
         prediction = send_request(
-            data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header
+            data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header,
         )
         return prediction

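Read together, the hunks above extend `ModelDeployment.predict` for Triton-backed deployments: when both new arguments are supplied they are forwarded as the `model-name` and `model-version` request headers, and supplying only one of the pair raises a `ValueError`. A minimal sketch of the new call, assuming an already-active deployment (the OCID, payload, and model names are placeholders):

    from ads.model.deployment import ModelDeployment

    deployment = ModelDeployment.from_id("<deployment_ocid>")

    # Both values travel as the `model-name` / `model-version` request headers.
    prediction = deployment.predict(
        data=b'{"inputs": [1.0, 2.0, 3.0]}',
        model_name="my_model",
        model_version="1",
    )

    # Passing only one of the pair now fails fast:
    # ValueError: `model_name` and `model_version` have to be provided together.
    deployment.predict(data=b'{}', model_name="my_model")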
@@ -1391,9 +1404,9 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment":
             infrastructure.CONST_WEB_CONCURRENCY,
             runtime.env.get("WEB_CONCURRENCY", None),
         )
-        if runtime.env.get("CONTAINER_TYPE", None) == "TRITON":
+        if runtime.env.pop("CONTAINER_TYPE", None) == "TRITON":
             runtime.set_spec(
-                runtime.CONST_TRITON, True
+                runtime.CONST_INFERENCE_SERVER, "triton"
             )

         self.set_spec(self.CONST_INFRASTRUCTURE, infrastructure)
@@ -1571,7 +1584,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
                 infrastructure.web_concurrency
             )
             runtime.set_spec(runtime.CONST_ENV, environment_variables)
-        if runtime.triton:
+        if runtime.inference_server and runtime.inference_server.lower() == "triton":
             environment_variables["CONTAINER_TYPE"] = "TRITON"
             runtime.set_spec(runtime.CONST_ENV, environment_variables)
         environment_configuration_details = {
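The last two hunks are the two halves of a round trip for the renamed spec. When building the deployment configuration, `inference_server == "triton"` is translated into the `CONTAINER_TYPE=TRITON` environment variable the service consumes; when reading a deployment back in `_update_from_oci_model`, `env.pop` removes that variable from the user-visible environment and restores it as the `inferenceServer` spec. A toy sketch of the mapping (hypothetical helpers, not the actual ADS classes):

    # Hypothetical helpers illustrating the env-var round trip; not ADS API.
    def to_service_env(spec: dict) -> dict:
        env = dict(spec.get("env", {}))
        if spec.get("inferenceServer") == "triton":
            env["CONTAINER_TYPE"] = "TRITON"  # what the service consumes
        return env

    def from_service_env(env: dict) -> dict:
        env = dict(env)  # copy so pop() does not mutate the caller's dict
        spec = {}
        # pop() keeps the implementation detail out of the user-visible env
        if env.pop("CONTAINER_TYPE", None) == "TRITON":
            spec["inferenceServer"] = "triton"
        spec["env"] = env
        return spec

    assert from_service_env(to_service_env({"inferenceServer": "triton"})) == {
        "inferenceServer": "triton",
        "env": {},
    }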

ads/model/deployment/model_deployment_runtime.py

Lines changed: 28 additions & 4 deletions
@@ -330,7 +330,7 @@ class ModelDeploymentContainerRuntime(ModelDeploymentRuntime):
     CONST_ENTRYPOINT = "entrypoint"
     CONST_SERVER_PORT = "serverPort"
     CONST_HEALTH_CHECK_PORT = "healthCheckPort"
-    CONST_TRITON = "triton"
+    CONST_INFERENCE_SERVER = "inferenceServer"

     attribute_map = {
         **ModelDeploymentRuntime.attribute_map,
@@ -340,7 +340,7 @@ class ModelDeploymentContainerRuntime(ModelDeploymentRuntime):
         CONST_ENTRYPOINT: "entrypoint",
         CONST_SERVER_PORT: "server_port",
         CONST_HEALTH_CHECK_PORT: "health_check_port",
-        CONST_TRITON: "triton"
+        CONST_INFERENCE_SERVER: "inference_server"
     }

     payload_attribute_map = {
@@ -544,7 +544,7 @@ def inference_server(self) -> str:
         str
             The inference server.
         """
-        return self.get_spec(self.CONST_TRITON, None)
+        return self.get_spec(self.CONST_INFERENCE_SERVER, None)

     def with_inference_server(self, inference_server: str = "triton") -> "ModelDeploymentRuntime":
         """Sets the inference server. Current supported inference server is "triton".
@@ -559,5 +559,29 @@ def with_inference_server(self, inference_server: str = "triton") -> "ModelDeploymentRuntime":
         -------
         ModelDeploymentRuntime
             The ModelDeploymentRuntime instance (self).
+
+        Example
+        -------
+        >>> infrastructure = ModelDeploymentInfrastructure()\
+        ...     .with_project_id(<project_id>)\
+        ...     .with_compartment_id(<compartment_id>)\
+        ...     .with_shape_name("VM.Standard.E4.Flex")\
+        ...     .with_replica(2)\
+        ...     .with_bandwidth_mbps(10)\
+        ...     .with_access_log(log_group_id=<deployment_log_group_id>, log_id=<deployment_access_log_id>)\
+        ...     .with_predict_log(log_group_id=<deployment_log_group_id>, log_id=<deployment_predict_log_id>)
+
+        >>> runtime = ModelDeploymentContainerRuntime()\
+        ...     .with_image(<container_image>)\
+        ...     .with_server_port(<server_port>)\
+        ...     .with_health_check_port(<health_check_port>)\
+        ...     .with_model_uri(<model_id>)\
+        ...     .with_env({"key":"value", "key2":"value2"})\
+        ...     .with_inference_server("triton")
+        >>> deployment = ModelDeployment()\
+        ...     .with_display_name("Triton Example")\
+        ...     .with_infrastructure(infrastructure)\
+        ...     .with_runtime(runtime)
+        >>> deployment.deploy()
         """
-        return self.set_spec(self.CONST_TRITON, inference_server.lower())
+        return self.set_spec(self.CONST_INFERENCE_SERVER, inference_server.lower())
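Once a deployment built this way is active, the new `predict` arguments in model_deployment.py select an individual model inside the Triton container. A hypothetical follow-on to the docstring example; the payload shape (a KServe-style inference request) and model names depend on what the image actually serves:

    # Hypothetical continuation; "INPUT__0" and "my_model" are placeholders.
    prediction = deployment.predict(
        data='{"inputs": [{"name": "INPUT__0", "shape": [1, 3], '
             '"datatype": "FP32", "data": [1.0, 2.0, 3.0]}]}',
        model_name="my_model",    # sent as the `model-name` header
        model_version="1",        # sent as the `model-version` header
    )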
