Skip to content

Commit 5b2f44d

Browse files
author
Ziqun Ye
committed
run black
1 parent 00dfd55 commit 5b2f44d

File tree

4 files changed

+84
-31
lines changed

4 files changed

+84
-31
lines changed

ads/opctl/conda/cmds.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,9 @@ def _install(
311311
"""
312312
ns, bucket, path, slug = parse_conda_uri(conda_uri)
313313
if bucket == "service-conda-packs":
314-
raise ValueError("Download service conda pack is not allowed. Only custom conda pack can be downloaded to local machine. You need to publish it to your own bucket first.")
314+
raise ValueError(
315+
"Download service conda pack is not allowed. Only custom conda pack can be downloaded to local machine. You need to publish it to your own bucket first."
316+
)
315317
os.makedirs(conda_pack_folder, exist_ok=True)
316318
pack_path = os.path.join(os.path.expanduser(conda_pack_folder), slug + ".tar.gz")
317319
pack_folder_path = os.path.join(os.path.expanduser(conda_pack_folder), slug)

ads/opctl/model/cli.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,46 @@ def commands():
1919

2020
@commands.command()
2121
@click.argument("ocid", required=True)
22-
@click.option("--model-save-folder", "-mf", nargs=1, required=False, default=DEFAULT_MODEL_FOLDER, help="Which location to store model artifact folders. Defaults to ~/.ads_ops/models. This is only used when model id is passed to `ocid` and a local predict is conducted.")
22+
@click.option(
23+
"--model-save-folder",
24+
"-mf",
25+
nargs=1,
26+
required=False,
27+
default=DEFAULT_MODEL_FOLDER,
28+
help="Which location to store model artifact folders. Defaults to ~/.ads_ops/models. This is only used when model id is passed to `ocid` and a local predict is conducted.",
29+
)
2330
@click.option(
2431
"--auth",
2532
"-a",
2633
help="authentication method",
2734
type=click.Choice(AuthType.values()),
2835
default=None,
2936
)
30-
@click.option("--bucket-uri", nargs=1, required=False, help="The OCI Object Storage URI where model artifacts will be copied to. The `bucket_uri` is only necessary for uploading large artifacts which size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`. This is only used when the model id is passed.")
31-
@click.option("--region", nargs=1, required=False, help="The destination Object Storage bucket region. By default the value will be extracted from the `OCI_REGION_METADATA` environment variables. This is only used when the model id is passed.")
32-
@click.option("--timeout", nargs=1, required=False, help="The connection timeout in seconds for the client. This is only used when the model id is passed.")
33-
@click.option("--force-overwrite", "-f", help="Overwrite existing model.", is_flag=True, default=False)
37+
@click.option(
38+
"--bucket-uri",
39+
nargs=1,
40+
required=False,
41+
help="The OCI Object Storage URI where model artifacts will be copied to. The `bucket_uri` is only necessary for uploading large artifacts which size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`. This is only used when the model id is passed.",
42+
)
43+
@click.option(
44+
"--region",
45+
nargs=1,
46+
required=False,
47+
help="The destination Object Storage bucket region. By default the value will be extracted from the `OCI_REGION_METADATA` environment variables. This is only used when the model id is passed.",
48+
)
49+
@click.option(
50+
"--timeout",
51+
nargs=1,
52+
required=False,
53+
help="The connection timeout in seconds for the client. This is only used when the model id is passed.",
54+
)
55+
@click.option(
56+
"--force-overwrite",
57+
"-f",
58+
help="Overwrite existing model.",
59+
is_flag=True,
60+
default=False,
61+
)
3462
@click.option("--debug", "-d", help="set debug mode", is_flag=True, default=False)
3563
def download(**kwargs):
36-
suppress_traceback(kwargs["debug"])(download_model_cmd)(**kwargs)
64+
suppress_traceback(kwargs["debug"])(download_model_cmd)(**kwargs)

ads/opctl/model/cmds.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,41 +19,59 @@ def download_model(**kwargs):
1919
oci_auth = create_signer(
2020
auth_type,
2121
oci_config,
22-
profile ,
22+
profile,
23+
)
24+
model_folder = os.path.expanduser(
25+
p.config["execution"].get("model_save_folder", DEFAULT_MODEL_FOLDER)
2326
)
24-
model_folder = os.path.expanduser(p.config["execution"].get("model_save_folder", DEFAULT_MODEL_FOLDER))
2527
force_overwrite = p.config["execution"].get("force_overwrite", False)
26-
27-
artifact_directory = os.path.join(model_folder, str(ocid))
28-
if (not os.path.exists(artifact_directory) or len(os.listdir(artifact_directory)) == 0) or force_overwrite:
2928

29+
artifact_directory = os.path.join(model_folder, str(ocid))
30+
if (
31+
not os.path.exists(artifact_directory)
32+
or len(os.listdir(artifact_directory)) == 0
33+
) or force_overwrite:
3034
region = p.config["execution"].get("region", None)
3135
bucket_uri = p.config["execution"].get("bucket_uri", None)
3236
timeout = p.config["execution"].get("timeout", None)
33-
logger.info(f"No cached model found. Downloading the model {ocid} to {artifact_directory}. If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts.")
34-
_download_model(ocid=ocid, artifact_directory=artifact_directory, region=region, bucket_uri=bucket_uri, timeout=timeout, force_overwrite=force_overwrite, oci_auth=oci_auth)
37+
logger.info(
38+
f"No cached model found. Downloading the model {ocid} to {artifact_directory}. If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts."
39+
)
40+
_download_model(
41+
ocid=ocid,
42+
artifact_directory=artifact_directory,
43+
region=region,
44+
bucket_uri=bucket_uri,
45+
timeout=timeout,
46+
force_overwrite=force_overwrite,
47+
oci_auth=oci_auth,
48+
)
3549
else:
3650
logger.error(f"Model already exists. Set `force_overwrite=True` to overwrite.")
37-
raise ValueError(f"Model already exists. Set `force_overwrite=True` to overwrite.")
51+
raise ValueError(
52+
f"Model already exists. Set `force_overwrite=True` to overwrite."
53+
)
3854

3955

40-
def _download_model(ocid, artifact_directory, oci_auth, region, bucket_uri, timeout, force_overwrite):
56+
def _download_model(
57+
ocid, artifact_directory, oci_auth, region, bucket_uri, timeout, force_overwrite
58+
):
4159
os.makedirs(artifact_directory, exist_ok=True)
4260
os.chmod(artifact_directory, 777)
43-
61+
4462
try:
4563
dsc_model = DataScienceModel.from_id(ocid)
4664
dsc_model.download_artifact(
47-
target_dir=artifact_directory,
48-
force_overwrite=force_overwrite,
49-
overwrite_existing_artifact=True,
50-
remove_existing_artifact=True,
51-
auth=oci_auth,
52-
region=region,
53-
timeout=timeout,
54-
bucket_uri=bucket_uri,
65+
target_dir=artifact_directory,
66+
force_overwrite=force_overwrite,
67+
overwrite_existing_artifact=True,
68+
remove_existing_artifact=True,
69+
auth=oci_auth,
70+
region=region,
71+
timeout=timeout,
72+
bucket_uri=bucket_uri,
5573
)
5674
except Exception as e:
5775
print(type(e))
5876
shutil.rmtree(artifact_directory, ignore_errors=True)
59-
raise e
77+
raise e

ads/opctl/script.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,23 @@
77

88
def verify(artifact_dir, data, compartment_id, project_id):
99
with tempfile.TemporaryDirectory() as td:
10-
model = GenericModel.from_model_artifact(uri = artifact_dir, artifact_dir=artifact_dir,
11-
force_overwrite=True,
12-
compartment_id=compartment_id,
13-
project_id=project_id)
10+
model = GenericModel.from_model_artifact(
11+
uri=artifact_dir,
12+
artifact_dir=artifact_dir,
13+
force_overwrite=True,
14+
compartment_id=compartment_id,
15+
project_id=project_id,
16+
)
1417

1518
data = json.loads(data)
1619
print(model.verify(data, auto_serialize_data=False))
1720

1821

1922
def main():
2023
args = sys.argv[1:]
21-
verify(artifact_dir = args[0], data=args[1], compartment_id=args[2], project_id=args[3])
24+
verify(
25+
artifact_dir=args[0], data=args[1], compartment_id=args[2], project_id=args[3]
26+
)
2227
return 0
2328

2429

0 commit comments

Comments
 (0)