From b5714a802661f50e39796e0ddf42b4065577465d Mon Sep 17 00:00:00 2001 From: Chad Chiang Date: Tue, 10 Jun 2025 13:04:40 -0700 Subject: [PATCH 1/5] feat: Add support for MetricDefinitions in ModelTrainer --- src/sagemaker/modules/configs.py | 2 ++ src/sagemaker/modules/train/model_trainer.py | 25 +++++++++++++++++ .../modules/train/test_model_trainer.py | 27 +++++++++++++++++++ 3 files changed, 54 insertions(+) diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index 1ada10dff3..8fdf88e735 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -42,6 +42,7 @@ RemoteDebugConfig, SessionChainingConfig, InstanceGroup, + MetricDefinition, ) from sagemaker.modules.utils import convert_unassigned_to_none @@ -68,6 +69,7 @@ "Compute", "Networking", "InputData", + "MetricDefinition", ] diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 7d83766c9f..cf92f5294f 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -66,6 +66,7 @@ RemoteDebugConfig, SessionChainingConfig, InputData, + MetricDefinition, ) from sagemaker.modules.local_core.local_container import _LocalContainer @@ -1290,3 +1291,27 @@ def with_checkpoint_config( """ self.checkpoint_config = checkpoint_config or configs.CheckpointConfig() return self + + def with_metric_definitions( + self, metric_definitions: List[MetricDefinition] + ) -> "ModelTrainer": # noqa: D412 + """Set the metric definitions for the training job. + Example: + .. code:: python + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import MetricDefinition + metric_definitions = [ + MetricDefinition( + name="loss", + regex="Loss: (.*?)", + ) + ] + model_trainer = ModelTrainer( + ... + ).with_metric_definitions(metric_definitions) + Args: + metric_definitions (List[MetricDefinition]): + The metric definitions for the training job. 
+ """ + self._metric_definitions = metric_definitions + return self diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index cf38f26334..ba459d6c61 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -64,6 +64,7 @@ FileSystemDataSource, Channel, DataSource, + MetricDefinition, ) from sagemaker.modules.distributed import Torchrun, SMP, MPI from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg @@ -705,6 +706,32 @@ def test_remote_debug_config(mock_training_job, modules_session): ) +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_metric_definitions(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + metric_definitions = [ + MetricDefinition( + name="loss", + regex="Loss: (.*?);", + ) + ] + + model_trainer = ModelTrainer( + training_image=image_uri, sagemaker_session=modules_session, role=role + ).with_metric_definitions(metric_definitions) + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions + == metric_definitions + ) + + @patch("sagemaker.modules.train.model_trainer._get_unique_name") @patch("sagemaker.modules.train.model_trainer.TrainingJob") def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session): From 3fd063c8f93b5fb4032b6e21f9d7e3660a0abd44 Mon Sep 17 00:00:00 2001 From: Chad Chiang Date: Tue, 10 Jun 2025 13:47:30 -0700 Subject: [PATCH 2/5] style fix --- src/sagemaker/modules/train/model_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/modules/train/model_trainer.py 
b/src/sagemaker/modules/train/model_trainer.py index cf92f5294f..2691b5362a 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -1291,12 +1291,14 @@ def with_checkpoint_config( """ self.checkpoint_config = checkpoint_config or configs.CheckpointConfig() return self - + def with_metric_definitions( self, metric_definitions: List[MetricDefinition] ) -> "ModelTrainer": # noqa: D412 """Set the metric definitions for the training job. + Example: + .. code:: python from sagemaker.modules.train import ModelTrainer from sagemaker.modules.configs import MetricDefinition @@ -1309,6 +1311,7 @@ def with_metric_definitions( model_trainer = ModelTrainer( ... ).with_metric_definitions(metric_definitions) + Args: metric_definitions (List[MetricDefinition]): The metric definitions for the training job. From f61c6b0c5679a6becb1ffe0a24d2091c2c450caa Mon Sep 17 00:00:00 2001 From: Chad Chiang <42759281+chad119@users.noreply.github.com> Date: Tue, 10 Jun 2025 14:27:13 -0700 Subject: [PATCH 3/5] Update model_trainer.py to generate the doc --- src/sagemaker/modules/train/model_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 2691b5362a..a1cfb6d729 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -1300,14 +1300,17 @@ def with_metric_definitions( Example: .. code:: python + from sagemaker.modules.train import ModelTrainer from sagemaker.modules.configs import MetricDefinition + metric_definitions = [ MetricDefinition( name="loss", regex="Loss: (.*?)", ) ] + model_trainer = ModelTrainer( ... 
).with_metric_definitions(metric_definitions) From 395a4353b6250a842c60bfd5f8270f3420ae5419 Mon Sep 17 00:00:00 2001 From: Chad Chiang <42759281+chad119@users.noreply.github.com> Date: Tue, 10 Jun 2025 16:12:37 -0700 Subject: [PATCH 4/5] resolve unit test failure --- tests/unit/sagemaker/modules/train/test_model_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index ba459d6c61..23ea167ecf 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -849,6 +849,7 @@ def mock_upload_data(path, bucket, key_prefix): training_input_mode=training_input_mode, training_image=training_image, algorithm_name=None, + metric_definitions=None, container_entrypoint=DEFAULT_ENTRYPOINT, container_arguments=DEFAULT_ARGUMENTS, training_image_config=training_image_config, From 3ae0d3a362d89b25d79fa8a84663c4af361fd8c5 Mon Sep 17 00:00:00 2001 From: Chad Chiang Date: Wed, 11 Jun 2025 10:55:15 -0700 Subject: [PATCH 5/5] solve another unit test error --- src/sagemaker/modules/train/model_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index a1cfb6d729..eaabe5972a 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -240,6 +240,7 @@ class ModelTrainer(BaseModel): _infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None) _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None) _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None) + _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None) _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) @@ -697,6 +698,7 @@ def train(
training_image_config=self.training_image_config, container_entrypoint=container_entrypoint, container_arguments=container_arguments, + metric_definitions=self._metric_definitions, ) resource_config = self.compute._to_resource_config()