Skip to content

Commit 3fe5738

Browse files
author
Ziqun Ye
committed
make predict work
1 parent c8ed754 commit 3fe5738

File tree

5 files changed

+63
-30
lines changed

5 files changed

+63
-30
lines changed

ads/opctl/backend/ads_model_deployment.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -178,26 +178,27 @@ def watch(self) -> None:
178178
def predict(self) -> None:
179179
ocid = self.config["execution"].get("ocid")
180180
data = self.config["execution"].get("data")
181-
if "datasciencemodeldeployment" in ocid:
182-
with AuthContext(auth=self.auth_type, profile=self.profile):
183-
model_deployment = ModelDeployment.from_id(ocid)
184-
return model_deployment.predict(data)
185-
elif "datasciencemodel":
186-
with AuthContext(auth=self.auth_type, profile=self.profile):
187-
import tempfile
188-
with tempfile.TemporaryDirectory() as td:
189-
190-
model = GenericModel.from_model_catalog(ocid, artifact_dir=os.path.join(td, str(uuid.uuid4())), force_overwrite=True)
181+
# if "datasciencemodeldeployment" in ocid:
182+
with AuthContext(auth=self.auth_type, profile=self.profile):
183+
model_deployment = ModelDeployment.from_id(ocid)
184+
data = json.loads(data)
185+
print(model_deployment.predict(data))
186+
# elif "datasciencemodel":
187+
# with AuthContext(auth=self.auth_type, profile=self.profile):
188+
# import tempfile
189+
# with tempfile.TemporaryDirectory() as td:
190+
191+
# model = GenericModel.from_model_catalog(ocid, artifact_dir=os.path.join(td, str(uuid.uuid4())), force_overwrite=True)
191192

192-
conda_pack = self.config["execution"].get("conda", None)
193-
if not conda_pack and hasattr(model.metadata_custom, "EnvironmentType") and model.metadata_custom.EnvironmentType == "published" and hasattr(model.metadata_custom, "CondaEnvironmentPath"):
194-
conda_pack = model.metadata_custom.CondaEnvironmentPath
195-
if conda_pack and "service-conda-packs" not in conda_pack:
196-
print("install conda pack and activate the conda pack.")
193+
# conda_pack = self.config["execution"].get("conda", None)
194+
# if not conda_pack and hasattr(model.metadata_custom, "EnvironmentType") and model.metadata_custom.EnvironmentType == "published" and hasattr(model.metadata_custom, "CondaEnvironmentPath"):
195+
# conda_pack = model.metadata_custom.CondaEnvironmentPath
196+
# if conda_pack and "service-conda-packs" not in conda_pack:
197+
# print("install conda pack and activate the conda pack.")
197198

198-
data = json.loads(data)
199-
print(model.verify(data))
200-
else:
201-
raise ValueError("Only model ocid or model deployment ocid is supported.")
199+
# data = json.loads(data)
200+
# print(model.verify(data))
201+
# else:
202+
# raise ValueError("Only model ocid or model deployment ocid is supported.")
202203

203204

ads/opctl/backend/local.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,5 +653,20 @@ def predict(self) -> None:
653653
f"Run with the --debug argument to view container logs."
654654
)
655655

656+
def _run_with_image(self, bind_volumes):
657+
ocid = self.config["execution"].get("ocid")
658+
data = self.config["execution"].get("data")
659+
image = self.config["execution"].get("image")
660+
env_vars = self.config["execution"]["env_vars"]
661+
# compartment_id = self.config["execution"].get("compartment_id", self.config["infrastructure"].get("compartment_id"))
662+
# project_id = self.config["execution"].get("project_id", self.config["infrastructure"].get("project_id"))
663+
entrypoint = self.config["execution"].get("entrypoint", None)
664+
command = self.config["execution"].get("command", None)
665+
if self.config["execution"].get("source_folder", None):
666+
bind_volumes.update(self._mount_source_folder_if_exists(bind_volumes))
667+
bind_volumes.update(self.config["execution"]["volumes"])
668+
669+
return run_container(image, bind_volumes, env_vars, command, entrypoint)
670+
656671
def _run_with_local_env(self, ):
657672
pass

ads/opctl/cli.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -503,23 +503,14 @@ def deactivate(**kwargs):
503503
@commands.command()
504504
@click.argument("ocid", nargs=1)
505505
@click.argument("data", nargs=1)
506-
@click.argument("conda_slug", nargs=1)
506+
@click.argument("conda_slug", nargs=1, required=False)
507507
@add_options(_options)
508508
def predict(**kwargs):
509509
"""
510510
Deactivates a data science service.
511511
"""
512512
suppress_traceback(kwargs["debug"])(predict_cmd)(**kwargs)
513513

514-
515-
@commands.command()
516-
@click.argument("conda", nargs=1)
517-
@click.argument("ocid", nargs=1)
518-
@click.argument("data", nargs=1)
519-
@add_options(_options)
520-
def verify(**kwargs):
521-
suppress_traceback(kwargs["debug"])(verify_cmd)(**kwargs)
522-
523514

524515
commands.add_command(ads.opctl.conda.cli.commands)
525516
commands.add_command(ads.opctl.spark.cli.commands)

ads/opctl/cmds.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,8 @@ def predict(**kwargs):
484484
# p.config["execution"]["backend"] = b.value
485485

486486
# return _BackendFactory(p.config).backend.predict()
487-
488487
if "datasciencemodeldeployment" in p.config["execution"].get("ocid", ""):
488+
489489
return ModelDeploymentBackend(p.config).predict()
490490
elif "datasciencemodel" in p.config["execution"].get("ocid", ""):
491491
return LocalModelDeploymentBackend(p.config).predict()

ads/opctl/script.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import json
2+
import os
3+
import sys
4+
import tempfile
5+
import uuid
6+
7+
from ads.model.generic_model import GenericModel
8+
9+
10+
def verify(ocid, data, compartment_id, project_id):
11+
with tempfile.TemporaryDirectory() as td:
12+
model = GenericModel.from_model_catalog(ocid,
13+
artifact_dir=os.path.join(td, str(uuid.uuid4())),
14+
force_overwrite=True, compartment_id=compartment_id, project_id=project_id)
15+
data = json.loads(data)
16+
print(model.verify(data, auto_serialize_data=False))
17+
18+
19+
def main():
20+
args = sys.argv[1:]
21+
verify(ocid = args[0], data=args[1], compartment_id=args[2], project_id=args[3])
22+
return 0
23+
24+
25+
if __name__ == "__main__":
26+
main()

0 commit comments

Comments
 (0)