2525
2626from sagemaker .apiutils import _utils
2727from sagemaker .experiments import _api_types
28- from sagemaker .experiments ._api_types import TrialComponentArtifact , _TrialComponentStatusType
28+ from sagemaker .experiments ._api_types import (
29+ TrialComponentArtifact ,
30+ _TrialComponentStatusType ,
31+ )
2932from sagemaker .experiments ._helper import (
3033 _ArtifactUploader ,
3134 _LineageArtifactTracker ,
@@ -200,7 +203,11 @@ def __init__(
200203 self .run_name ,
201204 self .experiment_name ,
202205 )
203- self ._trial .add_trial_component (self ._trial_component )
206+
207+ if not _TrialComponent ._trial_component_is_associated_to_trial (
208+ self ._trial_component .trial_component_name , self ._trial .trial_name , sagemaker_session
209+ ):
210+ self ._trial .add_trial_component (self ._trial_component )
204211
205212 self ._artifact_uploader = _ArtifactUploader (
206213 trial_component_name = self ._trial_component .trial_component_name ,
@@ -348,7 +355,10 @@ def log_precision_recall(
348355 "noSkill" : no_skill ,
349356 }
350357 self ._log_graph_artifact (
351- artifact_name = title , data = data , graph_type = "PrecisionRecallCurve" , is_output = is_output
358+ artifact_name = title ,
359+ data = data ,
360+ graph_type = "PrecisionRecallCurve" ,
361+ is_output = is_output ,
352362 )
353363
354364 @validate_invoked_inside_run_context
@@ -381,7 +391,9 @@ def log_roc_curve(
381391 If set to False then represented as input association.
382392 """
383393 verify_length_of_true_and_predicted (
384- true_labels = y_true , predicted_attrs = y_score , predicted_attrs_name = "predicted scores"
394+ true_labels = y_true ,
395+ predicted_attrs = y_score ,
396+ predicted_attrs_name = "predicted scores" ,
385397 )
386398
387399 get_module ("sklearn" )
@@ -432,7 +444,9 @@ def log_confusion_matrix(
432444 If set to False then represented as input association.
433445 """
434446 verify_length_of_true_and_predicted (
435- true_labels = y_true , predicted_attrs = y_pred , predicted_attrs_name = "predicted labels"
447+ true_labels = y_true ,
448+ predicted_attrs = y_pred ,
449+ predicted_attrs_name = "predicted labels" ,
436450 )
437451
438452 get_module ("sklearn" )
@@ -447,12 +461,19 @@ def log_confusion_matrix(
447461 "confusionMatrix" : matrix .tolist (),
448462 }
449463 self ._log_graph_artifact (
450- artifact_name = title , data = data , graph_type = "ConfusionMatrix" , is_output = is_output
464+ artifact_name = title ,
465+ data = data ,
466+ graph_type = "ConfusionMatrix" ,
467+ is_output = is_output ,
451468 )
452469
453470 @validate_invoked_inside_run_context
454471 def log_artifact (
455- self , name : str , value : str , media_type : Optional [str ] = None , is_output : bool = True
472+ self ,
473+ name : str ,
474+ value : str ,
475+ media_type : Optional [str ] = None ,
476+ is_output : bool = True ,
456477 ):
457478 """Record a single artifact for this run.
458479
@@ -575,11 +596,17 @@ def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None):
575596 # create an artifact and association for the table
576597 if is_output :
577598 self ._lineage_artifact_tracker .add_output_artifact (
578- name = artifact_name , source_uri = s3_uri , etag = etag , artifact_type = graph_type
599+ name = artifact_name ,
600+ source_uri = s3_uri ,
601+ etag = etag ,
602+ artifact_type = graph_type ,
579603 )
580604 else :
581605 self ._lineage_artifact_tracker .add_input_artifact (
582- name = artifact_name , source_uri = s3_uri , etag = etag , artifact_type = graph_type
606+ name = artifact_name ,
607+ source_uri = s3_uri ,
608+ etag = etag ,
609+ artifact_type = graph_type ,
583610 )
584611
585612 def _verify_trial_component_artifacts_length (self , is_output ):
@@ -719,7 +746,8 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
719746 self ._trial_component .end_time = end_time
720747 if exc_value :
721748 self ._trial_component .status = _api_types .TrialComponentStatus (
722- primary_status = _TrialComponentStatusType .Failed .value , message = str (exc_value )
749+ primary_status = _TrialComponentStatusType .Failed .value ,
750+ message = str (exc_value ),
723751 )
724752 else :
725753 self ._trial_component .status = _api_types .TrialComponentStatus (
@@ -837,7 +865,8 @@ def load_run(
837865 run_instance = _RunContext .get_current_run ()
838866 elif environment :
839867 exp_config = get_tc_and_exp_config_from_job_env (
840- environment = environment , sagemaker_session = sagemaker_session or _utils .default_session ()
868+ environment = environment ,
869+ sagemaker_session = sagemaker_session or _utils .default_session (),
841870 )
842871 run_name = Run ._extract_run_name_from_tc_name (
843872 trial_component_name = exp_config [RUN_NAME ],
0 commit comments