Skip to content

Commit 0709c17

Browse files
committed
Fix gguf deployment
1 parent e06e47f commit 0709c17

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

ads/aqua/model/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,8 @@ def _build_ft_metrics(
364364
training_final,
365365
]
366366

367+
@staticmethod
367368
def to_aqua_model(
368-
self,
369369
model: Union[
370370
DataScienceModel,
371371
oci.data_science.models.model.Model,
@@ -375,7 +375,7 @@ def to_aqua_model(
375375
region: str,
376376
) -> AquaModel:
377377
"""Converts a model to an Aqua model."""
378-
return AquaModel(**self._process_model(model, region))
378+
return AquaModel(**AquaModelApp._process_model(model, region))
379379

380380
@staticmethod
381381
def _process_model(

ads/aqua/modeldeployment/deployment.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
load_config,
2626
)
2727
from ads.aqua.constants import (
28+
AQUA_MODEL_ARTIFACT_FILE,
2829
AQUA_MODEL_TYPE_CUSTOM,
2930
AQUA_MODEL_TYPE_SERVICE,
3031
MODEL_BY_REFERENCE_OSS_PATH_KEY,
@@ -39,6 +40,7 @@
3940
AquaDeploymentDetail,
4041
ContainerSpec,
4142
)
43+
from ads.aqua.ui import ModelFormat
4244
from ads.common.object_storage_details import ObjectStorageDetails
4345
from ads.common.utils import get_log_links
4446
from ads.config import (
@@ -310,6 +312,17 @@ def create(
310312
if isinstance(env, dict):
311313
env_var.update(env)
312314

315+
if (
316+
AquaModelApp.to_aqua_model(
317+
model=aqua_model, region=self.region
318+
).model_format
319+
== ModelFormat.GGUF
320+
):
321+
model_file = aqua_model.custom_metadata_list.get(
322+
AQUA_MODEL_ARTIFACT_FILE
323+
).value
324+
env_var.update({"MODEL": f"/opt/ds/model/deployed_model/{model_file}"})
325+
313326
logging.info(f"Env vars used for deploying {aqua_model.id} :{env_var}")
314327

315328
# Start model deployment

ads/common/object_storage_details.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

43
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import json
87
import os
98
import re
9+
from concurrent.futures import ThreadPoolExecutor, as_completed
1010
from dataclasses import dataclass
1111
from typing import Dict, List
1212
from urllib.parse import urlparse
1313

14-
1514
import oci
1615
from ads.common import auth as authutil
1716
from ads.common import oci_client
1817
from ads.dataset.progress import TqdmProgressBar
19-
from concurrent.futures import ThreadPoolExecutor, as_completed
2018

2119
THREAD_POOL_MAX_WORKERS = 10
2220

@@ -169,8 +167,7 @@ def is_bucket_versioned(self) -> bool:
169167

170168
def list_objects(self, **kwargs):
171169
"""Lists objects in a given oss path
172-
173-
Parameters
170+
Parameters
174171
-------
175172
**kwargs:
176173
namespace, bucket, filepath are set by the class. By default, fields gets all values. For other supported

0 commit comments

Comments
 (0)