diff --git a/ads/aqua/app.py b/ads/aqua/app.py index 95b77b6a8..3675bdf2a 100644 --- a/ads/aqua/app.py +++ b/ads/aqua/app.py @@ -6,14 +6,14 @@ import os import traceback from dataclasses import fields -from typing import Dict, Union +from typing import Dict, Optional, Union import oci from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails from ads import set_auth from ads.aqua import logger -from ads.aqua.common.enums import Tags +from ads.aqua.common.enums import ConfigFolder, Tags from ads.aqua.common.errors import AquaRuntimeError, AquaValueError from ads.aqua.common.utils import ( _is_valid_mvs, @@ -268,7 +268,12 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool: logger.info(f"Artifact not found in model {model_id}.") return False - def get_config(self, model_id: str, config_file_name: str) -> Dict: + def get_config( + self, + model_id: str, + config_file_name: str, + config_folder: Optional[str] = ConfigFolder.CONFIG, + ) -> Dict: """Gets the config for the given Aqua model. Parameters @@ -277,12 +282,17 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict: The OCID of the Aqua model. config_file_name: str name of the config file + config_folder: (str, optional): + Subfolder in which `config_file_name` is searched. + Defaults to `ConfigFolder.CONFIG`. + When searching inside the model artifact directory, the value is `ConfigFolder.ARTIFACT`. Returns ------- Dict: A dict of allowed configs. """ + config_folder = config_folder or ConfigFolder.CONFIG oci_model = self.ds_client.get_model(model_id).data oci_aqua = ( ( @@ -304,22 +314,25 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict: f"Base model found for the model: {oci_model.id}. " f"Loading {config_file_name} for base model {base_model_ocid}." 
) - base_model = self.ds_client.get_model(base_model_ocid).data - artifact_path = get_artifact_path(base_model.custom_metadata_list) + if config_folder == ConfigFolder.ARTIFACT: + artifact_path = get_artifact_path(oci_model.custom_metadata_list) + else: + base_model = self.ds_client.get_model(base_model_ocid).data + artifact_path = get_artifact_path(base_model.custom_metadata_list) else: logger.info(f"Loading {config_file_name} for model {oci_model.id}...") artifact_path = get_artifact_path(oci_model.custom_metadata_list) - if not artifact_path: logger.debug( f"Failed to get artifact path from custom metadata for the model: {model_id}" ) return config - config_path = f"{os.path.dirname(artifact_path)}/config/" + config_path = os.path.join(os.path.dirname(artifact_path), config_folder) if not is_path_exists(config_path): - config_path = f"{artifact_path.rstrip('/')}/config/" - + config_path = os.path.join(artifact_path.rstrip("/"), config_folder) + if not is_path_exists(config_path): + config_path = f"{artifact_path.rstrip('/')}/" config_file_path = f"{config_path}{config_file_name}" if is_path_exists(config_file_path): try: diff --git a/ads/aqua/common/enums.py b/ads/aqua/common/enums.py index 9f505d9f4..d893e30cf 100644 --- a/ads/aqua/common/enums.py +++ b/ads/aqua/common/enums.py @@ -92,3 +92,8 @@ class TextEmbeddingInferenceContainerParams(ExtendedEnum): MODEL_ID = "model-id" PORT = "port" + + +class ConfigFolder(ExtendedEnum): + CONFIG = "config" + ARTIFACT = "artifact" diff --git a/ads/aqua/constants.py b/ads/aqua/constants.py index 0b03a1507..8e0d5ca76 100644 --- a/ads/aqua/constants.py +++ b/ads/aqua/constants.py @@ -10,6 +10,7 @@ README = "README.md" LICENSE_TXT = "config/LICENSE.txt" DEPLOYMENT_CONFIG = "deployment_config.json" +AQUA_MODEL_TOKENIZER_CONFIG = "tokenizer_config.json" COMPARTMENT_MAPPING_KEY = "service-model-compartment" CONTAINER_INDEX = "container_index.json" EVALUATION_REPORT_JSON = "report.json" diff --git 
a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 269337d2d..a5b89f8d1 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -14,6 +14,7 @@ from ads.aqua.common.errors import AquaRuntimeError, AquaValueError from ads.aqua.common.utils import ( get_hf_model_info, + is_valid_ocid, list_hf_models, ) from ads.aqua.extension.base_handler import AquaAPIhandler @@ -316,8 +317,30 @@ def post(self, *args, **kwargs): # noqa: ARG002 ) +class AquaModelTokenizerConfigHandler(AquaAPIhandler): + def get(self, model_id): + """ + Handles requests for retrieving the Hugging Face tokenizer configuration of a specified model. + Expected request format: GET /aqua/model/<model_id>/tokenizer + + """ + + path_list = urlparse(self.request.path).path.strip("/").split("/") + # Path should be /aqua/model/ocid1.iad.ahdxxx/tokenizer + # path_list=['aqua','model','<model_id>','tokenizer'] + if ( + len(path_list) == 4 + and is_valid_ocid(path_list[2]) + and path_list[3] == "tokenizer" + ): + return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id)) + + raise HTTPError(400, f"The request {self.request.path} is invalid.") + + __handlers__ = [ ("model/?([^/]*)", AquaModelHandler), ("model/?([^/]*)/license", AquaModelLicenseHandler), + ("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler), ("model/hf/search/?([^/]*)", AquaHuggingFaceHandler), ] diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 59f8decff..fff23578f 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -15,6 +15,7 @@ from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger from ads.aqua.app import AquaApp from ads.aqua.common.enums import ( + ConfigFolder, CustomInferenceContainerTypeFamily, FineTuningContainerTypeFamily, InferenceContainerTypeFamily, @@ -44,6 +45,7 @@ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME, AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE, AQUA_MODEL_ARTIFACT_FILE, + AQUA_MODEL_TOKENIZER_CONFIG, AQUA_MODEL_TYPE_CUSTOM, 
HF_METADATA_FOLDER, LICENSE_TXT, @@ -568,6 +570,26 @@ def _build_ft_metrics( training_final, ] + def get_hf_tokenizer_config(self, model_id): + """Gets the Hugging Face tokenizer config (`tokenizer_config.json`) for the given Aqua model. + + Parameters + ---------- + model_id: str + The OCID of the Aqua model. + + Returns + ------- + Dict: + The tokenizer config, or an empty value if it is not available. + """ + config = self.get_config( + model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT + ) + if not config: + logger.debug(f"Tokenizer config for model: {model_id} is not available.") + return config + @staticmethod def to_aqua_model( model: Union[ diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py index 0b73ffe25..9a6eb5b7b 100644 --- a/ads/aqua/modeldeployment/entities.py +++ b/ads/aqua/modeldeployment/entities.py @@ -41,6 +41,7 @@ class AquaDeployment(DataClassSerializable): id: str = None display_name: str = None aqua_service_model: bool = None + model_id: str = None aqua_model_name: str = None state: str = None description: str = None @@ -97,7 +98,7 @@ def from_oci_model_deployment( else None ), ) - + model_id = oci_model_deployment._model_deployment_configuration_details.model_configuration_details.model_id tags = {} tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT) tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT) @@ -110,6 +111,7 @@ def from_oci_model_deployment( return AquaDeployment( id=oci_model_deployment.id, + model_id=model_id, display_name=oci_model_deployment.display_name, aqua_service_model=aqua_service_model_tag is not None, aqua_model_name=aqua_model_name, diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index 74612ac8d..cad2759f1 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -254,6 +254,7 @@ class TestDataset: "created_by": "ocid1.user.oc1..", "endpoint": MODEL_DEPLOYMENT_URL, "private_endpoint_id": null, + 
"model_id": "ocid1.datasciencemodel.oc1..", "environment_variables": { "BASE_MODEL": "service_models/model-name/artifact", "MODEL_DEPLOY_ENABLE_STREAMING": "true", diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index a9236597d..d14424990 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -9,7 +9,7 @@ import pytest from huggingface_hub.hf_api import HfApi, ModelInfo from huggingface_hub.utils import GatedRepoError -from notebook.base.handlers import IPythonHandler +from notebook.base.handlers import IPythonHandler, HTTPError from parameterized import parameterized from ads.aqua.common.errors import AquaRuntimeError @@ -18,6 +18,7 @@ AquaHuggingFaceHandler, AquaModelHandler, AquaModelLicenseHandler, + AquaModelTokenizerConfigHandler, ) from ads.aqua.model import AquaModelApp from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary @@ -250,6 +251,41 @@ def test_get(self, mock_load_license): mock_load_license.assert_called_with("test_model_id") +class ModelTokenizerConfigHandlerTestCase(TestCase): + @patch.object(IPythonHandler, "__init__") + def setUp(self, ipython_init_mock) -> None: + ipython_init_mock.return_value = None + self.model_tokenizer_config_handler = AquaModelTokenizerConfigHandler( + MagicMock(), MagicMock() + ) + self.model_tokenizer_config_handler.finish = MagicMock() + self.model_tokenizer_config_handler.request = MagicMock() + + @patch.object(AquaModelApp, "get_hf_tokenizer_config") + @patch("ads.aqua.extension.model_handler.urlparse") + def test_get(self, mock_urlparse, mock_get_hf_tokenizer_config): + request_path = MagicMock(path="aqua/model/ocid1.xx./tokenizer") + mock_urlparse.return_value = request_path + self.model_tokenizer_config_handler.get(model_id="test_model_id") + self.model_tokenizer_config_handler.finish.assert_called_with( + 
mock_get_hf_tokenizer_config.return_value + ) + mock_get_hf_tokenizer_config.assert_called_with("test_model_id") + + @patch.object(AquaModelApp, "get_hf_tokenizer_config") + @patch("ads.aqua.extension.model_handler.urlparse") + def test_get_invalid_path(self, mock_urlparse, mock_get_hf_tokenizer_config): + """Test invalid request path should raise HTTPError(400)""" + request_path = MagicMock(path="/invalid/path") + mock_urlparse.return_value = request_path + + with self.assertRaises(HTTPError) as context: + self.model_tokenizer_config_handler.get(model_id="test_model_id") + self.assertEqual(context.exception.status_code, 400) + self.model_tokenizer_config_handler.finish.assert_not_called() + mock_get_hf_tokenizer_config.assert_not_called() + + class TestAquaHuggingFaceHandler: def setup_method(self): with patch.object(IPythonHandler, "__init__"):