@@ -828,6 +828,8 @@ def predict(
         data: Any = None,
         serializer: "ads.model.ModelInputSerializer" = model_input_serializer,
         auto_serialize_data: bool = False,
+        model_name: str = None,
+        model_version: str = None,
         **kwargs,
     ) -> dict:
         """Returns prediction of input data run against the model deployment endpoint.
@@ -860,6 +862,10 @@ def predict(
             If `auto_serialize_data=False`, `data` required to be bytes or json serializable
             and `json_input` required to be json serializable. If `auto_serialize_data` set
             to True, data will be serialized before sending to model deployment endpoint.
+        model_name: str
+            Defaults to None. When `inference_server="triton"`, the name of the model to invoke.
+        model_version: str
+            Defaults to None. When `inference_server="triton"`, the version of the model to invoke.
         kwargs:
             content_type: str
                 Used to indicate the media type of the resource.
@@ -917,9 +923,16 @@ def predict(
             raise TypeError(
                 "`data` is not bytes or json serializable. Set `auto_serialize_data` to `True` to serialize the input data."
             )
-
+        if model_name and model_version:
+            header['model-name'] = model_name
+            header['model-version'] = model_version
+        elif not model_version and not model_name:
+
+            pass
+        else:
+            raise ValueError("`model_name` and `model_version` have to be provided together.")
         prediction = send_request(
-            data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header
+            data=data, endpoint=endpoint, is_json_payload=is_json_payload, header=header,
         )
         return prediction

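For context, a hedged usage sketch of the new parameters (the deployment OCID, model name, and input payload below are placeholders, and `ModelDeployment.from_id` is assumed as the usual way to rehydrate an existing deployment):

```python
from ads.model.deployment import ModelDeployment

# Rehydrate an existing deployment backed by a Triton container image (OCID is a placeholder).
deployment = ModelDeployment.from_id("ocid1.datasciencemodeldeployment.oc1..<unique_id>")

# A KServe-v2-style request body; the exact shape depends on the model served by Triton.
payload = {
    "inputs": [
        {"name": "input_1", "shape": [1, 3], "datatype": "FP32", "data": [[1.0, 2.0, 3.0]]}
    ]
}

# model_name and model_version must be passed together; they are forwarded as the
# `model-name` and `model-version` request headers to the deployment endpoint.
prediction = deployment.predict(
    data=payload,
    auto_serialize_data=True,
    model_name="my_triton_model",  # placeholder: model name in the Triton model repository
    model_version="1",             # placeholder: version within that repository
)
```

Passing only one of the two raises the `ValueError` added above.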
@@ -1391,9 +1404,9 @@ def _update_from_oci_model(self, oci_model_instance) -> "ModelDeployment":
                 infrastructure.CONST_WEB_CONCURRENCY,
                 runtime.env.get("WEB_CONCURRENCY", None),
             )
-            if runtime.env.get("CONTAINER_TYPE", None) == "TRITON":
+            if runtime.env.pop("CONTAINER_TYPE", None) == "TRITON":
                 runtime.set_spec(
-                    runtime.CONST_TRITON, True
+                    runtime.CONST_INFERENCE_SERVER, "triton"
                 )

         self.set_spec(self.CONST_INFRASTRUCTURE, infrastructure)
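A minimal standalone sketch of what the changed branch in `_update_from_oci_model` now does (the `inferenceServer` key is an assumption for the value of `CONST_INFERENCE_SERVER`; only the constant name appears in the diff): the `CONTAINER_TYPE=TRITON` marker is consumed from the recorded environment variables with `pop` and replayed as `inference_server="triton"` on the runtime, replacing the old boolean `CONST_TRITON` flag.

```python
# Standalone illustration of the env-to-spec translation, not the ADS internals.
env = {"CONTAINER_TYPE": "TRITON", "WEB_CONCURRENCY": "10"}
runtime_spec = {}

# pop() removes the internal marker so it is not surfaced back to the user as a plain env var.
if env.pop("CONTAINER_TYPE", None) == "TRITON":
    runtime_spec["inferenceServer"] = "triton"  # assumed spec key for CONST_INFERENCE_SERVER

print(env)           # {'WEB_CONCURRENCY': '10'}
print(runtime_spec)  # {'inferenceServer': 'triton'}
```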
@@ -1571,7 +1584,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
                     infrastructure.web_concurrency
                 )
                 runtime.set_spec(runtime.CONST_ENV, environment_variables)
-            if runtime.triton:
+            if runtime.inference_server.lower() == "triton":
                 environment_variables["CONTAINER_TYPE"] = "TRITON"
                 runtime.set_spec(runtime.CONST_ENV, environment_variables)
             environment_configuration_details = {
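And the reverse direction in `_build_model_deployment_configuration_details`: a runtime whose `inference_server` spec is `"triton"` gets the `CONTAINER_TYPE=TRITON` environment variable injected before the environment configuration details are assembled. A standalone sketch of that translation (again not the ADS internals):

```python
# Standalone illustration of the spec-to-env translation, not the ADS internals.
inference_server = "triton"                        # value read from the runtime spec
environment_variables = {"WEB_CONCURRENCY": "10"}  # env vars collected so far

if inference_server and inference_server.lower() == "triton":
    environment_variables["CONTAINER_TYPE"] = "TRITON"

print(environment_variables)  # {'WEB_CONCURRENCY': '10', 'CONTAINER_TYPE': 'TRITON'}
```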