Skip to content

Commit b38db75

Browse files
committed
added inital code for supporting FT models in multi model
1 parent 845ec6d commit b38db75

File tree

3 files changed

+69
-34
lines changed

3 files changed

+69
-34
lines changed

ads/aqua/finetuning/entities.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
6-
from typing import Any, Dict, List, Literal, Optional, Union
6+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
77

88
from pydantic import Field, model_validator
99

1010
from ads.aqua.common.errors import AquaValueError
11+
from ads.aqua.common.utils import get_model_by_reference_paths
1112
from ads.aqua.config.utils.serializer import Serializable
1213
from ads.aqua.data import AquaResourceIdentifier
13-
from ads.aqua.finetuning.constants import FineTuningRestrictedParams
14+
from ads.aqua.finetuning.constants import (
15+
FineTuneCustomMetadata,
16+
FineTuningRestrictedParams,
17+
)
18+
from ads.common.object_storage_details import ObjectStorageDetails
19+
from ads.model.datascience_model import DataScienceModel
1420

1521

1622
class AquaFineTuningParams(Serializable):
@@ -179,3 +185,44 @@ class CreateFineTuningDetails(Serializable):
179185

180186
class Config:
181187
extra = "ignore"
188+
189+
190+
@staticmethod
191+
def extract_base_model_ocid(aqua_model: DataScienceModel) -> Tuple[str, str]:
192+
"""Extracts the model_name, base model (config_source_id) OCID of the Fine Tuned Model
193+
"""
194+
config_source_id = aqua_model.custom_metadata_list.get(
195+
FineTuneCustomMetadata.FINE_TUNE_SOURCE
196+
).value
197+
model_name = aqua_model.custom_metadata_list.get(
198+
FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME
199+
).value
200+
201+
if not config_source_id or not model_name:
202+
raise AquaValueError(
203+
f"Either {FineTuneCustomMetadata.FINE_TUNE_SOURCE} or {FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME} is missing "
204+
f"from custom metadata for the model {config_source_id}")
205+
206+
return config_source_id, model_name
207+
208+
209+
@staticmethod
210+
def set_fine_tune_env_var(aqua_model: DataScienceModel, env_var: Dict[str,str]) -> Dict[str,str]:
211+
"""Extracts the fine tuning source (fine_tune_output_path).
212+
Sets the environment variable (env_var) of the fine tuned model to include FT_model (fine tuning source)"""
213+
214+
base_model_path, fine_tune_output_path = get_model_by_reference_paths(
215+
aqua_model.model_file_description
216+
)
217+
218+
if fine_tune_output_path and ObjectStorageDetails.is_oci_path(fine_tune_output_path):
219+
os_path = ObjectStorageDetails.from_path(fine_tune_output_path)
220+
fine_tune_output_path = os_path.filepath.rstrip("/")
221+
222+
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
223+
224+
return env_var
225+
226+
raise AquaValueError(
227+
"Fine tuned output path is not available in the model artifact."
228+
)

ads/aqua/model/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
VALIDATION_METRICS,
6565
VALIDATION_METRICS_FINAL,
6666
)
67+
from ads.aqua.finetuning.entities import extract_base_model_ocid, set_fine_tune_env_var
6768
from ads.aqua.model.constants import (
6869
AquaModelMetadataKeys,
6970
FineTuningCustomMetadata,
@@ -311,6 +312,12 @@ def create_multi(
311312
# "Currently only service models are supported for multi model deployment."
312313
# )
313314

315+
is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in source_model.freeform_tags
316+
317+
if is_fine_tuned_model:
318+
model.model_id, model.model_name = extract_base_model_ocid(source_model)
319+
model.env_var = set_fine_tune_env_var(source_model, model.env_var)
320+
314321
display_name_list.append(display_name)
315322

316323
self._extract_model_task(model, source_model)

ads/aqua/modeldeployment/deployment.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
build_pydantic_error_message,
2626
get_combined_params,
2727
get_container_params_type,
28-
get_model_by_reference_paths,
2928
get_ocid_substring,
3029
get_params_dict,
3130
get_params_list,
@@ -46,7 +45,7 @@
4645
UNKNOWN_DICT,
4746
)
4847
from ads.aqua.data import AquaResourceIdentifier
49-
from ads.aqua.finetuning.finetuning import FineTuneCustomMetadata
48+
from ads.aqua.finetuning.entities import extract_base_model_ocid, set_fine_tune_env_var
5049
from ads.aqua.model import AquaModelApp
5150
from ads.aqua.model.constants import AquaModelMetadataKeys, ModelCustomMetadataFields
5251
from ads.aqua.modeldeployment.entities import (
@@ -210,7 +209,8 @@ def create(
210209
container_config=container_config,
211210
)
212211
else:
213-
model_ids = [model.model_id for model in create_deployment_details.models]
212+
model_ids =[model.model_id for model in create_deployment_details.models]
213+
214214
try:
215215
model_config_summary = self.get_multimodel_deployment_config(
216216
model_ids=model_ids, compartment_id=compartment_id
@@ -343,22 +343,6 @@ def _create(
343343
config_source_id = create_deployment_details.model_id
344344
model_name = aqua_model.display_name
345345

346-
is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in aqua_model.freeform_tags
347-
348-
if is_fine_tuned_model:
349-
try:
350-
config_source_id = aqua_model.custom_metadata_list.get(
351-
FineTuneCustomMetadata.FINE_TUNE_SOURCE
352-
).value
353-
model_name = aqua_model.custom_metadata_list.get(
354-
FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME
355-
).value
356-
except ValueError as err:
357-
raise AquaValueError(
358-
f"Either {FineTuneCustomMetadata.FINE_TUNE_SOURCE} or {FineTuneCustomMetadata.FINE_TUNE_SOURCE_NAME} is missing "
359-
f"from custom metadata for the model {config_source_id}"
360-
) from err
361-
362346
# set up env and cmd var
363347
env_var = create_deployment_details.env_var or {}
364348
cmd_var = create_deployment_details.cmd_var or []
@@ -378,20 +362,11 @@ def _create(
378362

379363
env_var.update({"BASE_MODEL": f"{model_path_prefix}"})
380364

381-
if is_fine_tuned_model:
382-
_, fine_tune_output_path = get_model_by_reference_paths(
383-
aqua_model.model_file_description
384-
)
385-
386-
if not fine_tune_output_path:
387-
raise AquaValueError(
388-
"Fine tuned output path is not available in the model artifact."
389-
)
390-
391-
os_path = ObjectStorageDetails.from_path(fine_tune_output_path)
392-
fine_tune_output_path = os_path.filepath.rstrip("/")
365+
is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in aqua_model.freeform_tags
393366

394-
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})
367+
if is_fine_tuned_model:
368+
config_source_id, model_name = extract_base_model_ocid(aqua_model)
369+
env_var = set_fine_tune_env_var(aqua_model, env_var)
395370

396371
container_type_key = self._get_container_type_key(
397372
model=aqua_model,
@@ -647,6 +622,12 @@ def _create_multi(
647622
config_data = {"params": params, "model_path": artifact_path_prefix}
648623
if model.model_task:
649624
config_data["model_task"] = model.model_task
625+
626+
fine_tuned_model = model.env_var.get("FT_MODEL")
627+
628+
if fine_tuned_model:
629+
config_data["FT_MODEL"] = fine_tuned_model
630+
650631
model_config.append(config_data)
651632
model_name_list.append(model.model_name)
652633

0 commit comments

Comments
 (0)