Skip to content

Commit e40302d

Browse files
author
Ziqun Ye
committed
commit the code for triton
1 parent c937a9b commit e40302d

File tree

4 files changed

+43
-10
lines changed

4 files changed

+43
-10
lines changed

ads/model/deployment/common/utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,15 @@ def send_request(
119119
Returns:
120120
A JSON representive of a requests.Response object.
121121
"""
122-
headers = dict()
122+
123123
if is_json_payload:
124-
headers["Content-Type"] = (
125-
header.get("content_type") or DEFAULT_CONTENT_TYPE_JSON
126-
)
124+
header["Content-Type"] = header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON)
127125
request_kwargs = {"json": data}
128126
else:
129-
headers["Content-Type"] = (
130-
header.get("content_type") or DEFAULT_CONTENT_TYPE_BYTES
131-
)
127+
header["Content-Type"] = header.pop("content_type", DEFAULT_CONTENT_TYPE_BYTES)
132128
request_kwargs = {"data": data} # should pass bytes when using data
133-
134-
request_kwargs["headers"] = headers
129+
130+
request_kwargs["headers"] = header
135131

136132
if dry_run:
137133
request_kwargs["headers"]["Accept"] = "*/*"
@@ -140,7 +136,7 @@ def send_request(
140136
return json.loads(req.body)
141137
return req.body
142138
else:
143-
request_kwargs["auth"] = header.get("signer")
139+
request_kwargs["auth"] = header.pop("signer")
144140
return requests.post(endpoint, **request_kwargs).json()
145141

146142

ads/model/deployment/model_deployment.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,7 @@ def predict(
878878
"signer": signer,
879879
"content_type": kwargs.get("content_type", None),
880880
}
881+
header.update(kwargs.pop("headers", {}))
881882

882883
if data is None and json_input is None:
883884
raise AttributeError(
@@ -1390,6 +1391,10 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment":
13901391
infrastructure.CONST_WEB_CONCURRENCY,
13911392
runtime.env.get("WEB_CONCURRENCY", None),
13921393
)
1394+
if runtime.env.get("CONTAINER_TYPE", None) == "TRITON":
1395+
runtime.set_spec(
1396+
runtime.CONST_TRITON, True
1397+
)
13931398

13941399
self.set_spec(self.CONST_INFRASTRUCTURE, infrastructure)
13951400
self.set_spec(self.CONST_RUNTIME, runtime)
@@ -1566,6 +1571,9 @@ def _build_model_deployment_configuration_details(self) -> Dict:
15661571
infrastructure.web_concurrency
15671572
)
15681573
runtime.set_spec(runtime.CONST_ENV, environment_variables)
1574+
if runtime.triton:
1575+
environment_variables["CONTAINER_TYPE"] = "TRITON"
1576+
runtime.set_spec(runtime.CONST_ENV, environment_variables)
15691577
environment_configuration_details = {
15701578
runtime.CONST_ENVIRONMENT_CONFIG_TYPE: runtime.environment_config_type,
15711579
runtime.CONST_ENVIRONMENT_VARIABLES: runtime.env,

ads/model/deployment/model_deployment_runtime.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ class ModelDeploymentContainerRuntime(ModelDeploymentRuntime):
330330
CONST_ENTRYPOINT = "entrypoint"
331331
CONST_SERVER_PORT = "serverPort"
332332
CONST_HEALTH_CHECK_PORT = "healthCheckPort"
333+
CONST_TRITON = "triton"
333334

334335
attribute_map = {
335336
**ModelDeploymentRuntime.attribute_map,
@@ -339,6 +340,7 @@ class ModelDeploymentContainerRuntime(ModelDeploymentRuntime):
339340
CONST_ENTRYPOINT: "entrypoint",
340341
CONST_SERVER_PORT: "server_port",
341342
CONST_HEALTH_CHECK_PORT: "health_check_port",
343+
CONST_TRITON: "triton"
342344
}
343345

344346
payload_attribute_map = {
@@ -532,3 +534,29 @@ def with_health_check_port(
532534
The ModelDeploymentContainerRuntime instance (self).
533535
"""
534536
return self.set_spec(self.CONST_HEALTH_CHECK_PORT, health_check_port)
537+
538+
@property
539+
def triton(self) -> str:
540+
"""Whether container is triton or not.
541+
542+
Returns
543+
-------
544+
bool
545+
Whether container is triton or not.
546+
"""
547+
return self.get_spec(self.CONST_TRITON, False)
548+
549+
def with_triton(self, triton: bool = True) -> "ModelDeploymentRuntime":
550+
"""Sets the flag for triton.
551+
552+
Parameters
553+
----------
554+
triton: bool
555+
Whether it is a triton container.
556+
557+
Returns
558+
-------
559+
ModelDeploymentRuntime
560+
The ModelDeploymentRuntime instance (self).
561+
"""
562+
return self.set_spec(self.CONST_TRITON, triton)

ads/model/generic_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2054,6 +2054,7 @@ def deploy(
20542054
raise ValueError("`compartment_id` has to be provided.")
20552055
if not (self.properties.project_id or existing_infrastructure.project_id):
20562056
raise ValueError("`project_id` has to be provided.")
2057+
20572058
infrastructure = (
20582059
ModelDeploymentInfrastructure()
20592060
.with_compartment_id(

0 commit comments

Comments
 (0)