Skip to content

Commit f8c09fd

Browse files
committed
add more unit tests
1 parent 4a6bca9 commit f8c09fd

File tree

2 files changed

+219
-5
lines changed

2 files changed

+219
-5
lines changed

ads/opctl/operator/lowcode/feature_store_marketplace/operator_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,11 @@ def _create_new_stack(apigw_config: APIGatewayConfig):
157157
return stack.id
158158

159159

160-
def detect_or_create_stack(apigw_config: APIGatewayConfig):
161-
def _print_stack_detail(stack: StackSummary):
162-
print(
163-
f"Detected stack :'{stack.display_name}' created on: '{stack.time_created}'"
164-
)
160+
def _print_stack_detail(stack: StackSummary):
161+
print(f"Detected stack :'{stack.display_name}' created on: '{stack.time_created}'")
165162

163+
164+
def detect_or_create_stack(apigw_config: APIGatewayConfig):
166165
resource_manager_client: oci.resource_manager.ResourceManagerClient = (
167166
OCIClientFactory(**authutil.default_signer()).create_client(
168167
oci.resource_manager.ResourceManagerClient

tests/unitary/with_extras/operator/feature-store/test_operator_utils.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
import base64
12
from unittest.mock import Mock, patch
23

4+
import oci.resource_manager.models
35
import pytest
6+
from typing import List
7+
8+
from ads.opctl.operator.lowcode.feature_store_marketplace.models.apigw_config import (
9+
APIGatewayConfig,
10+
)
411

512
from ads.opctl.operator.lowcode.feature_store_marketplace.const import LISTING_ID
613
from ads.opctl.operator.lowcode.feature_store_marketplace.models.db_config import (
@@ -13,6 +20,10 @@
1320
from ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils import (
1421
get_db_details,
1522
get_latest_listing_version,
23+
_create_new_stack,
24+
detect_or_create_stack,
25+
get_admin_group,
26+
get_api_gw_details,
1627
)
1728

1829

@@ -114,3 +125,207 @@ def throw_exception(*args, **kwargs):
114125
client_mock.get_listing.side_effect = throw_exception
115126
with pytest.raises(TestException):
116127
get_latest_listing_version("compartment_id")
128+
129+
130+
@patch(
131+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.OCIClientFactory"
132+
)
133+
@patch("ads.common.auth.default_signer")
134+
@patch("ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.requests")
135+
def test_create_new_stack(requests_mock: Mock, auth_mock: Mock, client_factory: Mock):
136+
zip_file = b"hello"
137+
requests_mock.get = Mock()
138+
requests_mock.get.return_value = Mock(content=zip_file)
139+
140+
def validate_stack_details(
141+
stack_details: oci.resource_manager.models.CreateStackDetails,
142+
):
143+
source_details: oci.resource_manager.models.CreateZipUploadConfigSourceDetails = (
144+
stack_details.config_source
145+
)
146+
assert (
147+
source_details.zip_file_base64_encoded
148+
== base64.b64encode(zip_file).decode()
149+
)
150+
stack = oci.resource_manager.models.Stack()
151+
stack.id = "ID"
152+
return oci.Response(data=stack, request=None, headers=None, status=None)
153+
154+
client_mock = Mock()
155+
client_factory.return_value = Mock(create_client=Mock(return_value=client_mock))
156+
client_mock.create_stack = Mock()
157+
client_mock.create_stack.side_effect = validate_stack_details
158+
_create_new_stack(APIGatewayConfig())
159+
assert client_mock.create_stack.call_count == 1
160+
161+
162+
@patch(
163+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.OCIClientFactory"
164+
)
165+
@patch("ads.common.auth.default_signer")
166+
@patch(
167+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils._create_new_stack"
168+
)
169+
@patch(
170+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils._print_stack_detail"
171+
)
172+
@patch("click.prompt")
173+
def test_detect_or_create_new_stack(
174+
mock_prompt: Mock,
175+
print_mock: Mock,
176+
create_new_mock_stack: Mock,
177+
auth_mock: Mock,
178+
client_factory: Mock,
179+
):
180+
import oci.resource_manager.models as models
181+
182+
ocid = "id"
183+
mock_prompt.side_effect = ["1"]
184+
client_mock = Mock()
185+
client_factory.return_value = Mock(create_client=Mock(return_value=client_mock))
186+
client_mock.list_stacks = Mock()
187+
create_new_mock_stack.return_value = ocid
188+
stacks: List[models.StackSummary] = [models.StackSummary(), models.StackSummary()]
189+
client_mock.list_stacks.return_value = oci.Response(
190+
data=stacks, request=None, headers=None, status=None
191+
)
192+
assert detect_or_create_stack(apigw_config=APIGatewayConfig()) == ocid
193+
assert print_mock.call_count == len(stacks)
194+
assert create_new_mock_stack.call_count == 1
195+
196+
197+
@patch(
198+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.OCIClientFactory"
199+
)
200+
@patch("ads.common.auth.default_signer")
201+
@patch(
202+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils._print_stack_detail"
203+
)
204+
@patch("click.prompt")
205+
def test_detect_or_create_existing_stack(
206+
mock_prompt: Mock,
207+
print_mock: Mock,
208+
auth_mock: Mock,
209+
client_factory: Mock,
210+
):
211+
ocid = "id"
212+
import oci.resource_manager.models as models
213+
214+
mock_prompt.side_effect = ["2", ocid]
215+
client_mock = Mock()
216+
client_factory.return_value = Mock(create_client=Mock(return_value=client_mock))
217+
client_mock.list_stacks = Mock()
218+
219+
stacks: List[models.StackSummary] = []
220+
client_mock.list_stacks.return_value = oci.Response(
221+
data=stacks, request=None, headers=None, status=None
222+
)
223+
assert detect_or_create_stack(apigw_config=APIGatewayConfig()) == ocid
224+
assert print_mock.call_count == len(stacks)
225+
226+
227+
@patch(
228+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.OCIClientFactory"
229+
)
230+
@patch("ads.common.auth.default_signer")
231+
def test_get_admin_group(auth_mock: Mock, client_factory: Mock):
232+
import oci.identity.models as models
233+
234+
ocid = "id"
235+
client_mock = Mock()
236+
client_factory.return_value = Mock(create_client=Mock(return_value=client_mock))
237+
client_mock.list_groups = Mock()
238+
groups: List[models.Group] = [models.Group(), models.Group()]
239+
groups[0].id = ocid
240+
client_mock.list_groups.return_value = oci.Response(
241+
data=groups, request=None, headers=None, status=None
242+
)
243+
assert get_admin_group("tenant_id") == ocid
244+
assert client_mock.list_groups.call_count == 1
245+
246+
247+
@patch(
248+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.OCIClientFactory"
249+
)
250+
@patch("ads.common.auth.default_signer")
251+
def test_get_admin_group_with_no_groups(auth_mock: Mock, client_factory: Mock):
252+
import oci.identity.models as models
253+
254+
ocid = "id"
255+
client_mock = Mock()
256+
client_factory.return_value = Mock(create_client=Mock(return_value=client_mock))
257+
client_mock.list_groups = Mock()
258+
groups: List[models.Group] = []
259+
client_mock.list_groups.return_value = oci.Response(
260+
data=groups, request=None, headers=None, status=None
261+
)
262+
assert get_admin_group("tenant_id") is None
263+
assert client_mock.list_groups.call_count == 1
264+
265+
266+
@patch("click.prompt")
267+
@patch("click.confirm")
268+
@patch("ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.get_region")
269+
@patch(
270+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.detect_or_create_stack"
271+
)
272+
@patch(
273+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.get_admin_group"
274+
)
275+
def test_get_api_gw_details(
276+
mock_get_admin_group: Mock,
277+
mock_detect_or_create_stack: Mock,
278+
mock_region: Mock,
279+
mock_confirm: Mock,
280+
mock_prompt: Mock,
281+
):
282+
admin_group = "group"
283+
comp_id = "comp_id"
284+
stack_id = "stack_id"
285+
region = "ashburn"
286+
mock_confirm.side_effect = ["Y"]
287+
mock_region.return_value = None
288+
mock_prompt.side_effect = [comp_id, region, admin_group]
289+
mock_get_admin_group.return_value = ""
290+
mock_detect_or_create_stack.return_value = stack_id
291+
api_gw_details: APIGatewayConfig = get_api_gw_details("")
292+
assert mock_detect_or_create_stack.call_count == 1
293+
assert mock_confirm.call_count == 1
294+
assert mock_prompt.call_count == 3
295+
assert api_gw_details.root_compartment_id == comp_id
296+
assert api_gw_details.stack_id == stack_id
297+
assert api_gw_details.region == region
298+
299+
300+
@patch("click.prompt")
301+
@patch("click.confirm")
302+
@patch("ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.get_region")
303+
@patch(
304+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.detect_or_create_stack"
305+
)
306+
@patch(
307+
"ads.opctl.operator.lowcode.feature_store_marketplace.operator_utils.get_admin_group"
308+
)
309+
def test_get_api_gw_details_auto_detect_compartment_and_region(
310+
mock_get_admin_group: Mock,
311+
mock_detect_or_create_stack: Mock,
312+
mock_region: Mock,
313+
mock_confirm: Mock,
314+
mock_prompt: Mock,
315+
):
316+
admin_group = "group"
317+
comp_id = "tenancy"
318+
stack_id = "stack_id"
319+
region = "ashburn"
320+
mock_confirm.side_effect = ["Y"]
321+
mock_region.return_value = region
322+
mock_prompt.side_effect = [admin_group]
323+
mock_get_admin_group.return_value = ""
324+
mock_detect_or_create_stack.return_value = stack_id
325+
api_gw_details: APIGatewayConfig = get_api_gw_details(comp_id)
326+
assert mock_detect_or_create_stack.call_count == 1
327+
assert mock_confirm.call_count == 1
328+
assert mock_prompt.call_count == 1
329+
assert api_gw_details.root_compartment_id == comp_id
330+
assert api_gw_details.stack_id == stack_id
331+
assert api_gw_details.region == region

0 commit comments

Comments
 (0)