diff --git a/ads/aqua/finetuning/constants.py b/ads/aqua/finetuning/constants.py index 1e7309e61..9e18bcb1f 100644 --- a/ads/aqua/finetuning/constants.py +++ b/ads/aqua/finetuning/constants.py @@ -15,3 +15,6 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta): SERVICE_MODEL_ARTIFACT_LOCATION = "artifact_location" SERVICE_MODEL_DEPLOYMENT_CONTAINER = "deployment-container" SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container" + + +ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER" diff --git a/ads/aqua/finetuning/finetuning.py b/ads/aqua/finetuning/finetuning.py index 2fc52e441..11a99c6b1 100644 --- a/ads/aqua/finetuning/finetuning.py +++ b/ads/aqua/finetuning/finetuning.py @@ -31,7 +31,10 @@ UNKNOWN_DICT, ) from ads.aqua.data import AquaResourceIdentifier -from ads.aqua.finetuning.constants import * +from ads.aqua.finetuning.constants import ( + ENV_AQUA_FINE_TUNING_CONTAINER, + FineTuneCustomMetadata, +) from ads.aqua.finetuning.entities import * from ads.common.auth import default_signer from ads.common.object_storage_details import ObjectStorageDetails @@ -310,6 +313,15 @@ def create( except Exception: pass + if not is_custom_container and ENV_AQUA_FINE_TUNING_CONTAINER in os.environ: + ft_container = os.environ[ENV_AQUA_FINE_TUNING_CONTAINER] + logger.info( + "Using container set by environment variable %s=%s", + ENV_AQUA_FINE_TUNING_CONTAINER, + ft_container, + ) + is_custom_container = True + ft_parameters.batch_size = ft_parameters.batch_size or ( ft_config.get("shape", UNKNOWN_DICT) .get(create_fine_tuning_details.shape_name, UNKNOWN_DICT) @@ -559,7 +571,6 @@ def get_finetuning_config(self, model_id: str) -> Dict: Dict: A dict of allowed finetuning configs. """ - config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG) if not config: logger.debug(