Skip to content

Commit eebe150

Browse files
committed
Add unit tests for the opctl backends
1 parent fa650c1 commit eebe150

File tree

6 files changed

+189
-19
lines changed

6 files changed

+189
-19
lines changed

ads/jobs/ads_job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def from_dict(cls, config: dict) -> "Job":
502502
Raises
503503
------
504504
NotImplementedError
505-
If the type of the intrastructure or runtime is not supported.
505+
If the type of the infrastructure or runtime is not supported.
506506
"""
507507
if not isinstance(config, dict):
508508
raise ValueError("The config data for initializing the job is invalid.")

ads/opctl/backend/ads_ml_pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,10 @@ def init(
130130

131131
# define a pipeline
132132
pipeline = (
133-
Pipeline(**(self.config.get("infrastructure", {}) or {}))
133+
Pipeline(
134+
name="Pipeline Name",
135+
spec=(self.config.get("infrastructure", {}) or {}),
136+
)
134137
.with_step_details([pipeline_step])
135138
.with_dag(["pipeline_step_name_1"])
136139
.build()

tests/unitary/with_extras/opctl/test_opctl_dataflow_backend.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,46 @@
33
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

6+
import json
67
import os
78
import tempfile
89
from pathlib import Path
910
from unittest.mock import patch
10-
import json
1111

1212
import pytest
13+
import yaml
1314

1415
from ads.opctl.backend.ads_dataflow import DataFlowBackend
1516

1617

1718
class TestDataFlowBackend:
18-
def test_dataflow_apply(self):
19-
config = {
19+
@property
20+
def curr_dir(self):
21+
return os.path.dirname(os.path.abspath(__file__))
22+
23+
@property
24+
def config(self):
25+
return {
2026
"execution": {
2127
"backend": "dataflow",
28+
"auth": "api_key",
2229
"oci_profile": "DEFAULT",
2330
"oci_config": "~/.oci/config",
24-
}
31+
},
32+
"infrastructure": {
33+
"compartment_id": "ocid1.compartment.oc1..<unique_id>",
34+
"driver_shape": "VM.Standard.E2.4",
35+
"executor_shape": "VM.Standard.E2.4",
36+
"logs_bucket_uri": "oci://bucket@namespace",
37+
"script_bucket": "oci://bucket@namespace/prefix",
38+
"num_executors": "1",
39+
"spark_version": "3.2.1",
40+
},
2541
}
26-
with pytest.raises(ValueError):
27-
DataFlowBackend(config).apply()
42+
43+
def test_dataflow_apply(self):
44+
with pytest.raises(NotImplementedError):
45+
DataFlowBackend(self.config).apply()
2846

2947
@patch("ads.jobs.builders.infrastructure.dataflow.DataFlowApp.create")
3048
@patch("ads.opctl.backend.ads_dataflow.Job.run")
@@ -68,3 +86,30 @@ def test_dataflow_run(self, file_upload, job_run, job_create):
6886
"oci://<bucket_name>@<namespace>/<prefix>",
6987
False,
7088
)
89+
90+
@pytest.mark.parametrize(
91+
"runtime_type",
92+
["dataFlow", "dataFlowNotebook"],
93+
)
94+
def test_init(self, runtime_type, monkeypatch):
95+
"""Ensures that starter YAML can be generated for every supported runtime of the Data Flow."""
96+
monkeypatch.delenv("NB_SESSION_OCID", raising=False)
97+
98+
with tempfile.TemporaryDirectory() as td:
99+
test_yaml_uri = os.path.join(td, f"dataflow_{runtime_type}.yaml")
100+
expected_yaml_uri = os.path.join(
101+
self.curr_dir, "test_files", f"dataflow_{runtime_type}.yaml"
102+
)
103+
104+
DataFlowBackend(self.config).init(
105+
uri=test_yaml_uri,
106+
overwrite=False,
107+
runtime_type=runtime_type,
108+
)
109+
110+
with open(test_yaml_uri, "r") as stream:
111+
test_yaml_dict = yaml.safe_load(stream)
112+
with open(expected_yaml_uri, "r") as stream:
113+
expected_yaml_dict = yaml.safe_load(stream)
114+
115+
assert test_yaml_dict == expected_yaml_dict

tests/unitary/with_extras/opctl/test_opctl_ml_job_backend.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import tempfile
77
import os
88
from unittest.mock import patch
9+
import pytest
10+
import yaml
911

1012
from ads.opctl.backend.ads_ml_job import MLJobBackend
1113
from ads.jobs import Job, DataScienceJobRun
@@ -106,3 +108,31 @@ def test_run_with_image(self, rt, job_run, job_create):
106108
rt.with_cmd.assert_called_with("-n,hello-world,-c,~/.oci/config,-p,DEFAULT")
107109
job_create.assert_called()
108110
job_run.assert_called()
111+
112+
@pytest.mark.parametrize(
113+
"runtime_type",
114+
["container", "script", "python", "notebook", "gitPython"],
115+
)
116+
def test_init(self, runtime_type, monkeypatch):
117+
"""Ensures that starter YAML can be generated for every supported runtime of the Job."""
118+
119+
monkeypatch.delenv("NB_SESSION_OCID", raising=False)
120+
121+
with tempfile.TemporaryDirectory() as td:
122+
test_yaml_uri = os.path.join(td, f"job_{runtime_type}.yaml")
123+
expected_yaml_uri = os.path.join(
124+
self.curr_dir, "test_files", f"job_{runtime_type}.yaml"
125+
)
126+
127+
MLJobBackend(self.config).init(
128+
uri=test_yaml_uri,
129+
overwrite=False,
130+
runtime_type=runtime_type,
131+
)
132+
133+
with open(test_yaml_uri, "r") as stream:
134+
test_yaml_dict = yaml.safe_load(stream)
135+
with open(expected_yaml_uri, "r") as stream:
136+
expected_yaml_dict = yaml.safe_load(stream)
137+
138+
assert test_yaml_dict == expected_yaml_dict

tests/unitary/with_extras/opctl/test_opctl_ml_pipeline_backend.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@
33
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

6-
import unittest
6+
7+
import os
8+
import tempfile
79
from unittest.mock import patch
810

9-
try:
10-
from ads.opctl.backend.ads_ml_pipeline import PipelineBackend
11-
from ads.pipeline import Pipeline, PipelineRun
12-
except (ImportError, AttributeError) as e:
13-
raise unittest.SkipTest(
14-
"OCI MLPipeline is not available. Skipping the MLPipeline tests."
15-
)
11+
import pytest
12+
import yaml
13+
14+
from ads.opctl.backend.ads_ml_pipeline import PipelineBackend
15+
from ads.pipeline import Pipeline, PipelineRun
1616

1717

1818
class TestMLPipelineBackend:
19+
@property
20+
def curr_dir(self):
21+
return os.path.dirname(os.path.abspath(__file__))
22+
1923
@property
2024
def config(self):
2125
return {
@@ -27,7 +31,13 @@ def config(self):
2731
"oci_profile": "DEFAULT",
2832
"ocid": "test",
2933
"auth": "api_key",
30-
}
34+
},
35+
"infrastructure": {
36+
"compartment_id": "ocid1.compartment.oc1..<unique_id>",
37+
"project_id": "ocid1.datascienceproject.oc1.<unique_id>",
38+
"log_group_id": "ocid1.loggroup.oc1.iad.<unique_id>",
39+
"log_id": "ocid1.log.oc1.iad.<unique_id>",
40+
},
3141
}
3242

3343
@patch(
@@ -123,3 +133,36 @@ def test_watch(self, mock_from_ocid, mock_watch):
123133
backend.watch()
124134
mock_from_ocid.assert_called_with("test_pipeline_run_id")
125135
mock_watch.assert_called_with(log_type="custom_log")
136+
137+
@pytest.mark.parametrize(
138+
"runtime_type",
139+
["container", "script", "python", "notebook", "gitPython"],
140+
)
141+
def test_init(self, runtime_type, monkeypatch):
142+
"""Ensures that starter YAML can be generated for every supported runtime of the Pipeline."""
143+
144+
monkeypatch.delenv("NB_SESSION_OCID", raising=False)
145+
monkeypatch.setenv(
146+
"NB_SESSION_COMPARTMENT_OCID",
147+
self.config["infrastructure"]["compartment_id"],
148+
)
149+
monkeypatch.setenv("PROJECT_OCID", self.config["infrastructure"]["project_id"])
150+
151+
with tempfile.TemporaryDirectory() as td:
152+
test_yaml_uri = os.path.join(td, f"pipeline_{runtime_type}.yaml")
153+
expected_yaml_uri = os.path.join(
154+
self.curr_dir, "test_files", f"pipeline_{runtime_type}.yaml"
155+
)
156+
157+
PipelineBackend(self.config).init(
158+
uri=test_yaml_uri,
159+
overwrite=False,
160+
runtime_type=runtime_type,
161+
)
162+
163+
with open(test_yaml_uri, "r") as stream:
164+
test_yaml_dict = yaml.safe_load(stream)
165+
with open(expected_yaml_uri, "r") as stream:
166+
expected_yaml_dict = yaml.safe_load(stream)
167+
168+
assert test_yaml_dict == expected_yaml_dict

tests/unitary/with_extras/opctl/test_opctl_model_deployment_backend.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import os
8+
import tempfile
79
from unittest.mock import patch
810

9-
from ads.opctl.backend.ads_model_deployment import ModelDeploymentBackend
11+
import pytest
12+
import yaml
13+
1014
from ads.model import ModelDeployment
15+
from ads.opctl.backend.ads_model_deployment import ModelDeploymentBackend
1116

1217

1318
class TestModelDeploymentBackend:
19+
@property
20+
def curr_dir(self):
21+
return os.path.dirname(os.path.abspath(__file__))
22+
1423
@property
1524
def config(self):
1625
return {
@@ -27,7 +36,17 @@ def config(self):
2736
"log_type": "predict",
2837
"log_filter": "test_filter",
2938
"interval": 3,
30-
}
39+
},
40+
"infrastructure": {
41+
"compartment_id": "ocid1.compartment.oc1..<unique_id>",
42+
"project_id": "ocid1.datascienceproject.oc1.<unique_id>",
43+
"log_group_id": "ocid1.loggroup.oc1.iad.<unique_id>",
44+
"log_id": "ocid1.log.oc1.iad.<unique_id>",
45+
"shape_name": "VM.Standard.E2.4",
46+
"bandwidth_mbps": 10,
47+
"replica": 1,
48+
"web_concurrency": 10,
49+
},
3150
}
3251

3352
@patch("ads.opctl.backend.ads_model_deployment.ModelDeployment.deploy")
@@ -103,3 +122,33 @@ def test_watch(self, mock_from_id, mock_watch):
103122
mock_watch.assert_called_with(
104123
log_type="predict", interval=3, log_filter="test_filter"
105124
)
125+
126+
@pytest.mark.parametrize(
127+
"runtime_type",
128+
["container", "conda"],
129+
)
130+
def test_init(self, runtime_type, monkeypatch):
131+
"""Ensures that starter YAML can be generated for every supported runtime of the Model Deployment."""
132+
133+
# Generate a starter YAML for each supported runtime into a temporary directory
134+
# and compare it against the expected YAML stored under test_files.
135+
monkeypatch.delenv("NB_SESSION_OCID", raising=False)
136+
137+
with tempfile.TemporaryDirectory() as td:
138+
test_yaml_uri = os.path.join(td, f"modeldeployment_{runtime_type}.yaml")
139+
expected_yaml_uri = os.path.join(
140+
self.curr_dir, "test_files", f"modeldeployment_{runtime_type}.yaml"
141+
)
142+
143+
ModelDeploymentBackend(self.config).init(
144+
uri=test_yaml_uri,
145+
overwrite=False,
146+
runtime_type=runtime_type,
147+
)
148+
149+
with open(test_yaml_uri, "r") as stream:
150+
test_yaml_dict = yaml.safe_load(stream)
151+
with open(expected_yaml_uri, "r") as stream:
152+
expected_yaml_dict = yaml.safe_load(stream)
153+
154+
assert test_yaml_dict == expected_yaml_dict

0 commit comments

Comments
 (0)