Skip to content

Commit d1e68e3

Browse files
author
Ziqun Ye
committed
adding model download
1 parent 8cc4e18 commit d1e68e3

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

ads/opctl/backend/local.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -627,9 +627,9 @@ def __init__(self, config: Dict) -> None:
627627
self.client = OCIClientFactory(**self.oci_auth).data_science
628628

629629
def predict(self) -> None:
630-
631630
ocid = self.config["execution"].get("ocid")
632631
data = self.config["execution"].get("data")
632+
model_folder = self.config["execution"].get("model_folder", DEFAULT_MODEL_FOLDER)
633633
conda_slug, conda_path = self._get_conda_info(ocid)
634634
compartment_id = self.config["execution"].get("compartment_id", self.config["infrastructure"].get("compartment_id"))
635635
project_id = self.config["execution"].get("project_id", self.config["infrastructure"].get("project_id"))
@@ -647,8 +647,9 @@ def predict(self) -> None:
647647
dir_path = os.path.dirname(os.path.realpath(__file__))
648648
script = "script.py"
649649
self.config["execution"]["source_folder"] = os.path.abspath(os.path.join(dir_path, ".."))
650-
# bind_volumes[os.path.join(dir_path, "..", "script.py")] = {"bind": script}
650+
651651
self.config["execution"]["entrypoint"] = script
652+
bind_volumes[os.path.join(model_folder)] = {"bind": script}
652653
if self.config["execution"].get("image"):
653654
exit_code = self._run_with_image(bind_volumes)
654655
elif self.config["execution"].get("conda_slug", conda_slug):
@@ -666,17 +667,17 @@ def predict(self) -> None:
666667
f"Run with the --debug argument to view container logs."
667668
)
668669

669-
def _download_model(self, ocid, region):
670+
def _download_model(self, ocid, region, bucket_uri, timeout):
670671
dsc_model = DataScienceModel.from_id(ocid)
671672
dsc_model.download_artifact(
672-
target_dir=self.config["execution"].get("source_folder", DEFAULT_MODEL_FOLDER),
673+
target_dir=os.path.join(self.config["execution"].get("source_folder", DEFAULT_MODEL_FOLDER), ocid),
673674
force_overwrite=True,
674675
overwrite_existing_artifact=True,
675676
remove_existing_artifact=True,
676677
auth=self.oci_auth,
677678
region=region,
678-
timeout=600,
679-
bucket_urr=None,
679+
timeout=timeout,
680+
bucket_uri=bucket_uri,
680681
)
681682

682683
def _get_conda_info(self, ocid):

ads/opctl/script.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
import json
2-
import os
32
import sys
43
import tempfile
5-
import uuid
64

75
from ads.model.generic_model import GenericModel
86

97

10-
def verify(ocid, data, compartment_id, project_id):
8+
def verify(artifact_dir, data, compartment_id, project_id):
119
with tempfile.TemporaryDirectory() as td:
12-
model = GenericModel.from_model_catalog(ocid,
13-
artifact_dir=os.path.join(td, str(uuid.uuid4())),
14-
force_overwrite=True, compartment_id=compartment_id, project_id=project_id)
10+
model = GenericModel.from_model_artifact(artifact_dir=artifact_dir,
11+
force_overwrite=True,
12+
compartment_id=compartment_id,
13+
project_id=project_id)
14+
1515
data = json.loads(data)
1616
print(model.verify(data, auto_serialize_data=False))
1717

1818

1919
def main():
2020
args = sys.argv[1:]
21-
verify(ocid = args[0], data=args[1], compartment_id=args[2], project_id=args[3])
21+
verify(artifact_dir = args[0], data=args[1], compartment_id=args[2], project_id=args[3])
2222
return 0
2323

2424

0 commit comments

Comments
 (0)