|
17 | 17 | from ads.common import auth as authutil |
18 | 18 | import pandas as pd |
19 | 19 | from ads.model.serde.model_input import JsonModelInputSERDE |
20 | | -from ads.common import auth, oci_client |
21 | 20 | from ads.common.oci_logging import ( |
22 | 21 | LOG_INTERVAL, |
23 | 22 | LOG_RECORDS_LIMIT, |
|
63 | 62 |
|
64 | 63 | MODEL_DEPLOYMENT_KIND = "deployment" |
65 | 64 | MODEL_DEPLOYMENT_TYPE = "modelDeployment" |
| 65 | +MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON = "TRITON" |
66 | 66 |
|
67 | 67 | MODEL_DEPLOYMENT_INSTANCE_SHAPE = "VM.Standard.E4.Flex" |
68 | 68 | MODEL_DEPLOYMENT_INSTANCE_OCPUS = 1 |
@@ -926,10 +926,7 @@ def predict( |
926 | 926 | if model_name and model_version: |
927 | 927 | header['model-name'] = model_name |
928 | 928 | header['model-version'] = model_version |
929 | | - elif not model_version and not model_name: |
930 | | - |
931 | | - pass |
932 | | - else: |
| 929 | + elif bool(model_version) ^ bool(model_name): |
933 | 930 | raise ValueError("`model_name` and `model_version` have to be provided together.") |
934 | 931 | prediction = send_request( |
935 | 932 | data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header, |
@@ -1404,9 +1401,9 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment": |
1404 | 1401 | infrastructure.CONST_WEB_CONCURRENCY, |
1405 | 1402 | runtime.env.get("WEB_CONCURRENCY", None), |
1406 | 1403 | ) |
1407 | | - if runtime.env.get("CONTAINER_TYPE", None) == "TRITON": |
| 1404 | + if runtime.env.get("CONTAINER_TYPE", None) == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON: |
1408 | 1405 | runtime.set_spec( |
1409 | | - runtime.CONST_INFERENCE_SERVER, "triton" |
| 1406 | + runtime.CONST_INFERENCE_SERVER, MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON.lower() |
1410 | 1407 | ) |
1411 | 1408 |
|
1412 | 1409 | self.set_spec(self.CONST_INFRASTRUCTURE, infrastructure) |
|
0 commit comments