Skip to content

Commit 8b5eae3

Browse files
committed
Adding bugfixes and unit tests
1 parent b6dfdd2 commit 8b5eae3

File tree

5 files changed

+165
-107
lines changed

5 files changed

+165
-107
lines changed

ads/opctl/backend/marketplace/marketplace_utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,6 @@ def _export_helm_chart_(listing_details: HelmMarketplaceListingDetails):
140140
# Get the data from response
141141

142142

143-
def get_marketplace_request_status(work_request_id):
144-
get_work_request_response = get_marketplace_client().get_work_request(
145-
work_request_id=work_request_id
146-
)
147-
return get_work_request_response.data
148-
149-
150143
def list_container_images(
151144
compartment_id: str, ocir_image_path: str
152145
) -> oci.artifacts.models.ContainerImageCollection:

ads/opctl/operator/lowcode/feature_store_marketplace/cmd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
from ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils import (
1111
get_latest_listing_version,
1212
get_api_gw_details,
13+
get_db_details,
1314
)
1415

1516
from ads.opctl.operator.lowcode.feature_store_marketplace.models.db_config import (
1617
DBConfig,
1718
)
18-
from ads.opctl.operator.lowcode.feature_store_marketplace.prompts import get_db_details
1919
from ads.opctl.backend.marketplace.marketplace_utils import Color, print_heading
2020
from ads.opctl.operator.common.utils import _load_yaml_from_uri
2121
from ads.opctl.operator.common.operator_yaml_generator import YamlGenerator

ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def get_db_details() -> DBConfig:
7070
prefix_newline_count=2,
7171
)
7272
mysql_db_config.username = click.prompt("Username", default="admin")
73+
7374
mysql_db_config.auth_type = MySqlConfig.MySQLAuthType(
7475
click.prompt(
7576
"Is password provided as plain-text or via a Vault secret?\n"
@@ -128,7 +129,40 @@ def get_region() -> Optional[str]:
128129
return None
129130

130131

132+
def _create_new_stack(apigw_config: APIGatewayConfig):
133+
resource_manager_client: oci.resource_manager.ResourceManagerClient = (
134+
OCIClientFactory(**authutil.default_signer()).create_client(
135+
oci.resource_manager.ResourceManagerClient
136+
)
137+
)
138+
print("Creating new api gateway stack...")
139+
response = requests.get(STACK_URL)
140+
source_details = oci.resource_manager.models.CreateZipUploadConfigSourceDetails()
141+
source_details.zip_file_base64_encoded = base64.b64encode(response.content).decode()
142+
stack_details = oci.resource_manager.models.CreateStackDetails()
143+
stack_details.compartment_id = apigw_config.root_compartment_id
144+
stack_details.display_name = APIGW_STACK_NAME
145+
stack_details.config_source = source_details
146+
stack_details.variables = {
147+
"nlb_id": "",
148+
"tenancy_ocid": apigw_config.root_compartment_id,
149+
"function_img_ocir_url": "",
150+
"authorized_user_groups": apigw_config.authorized_user_groups,
151+
"region": apigw_config.region,
152+
}
153+
stack: oci.resource_manager.models.Stack = resource_manager_client.create_stack(
154+
stack_details
155+
).data
156+
print(f"Created stack {stack.display_name} with id {stack.id}")
157+
return stack.id
158+
159+
131160
def detect_or_create_stack(apigw_config: APIGatewayConfig):
161+
def _print_stack_detail(stack: StackSummary):
162+
print(
163+
f"Detected stack :'{stack.display_name}' created on: '{stack.time_created}'"
164+
)
165+
132166
resource_manager_client: oci.resource_manager.ResourceManagerClient = (
133167
OCIClientFactory(**authutil.default_signer()).create_client(
134168
oci.resource_manager.ResourceManagerClient
@@ -142,46 +176,21 @@ def detect_or_create_stack(apigw_config: APIGatewayConfig):
142176
sort_order="DESC",
143177
).data
144178

145-
if len(stacks) == 1:
146-
print(
147-
f"Auto-detected feature store APIGW stack: {stacks[0].display_name}({stacks[0].id}"
148-
)
149-
click.prompt(
150-
f"Auto detected existing feature store stack: '{stacks[0].display_name}({stacks[0].id}'\n.Provide an OCID to use or",
151-
default=stacks[0].id,
152-
)
153-
return stacks[0].id
154-
elif len(stacks) == 0:
155-
if not click.confirm(
156-
f"Couldn't detect any existing feature store api gateway stack. Should we create one?",
157-
default=True,
158-
):
159-
return click.prompt(
160-
"Enter the resource manager stack OCID of the stack to use"
161-
)
162-
print("Creating feature store API Gateway stack")
163-
response = requests.get(STACK_URL)
164-
source_details = (
165-
oci.resource_manager.models.CreateZipUploadConfigSourceDetails()
166-
)
167-
source_details.zip_file_base64_encoded = base64.b64encode(
168-
response.content
169-
).decode()
170-
stack_details = oci.resource_manager.models.CreateStackDetails()
171-
stack_details.compartment_id = apigw_config.root_compartment_id
172-
stack_details.display_name = APIGW_STACK_NAME
173-
stack_details.config_source = source_details
174-
stack_details.variables = {
175-
"nlb_id": "",
176-
"tenancy_ocid": apigw_config.root_compartment_id,
177-
"function_img_ocir_url": "",
178-
"authorized_user_groups": apigw_config.authorized_user_groups,
179-
"region": apigw_config.region,
180-
}
181-
return resource_manager_client.create_stack(stack_details).data.id
182-
elif len(stacks) > 1:
179+
if len(stacks) >= 1:
180+
print(f"Auto-detected feature store stack(s) in tenancy:")
181+
for stack in stacks:
182+
_print_stack_detail(stack)
183+
choices = {"1": "new", "2": "existing"}
184+
stack_provision_method = click.prompt(
185+
f"Select stack provisioning method:\n1.Create new stack\n2.Existing stack\n",
186+
type=click.Choice(list(choices.keys())),
187+
)
188+
if choices[stack_provision_method] == "new":
189+
return _create_new_stack(apigw_config)
190+
else:
183191
return click.prompt(
184-
"Multiple feature store apigw stacks detected. Please enter the resource manager stack OCID of the stack to use:"
192+
"Enter the resource manager stack OCID of the stack to use",
193+
show_choices=False,
185194
)
186195

187196

ads/opctl/operator/lowcode/feature_store_marketplace/prompts.py

Lines changed: 0 additions & 60 deletions
This file was deleted.
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from unittest.mock import Mock, patch
2+
3+
import pytest
4+
5+
from ads.opctl.operator.lowcode.feature_store_marketplace.const import LISTING_ID
6+
from ads.opctl.operator.lowcode.feature_store_marketplace.models.db_config import (
7+
DBConfig,
8+
)
9+
10+
from ads.opctl.operator.lowcode.feature_store_marketplace.models.mysql_config import (
11+
MySqlConfig,
12+
)
13+
from ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils import (
14+
get_db_details,
15+
get_latest_listing_version,
16+
)
17+
18+
19+
@patch("click.prompt")
20+
def test_get_db_basic_details(prompt_mock: Mock):
21+
IP = "10.0.10.122:3306"
22+
DATABASE_NAME = "featurestore"
23+
expected_db_config = MySqlConfig()
24+
expected_db_config.username = "username"
25+
expected_db_config.auth_type = MySqlConfig.MySQLAuthType.BASIC.value
26+
expected_db_config.basic_config = MySqlConfig.BasicConfig()
27+
expected_db_config.basic_config.password = "password"
28+
expected_db_config.url = (
29+
f"jdbc:mysql://{IP}/{DATABASE_NAME}?createDatabaseIfNotExist=true"
30+
)
31+
32+
prompt_mock.side_effect = [
33+
expected_db_config.username,
34+
expected_db_config.auth_type,
35+
expected_db_config.basic_config.password,
36+
IP,
37+
DATABASE_NAME,
38+
]
39+
db_config: DBConfig = get_db_details()
40+
41+
assert db_config.mysql_config.url == expected_db_config.url
42+
assert db_config.mysql_config.username == expected_db_config.username
43+
assert (
44+
db_config.mysql_config.basic_config.password
45+
== expected_db_config.basic_config.password
46+
)
47+
48+
49+
@patch("click.prompt")
50+
def test_get_db_vault_details(prompt_mock: Mock):
51+
IP = "10.0.10.122:3306"
52+
DATABASE_NAME = "featurestore"
53+
expected_db_config = MySqlConfig()
54+
expected_db_config.username = "username"
55+
expected_db_config.auth_type = MySqlConfig.MySQLAuthType.VAULT.value
56+
expected_db_config.vault_config = MySqlConfig.VaultConfig()
57+
expected_db_config.vault_config.vault_ocid = "vaultocid"
58+
expected_db_config.vault_config.secret_name = "secretname"
59+
expected_db_config.url = (
60+
f"jdbc:mysql://{IP}/{DATABASE_NAME}?createDatabaseIfNotExist=true"
61+
)
62+
prompt_mock.side_effect = [
63+
expected_db_config.username,
64+
expected_db_config.auth_type,
65+
expected_db_config.vault_config.vault_ocid,
66+
expected_db_config.vault_config.secret_name,
67+
IP,
68+
DATABASE_NAME,
69+
]
70+
db_config: DBConfig = get_db_details()
71+
72+
assert db_config.mysql_config.url == expected_db_config.url
73+
assert db_config.mysql_config.username == expected_db_config.username
74+
assert (
75+
db_config.mysql_config.vault_config.vault_ocid
76+
== expected_db_config.vault_config.vault_ocid
77+
)
78+
assert (
79+
db_config.mysql_config.vault_config.secret_name
80+
== expected_db_config.vault_config.secret_name
81+
)
82+
83+
84+
@patch(
85+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.OCIClientFactory"
86+
)
87+
@patch("ads.common.auth.default_signer")
88+
def test_get_latest_listing_revision(auth_mock: Mock, client_factory: Mock):
89+
client_mock = Mock()
90+
client_factory.return_value = Mock(create_client=Mock(return_value=client_mock))
91+
get_latest_listing_version("compartment_id")
92+
client_mock.get_listing.assert_called_once_with(
93+
LISTING_ID, compartment_id="compartment_id"
94+
)
95+
96+
97+
@patch(
98+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.OCIClientFactory"
99+
)
100+
@patch("ads.common.auth.default_signer")
101+
def test_get_latest_listing_revision_with_exception(
102+
auth_mock: Mock, client_factory: Mock
103+
):
104+
class TestException(Exception):
105+
pass
106+
107+
def throw_exception(*args, **kwargs):
108+
raise TestException()
109+
110+
client_mock = Mock()
111+
112+
client_factory.return_value = Mock(create_client=Mock(return_value=client_mock))
113+
client_mock.get_listing = Mock()
114+
client_mock.get_listing.side_effect = throw_exception
115+
with pytest.raises(TestException):
116+
get_latest_listing_version("compartment_id")

0 commit comments

Comments
 (0)