1010import os
1111import random
1212import re
13+ import shlex
14+ import subprocess
15+ from datetime import datetime , timedelta
1316from functools import wraps
1417from pathlib import Path
1518from string import Template
1619from typing import List , Union
1720
1821import fsspec
1922import oci
23+ from cachetools import TTLCache , cached
24+ from huggingface_hub .hf_api import HfApi , ModelInfo
25+ from huggingface_hub .utils import (
26+ GatedRepoError ,
27+ HfHubHTTPError ,
28+ RepositoryNotFoundError ,
29+ RevisionNotFoundError ,
30+ )
2031from oci .data_science .models import JobRun , Model
32+ from oci .object_storage .models import ObjectSummary
2133
2234from ads .aqua .common .enums import (
2335 InferenceContainerParamType ,
3446 COMPARTMENT_MAPPING_KEY ,
3547 CONSOLE_LINK_RESOURCE_TYPE_MAPPING ,
3648 CONTAINER_INDEX ,
49+ HF_LOGIN_DEFAULT_TIMEOUT ,
3750 MAXIMUM_ALLOWED_DATASET_IN_BYTE ,
3851 MODEL_BY_REFERENCE_OSS_PATH_KEY ,
3952 SERVICE_MANAGED_CONTAINER_URI_SCHEME ,
4457 VLLM_INFERENCE_RESTRICTED_PARAMS ,
4558)
4659from ads .aqua .data import AquaResourceIdentifier
47- from ads .common .auth import default_signer
48- from ads .common .decorator .threaded import threaded
60+ from ads .common .auth import AuthState , default_signer
4961from ads .common .extended_enum import ExtendedEnumMeta
5062from ads .common .object_storage_details import ObjectStorageDetails
5163from ads .common .oci_resource import SEARCH_TYPE , OCIResource
@@ -213,7 +225,6 @@ def read_file(file_path: str, **kwargs) -> str:
213225 return UNKNOWN
214226
215227
216- @threaded ()
217228def load_config (file_path : str , config_file_name : str , ** kwargs ) -> dict :
218229 artifact_path = f"{ file_path .rstrip ('/' )} /{ config_file_name } "
219230 signer = default_signer () if artifact_path .startswith ("oci://" ) else {}
@@ -228,6 +239,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
228239 return config
229240
230241
def list_os_files_with_extension(oss_path: str, extension: str) -> List[str]:
    """
    List files in the specified directory with the given extension.

    Parameters:
    - oss_path: The path to the directory where files are located.
    - extension: The file extension to filter by (e.g., 'txt' for text files).
      The leading dot is optional; it is added when missing.

    Returns:
    - A list of object names ending with the extension. The first
      ``len(filepath)`` characters of each name (``filepath`` being the
      path component parsed from ``oss_path``) and any leading "/" are removed.
    """
    oss_client = ObjectStorageDetails.from_path(oss_path)

    # Normalize so that "txt" and ".txt" filter identically.
    suffix = extension if extension.startswith(".") else "." + extension

    files: List[ObjectSummary] = oss_client.list_objects().objects

    # NOTE(review): the slice below assumes every listed object name starts
    # with `oss_client.filepath` - confirm against `list_objects()`.
    prefix_len = len(oss_client.filepath)
    return [
        file.name[prefix_len:].lstrip("/")
        for file in files
        if file.name.endswith(suffix)
    ]
266+
267+
231268def is_valid_ocid (ocid : str ) -> bool :
232269 """Checks if the given ocid is valid.
233270
@@ -503,6 +540,7 @@ def container_config_path():
503540 return f"oci://{ AQUA_SERVICE_MODELS_BUCKET } @{ CONDA_BUCKET_NS } /service_models/config"
504541
505542
543+ @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (hours = 5 ), timer = datetime .now ))
506544def get_container_config ():
507545 config = load_config (
508546 file_path = container_config_path (),
@@ -743,6 +781,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
743781 return ocid [- key_len :] if ocid and len (ocid ) > key_len else ""
744782
745783
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
    """Upload the local folder to the object storage

    Args:
        os_path (str): object storage URI with prefix. This is the path to upload
        local_dir (str): Local directory where the object is downloaded
        model_name (str): Name of the huggingface model
    Returns:
        str: Object storage URI (``oci://<bucket>@<namespace>/<prefix>/<model_name>/``)
            under which the folder content is uploaded
    Raises:
        ValueError: If versioning is not enabled on the target bucket.
    """
    os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
    if not os_details.is_bucket_versioned():
        raise ValueError(f"Version is not enabled at object storage location {os_path}")
    auth_state = AuthState()
    object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
    # Build the argument vector directly instead of interpolating into a string
    # and re-splitting it: values containing whitespace or quotes (e.g. a local
    # directory with spaces) stay a single argument and cannot inject options.
    command = [
        "oci",
        "os",
        "object",
        "bulk-upload",
        "--src-dir",
        str(local_dir),
        "--prefix",
        str(object_path),
        "-bn",
        str(os_details.bucket),
        "-ns",
        str(os_details.namespace),
        "--auth",
        str(auth_state.oci_iam_type),
        "--profile",
        str(auth_state.oci_key_profile),
        "--no-overwrite",
    ]
    try:
        logger.info(f"Running: {' '.join(command)}")
        subprocess.check_call(command)
    except subprocess.CalledProcessError as e:
        # NOTE(review): a failed upload is only logged and the destination URI
        # is still returned below; callers cannot detect the failure. Also,
        # `check_call` does not capture output, so `e.stdout` is None here.
        logger.error(
            f"Error uploading the object. Exit code: {e.returncode} with error {e.stdout}"
        )

    return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
809+
810+
def is_service_managed_container(container):
    """Tell whether `container` starts with the service managed container URI scheme.

    A falsy `container` (e.g. ``None`` or ``""``) is returned unchanged;
    otherwise the result of the `startswith` check is returned.
    """
    if not container:
        return container
    return container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
748813
@@ -881,6 +946,8 @@ def get_container_params_type(container_type_name: str) -> str:
881946 return InferenceContainerParamType .PARAM_TYPE_VLLM
882947 elif InferenceContainerType .CONTAINER_TYPE_TGI in container_type_name .lower ():
883948 return InferenceContainerParamType .PARAM_TYPE_TGI
949+ elif InferenceContainerType .CONTAINER_TYPE_LLAMA_CPP in container_type_name .lower ():
950+ return InferenceContainerParamType .PARAM_TYPE_LLAMA_CPP
884951 else :
885952 return UNKNOWN
886953
@@ -905,3 +972,93 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
905972 return TGI_INFERENCE_RESTRICTED_PARAMS
906973 else :
907974 return set ()
975+
976+
def get_huggingface_login_timeout() -> int:
    """Resolve the huggingface login timeout.

    The value is read from the ``HF_LOGIN_DEFAULT_TIMEOUT`` environment
    variable; when the variable is not set, or its value cannot be parsed as
    an integer, the module default ``HF_LOGIN_DEFAULT_TIMEOUT`` is used.

    Returns
    -------
    timeout: int
        huggingface login timeout.

    """
    raw_value = os.environ.get("HF_LOGIN_DEFAULT_TIMEOUT", HF_LOGIN_DEFAULT_TIMEOUT)
    try:
        return int(raw_value)
    except ValueError:
        # Unparsable override: fall back to the default.
        return HF_LOGIN_DEFAULT_TIMEOUT
994+
995+
def format_hf_custom_error_message(error: HfHubHTTPError):
    """
    Formats a custom error message based on the Hugging Face error response.

    This function never returns normally: every path raises AquaRuntimeError.

    Parameters
    ----------
    error (HfHubHTTPError): The caught exception.

    Raises
    ------
    AquaRuntimeError: A user-friendly error message. The ``service_payload``
        carries the name of the matched Hugging Face error class, or
        ``"Error"`` when no specific class matched.
    """
    # Extract the repository URL from the error message if present
    match = re.search(r"(https://huggingface.co/[^\s]+)", str(error))
    url = match.group(1) if match else "the requested Hugging Face URL."

    # GatedRepoError is a subclass of RepositoryNotFoundError in
    # huggingface_hub, so it must be tested first; otherwise the generic
    # "repository not found" branch would always shadow it.
    if isinstance(error, GatedRepoError):
        raise AquaRuntimeError(
            reason=f"Access denied to `{url}` "
            "This repository is gated. Access is restricted to authorized users. "
            "Please request access or check with the repository administrator. "
            "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "GatedRepoError"},
        )

    if isinstance(error, RepositoryNotFoundError):
        raise AquaRuntimeError(
            reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
            "If the repo is private, make sure you are authenticated and have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "RepositoryNotFoundError"},
        )

    if isinstance(error, RevisionNotFoundError):
        raise AquaRuntimeError(
            reason=f"The specified revision could not be found at `{url}` "
            "Please check the revision identifier and try again.",
            service_payload={"error": "RevisionNotFoundError"},
        )

    # Any other Hugging Face HTTP error falls through to a generic message.
    raise AquaRuntimeError(
        reason=f"An error occurred while accessing `{url}` "
        "Please check your network connection and try again. "
        "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
        "To register your token, run this command in your terminal: `huggingface-cli login`",
        service_payload={"error": "Error"},
    )
1044+
1045+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
def get_hf_model_info(repo_id: str) -> ModelInfo:
    """Fetch the Hugging Face model information object for a repository.

    For models that require a token, this method assumes that the token
    validation is already done. The call is wrapped in a single-entry
    `TTLCache` with a five hour time-to-live.

    Parameters
    ----------
    repo_id: str
        hugging face model repository name

    Returns
    -------
    instance of ModelInfo object

    Raises
    ------
    AquaRuntimeError
        Raised by `format_hf_custom_error_message` when the Hugging Face API
        call fails with an `HfHubHTTPError`.

    """
    try:
        model_info = HfApi().model_info(repo_id=repo_id)
    except HfHubHTTPError as err:
        # The formatter itself raises the user-friendly error.
        raise format_hf_custom_error_message(err) from err
    else:
        return model_info
0 commit comments