Skip to content

Commit e06e47f

Browse files
authored
Add GGUF files validation and files endpoint (#900)
2 parents c267386 + 8f02859 commit e06e47f

File tree

10 files changed

+426
-139
lines changed

10 files changed

+426
-139
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ repos:
4545
rev: v8.18.4
4646
hooks:
4747
- id: gitleaks
48-
exclude: .github/workflows/reusable-actions/set-dummy-conf.yml
48+
exclude: .github/workflows/reusable-actions/set-dummy-conf.yml|./tests/operators/common/test_load_data.py
4949
# Oracle copyright checker
5050
- repo: https://github.com/oracle-samples/oci-data-science-ai-samples/
5151
rev: 1bc5270a443b791c62f634233c0f4966dfcc0dd6

ads/aqua/common/enums.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class Tags(str, metaclass=ExtendedEnumMeta):
2828
TASK = "task"
2929
LICENSE = "license"
3030
ORGANIZATION = "organization"
31-
PLATFORM = "platform"
3231
AQUA_TAG = "OCI_AQUA"
3332
AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
3433
AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"
@@ -39,6 +38,7 @@ class Tags(str, metaclass=ExtendedEnumMeta):
3938
READY_TO_IMPORT = "ready_to_import"
4039
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
4140
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
41+
MODEL_FORMAT = "model_format"
4242

4343

4444
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):

ads/aqua/common/utils.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@
1010
import os
1111
import random
1212
import re
13+
from datetime import datetime, timedelta
1314
from functools import wraps
1415
from pathlib import Path
1516
from string import Template
1617
from typing import List, Union
1718

1819
import fsspec
19-
import oci
20-
from oci.data_science.models import JobRun, Model
20+
from cachetools import TTLCache, cached
2121

22+
import oci
2223
from ads.aqua.common.enums import (
2324
InferenceContainerParamType,
2425
InferenceContainerType,
@@ -52,6 +53,8 @@
5253
from ads.common.utils import copy_file, get_console_link, upload_to_os
5354
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
5455
from ads.model import DataScienceModel, ModelVersionSet
56+
from oci.data_science.models import JobRun, Model
57+
from oci.object_storage.models import ObjectSummary
5558

5659
logger = logging.getLogger("ads.aqua")
5760

@@ -228,6 +231,32 @@ def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
228231
return config
229232

230233

234+
def list_os_files_with_extension(oss_path: str, extension: str) -> [str]:
235+
"""
236+
List files in the specified directory with the given extension.
237+
238+
Parameters:
239+
- oss_path: The path to the directory where files are located.
240+
- extension: The file extension to filter by (e.g., 'txt' for text files).
241+
242+
Returns:
243+
- A list of file paths matching the specified extension.
244+
"""
245+
246+
oss_client = ObjectStorageDetails.from_path(oss_path)
247+
248+
# Ensure the extension is prefixed with a dot if not already
249+
if not extension.startswith("."):
250+
extension = "." + extension
251+
files: List[ObjectSummary] = oss_client.list_objects().objects
252+
253+
return [
254+
file.name
255+
for file in files
256+
if file.name.endswith(extension) and "/" not in file.name
257+
]
258+
259+
231260
def is_valid_ocid(ocid: str) -> bool:
232261
"""Checks if the given ocid is valid.
233262
@@ -503,6 +532,7 @@ def container_config_path():
503532
return f"oci://{AQUA_SERVICE_MODELS_BUCKET}@{CONDA_BUCKET_NS}/service_models/config"
504533

505534

535+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
506536
def get_container_config():
507537
config = load_config(
508538
file_path=container_config_path(),
Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,36 @@
11
{
2+
"configuration": {
3+
"VM.Standard.A1.Flex": {
4+
"parameters": {},
5+
"shape_info": {
6+
"configs": [
7+
{
8+
"memory_in_gbs": 128,
9+
"ocpu": 32
10+
},
11+
{
12+
"memory_in_gbs": 256,
13+
"ocpu": 64
14+
},
15+
{
16+
"memory_in_gbs": 512,
17+
"ocpu": 128
18+
},
19+
{
20+
"memory_in_gbs": 1024,
21+
"ocpu": 256
22+
}
23+
],
24+
"type": "CPU"
25+
}
26+
}
27+
},
228
"shape": [
329
"VM.GPU.A10.1",
430
"VM.GPU.A10.2",
531
"BM.GPU.A10.4",
632
"BM.GPU4.8",
7-
"BM.GPU.A100-v2.8"
33+
"BM.GPU.A100-v2.8",
34+
"VM.Standard.A1.Flex"
835
]
936
}

ads/aqua/constants.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
DEFAULT_FT_REPLICA = 1
2222
DEFAULT_FT_BATCH_SIZE = 1
2323
DEFAULT_FT_VALIDATION_SET_SIZE = 0.1
24-
ARM_CPU="arm_cpu"
25-
NVIDIA_GPU="nvidia_gpu"
2624
MAXIMUM_ALLOWED_DATASET_IN_BYTE = 52428800 # 1024 x 1024 x 50 = 50MB
2725
JOB_INFRASTRUCTURE_TYPE_DEFAULT_NETWORKING = "ME_STANDALONE"
2826
NB_SESSION_IDENTIFIER = "NB_SESSION_OCID"
@@ -35,6 +33,7 @@
3533
AQUA_MODEL_ARTIFACT_CONFIG = "config.json"
3634
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
3735
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
36+
AQUA_MODEL_ARTIFACT_FILE = "model_file"
3837

3938
TRAINING_METRICS_FINAL = "training_metrics_final"
4039
VALIDATION_METRICS_FINAL = "validation_metrics_final"

ads/aqua/extension/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -8,3 +7,4 @@ class Errors(str):
87
INVALID_INPUT_DATA_FORMAT = "Invalid format of input data."
98
NO_INPUT_DATA = "No input data provided."
109
MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
10+
MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."

ads/aqua/extension/model_handler.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,50 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

6-
import re
7-
from typing import Optional
85
from urllib.parse import urlparse
96

107
from tornado.web import HTTPError
11-
from ads.aqua.extension.errors import Errors
8+
129
from ads.aqua.common.decorator import handle_exceptions
10+
from ads.aqua.common.errors import AquaValueError
1311
from ads.aqua.extension.base_handler import AquaAPIhandler
12+
from ads.aqua.extension.errors import Errors
1413
from ads.aqua.model import AquaModelApp
14+
from ads.aqua.ui import ModelFormat
1515

1616

1717
class AquaModelHandler(AquaAPIhandler):
1818
"""Handler for Aqua Model REST APIs."""
1919

2020
@handle_exceptions
21-
def get(self, model_id=""):
21+
def get(
22+
self,
23+
model_id="",
24+
):
2225
"""Handle GET request."""
23-
if not model_id:
26+
url_parse = urlparse(self.request.path)
27+
paths = url_parse.path.strip("/")
28+
if paths.startswith("aqua/model/files"):
29+
os_path = self.get_argument("os_path")
30+
if not os_path:
31+
raise HTTPError(
32+
400, Errors.MISSING_REQUIRED_PARAMETER.format("os_path")
33+
)
34+
model_format = self.get_argument("model_format")
35+
if not model_format:
36+
raise HTTPError(
37+
400, Errors.MISSING_REQUIRED_PARAMETER.format("model_format")
38+
)
39+
try:
40+
model_format = ModelFormat(model_format.upper())
41+
except ValueError:
42+
raise AquaValueError(f"Invalid model format: {model_format}")
43+
else:
44+
return self.finish(AquaModelApp.get_model_files(os_path, model_format))
45+
elif not model_id:
2446
return self.list()
47+
2548
return self.read(model_id)
2649

2750
def read(self, model_id):
@@ -81,6 +104,7 @@ def post(self, *args, **kwargs):
81104
finetuning_container = input_data.get("finetuning_container")
82105
compartment_id = input_data.get("compartment_id")
83106
project_id = input_data.get("project_id")
107+
model_file = input_data.get("model_file")
84108

85109
return self.finish(
86110
AquaModelApp().register(
@@ -90,6 +114,7 @@ def post(self, *args, **kwargs):
90114
finetuning_container=finetuning_container,
91115
compartment_id=compartment_id,
92116
project_id=project_id,
117+
model_file=model_file,
93118
)
94119
)
95120

ads/aqua/model/entities.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
from typing import List, Optional
1515

1616
import oci
17-
1817
from ads.aqua import logger
1918
from ads.aqua.app import CLIBuilderMixin
2019
from ads.aqua.common import utils
2120
from ads.aqua.constants import LIFECYCLE_DETAILS_MISSING_JOBRUN, UNKNOWN_VALUE
2221
from ads.aqua.data import AquaResourceIdentifier
2322
from ads.aqua.model.enums import FineTuningDefinedMetadata
2423
from ads.aqua.training.exceptions import exit_code_dict
24+
from ads.aqua.ui import ModelFormat
2525
from ads.common.serializer import DataClassSerializable
2626
from ads.common.utils import get_log_links
2727
from ads.model.datascience_model import DataScienceModel
@@ -41,6 +41,12 @@ class AquaFineTuneValidation(DataClassSerializable):
4141
value: str = ""
4242

4343

44+
class ModelValidationResult:
45+
model_file: Optional[str] = None
46+
model_format: ModelFormat = None
47+
telemetry_model_name: str = None
48+
49+
4450
@dataclass(repr=False)
4551
class AquaFineTuningMetric(DataClassSerializable):
4652
name: str = field(default_factory=str)
@@ -76,7 +82,9 @@ class AquaModelSummary(DataClassSerializable):
7682
ready_to_deploy: bool = True
7783
ready_to_finetune: bool = False
7884
ready_to_import: bool = False
79-
platform: List[str] = field(default_factory=lambda: ["nvidia_gpu"])
85+
nvidia_gpu_supported: bool = False
86+
arm_cpu_supported: bool = False
87+
model_format: ModelFormat = ModelFormat.UNKNOWN
8088

8189

8290
@dataclass(repr=False)
@@ -260,6 +268,7 @@ class ImportModelDetails(CLIBuilderMixin):
260268
finetuning_container: Optional[str] = None
261269
compartment_id: Optional[str] = None
262270
project_id: Optional[str] = None
271+
model_file: Optional[str] = None
263272

264273
def __post_init__(self):
265274
self._command = "model register"

0 commit comments

Comments
 (0)