Skip to content

Commit 61fd464

Browse files
committed
added unit test and addressed PR review comments
1 parent 6257e13 commit 61fd464

File tree

4 files changed

+53
-8
lines changed

4 files changed

+53
-8
lines changed

ads/aqua/client/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -590,10 +590,11 @@ def list_models(self) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
590590
Returns:
591591
Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: The server's response, typically including the generated embeddings.
592592
"""
593-
headers = {"Content-Type", "application/json"}
594-
response = self._client.get(self.endpoint, headers=headers, json={}).json()
593+
# headers = {"Content-Type", "application/json"}
594+
response = self._client.get(self.endpoint)
595+
logger.debug(f"Response JSON: {response}")
595596
json_response = response.json()
596-
logger.debug(f"Response JSON: {json_response}")
597+
print(json_response)
597598
return json_response
598599

599600

ads/aqua/config/container_config.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
UNKNOWN_JSON_STR,
1616
)
1717
from ads.common.extended_enum import ExtendedEnum
18+
from ads.common.utils import UNKNOWN
1819

1920

2021
class Usage(ExtendedEnum):
@@ -189,10 +190,22 @@ def from_service_config(
189190
container_item.model_formats.append(
190191
additional_configurations.get("modelFormats")
191192
)
192-
env_vars_dict = json.loads(
193-
additional_configurations.get("env_vars") or "{}"
194-
)
195-
env_vars = [{key: str(value)} for key, value in env_vars_dict.items()]
193+
194+
# Parse environment variables from `additional_configurations`.
195+
# Only keys present in the configuration will be added to the result.
196+
config_keys = {
197+
"MODEL_DEPLOY_PREDICT_ENDPOINT": UNKNOWN,
198+
"MODEL_DEPLOY_HEALTH_ENDPOINT": UNKNOWN,
199+
"PORT": UNKNOWN,
200+
"HEALTH_CHECK_PORT": UNKNOWN,
201+
"VLLM_USE_V1": UNKNOWN,
202+
}
203+
204+
env_vars = [
205+
{key: additional_configurations.get(key, default)}
206+
for key, default in config_keys.items()
207+
if key in additional_configurations
208+
]
196209

197210
# Build container spec
198211
container_item.spec = AquaContainerConfigSpec(

ads/aqua/extension/deployment_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,13 @@ def get(self, model_deployment_id):
395395
self.set_header("Content-Type", "application/json")
396396

397397
model_deployment = AquaDeploymentApp().get(model_deployment_id)
398-
endpoint = model_deployment.endpoint + "/predict/v1/models"
399398

399+
endpoint = model_deployment.endpoint + "/predict/v1/models"
400+
print(endpoint)
400401
aqua_client = Client(endpoint=endpoint)
401402
try:
402403
list_model_result = aqua_client.list_models()
404+
print(list_model_result)
403405
return self.finish(list_model_result)
404406
except Exception as ex:
405407
raise HTTPError(500, str(ex))

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
AquaDeploymentHandler,
1919
AquaDeploymentParamsHandler,
2020
AquaDeploymentStreamingInferenceHandler,
21+
AquaModelListHandler,
2122
)
2223

2324

@@ -260,3 +261,31 @@ def test_post(self, mock_get_model_deployment_response):
260261
self.handler.write.assert_any_call("chunk1")
261262
self.handler.write.assert_any_call("chunk2")
262263
self.handler.finish.assert_called_once()
264+
265+
266+
class AquaModelListHandlerTestCase(unittest.TestCase):
267+
default_params = ["--seed 42", "--trust-remote-code"]
268+
269+
@patch.object(IPythonHandler, "__init__")
270+
def setUp(self, ipython_init_mock) -> None:
271+
ipython_init_mock.return_value = None
272+
self.test_instance = AquaModelListHandler(MagicMock(), MagicMock())
273+
274+
@patch("notebook.base.handlers.APIHandler.finish")
275+
# @patch("ads.aqua.modeldeployment.AquaDeploymentApp.get_deployment_default_params")
276+
def test_get_model_list(self, mock_get_model_list_default_params, mock_finish):
277+
"""Test to check the handler get method to return model list."""
278+
279+
mock_get_model_list_default_params.return_value = self.default_params
280+
mock_finish.side_effect = lambda x: x
281+
282+
# args = {"instance_shape": TestDataset.INSTANCE_SHAPE}
283+
# self.test_instance.get_argument = MagicMock(
284+
# side_effect=lambda arg, default=None : args.get(arg, default)
285+
# )
286+
result = self.test_instance.get(model_id="test_model_id")
287+
self.assertCountEqual(result["data"], self.default_params)
288+
289+
mock_get_model_list_default_params.assert_called_with(
290+
model_id="test_model_id",
291+
)

0 commit comments

Comments
 (0)