From 9059c254c3226dc8b158bde47f9c909488d9e93a Mon Sep 17 00:00:00 2001
From: Malav Shastri
Date: Thu, 30 Oct 2025 15:26:18 +0000
Subject: [PATCH 1/2] feat: Extract reward_lambda_arn from Nova recipes to training job hyperparameters

---
 .../modules/train/sm_recipes/utils.py      |  6 +++
 src/sagemaker/pytorch/estimator.py         |  6 +++
 .../modules/train/sm_recipes/test_utils.py | 54 +++++++++++++++++++
 3 files changed, 66 insertions(+)

diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py
index c7457f6fad..e0400a1c0e 100644
--- a/src/sagemaker/modules/train/sm_recipes/utils.py
+++ b/src/sagemaker/modules/train/sm_recipes/utils.py
@@ -312,6 +312,12 @@ def _get_args_from_nova_recipe(
         if lambda_arn:
             args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
 
+    # Handle reward lambda configuration
+    run_config = recipe.get("run", {})
+    reward_lambda_arn = run_config.get("reward_lambda_arn", "")
+    if reward_lambda_arn:
+        args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn
+
     _register_custom_resolvers()
 
     # Resolve Final Recipe
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index ce8daae9d1..db137b11f9 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -1251,6 +1251,12 @@ def _setup_for_nova_recipe(
             if lambda_arn:
                 args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
 
+        # Handle reward lambda configuration
+        run_config = recipe.get("run", {})
+        reward_lambda_arn = run_config.get("reward_lambda_arn", "")
+        if reward_lambda_arn:
+            args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn
+
         # Resolve and save the final recipe
         self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"])
 
diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py
index 6087050171..7a5912d25e 100644
--- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py
+++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py
@@ -478,3 +478,57 @@ def test_get_args_from_nova_recipe_with_evaluation(test_case):
         recipe=recipe, compute=test_case["compute"], role=test_case["role"]
     )
     assert args == test_case["expected_args"]
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "recipe": {
+                "run": {
+                    "model_type": "amazon.nova",
+                    "model_name_or_path": "dummy-test",
+                    "reward_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyRewardLambdaFunction",
+                },
+            },
+            "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
+            "role": "arn:aws:iam::123456789012:role/SageMakerRole",
+            "expected_args": {
+                "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
+                "hyperparameters": {
+                    "base_model": "dummy-test",
+                    "reward_lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyRewardLambdaFunction",
+                },
+                "training_image": None,
+                "source_code": None,
+                "distributed": None,
+            },
+        },
+        {
+            "recipe": {
+                "run": {
+                    "model_type": "amazon.nova",
+                    "model_name_or_path": "dummy-test",
+                    # No reward_lambda_arn - should not be in hyperparameters
+                },
+            },
+            "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
+            "role": "arn:aws:iam::123456789012:role/SageMakerRole",
+            "expected_args": {
+                "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2),
+                "hyperparameters": {
+                    "base_model": "dummy-test",
+                },
+                "training_image": None,
+                "source_code": None,
+                "distributed": None,
+            },
+        },
+    ],
+)
+def test_get_args_from_nova_recipe_with_reward_lambda(test_case):
+    recipe = OmegaConf.create(test_case["recipe"])
+    args, _ = _get_args_from_nova_recipe(
+        recipe=recipe, compute=test_case["compute"], role=test_case["role"]
+    )
+    assert args == test_case["expected_args"]

From ef3bf7b716d6e9fa853574c08de8e63f9e3ae0f6 Mon Sep 17 00:00:00 2001
From: Malav Shastri
Date: Thu, 30 Oct 2025 15:44:17 +0000
Subject: [PATCH 2/2] Add test for pytorch reward lambda

---
 tests/unit/test_pytorch_nova.py | 78 +++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)

diff --git a/tests/unit/test_pytorch_nova.py b/tests/unit/test_pytorch_nova.py
index 662d27e85f..ddc4b62d1e 100644
--- a/tests/unit/test_pytorch_nova.py
+++ b/tests/unit/test_pytorch_nova.py
@@ -832,3 +832,81 @@ def test_setup_for_nova_recipe_sets_model_type(mock_resolve_save, sagemaker_sess
 
         # Verify that model_type hyperparameter was set correctly
         assert pytorch._hyperparameters.get("model_type") == "amazon.nova.llama-2-7b"
+
+
+@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
+def test_setup_for_nova_recipe_with_reward_lambda(mock_resolve_save, sagemaker_session):
+    """Test that _setup_for_nova_recipe correctly handles reward lambda configuration."""
+    # Create a mock recipe with reward lambda config
+    recipe = OmegaConf.create(
+        {
+            "run": {
+                "model_type": "amazon.nova.foobar3",
+                "model_name_or_path": "foobar/foobar-3-8b",
+                "reward_lambda_arn": "arn:aws:lambda:us-west-2:123456789012:function:reward-function",
+                "replicas": 1,
+            },
+        }
+    )
+
+    with patch(
+        "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe)
+    ):
+        mock_resolve_save.return_value = recipe
+
+        pytorch = PyTorch(
+            training_recipe="nova_recipe",
+            role=ROLE,
+            sagemaker_session=sagemaker_session,
+            instance_count=INSTANCE_COUNT,
+            instance_type=INSTANCE_TYPE_GPU,
+            image_uri=IMAGE_URI,
+            framework_version="1.13.1",
+            py_version="py3",
+        )
+
+        # Check that the Nova recipe was correctly identified
+        assert pytorch.is_nova_or_eval_recipe is True
+
+        # Verify that reward_lambda_arn hyperparameter was set correctly
+        assert (
+            pytorch._hyperparameters.get("reward_lambda_arn")
+            == "arn:aws:lambda:us-west-2:123456789012:function:reward-function"
+        )
+
+
+@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save")
+def test_setup_for_nova_recipe_without_reward_lambda(mock_resolve_save, sagemaker_session):
+    """Test that _setup_for_nova_recipe does not set reward_lambda_arn when not present."""
+    # Create a mock recipe without reward lambda config
+    recipe = OmegaConf.create(
+        {
+            "run": {
+                "model_type": "amazon.nova.foobar3",
+                "model_name_or_path": "foobar/foobar-3-8b",
+                "replicas": 1,
+            },
+        }
+    )
+
+    with patch(
+        "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe)
+    ):
+        mock_resolve_save.return_value = recipe
+
+        pytorch = PyTorch(
+            training_recipe="nova_recipe",
+            role=ROLE,
+            sagemaker_session=sagemaker_session,
+            instance_count=INSTANCE_COUNT,
+            instance_type=INSTANCE_TYPE_GPU,
+            image_uri=IMAGE_URI,
+            framework_version="1.13.1",
+            py_version="py3",
+        )
+
+        # Check that the Nova recipe was correctly identified
+        assert pytorch.is_nova_or_eval_recipe is True
+
+        # Verify that reward_lambda_arn hyperparameter was not set
+        assert "reward_lambda_arn" not in pytorch._hyperparameters