|
10 | 10 | import os |
11 | 11 | import random |
12 | 12 | import re |
| 13 | +from datetime import datetime, timedelta |
13 | 14 | from functools import wraps |
14 | 15 | from pathlib import Path |
15 | 16 | from string import Template |
16 | 17 | from typing import List, Union |
17 | 18 |
|
18 | 19 | import fsspec |
19 | | -import oci |
20 | | -from oci.data_science.models import JobRun, Model |
| 20 | +import ocifs |
| 21 | +from cachetools import TTLCache, cached |
21 | 22 |
|
| 23 | +import oci |
22 | 24 | from ads.aqua.common.enums import ( |
23 | 25 | InferenceContainerParamType, |
24 | 26 | InferenceContainerType, |
|
52 | 54 | from ads.common.utils import copy_file, get_console_link, upload_to_os |
53 | 55 | from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID |
54 | 56 | from ads.model import DataScienceModel, ModelVersionSet |
| 57 | +from oci.data_science.models import JobRun, Model |
55 | 58 |
|
56 | 59 | logger = logging.getLogger("ads.aqua") |
57 | 60 |
|
@@ -228,6 +231,29 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict: |
228 | 231 | return config |
229 | 232 |
|
230 | 233 |
|
| 234 | +def list_os_files_with_extension(oss_path: str, extension: str) -> [str]: |
| 235 | + """ |
| 236 | + List files in the specified directory with the given extension. |
| 237 | +
|
| 238 | + Parameters: |
| 239 | + - oss_path: The path to the directory where files are located. |
| 240 | + - extension: The file extension to filter by (e.g., 'txt' for text files). |
| 241 | +
|
| 242 | + Returns: |
| 243 | + - A list of file paths matching the specified extension. |
| 244 | + """ |
| 245 | + |
| 246 | + signer = default_signer() |
| 247 | + |
| 248 | + # Ensure the extension is prefixed with a dot if not already |
| 249 | + if not extension.startswith("."): |
| 250 | + extension = "." + extension |
| 251 | + fs = ocifs.OCIFileSystem(**signer) |
| 252 | + files: [str] = fs.ls(oss_path) |
| 253 | + |
| 254 | + return [file for file in files if file.endswith(extension)] |
| 255 | + |
| 256 | + |
231 | 257 | def is_valid_ocid(ocid: str) -> bool: |
232 | 258 | """Checks if the given ocid is valid. |
233 | 259 |
|
@@ -503,6 +529,7 @@ def container_config_path(): |
503 | 529 | return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config" |
504 | 530 |
|
505 | 531 |
|
| 532 | +@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now)) |
506 | 533 | def get_container_config(): |
507 | 534 | config = load_config( |
508 | 535 | file_path=container_config_path(), |
|
0 commit comments