1818import logging
1919import os
2020
21+ from sagemaker import Session
2122from sagemaker .experiments import trial_component
2223from sagemaker .utils import retry_with_backoff
2324
2425TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN"
2526PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json"
26- TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH "
27+ TRANSFORM_JOB_ARN_ENV = "TRANSFORM_JOB_ARN "
2728MAX_RETRY_ATTEMPTS = 7
2829
2930logger = logging .getLogger (__name__ )
@@ -40,7 +41,7 @@ class _EnvironmentType(enum.Enum):
4041class _RunEnvironment (object ):
4142 """Retrieves job specific data from the environment."""
4243
43- def __init__ (self , environment_type , source_arn ):
44+ def __init__ (self , environment_type : _EnvironmentType , source_arn : str ):
4445 """Init for _RunEnvironment.
4546
4647 Args:
@@ -53,9 +54,9 @@ def __init__(self, environment_type, source_arn):
5354 @classmethod
5455 def load (
5556 cls ,
56- training_job_arn_env = TRAINING_JOB_ARN_ENV ,
57- processing_job_config_path = PROCESSING_JOB_CONFIG_PATH ,
58- transform_job_batch_var = TRANSFORM_JOB_ENV_BATCH_VAR ,
57+ training_job_arn_env : str = TRAINING_JOB_ARN_ENV ,
58+ processing_job_config_path : str = PROCESSING_JOB_CONFIG_PATH ,
59+ transform_job_arn_env : str = TRANSFORM_JOB_ARN_ENV ,
5960 ):
6061 """Loads source arn of current job from environment.
6162
@@ -64,8 +65,8 @@ def load(
6465 (default: `TRAINING_JOB_ARN`).
6566 processing_job_config_path (str): The processing job config path
6667 (default: `/opt/ml/config/processingjobconfig.json`).
67- transform_job_batch_var (str): The environment variable indicating if
68- it is a transform job (default: `SAGEMAKER_BATCH `).
68+ transform_job_arn_env (str): The environment key for transform job ARN
69+ (default: `TRANSFORM_JOB_ARN_ENV `).
6970
7071 Returns:
7172 _RunEnvironment: Job data loaded from the environment. None if config does not exist.
@@ -78,16 +79,15 @@ def load(
7879 environment_type = _EnvironmentType .SageMakerProcessingJob
7980 source_arn = json .loads (open (processing_job_config_path ).read ())["ProcessingJobArn" ]
8081 return _RunEnvironment (environment_type , source_arn )
81- if transform_job_batch_var in os .environ and os . environ [ transform_job_batch_var ] == "true" :
82+ if transform_job_arn_env in os .environ :
8283 environment_type = _EnvironmentType .SageMakerTransformJob
83- # TODO: need to figure out how to get source_arn from job env
84- # with Transform team's help.
85- source_arn = ""
84+ # TODO: need to update to get source_arn from config file once Transform side ready
85+ source_arn = os .environ .get (transform_job_arn_env )
8686 return _RunEnvironment (environment_type , source_arn )
8787
8888 return None
8989
90- def get_trial_component (self , sagemaker_session ):
90+ def get_trial_component (self , sagemaker_session : Session ):
9191 """Retrieves the trial component from the job in the environment.
9292
9393 Args:
@@ -99,14 +99,6 @@ def get_trial_component(self, sagemaker_session):
9999 Returns:
100100 _TrialComponent: The trial component created from the job. None if not found.
101101 """
102- # TODO: Remove this condition check once we have a way to retrieve source ARN
103- # from transform job env
104- if self .environment_type == _EnvironmentType .SageMakerTransformJob :
105- logger .error (
106- "Currently getting the job trial component from the transform job environment "
107- "is not supported. Returning None."
108- )
109- return None
110102
111103 def _get_trial_component ():
112104 summaries = list (
0 commit comments