|
45 | 45 | from ads.model.service.oci_datascience_model import OCIDataScienceModel |
46 | 46 |
|
47 | 47 |
|
48 | | -# Fixture that reloads the module before any patching is applied. |
49 | | -@pytest.fixture(autouse=True, scope="class") |
50 | | -def reload_model_module(): |
51 | | - reload(ads.aqua.model.model) |
52 | | - yield |
53 | | - |
54 | | - |
55 | 48 | @pytest.fixture(autouse=True, scope="class") |
56 | 49 | def mock_auth(): |
57 | 50 | with patch("ads.common.auth.default_signer") as mock_default_signer: |
58 | 51 | yield mock_default_signer |
59 | 52 |
|
60 | 53 |
|
| 54 | +def get_container_config(): |
| 55 | + with open( |
| 56 | + os.path.join( |
| 57 | + os.path.dirname(os.path.abspath(__file__)), |
| 58 | + "test_data/ui/container_index.json", |
| 59 | + ), |
| 60 | + "r", |
| 61 | + ) as _file: |
| 62 | + container_index_json = json.load(_file) |
| 63 | + |
| 64 | + return container_index_json |
| 65 | + |
| 66 | + |
61 | 67 | @pytest.fixture(autouse=True, scope="class") |
62 | 68 | def mock_get_container_config(): |
63 | 69 | with patch("ads.aqua.model.model.get_container_config") as mock_config: |
64 | | - with open( |
65 | | - os.path.join( |
66 | | - os.path.dirname(os.path.abspath(__file__)), |
67 | | - "test_data/ui/container_index.json", |
68 | | - ), |
69 | | - "r", |
70 | | - ) as _file: |
71 | | - container_index_json = json.load(_file) |
72 | | - mock_config.return_value = container_index_json |
| 70 | + mock_config.return_value = get_container_config() |
73 | 71 | yield mock_config |
74 | 72 |
|
75 | 73 |
|
@@ -283,7 +281,7 @@ def setup_class(cls): |
283 | 281 | os.environ["ODSC_MODEL_COMPARTMENT_OCID"] = TestDataset.SERVICE_COMPARTMENT_ID |
284 | 282 | reload(ads.config) |
285 | 283 | reload(ads.aqua) |
286 | | - # reload(ads.aqua.model.model) |
| 284 | + reload(ads.aqua.model.model) |
287 | 285 |
|
288 | 286 | @classmethod |
289 | 287 | def teardown_class(cls): |
@@ -382,6 +380,7 @@ def test_get_foundation_models( |
382 | 380 | mock_get_container_config, |
383 | 381 | mock_auth, |
384 | 382 | ): |
| 383 | + mock_get_container_config.return_value = get_container_config() |
385 | 384 | ds_model = MagicMock() |
386 | 385 | ds_model.id = "test_id" |
387 | 386 | ds_model.compartment_id = "test_compartment_id" |
@@ -496,6 +495,7 @@ def test_get_model_fine_tuned( |
496 | 495 | mock_get_container_config, |
497 | 496 | mock_auth, |
498 | 497 | ): |
| 498 | + mock_get_container_config.return_value = get_container_config() |
499 | 499 | ds_model = MagicMock() |
500 | 500 | ds_model.id = "test_id" |
501 | 501 | ds_model.compartment_id = "test_model_compartment_id" |
|
0 commit comments