|
10 | 10 | import os |
11 | 11 | import random |
12 | 12 | import re |
| 13 | +from datetime import datetime, timedelta |
13 | 14 | from functools import wraps |
14 | 15 | from pathlib import Path |
15 | 16 | from string import Template |
16 | 17 | from typing import List, Union |
17 | 18 |
|
18 | 19 | import fsspec |
19 | | -import oci |
20 | | -from oci.data_science.models import JobRun, Model |
| 20 | +from cachetools import TTLCache, cached |
21 | 21 |
|
| 22 | +import oci |
22 | 23 | from ads.aqua.common.enums import ( |
23 | 24 | InferenceContainerParamType, |
24 | 25 | InferenceContainerType, |
|
52 | 53 | from ads.common.utils import copy_file, get_console_link, upload_to_os |
53 | 54 | from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID |
54 | 55 | from ads.model import DataScienceModel, ModelVersionSet |
| 56 | +from oci.data_science.models import JobRun, Model |
| 57 | +from oci.object_storage.models import ObjectSummary |
55 | 58 |
|
56 | 59 | logger = logging.getLogger("ads.aqua") |
57 | 60 |
|
@@ -228,6 +231,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict: |
228 | 231 | return config |
229 | 232 |
|
230 | 233 |
|
| 234 | +def list_os_files_with_extension(oss_path: str, extension: str) -> [str]: |
| 235 | + """ |
| 236 | + List files in the specified directory with the given extension. |
| 237 | +
|
| 238 | + Parameters: |
| 239 | + - oss_path: The path to the directory where files are located. |
| 240 | + - extension: The file extension to filter by (e.g., 'txt' for text files). |
| 241 | +
|
| 242 | + Returns: |
| 243 | + - A list of file paths matching the specified extension. |
| 244 | + """ |
| 245 | + |
| 246 | + oss_client = ObjectStorageDetails.from_path(oss_path) |
| 247 | + |
| 248 | + # Ensure the extension is prefixed with a dot if not already |
| 249 | + if not extension.startswith("."): |
| 250 | + extension = "." + extension |
| 251 | + files: List[ObjectSummary] = oss_client.list_objects().objects |
| 252 | + |
| 253 | + return [ |
| 254 | + file.name |
| 255 | + for file in files |
| 256 | + if file.name.endswith(extension) and "/" not in file.name |
| 257 | + ] |
| 258 | + |
| 259 | + |
231 | 260 | def is_valid_ocid(ocid: str) -> bool: |
232 | 261 | """Checks if the given ocid is valid. |
233 | 262 |
|
@@ -503,6 +532,7 @@ def container_config_path(): |
503 | 532 | return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config" |
504 | 533 |
|
505 | 534 |
|
| 535 | +@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now)) |
506 | 536 | def get_container_config(): |
507 | 537 | config = load_config( |
508 | 538 | file_path=container_config_path(), |
|
0 commit comments