1313"""Placeholder docstring"""
1414from __future__ import absolute_import
1515
16+ import json
1617import logging
1718import math
1819import os
@@ -35,9 +36,11 @@
     profiler_config_deprecation_warning,
 )
 from sagemaker.git_utils import _run_clone_command
+from sagemaker.image_uris import retrieve
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
 from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
+from sagemaker.session import Session
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 from sagemaker.workflow.entities import PipelineVariable
 
@@ -67,15 +70,6 @@ class PyTorch(Framework):
 
     # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
     # to retrieve the image uri below before GA.
-    SM_ADAPTER_REPO = "git@github.com:aws/private-sagemaker-training-adapter-for-nemo-staging.git"
-    SM_LAUNCHER_REPO = "git@github.com:aws/private-sagemaker-training-launcher-staging.git"
-    SM_TRAINING_RECIPE_GPU_IMG = (
-        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
-    )
-    SM_NEURONX_DIST_REPO = "https://github.com/aws-neuron/neuronx-distributed-training.git"
-    SM_NEURONX_DIST_IMG = (
-        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:neuron_sept26_v1"
-    )
 
     def __init__(
         self,
@@ -561,6 +555,16 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
             dict containing arg values for estimator initialization and setup.
 
         """
+        if kwargs.get("sagemaker_session") is not None:
+            region_name = kwargs.get("sagemaker_session").boto_region_name
+        else:
+            region_name = Session().boto_region_name
+        training_recipes_cfg_filename = os.path.join(
+            os.path.dirname(__file__), "training_recipes.json"
+        )
+        with open(training_recipes_cfg_filename) as training_recipes_cfg_file:
+            training_recipes_cfg = json.load(training_recipes_cfg_file)
+
         if recipe_overrides is None:
             recipe_overrides = dict()
         cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
@@ -580,7 +584,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
580584 f"Could not fetch the provided recipe { training_recipe } : exception { str (e )} "
581585 )
582586 else :
583- launcher_repo = os .environ .get ("training_launcher_git" , None ) or cls .SM_LAUNCHER_REPO
587+ launcher_repo = os .environ .get (
588+ "training_launcher_git" , None
589+ ) or training_recipes_cfg .get ("launcher_repo" )
584590 _run_clone_command (launcher_repo , cls .recipe_launcher_dir .name )
585591 recipe = os .path .join (
586592 cls .recipe_launcher_dir .name ,
@@ -629,7 +635,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
         # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
         # to retrieve the image uri below before we go GA.
         if device_type == "gpu":
-            adapter_repo = os.environ.get("training_adapter_git", None) or cls.SM_ADAPTER_REPO
+            adapter_repo = os.environ.get("training_adapter_git", None) or training_recipes_cfg.get(
+                "adapter_repo"
+            )
             _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
 
             model_type_to_entry = {
@@ -650,7 +658,17 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
                 cls.recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
             )
             args["entry_point"] = model_type_to_entry[model_type][1]
-            args["default_image_uri"] = cls.SM_TRAINING_RECIPE_GPU_IMG
+            gpu_image_cfg = training_recipes_cfg.get("gpu_image")
+            if isinstance(gpu_image_cfg, str):
+                args["default_image_uri"] = gpu_image_cfg
+            else:
+                args["default_image_uri"] = retrieve(
+                    gpu_image_cfg.get("framework"),
+                    region=region_name,
+                    version=gpu_image_cfg.get("version"),
+                    image_scope="training",
+                    **gpu_image_cfg.get("additional_args"),
+                )
             smp_options = {
                 "enabled": True,
                 "parameters": {
@@ -662,10 +680,22 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
662680 "torch_distributed" : {"enabled" : True },
663681 }
664682 elif device_type == "trainium" :
665- _run_clone_command (cls .SM_NEURONX_DIST_REPO , cls .recipe_train_dir .name )
683+ _run_clone_command (
684+ training_recipes_cfg .get ("neuron_dist_repo" ), cls .recipe_train_dir .name
685+ )
666686 args ["source_dir" ] = os .path .join (cls .recipe_train_dir .name , "examples" )
667687 args ["entry_point" ] = "training_orchestrator.py"
668- args ["default_image_uri" ] = cls .SM_NEURONX_DIST_IMG
688+ neuron_image_cfg = training_recipes_cfg .get ("neuron_image" )
689+ if isinstance (neuron_image_cfg , str ):
690+ args ["default_image_uri" ] = neuron_image_cfg
691+ else :
692+ args ["default_image_uri" ] = retrieve (
693+ neuron_image_cfg .get ("framework" ),
694+ region = region_name ,
695+ version = neuron_image_cfg .get ("version" ),
696+ image_scope = "training" ,
697+ ** neuron_image_cfg .get ("additional_args" ),
698+ )
669699 args ["distribution" ] = {
670700 "torch_distributed" : {"enabled" : True },
671701 }
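
Note on the new configuration file: the diff above reads repo URLs and image settings from a training_recipes.json shipped next to estimator.py, but that file is not part of these hunks. The sketch below is an assumption about its shape, inferred only from the keys the code accesses (launcher_repo, adapter_repo, neuron_dist_repo, gpu_image, neuron_image); every value, and the example_training_recipes_cfg and resolve_image_uri names, is an illustrative placeholder rather than the PR's actual content.

import json

# Hypothetical example of the training_recipes.json schema implied by the
# keys read in the diff; URLs, framework, and versions are placeholders.
example_training_recipes_cfg = {
    "launcher_repo": "https://github.com/example/sagemaker-training-launcher.git",
    "adapter_repo": "https://github.com/example/sagemaker-training-adapter.git",
    "neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git",
    # An image entry may be a literal ECR URI string ...
    "neuron_image": "123456789012.dkr.ecr.us-west-2.amazonaws.com/example:neuron",
    # ... or a dict whose fields are passed to sagemaker.image_uris.retrieve().
    "gpu_image": {
        "framework": "pytorch",
        "version": "2.2",
        "additional_args": {"py_version": "py310", "instance_type": "ml.p4d.24xlarge"},
    },
}


def resolve_image_uri(image_cfg, region_name):
    """Mirror the fallback in the diff: a string is used as-is, a dict is
    resolved through sagemaker.image_uris.retrieve()."""
    if isinstance(image_cfg, str):
        return image_cfg
    # Imported lazily so the schema example runs without the SageMaker SDK installed.
    from sagemaker.image_uris import retrieve

    return retrieve(
        image_cfg.get("framework"),
        region=region_name,
        version=image_cfg.get("version"),
        image_scope="training",
        **image_cfg.get("additional_args"),
    )


if __name__ == "__main__":
    print(json.dumps(example_training_recipes_cfg, indent=2))
    print(resolve_image_uri(example_training_recipes_cfg["neuron_image"], "us-west-2"))

Allowing each image entry to be either a plain URI string or a retrieve() argument dict lets the hard-coded staging images removed above be replaced by region-aware, released images without further code changes.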