Skip to content

Commit a22ff17

Browse files
author
Ziqun Ye
committed
adding code for model
1 parent 109b606 commit a22ff17

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed

ads/opctl/model/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

ads/opctl/model/cli.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import click
8+
from ads.common.auth import AuthType
9+
from ads.opctl.utils import suppress_traceback
10+
from ads.opctl.model.cmds import download_model as download_model_cmd
11+
from ads.opctl.backend.local import DEFAULT_MODEL_FOLDER
12+
13+
14+
@click.group("model")
15+
@click.help_option("--help", "-h")
16+
def commands():
17+
pass
18+
19+
20+
@commands.command()
21+
@click.argument("ocid", required=True)
22+
@click.option("--model-save-folder", "-mf", nargs=1, required=False, default=DEFAULT_MODEL_FOLDER, help="Which location to store model artifact folders. Defaults to ~/.ads_ops/models. This is only used when model id is passed to `ocid` and a local predict is conducted.")
23+
@click.option(
24+
"--auth",
25+
"-a",
26+
help="authentication method",
27+
type=click.Choice(AuthType.values()),
28+
default=None,
29+
)
30+
@click.option("--bucket-uri", nargs=1, required=False, help="The OCI Object Storage URI where model artifacts will be copied to. The `bucket_uri` is only necessary for uploading large artifacts which size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`. This is only used when the model id is passed.")
31+
@click.option("--region", nargs=1, required=False, help="The destination Object Storage bucket region. By default the value will be extracted from the `OCI_REGION_METADATA` environment variables. This is only used when the model id is passed.")
32+
@click.option("--timeout", nargs=1, required=False, help="The connection timeout in seconds for the client. This is only used when the model id is passed.")
33+
@click.option("--force-overwrite", "-f", help="Overwrite existing model.", is_flag=True, default=False)
34+
@click.option("--debug", "-d", help="set debug mode", is_flag=True, default=False)
35+
def download(**kwargs):
36+
suppress_traceback(kwargs["debug"])(download_model_cmd)(**kwargs)

ads/opctl/model/cmds.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
import shutil
3+
4+
from ads.common.auth import create_signer
5+
from ads.model.datascience_model import DataScienceModel
6+
from ads.opctl import logger
7+
from ads.opctl.backend.local import DEFAULT_MODEL_FOLDER
8+
from ads.opctl.config.base import ConfigProcessor
9+
from ads.opctl.config.merger import ConfigMerger
10+
11+
12+
def download_model(**kwargs):
13+
p = ConfigProcessor().step(ConfigMerger, **kwargs)
14+
ocid = p.config["execution"]["ocid"]
15+
16+
auth_type = p.config["execution"].get("auth")
17+
profile = p.config["execution"].get("oci_profile", None)
18+
oci_config = p.config["execution"].get("oci_config", None)
19+
oci_auth = create_signer(
20+
auth_type,
21+
oci_config,
22+
profile ,
23+
)
24+
model_folder = os.path.expanduser(p.config["execution"].get("model_save_folder", DEFAULT_MODEL_FOLDER))
25+
force_overwrite = p.config["execution"].get("force_overwrite", False)
26+
27+
artifact_directory = os.path.join(model_folder, str(ocid))
28+
if (not os.path.exists(artifact_directory) or len(os.listdir(artifact_directory)) == 0) or force_overwrite:
29+
30+
region = p.config["execution"].get("region", None)
31+
bucket_uri = p.config["execution"].get("bucket_uri", None)
32+
timeout = p.config["execution"].get("timeout", None)
33+
logger.info(f"No cached model found. Downloading the model {ocid} to {artifact_directory}. If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts.")
34+
_download_model(ocid=ocid, artifact_directory=artifact_directory, region=region, bucket_uri=bucket_uri, timeout=timeout, force_overwrite=force_overwrite, oci_auth=oci_auth)
35+
else:
36+
logger.error(f"Model already exists. Set `force_overwrite=True` to overwrite.")
37+
raise ValueError(f"Model already exists. Set `force_overwrite=True` to overwrite.")
38+
39+
40+
def _download_model(ocid, artifact_directory, oci_auth, region, bucket_uri, timeout, force_overwrite):
41+
os.makedirs(artifact_directory, exist_ok=True)
42+
os.chmod(artifact_directory, 777)
43+
44+
try:
45+
dsc_model = DataScienceModel.from_id(ocid)
46+
dsc_model.download_artifact(
47+
target_dir=artifact_directory,
48+
force_overwrite=force_overwrite,
49+
overwrite_existing_artifact=True,
50+
remove_existing_artifact=True,
51+
auth=oci_auth,
52+
region=region,
53+
timeout=timeout,
54+
bucket_uri=bucket_uri,
55+
)
56+
except Exception as e:
57+
print(str(e))
58+
shutil.rmtree(artifact_directory, ignore_errors=True)

0 commit comments

Comments
 (0)