Skip to content

Commit deae57b

Browse files
committed
added artifact location for fine tuned model
1 parent 1661f1a commit deae57b

File tree

4 files changed

+20
-13
lines changed

4 files changed

+20
-13
lines changed

ads/aqua/model/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,13 @@ def create_multi(
312312
# "Currently only service models are supported for multi model deployment."
313313
# )
314314

315-
# check if model is a fine-tuned model and if so, pass FT_MODEL info to model's env variables
315+
# check if model is a fine-tuned model and if so, pass FT_MODEL info to model's
316+
# env variables & set model_path to be the base model
316317
is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in source_model.freeform_tags
317318

318319
if is_fine_tuned_model:
319320
model.model_id, model.model_name = extract_base_model_from_ft(source_model)
320-
set_fine_tune_env_var(source_model, model.env_var)
321+
set_fine_tune_env_var(source_model, model=model)
321322

322323
display_name_list.append(display_name)
323324

ads/aqua/model/utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
"""AQUA model utils"""
55

6-
from typing import Dict, Tuple
6+
from typing import Dict, Optional, Tuple
77

8+
from ads.aqua.common.entities import AquaMultiModelRef
89
from ads.aqua.common.errors import AquaValueError
910
from ads.aqua.common.utils import get_model_by_reference_paths
1011
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
@@ -30,7 +31,7 @@ def extract_base_model_from_ft(aqua_model: DataScienceModel) -> Tuple[str, str]:
3031
return config_source_id, model_name
3132

3233

33-
def set_fine_tune_env_var(aqua_model: DataScienceModel, env_var: Dict[str,str]) -> None:
34+
def set_fine_tune_env_var(aqua_model: DataScienceModel, env_var: Optional[Dict[str,str]], model: Optional[AquaMultiModelRef] = None) -> None:
3435
"""Extracts the fine tuning source (fine_tune_output_path).
3536
Sets the environment variable (env_var) of the fine tuned model to include FT_model (fine tuning source)"""
3637

@@ -44,5 +45,12 @@ def set_fine_tune_env_var(aqua_model: DataScienceModel, env_var: Dict[str,str])
4445
os_path = ObjectStorageDetails.from_path(fine_tune_output_path)
4546
fine_tune_output_path = os_path.filepath.rstrip("/")
4647

47-
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
48+
# we add the correct artifact location when using FT in Multi Model Deployment
49+
if model:
50+
model.artifact_location = base_model_path # validated later in _create_multi method in deployment.py
51+
model.env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
52+
53+
else:
54+
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
55+
4856

ads/aqua/modeldeployment/deployment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def _create(
366366

367367
if is_fine_tuned_model:
368368
config_source_id, model_name = extract_base_model_from_ft(aqua_model)
369-
set_fine_tune_env_var(aqua_model, env_var)
369+
set_fine_tune_env_var(aqua_model, env_var=env_var)
370370

371371
container_type_key = self._get_container_type_key(
372372
model=aqua_model,

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,21 @@
1919
)
2020
from parameterized import parameterized
2121

22+
import ads.aqua.modeldeployment.deployment
23+
import ads.config
24+
from ads.aqua.app import AquaApp
2225
from ads.aqua.common.entities import (
2326
AquaMultiModelRef,
2427
ComputeShapeSummary,
2528
ModelConfigResult,
2629
)
27-
from ads.aqua.app import AquaApp
28-
from ads.aqua.common.entities import ModelConfigResult
29-
import ads.aqua.modeldeployment.deployment
30-
import ads.config
31-
from ads.aqua.common.entities import AquaMultiModelRef
3230
from ads.aqua.common.enums import Tags
3331
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
3432
from ads.aqua.config.container_config import (
35-
AquaContainerConfigItem,
3633
AquaContainerConfig,
34+
AquaContainerConfigItem,
3735
)
36+
from ads.aqua.model.enums import MultiModelSupportedTaskType
3837
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
3938
from ads.aqua.modeldeployment.entities import (
4039
AquaDeployment,
@@ -45,7 +44,6 @@
4544
ModelDeploymentConfigSummary,
4645
ModelParams,
4746
)
48-
from ads.aqua.model.enums import MultiModelSupportedTaskType
4947
from ads.aqua.modeldeployment.utils import MultiModelDeploymentConfigLoader
5048
from ads.model.datascience_model import DataScienceModel
5149
from ads.model.deployment.model_deployment import ModelDeployment

0 commit comments

Comments
 (0)