|
| 1 | +#!/usr/bin/env python |
| 2 | +from typing import Union |
| 3 | + |
| 4 | +# Copyright (c) 2024 Oracle and/or its affiliates. |
| 5 | +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 6 | +import pytest |
| 7 | +from ads.opctl.operator.lowcode.common.utils import ( |
| 8 | + load_data, |
| 9 | +) |
| 10 | +from ads.opctl.operator.common.operator_config import InputData |
| 11 | +from unittest.mock import patch, Mock, MagicMock |
| 12 | +import unittest |
| 13 | +import pandas as pd |
| 14 | + |
| 15 | +mock_secret = { |
| 16 | + 'user_name': 'mock_user', |
| 17 | + 'password': 'mock_password', |
| 18 | + 'service_name': 'mock_service_name' |
| 19 | +} |
| 20 | + |
| 21 | +mock_connect_args = { |
| 22 | + 'user': 'mock_user', |
| 23 | + 'password': 'mock_password', |
| 24 | + 'service_name': 'mock_service_name', |
| 25 | + 'dsn': 'mock_dsn' |
| 26 | +} |
| 27 | + |
| 28 | +# Mock data for testing |
| 29 | +mock_data = pd.DataFrame({ |
| 30 | + 'id': [1, 2, 3], |
| 31 | + 'name': ['Alice', 'Bob', 'Charlie'] |
| 32 | +}) |
| 33 | + |
| 34 | +mock_db_connection = MagicMock() |
| 35 | + |
| 36 | +load_secret_err_msg = "Vault exception message" |
| 37 | +db_connect_err_msg = "Mocked DB connection error" |
| 38 | + |
| 39 | + |
| 40 | +def mock_oracledb_connect_failure(*args, **kwargs): |
| 41 | + raise Exception(db_connect_err_msg) |
| 42 | + |
| 43 | + |
| 44 | +def mock_oracledb_connect(**kwargs): |
| 45 | + assert kwargs == mock_connect_args, f"Expected connect_args {mock_connect_args}, but got {kwargs}" |
| 46 | + return mock_db_connection |
| 47 | + |
| 48 | + |
| 49 | +class MockADBSecretKeeper: |
| 50 | + @staticmethod |
| 51 | + def __enter__(*args, **kwargs): |
| 52 | + return mock_secret |
| 53 | + |
| 54 | + @staticmethod |
| 55 | + def __exit__(*args, **kwargs): |
| 56 | + pass |
| 57 | + |
| 58 | + @staticmethod |
| 59 | + def load_secret(vault_secret_id, wallet_dir): |
| 60 | + return MockADBSecretKeeper() |
| 61 | + |
| 62 | + @staticmethod |
| 63 | + def load_secret_fail(*args, **kwargs): |
| 64 | + raise Exception(load_secret_err_msg) |
| 65 | + |
| 66 | + |
| 67 | +class TestDataLoad(unittest.TestCase): |
| 68 | + def setUp(self): |
| 69 | + self.data_spec = Mock(spec=InputData) |
| 70 | + self.data_spec.connect_args = { |
| 71 | + 'dsn': 'mock_dsn' |
| 72 | + } |
| 73 | + self.data_spec.vault_secret_id = 'mock_secret_id' |
| 74 | + self.data_spec.table_name = 'mock_table_name' |
| 75 | + self.data_spec.url = None |
| 76 | + self.data_spec.format = None |
| 77 | + self.data_spec.columns = None |
| 78 | + self.data_spec.limit = None |
| 79 | + |
| 80 | + def testLoadSecretAndDBConnection(self): |
| 81 | + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret): |
| 82 | + with patch('oracledb.connect', side_effect=mock_oracledb_connect): |
| 83 | + with patch('pandas.read_sql', return_value=mock_data) as mock_read_sql: |
| 84 | + data = load_data(self.data_spec) |
| 85 | + mock_read_sql.assert_called_once_with(f"SELECT * FROM {self.data_spec.table_name}", |
| 86 | + mock_db_connection) |
| 87 | + pd.testing.assert_frame_equal(data, mock_data) |
| 88 | + |
| 89 | + def testLoadVaultFailure(self): |
| 90 | + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret_fail): |
| 91 | + with pytest.raises(Exception) as e: |
| 92 | + load_data(self.data_spec) |
| 93 | + |
| 94 | + expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {load_secret_err_msg}" |
| 95 | + assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'" |
| 96 | + |
| 97 | + def testDBConnectionFailure(self): |
| 98 | + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret): |
| 99 | + with patch('oracledb.connect', side_effect=mock_oracledb_connect_failure): |
| 100 | + with pytest.raises(Exception) as e: |
| 101 | + load_data(self.data_spec) |
| 102 | + |
| 103 | + assert str(e.value) == db_connect_err_msg , f"Expected exception message '{db_connect_err_msg }', but got '{str(e)}'" |
0 commit comments