From eb5ceccd1170e424f634a3ad02d65973afc438ec Mon Sep 17 00:00:00 2001 From: Nathan Park Date: Wed, 3 Sep 2025 15:23:44 -0700 Subject: [PATCH 1/5] fix: add missing FinalMetricDataList to LocalTrainingJob --- src/sagemaker/local/entities.py | 51 +++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 0cf6c6d55a..2d92d75632 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -13,23 +13,25 @@ """Placeholder docstring""" from __future__ import absolute_import -import enum import datetime +import enum import json import logging import os +import re import tempfile import time -from uuid import uuid4 from copy import deepcopy +from uuid import uuid4 + from botocore.exceptions import ClientError import sagemaker.local.data - -from sagemaker.local.image import _SageMakerContainer -from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host -from sagemaker.utils import DeferredError, get_config_value, format_tags from sagemaker.local.exceptions import StepExecutionException +from sagemaker.local.image import _SageMakerContainer +from sagemaker.local.utils import (copy_directory_structure, get_docker_host, + move_to_destination) +from sagemaker.utils import DeferredError, format_tags, get_config_value logger = logging.getLogger(__name__) @@ -272,9 +274,42 @@ def describe(self): "AlgorithmSpecification": { "ContainerEntrypoint": self.container.container_entrypoint, }, + "FinalMetricDataList": self._extract_final_metrics() } return response + def _extract_final_metrics(self): + """Extract metrics from container logs using metric definitions.""" + if not hasattr(self.container, 'logs') or not self.container.logs: + return [] + + # Get metric definitions from container + metric_definitions = getattr(self.container, 'metric_definitions', []) + if not metric_definitions: + return [] + + final_metrics = [] + logs = self.container.logs + + for metric_def in metric_definitions: + metric_name = metric_def.get('Name') + regex_pattern = metric_def.get('Regex') + + if not metric_name or not regex_pattern: + continue + + # Find all matches in logs + matches = re.findall(regex_pattern, logs) + if matches: + # Use the last match as final metric + final_value = float(matches[-1]) + final_metrics.append({ + 'MetricName': metric_name, + 'Value': final_value, + 'Timestamp': self.end_time or datetime.now() + }) + + return final_metrics class _LocalTransformJob(object): """Placeholder docstring""" @@ -711,8 +746,8 @@ def __init__( PipelineExecutionDisplayName=None, local_session=None, ): - from sagemaker.workflow.pipeline import PipelineGraph from sagemaker import LocalSession + from sagemaker.workflow.pipeline import PipelineGraph self.pipeline = pipeline self.pipeline_execution_name = execution_id @@ -809,7 +844,7 @@ def mark_step_executing(self, step_name): def _initialize_step_execution(self, steps): """Initialize step_execution dict.""" - from sagemaker.workflow.steps import StepTypeEnum, Step + from sagemaker.workflow.steps import Step, StepTypeEnum supported_steps_types = ( StepTypeEnum.TRAINING, From e1f206b70d9fdb287fbaa6b8d101c09522c79959 Mon Sep 17 00:00:00 2001 From: Nathan Park Date: Wed, 3 Sep 2025 15:40:31 -0700 Subject: [PATCH 2/5] format: black --- src/sagemaker/local/entities.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 2d92d75632..0840ebd466 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -274,17 +274,17 @@ def describe(self): "AlgorithmSpecification": { "ContainerEntrypoint": self.container.container_entrypoint, }, - "FinalMetricDataList": self._extract_final_metrics() + "FinalMetricDataList": self._extract_final_metrics(), } return response def _extract_final_metrics(self): """Extract metrics from container logs using metric definitions.""" - if not hasattr(self.container, 'logs') or not self.container.logs: + if not hasattr(self.container, "logs") or not self.container.logs: return [] # Get metric definitions from container - metric_definitions = getattr(self.container, 'metric_definitions', []) + metric_definitions = getattr(self.container, "metric_definitions", []) if not metric_definitions: return [] @@ -292,8 +292,8 @@ def _extract_final_metrics(self): logs = self.container.logs for metric_def in metric_definitions: - metric_name = metric_def.get('Name') - regex_pattern = metric_def.get('Regex') + metric_name = metric_def.get("Name") + regex_pattern = metric_def.get("Regex") if not metric_name or not regex_pattern: continue @@ -303,14 +303,17 @@ def _extract_final_metrics(self): if matches: # Use the last match as final metric final_value = float(matches[-1]) - final_metrics.append({ - 'MetricName': metric_name, - 'Value': final_value, - 'Timestamp': self.end_time or datetime.now() - }) + final_metrics.append( + { + "MetricName": metric_name, + "Value": final_value, + "Timestamp": self.end_time or datetime.now(), + } + ) return final_metrics + class _LocalTransformJob(object): """Placeholder docstring""" From b1f65b8812aee8e70770c930f4cd9033bf5cb9df Mon Sep 17 00:00:00 2001 From: Nathan Park Date: Wed, 3 Sep 2025 17:55:47 -0700 Subject: [PATCH 3/5] format: black again --- src/sagemaker/local/entities.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 0840ebd466..da6a8aac6f 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -29,8 +29,7 @@ import sagemaker.local.data from sagemaker.local.exceptions import StepExecutionException from sagemaker.local.image import _SageMakerContainer -from sagemaker.local.utils import (copy_directory_structure, get_docker_host, - move_to_destination) +from sagemaker.local.utils import copy_directory_structure, get_docker_host, move_to_destination from sagemaker.utils import DeferredError, format_tags, get_config_value logger = logging.getLogger(__name__) From 5d434624039b05d0cdbf99423626e45433094c09 Mon Sep 17 00:00:00 2001 From: Nathan Park Date: Thu, 4 Sep 2025 10:46:06 -0700 Subject: [PATCH 4/5] tests: add unit test for local training job describe --- .../local/test_local_training_job.py | 206 ++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 tests/unit/sagemaker/local/test_local_training_job.py diff --git a/tests/unit/sagemaker/local/test_local_training_job.py b/tests/unit/sagemaker/local/test_local_training_job.py new file mode 100644 index 0000000000..d6bb69a4fa --- /dev/null +++ b/tests/unit/sagemaker/local/test_local_training_job.py @@ -0,0 +1,206 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from datetime import datetime +from mock import Mock, patch + +from sagemaker.local.entities import _LocalTrainingJob + + +class TestLocalTrainingJobFinalMetrics: + """Test cases for FinalMetricDataList functionality in _LocalTrainingJob.""" + + def test_describe_includes_final_metric_data_list(self): + """Test that describe() includes FinalMetricDataList field.""" + container = Mock() + job = _LocalTrainingJob(container) + job.training_job_name = "test-job" + job.state = "Completed" + job.start_time = datetime.now() + job.end_time = datetime.now() + job.model_artifacts = "/path/to/model" + job.output_data_config = {} + job.environment = {} + + response = job.describe() + + assert "FinalMetricDataList" in response + assert isinstance(response["FinalMetricDataList"], list) + + def test_extract_final_metrics_no_logs(self): + """Test _extract_final_metrics returns empty list when no logs.""" + container = Mock() + container.logs = None + job = _LocalTrainingJob(container) + + result = job._extract_final_metrics() + + assert result == [] + + def test_extract_final_metrics_no_metric_definitions(self): + """Test _extract_final_metrics returns empty list when no metric definitions.""" + container = Mock() + container.logs = "some logs" + container.metric_definitions = [] + job = _LocalTrainingJob(container) + + result = job._extract_final_metrics() + + assert result == [] + + def test_extract_final_metrics_with_valid_metrics(self): + """Test _extract_final_metrics extracts metrics correctly.""" + container = Mock() + container.logs = "Training started\nGAN_loss=0.138318;\nTraining complete" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + job.end_time = datetime(2023, 1, 1, 12, 0, 0) + + result = job._extract_final_metrics() + + assert len(result) == 1 + assert result[0]["MetricName"] == "ganloss" + assert result[0]["Value"] == 0.138318 + assert result[0]["Timestamp"] == job.end_time + + def test_extract_final_metrics_multiple_matches_uses_last(self): + """Test _extract_final_metrics uses the last match for each metric.""" + container = Mock() + container.logs = "GAN_loss=0.5;\nGAN_loss=0.3;\nGAN_loss=0.138318;" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + job.end_time = datetime(2023, 1, 1, 12, 0, 0) + + result = job._extract_final_metrics() + + assert len(result) == 1 + assert result[0]["Value"] == 0.138318 + + def test_extract_final_metrics_multiple_metrics(self): + """Test _extract_final_metrics handles multiple different metrics.""" + container = Mock() + container.logs = "GAN_loss=0.138318;\nAccuracy=0.95;\nLoss=1.234;" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}, + {"Name": "accuracy", "Regex": r"Accuracy=([\d\.]+);"}, + {"Name": "loss", "Regex": r"Loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + job.end_time = datetime(2023, 1, 1, 12, 0, 0) + + result = job._extract_final_metrics() + + assert len(result) == 3 + metric_names = [m["MetricName"] for m in result] + assert "ganloss" in metric_names + assert "accuracy" in metric_names + assert "loss" in metric_names + + def test_extract_final_metrics_no_matches(self): + """Test _extract_final_metrics returns empty list when regex doesn't match.""" + container = Mock() + container.logs = "Training started\nTraining complete" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + + result = job._extract_final_metrics() + + assert result == [] + + def test_extract_final_metrics_invalid_metric_definition(self): + """Test _extract_final_metrics skips invalid metric definitions.""" + container = Mock() + container.logs = "GAN_loss=0.138318;" + container.metric_definitions = [ + {"Name": "ganloss"}, # Missing Regex + {"Regex": r"GAN_loss=([\d\.]+);"}, # Missing Name + {"Name": "valid", "Regex": r"GAN_loss=([\d\.]+);"} # Valid + ] + job = _LocalTrainingJob(container) + job.end_time = datetime(2023, 1, 1, 12, 0, 0) + + result = job._extract_final_metrics() + + assert len(result) == 1 + assert result[0]["MetricName"] == "valid" + + @patch("sagemaker.local.entities.datetime") + def test_extract_final_metrics_uses_current_time_when_no_end_time(self, mock_datetime): + """Test _extract_final_metrics uses current time when end_time is None.""" + container = Mock() + container.logs = "GAN_loss=0.138318;" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + job.end_time = None + + mock_now = datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.now.return_value = mock_now + + result = job._extract_final_metrics() + + assert len(result) == 1 + assert result[0]["Timestamp"] == mock_now + + @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") + def test_integration_describe_training_job_with_metrics(self, mock_train): + """Integration test: describe_training_job includes FinalMetricDataList.""" + from sagemaker.local.local_session import LocalSagemakerClient + + local_sagemaker_client = LocalSagemakerClient() + + algo_spec = {"TrainingImage": "my-image:1.0"} + input_data_config = [{ + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3Uri": "s3://bucket/data" + } + } + }] + output_data_config = {} + resource_config = {"InstanceType": "local", "InstanceCount": 1} + + # Create training job + local_sagemaker_client.create_training_job( + "test-job", + algo_spec, + output_data_config, + resource_config, + InputDataConfig=input_data_config, + HyperParameters={} + ) + + # Mock the container logs and metric definitions + training_job = local_sagemaker_client._training_jobs["test-job"] + training_job.container.logs = "GAN_loss=0.138318;" + training_job.container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + + response = local_sagemaker_client.describe_training_job("test-job") + + assert "FinalMetricDataList" in response + assert len(response["FinalMetricDataList"]) == 1 + assert response["FinalMetricDataList"][0]["MetricName"] == "ganloss" + assert response["FinalMetricDataList"][0]["Value"] == 0.138318 From 5d15929c26c1c1e5c45df725312bafac5e1fce19 Mon Sep 17 00:00:00 2001 From: Nathan Park Date: Tue, 9 Sep 2025 12:17:10 -0700 Subject: [PATCH 5/5] chore: add more tests --- .../sagemaker/local/test_local_session.py | 190 +++++++++++++++++- 1 file changed, 189 insertions(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/local/test_local_session.py b/tests/unit/sagemaker/local/test_local_session.py index ce8fd19b5c..e11c118e06 100644 --- a/tests/unit/sagemaker/local/test_local_session.py +++ b/tests/unit/sagemaker/local/test_local_session.py @@ -16,6 +16,7 @@ import pytest import urllib3 import os +from datetime import datetime from botocore.exceptions import ClientError from mock import Mock, patch from tests.unit import DATA_DIR, SAGEMAKER_CONFIG_SESSION @@ -25,7 +26,7 @@ from sagemaker.workflow.pipeline import Pipeline from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.local.local_session import LocalSession -from sagemaker.local.entities import _LocalPipelineExecution +from sagemaker.local.entities import _LocalPipelineExecution, _LocalTrainingJob OK_RESPONSE = urllib3.HTTPResponse() @@ -1100,3 +1101,190 @@ def test_config_setter(): with pytest.raises(jsonschema.ValidationError): session.config = INVALID_LOCAL_MODE_CONFIG + + +class TestLocalTrainingJobFinalMetrics: + """Test cases for FinalMetricDataList functionality in _LocalTrainingJob.""" + + def test_describe_includes_final_metric_data_list(self): + """Test that describe() includes FinalMetricDataList field.""" + container = Mock() + container.logs = None + container.metric_definitions = [] + job = _LocalTrainingJob(container) + job.training_job_name = "test-job" + job.state = "Completed" + job.start_time = datetime.now() + job.end_time = datetime.now() + job.model_artifacts = "/path/to/model" + job.output_data_config = {} + job.environment = {} + + response = job.describe() + + assert "FinalMetricDataList" in response + assert isinstance(response["FinalMetricDataList"], list) + + def test_extract_final_metrics_no_logs(self): + """Test _extract_final_metrics returns empty list when no logs.""" + container = Mock() + container.logs = None + job = _LocalTrainingJob(container) + + result = job._extract_final_metrics() + + assert result == [] + + def test_extract_final_metrics_no_metric_definitions(self): + """Test _extract_final_metrics returns empty list when no metric definitions.""" + container = Mock() + container.logs = "some logs" + container.metric_definitions = [] + job = _LocalTrainingJob(container) + + result = job._extract_final_metrics() + + assert result == [] + + def test_extract_final_metrics_with_valid_metrics(self): + """Test _extract_final_metrics extracts metrics correctly.""" + container = Mock() + container.logs = "Training started\nGAN_loss=0.138318;\nTraining complete" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + job.end_time = datetime(2023, 1, 1, 12, 0, 0) + + result = job._extract_final_metrics() + + assert len(result) == 1 + assert result[0]["MetricName"] == "ganloss" + assert result[0]["Value"] == 0.138318 + assert result[0]["Timestamp"] == job.end_time + + def test_extract_final_metrics_multiple_matches_uses_last(self): + """Test _extract_final_metrics uses the last match for each metric.""" + container = Mock() + container.logs = "GAN_loss=0.5;\nGAN_loss=0.3;\nGAN_loss=0.138318;" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + job.end_time = datetime(2023, 1, 1, 12, 0, 0) + + result = job._extract_final_metrics() + + assert len(result) == 1 + assert result[0]["Value"] == 0.138318 + + def test_extract_final_metrics_multiple_metrics(self): + """Test _extract_final_metrics handles multiple different metrics.""" + container = Mock() + container.logs = "GAN_loss=0.138318;\nAccuracy=0.95;\nLoss=1.234;" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"}, + {"Name": "accuracy", "Regex": r"Accuracy=([\d\.]+);"}, + {"Name": "loss", "Regex": r"Loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + job.end_time = datetime(2023, 1, 1, 12, 0, 0) + + result = job._extract_final_metrics() + + assert len(result) == 3 + metric_names = [m["MetricName"] for m in result] + assert "ganloss" in metric_names + assert "accuracy" in metric_names + assert "loss" in metric_names + + def test_extract_final_metrics_no_matches(self): + """Test _extract_final_metrics returns empty list when regex doesn't match.""" + container = Mock() + container.logs = "Training started\nTraining complete" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + + result = job._extract_final_metrics() + + assert result == [] + + def test_extract_final_metrics_invalid_metric_definition(self): + """Test _extract_final_metrics skips invalid metric definitions.""" + container = Mock() + container.logs = "GAN_loss=0.138318;" + container.metric_definitions = [ + {"Name": "ganloss"}, # Missing Regex + {"Regex": r"GAN_loss=([\d\.]+);"}, # Missing Name + {"Name": "valid", "Regex": r"GAN_loss=([\d\.]+);"} # Valid + ] + job = _LocalTrainingJob(container) + job.end_time = datetime(2023, 1, 1, 12, 0, 0) + + result = job._extract_final_metrics() + + assert len(result) == 1 + assert result[0]["MetricName"] == "valid" + + @patch("sagemaker.local.entities.datetime") + def test_extract_final_metrics_uses_current_time_when_no_end_time(self, mock_datetime): + """Test _extract_final_metrics uses current time when end_time is None.""" + container = Mock() + container.logs = "GAN_loss=0.138318;" + container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + job = _LocalTrainingJob(container) + job.end_time = None + + mock_now = datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.now.return_value = mock_now + + result = job._extract_final_metrics() + + assert len(result) == 1 + assert result[0]["Timestamp"] == mock_now + + @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") + def test_integration_describe_training_job_with_metrics(self, mock_train): + """Integration test: describe_training_job includes FinalMetricDataList.""" + local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() + + algo_spec = {"TrainingImage": "my-image:1.0"} + input_data_config = [{ + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3Uri": "s3://bucket/data" + } + } + }] + output_data_config = {} + resource_config = {"InstanceType": "local", "InstanceCount": 1} + + # Create training job + local_sagemaker_client.create_training_job( + "test-job", + algo_spec, + output_data_config, + resource_config, + InputDataConfig=input_data_config, + HyperParameters={} + ) + + # Mock the container logs and metric definitions + training_job = local_sagemaker_client._training_jobs["test-job"] + training_job.container.logs = "GAN_loss=0.138318;" + training_job.container.metric_definitions = [ + {"Name": "ganloss", "Regex": r"GAN_loss=([\d\.]+);"} + ] + + response = local_sagemaker_client.describe_training_job("test-job") + + assert "FinalMetricDataList" in response + assert len(response["FinalMetricDataList"]) == 1 + assert response["FinalMetricDataList"][0]["MetricName"] == "ganloss" + assert response["FinalMetricDataList"][0]["Value"] == 0.138318