Skip to content

Commit c8ed767

Browse files
authored
Fix integration tests to run with instance_principal (#339)
1 parent 6b75c58 commit c8ed767

18 files changed

+156
-84
lines changed

ads/jobs/serializer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212
import yaml
1313
from ads.common.auth import default_signer
1414

15+
# Special type to represent the current enclosed class.
16+
# This type is used by factory class method or when a method returns ``self``.
1517
Self = TypeVar("Self", bound="Serializable")
16-
"""Special type to represent the current enclosed class.
17-
18-
This type is used by factory class method or when a method returns ``self``.
19-
"""
2018

2119

2220
class Serializable(ABC):
@@ -72,6 +70,14 @@ def _write_to_file(s: str, uri: str, **kwargs) -> None:
7270
"if you wish to overwrite."
7371
)
7472

73+
# Add default signer if the uri is an object storage uri, and
74+
# the user does not specify config or signer.
75+
if (
76+
uri.startswith("oci://")
77+
and "config" not in kwargs
78+
and "signer" not in kwargs
79+
):
80+
kwargs.update(default_signer())
7581
with fsspec.open(uri, "w", **kwargs) as f:
7682
f.write(s)
7783

ads/opctl/config/merger.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _fill_config_with_defaults(self, ads_config_path: str) -> None:
117117
else:
118118
self.config["execution"]["auth"] = AuthType.API_KEY
119119
# determine profile
120-
if self.config["execution"]["auth"] == AuthType.RESOURCE_PRINCIPAL:
120+
if self.config["execution"]["auth"] != AuthType.API_KEY:
121121
profile = self.config["execution"]["auth"].upper()
122122
exec_config.pop("oci_profile", None)
123123
self.config["execution"]["oci_profile"] = None
@@ -202,20 +202,23 @@ def _get_service_config(self, oci_profile: str, ads_config_folder: str) -> Dict:
202202
def _config_flex_shape_details(self):
203203
infrastructure = self.config["infrastructure"]
204204
backend = self.config["execution"].get("backend", None)
205-
if backend == BACKEND_NAME.JOB.value or backend == BACKEND_NAME.MODEL_DEPLOYMENT.value:
205+
if (
206+
backend == BACKEND_NAME.JOB.value
207+
or backend == BACKEND_NAME.MODEL_DEPLOYMENT.value
208+
):
206209
shape_name = infrastructure.get("shape_name", "")
207210
if shape_name.endswith(".Flex"):
208211
if (
209-
"ocpus" not in infrastructure or
210-
"memory_in_gbs" not in infrastructure
212+
"ocpus" not in infrastructure
213+
or "memory_in_gbs" not in infrastructure
211214
):
212215
raise ValueError(
213216
"Parameters `ocpus` and `memory_in_gbs` must be provided for using flex shape. "
214217
"Call `ads opctl config` to specify."
215218
)
216219
infrastructure["shape_config_details"] = {
217220
"ocpus": infrastructure.pop("ocpus"),
218-
"memory_in_gbs": infrastructure.pop("memory_in_gbs")
221+
"memory_in_gbs": infrastructure.pop("memory_in_gbs"),
219222
}
220223
elif backend == BACKEND_NAME.DATAFLOW.value:
221224
executor_shape = infrastructure.get("executor_shape", "")
@@ -224,7 +227,7 @@ def _config_flex_shape_details(self):
224227
"driver_shape_memory_in_gbs",
225228
"driver_shape_ocpus",
226229
"executor_shape_memory_in_gbs",
227-
"executor_shape_ocpus"
230+
"executor_shape_ocpus",
228231
]
229232
# executor_shape and driver_shape must be the same shape family
230233
if executor_shape.endswith(".Flex") or driver_shape.endswith(".Flex"):
@@ -236,9 +239,9 @@ def _config_flex_shape_details(self):
236239
)
237240
infrastructure["driver_shape_config"] = {
238241
"ocpus": infrastructure.pop("driver_shape_ocpus"),
239-
"memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs")
242+
"memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs"),
240243
}
241244
infrastructure["executor_shape_config"] = {
242245
"ocpus": infrastructure.pop("executor_shape_ocpus"),
243-
"memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs")
246+
"memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs"),
244247
}

ads/opctl/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77

@@ -88,9 +88,8 @@ def get_namespace(auth: dict) -> str:
8888

8989

9090
def get_region_key(auth: dict) -> str:
91-
if len(auth["config"]) > 0:
92-
tenancy = auth["config"]["tenancy"]
93-
else:
91+
tenancy = auth["config"].get("tenancy")
92+
if not tenancy:
9493
tenancy = auth["signer"].tenancy_id
9594
client = OCIClientFactory(**auth).identity
9695
return client.get_tenancy(tenancy).data.home_region_key

tests/integration/jobs/test_dsc_job.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,10 @@ def assert_job_creation(self, job, expected_infra_spec, expected_runtime_spec):
222222
random.seed(threading.get_ident() + os.getpid())
223223
random_suffix = "".join(random.choices(string.ascii_uppercase, k=6))
224224
yaml_uri = f"oci://{self.BUCKET}@{self.NAMESPACE}/tests/{timestamp}/example_job_{random_suffix}.yaml"
225-
config_path = "~/.oci/config"
226-
job.to_yaml(uri=yaml_uri, config=config_path)
225+
job.to_yaml(uri=yaml_uri)
227226
print(f"Job YAML saved to {yaml_uri}")
228227
try:
229-
job = Job.from_yaml(uri=yaml_uri, config=config_path)
228+
job = Job.from_yaml(uri=yaml_uri)
230229
except Exception:
231230
self.fail(f"Failed to load job from YAML\n{traceback.format_exc()}")
232231

tests/integration/jobs/test_jobs_cli.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,38 +7,88 @@
77

88
from click.testing import CliRunner
99

10+
from ads.common.auth import AuthType
1011
from ads.jobs.cli import run, watch, delete
1112

1213

1314
class TestJobsCLI:
15+
# TeamCity will use Instance Principal, when running locally - set OCI_IAM_TYPE to security_token
16+
auth = os.environ.get("OCI_IAM_TYPE", AuthType.INSTANCE_PRINCIPAL)
17+
1418
def test_create_watch_delete_job(self):
1519
curr_dir = os.path.dirname(os.path.abspath(__file__))
1620
runner = CliRunner()
1721
res = runner.invoke(
18-
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_job.yaml")]
22+
run,
23+
args=[
24+
"-f",
25+
os.path.join(curr_dir, "../yamls", "sample_job.yaml"),
26+
"--auth",
27+
self.auth,
28+
],
1929
)
2030
assert res.exit_code == 0, res.output
2131
run_id = res.output.split("\n")[1]
22-
res2 = runner.invoke(watch, args=[run_id])
32+
res2 = runner.invoke(
33+
watch,
34+
args=[
35+
run_id,
36+
"--auth",
37+
self.auth,
38+
],
39+
)
2340
assert res2.exit_code == 0, res2.output
2441

25-
res3 = runner.invoke(delete, args=[run_id])
42+
res3 = runner.invoke(
43+
delete,
44+
args=[
45+
run_id,
46+
"--auth",
47+
self.auth,
48+
],
49+
)
2650
assert res3.exit_code == 0, res3.output
2751

2852
def test_create_watch_delete_dataflow(self):
2953
curr_dir = os.path.dirname(os.path.abspath(__file__))
3054
runner = CliRunner()
3155
res = runner.invoke(
32-
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml")]
56+
run,
57+
args=[
58+
"-f",
59+
os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml"),
60+
"--auth",
61+
self.auth,
62+
],
3363
)
3464
assert res.exit_code == 0, res.output
3565
run_id = res.output.split("\n")[1]
36-
res2 = runner.invoke(watch, args=[run_id])
66+
res2 = runner.invoke(
67+
watch,
68+
args=[
69+
run_id,
70+
"--auth",
71+
self.auth,
72+
],
73+
)
3774
assert res2.exit_code == 0, res2.output
3875

3976
res3 = runner.invoke(
40-
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml")]
77+
run,
78+
args=[
79+
"-f",
80+
os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml"),
81+
"--auth",
82+
self.auth,
83+
],
4184
)
4285
run_id2 = res3.output.split("\n")[1]
43-
res4 = runner.invoke(delete, args=[run_id2])
86+
res4 = runner.invoke(
87+
delete,
88+
args=[
89+
run_id2,
90+
"--auth",
91+
self.auth,
92+
],
93+
)
4494
assert res4.exit_code == 0, res4.output

tests/integration/jobs/test_jobs_notebook.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import tempfile
99

1010
import fsspec
11+
from ads.common.auth import default_signer, AuthType
1112
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
1213
NotebookRuntimeHandler,
1314
)
@@ -64,9 +65,7 @@ def run_notebook(
6465
# Clear the files in output URI
6566
try:
6667
# Ignore the error for unit tests.
67-
fs = fsspec.filesystem(
68-
"oci", config=os.path.expanduser("~/.oci/config")
69-
)
68+
fs = fsspec.filesystem("oci", **default_signer())
7069
if fs.find(output_uri):
7170
fs.rm(output_uri, recursive=True)
7271
except:

tests/integration/jobs/test_jobs_notebook_runtime.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import json
7+
import pytest
78
import os
89
import tempfile
910
from zipfile import ZipFile
1011

1112
import fsspec
12-
13+
from ads.common.auth import default_signer
1314
from tests.integration.config import secrets
1415
from tests.integration.jobs.test_dsc_job import DSCJobTestCaseWithCleanUp
1516
from tests.integration.jobs.test_jobs_notebook import NotebookDriverRunTest
@@ -19,7 +20,9 @@
1920

2021

2122
class NotebookRuntimeTest(DSCJobTestCaseWithCleanUp):
22-
NOTEBOOK_PATH = os.path.join(os.path.dirname(__file__), "../fixtures/ads_check.ipynb")
23+
NOTEBOOK_PATH = os.path.join(
24+
os.path.dirname(__file__), "../fixtures/ads_check.ipynb"
25+
)
2326
NOTEBOOK_PATH_EXCLUDE = os.path.join(
2427
os.path.dirname(__file__), "../fixtures/exclude_check.ipynb"
2528
)
@@ -86,10 +89,15 @@ def test_create_job_with_notebook(self):
8689

8790

8891
class NotebookDriverIntegrationTest(NotebookDriverRunTest):
92+
@pytest.mark.skip(
93+
reason="api_keys not an option anymore, this test is candidate to be removed"
94+
)
8995
def test_notebook_driver_with_outputs(self):
9096
"""Tests run the notebook driver with a notebook plotting and saving data."""
9197
# Notebook to be executed
92-
notebook_path = os.path.join(os.path.dirname(__file__), "../fixtures/plot.ipynb")
98+
notebook_path = os.path.join(
99+
os.path.dirname(__file__), "../fixtures/plot.ipynb"
100+
)
93101
# Object storage output location
94102
output_uri = f"oci://{secrets.jobs.BUCKET_B}@{secrets.common.NAMESPACE}/notebook_driver_int_test/plot/"
95103
# Run the notebook with driver and check the logs
@@ -100,7 +108,7 @@ def test_notebook_driver_with_outputs(self):
100108
# Check the notebook saved to object storage.
101109
with fsspec.open(
102110
os.path.join(output_uri, os.path.basename(notebook_path)),
103-
config=os.path.expanduser("~/.oci/config"),
111+
**default_signer(),
104112
) as f:
105113
outputs = [cell.get("outputs") for cell in json.load(f).get("cells")]
106114
# There should be 7 cells in the notebook
@@ -113,7 +121,7 @@ def test_notebook_driver_with_outputs(self):
113121
# Check the JSON output file from the notebook
114122
with fsspec.open(
115123
os.path.join(output_uri, "data.json"),
116-
config=os.path.expanduser("~/.oci/config"),
124+
**default_signer(),
117125
) as f:
118126
data = json.load(f)
119127
# There should be 10 data points

tests/integration/jobs/test_jobs_runs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def job_run_test_infra(self):
117117
@staticmethod
118118
def list_objects(uri: str) -> list:
119119
"""Lists objects on OCI object storage."""
120-
oci_os = fsspec.filesystem("oci", config=oci.config.from_file())
120+
oci_os = fsspec.filesystem("oci", **default_signer())
121121
if uri.startswith("oci://"):
122122
uri = uri[len("oci://") :]
123123
items = oci_os.ls(uri, detail=False, refresh=True)
@@ -126,7 +126,7 @@ def list_objects(uri: str) -> list:
126126
@staticmethod
127127
def remove_objects(uri: str):
128128
"""Removes objects from OCI object storage."""
129-
oci_os = fsspec.filesystem("oci", config=oci.config.from_file())
129+
oci_os = fsspec.filesystem("oci", **default_signer())
130130
try:
131131
oci_os.rm(uri, recursive=True)
132132
except FileNotFoundError:

tests/integration/opctl/test_opctl_cli.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515
)
1616
ADS_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
1717

18+
if "TEAMCITY_VERSION" in os.environ:
19+
# When running in TeamCity we specify dir, which is CHECKOUT_DIR="%teamcity.build.checkoutDir%"
20+
WORK_DIR = os.getenv("CHECKOUT_DIR", "~")
21+
CONDA_PACK_FOLDER = f"{WORK_DIR}/conda"
22+
else:
23+
CONDA_PACK_FOLDER = "~/conda"
24+
1825

1926
def _assert_run_command(cmd_str, expected_outputs: list = None):
2027
runner = CliRunner()
@@ -48,7 +55,7 @@ class TestLocalRunsWithConda:
4855
# For tests, we can always run the command in debug mode (-d)
4956
# By default, pytest only print the logs if the test is failed,
5057
# in which case we would like to see the debug logs.
51-
CMD_OPTIONS = "-d -b local "
58+
CMD_OPTIONS = f"-d -b local --conda-pack-folder {CONDA_PACK_FOLDER} "
5259

5360
def test_hello_world(self):
5461
test_folder = os.path.join(TESTS_FILES_DIR, "hello_world_test")
@@ -79,6 +86,9 @@ def test_linear_reg_test(self):
7986
]
8087
_assert_run_command(cmd, expected_outputs)
8188

89+
@pytest.mark.skip(
90+
reason="spark do not support instance principal - this test candidate to remove"
91+
)
8292
def test_spark_run(self):
8393
test_folder = os.path.join(TESTS_FILES_DIR, "spark_test")
8494
cmd = (

0 commit comments

Comments
 (0)