Skip to content

Commit a78f4c9

Browse files
author
Ziqun Ye
committed
ODSC-29065: adding more unit test
1 parent cc904ab commit a78f4c9

File tree

5 files changed

+34
-8
lines changed

5 files changed

+34
-8
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -863,9 +863,9 @@ def predict(
863863
and `json_input` required to be json serializable. If `auto_serialize_data` set
864864
to True, data will be serialized before sending to model deployment endpoint.
865865
model_name: str
866-
Defaults to None. When the `Inference_server="triton"`, the name of the model to invoke.
866+
Defaults to None. When the `inference_server="triton"`, the name of the model to invoke.
867867
model_version: str
868-
Defaults to None. When the `Inference_server="triton"`, the version of the model to invoke.
868+
Defaults to None. When the `inference_server="triton"`, the version of the model to invoke.
869869
kwargs:
870870
content_type: str
871871
Used to indicate the media type of the resource.

ads/opctl/backend/ads_model_deployment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,12 @@ def watch(self) -> None:
126126
def predict(self) -> None:
127127
ocid = self.config["execution"].get("ocid")
128128
data = self.config["execution"].get("payload")
129+
model_name = self.config["execution"].get("model_name")
130+
model_version = self.config["execution"].get("model_version")
129131
with AuthContext(auth=self.auth_type, profile=self.profile):
130132
model_deployment = ModelDeployment.from_id(ocid)
131-
data = json.loads(data)
132-
print(model_deployment.predict(data))
133+
try:
134+
data = json.loads(data)
135+
except:
136+
pass
137+
print(model_deployment.predict(data=data, model_name=model_name, model_version=model_version))

ads/opctl/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,8 @@ def deactivate(**kwargs):
558558
required=False,
559559
help="The conda env used to load the model and conduct the prediction. This is only used when model id is passed to `ocid` and a local predict is conducted. It should match the inference conda env specified in the runtime.yaml file which is the conda pack being used when conducting real model deployment.",
560560
)
561+
@click.option("--model-version", nargs=1, required=False, help="When the `inference_server='triton'`, the version of the model to invoke. This can only be used when model deployment id is passed in. For the other cases, it will be ignored.")
562+
@click.option("--model-name", nargs=1, required=False, help="When the `inference_server='triton'`, the name of the model to invoke. This can only be used when model deployment id is passed in. For the other cases, it will be ignored.")
561563
@click.option("--debug", "-d", help="set debug mode", is_flag=True, default=False)
562564
def predict(**kwargs):
563565
"""

ads/opctl/script.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ads.model.generic_model import GenericModel
66

77

8-
def verify(artifact_dir, data, compartment_id, project_id):
8+
def verify(artifact_dir, data, compartment_id, project_id): # pragma: no cover
99
with tempfile.TemporaryDirectory() as td:
1010
model = GenericModel.from_model_artifact(
1111
uri=artifact_dir,
@@ -15,17 +15,20 @@ def verify(artifact_dir, data, compartment_id, project_id):
1515
project_id=project_id,
1616
)
1717

18-
data = json.loads(data)
18+
try:
19+
data = json.loads(data)
20+
except:
21+
pass
1922
print(model.verify(data, auto_serialize_data=False))
2023

2124

22-
def main():
25+
def main(): # pragma: no cover
2326
args = sys.argv[1:]
2427
verify(
2528
artifact_dir=args[0], data=args[1], compartment_id=args[2], project_id=args[3]
2629
)
2730
return 0
2831

2932

30-
if __name__ == "__main__":
33+
if __name__ == "__main__": # pragma: no cover
3134
main()

tests/unitary/with_extras/opctl/test_opctl_model_deployment_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@ def config(self):
2020
"oci_config": "~/.oci/config",
2121
"oci_profile": "DEFAULT",
2222
"run_id": "test_model_deployment_id",
23+
"ocid": "fake_model_id",
2324
"auth": "api_key",
2425
"wait_for_completion": False,
2526
"max_wait_time": 1000,
2627
"poll_interval": 12,
2728
"log_type": "predict",
2829
"log_filter": "test_filter",
2930
"interval": 3,
31+
"payload": "fake_payload",
32+
"model_name": "model_name",
33+
"model_version": "model_version",
3034
}
3135
}
3236

@@ -103,3 +107,15 @@ def test_watch(self, mock_from_id, mock_watch):
103107
mock_watch.assert_called_with(
104108
log_type="predict", interval=3, log_filter="test_filter"
105109
)
110+
111+
@patch("ads.opctl.backend.ads_model_deployment.ModelDeployment.predict")
112+
@patch("ads.opctl.backend.ads_model_deployment.ModelDeployment.from_id")
113+
def test_predict(self, mock_from_id, mock_predict):
114+
config = self.config
115+
mock_from_id.return_value = ModelDeployment()
116+
backend = ModelDeploymentBackend(config)
117+
backend.predict()
118+
mock_from_id.assert_called_with("fake_model_id")
119+
mock_predict.assert_called_with(
120+
data="fake_payload", model_name='model_name', model_version='model_version'
121+
)

0 commit comments

Comments
 (0)