|
57 | 57 | VLLM_INFERENCE_RESTRICTED_PARAMS, |
58 | 58 | ) |
59 | 59 | from ads.aqua.data import AquaResourceIdentifier |
| 60 | +from ads.aqua.model.constants import ModelTask |
60 | 61 | from ads.common.auth import AuthState, default_signer |
61 | 62 | from ads.common.extended_enum import ExtendedEnumMeta |
62 | 63 | from ads.common.object_storage_details import ObjectStorageDetails |
63 | 64 | from ads.common.oci_resource import SEARCH_TYPE, OCIResource |
64 | 65 | from ads.common.utils import copy_file, get_console_link, upload_to_os |
65 | 66 | from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID |
66 | 67 | from ads.model import DataScienceModel, ModelVersionSet |
| 68 | +from tests.unitary.with_extras.model.score import model_name |
67 | 69 |
|
68 | 70 | logger = logging.getLogger("ads.aqua") |
69 | 71 |
|
@@ -1062,3 +1064,12 @@ def get_hf_model_info(repo_id: str) -> ModelInfo: |
1062 | 1064 | return HfApi().model_info(repo_id=repo_id) |
1063 | 1065 | except HfHubHTTPError as err: |
1064 | 1066 | raise format_hf_custom_error_message(err) from err |
| 1067 | + |
| 1068 | +def list_hf_models(query:str) -> List[str]: |
| 1069 | + try: |
| 1070 | + models= HfApi().list_models(model_name=query,task=ModelTask.TEXT_GENERATION) |
| 1071 | + return [model.id for model in models if model.disabled is None] |
| 1072 | + except HfHubHTTPError as err: |
| 1073 | + raise format_hf_custom_error_message(err) from err |
| 1074 | + |
| 1075 | + |
0 commit comments