@@ -54,6 +54,13 @@ class PyTorch(Framework):
5454 # to retrieve the image uri below before GA.
5555 SM_ADAPTER_REPO = "git@github.com:aws/private-sagemaker-training-adapter-for-nemo-staging.git"
5656 SM_LAUNCHER_REPO = "git@github.com:aws/private-sagemaker-training-launcher-staging.git"
57+ SM_TRAINING_RECIPE_GPU_IMG = (
58+ "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
59+ )
60+ SM_NEURONX_DIST_REPO = "https://github.com/aws-neuron/neuronx-distributed-training.git"
61+ SM_NEURONX_DIST_IMG = (
62+ "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:neuron_sept26_v1"
63+ )
5764
5865 def __init__ (
5966 self ,
@@ -66,6 +73,7 @@ def __init__(
6673 distribution : Optional [Dict ] = None ,
6774 compiler_config : Optional [TrainingCompilerConfig ] = None ,
6875 training_recipe : Optional [str ] = None ,
76+ recipe_overrides : Optional [Dict ] = None ,
6977 ** kwargs ,
7078 ):
7179 """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -262,6 +270,9 @@ def __init__(
262270 a url to fetch, or a recipe provided by SageMaker
263271 training.
264272
273+ recipe_overrides (Dict): Dictionary specifying key values to override in the
274+ training_recipe.
275+
265276 **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
266277 constructor.
267278
@@ -280,12 +291,12 @@ def __init__(
280291 logger .warning ("Argument hyperparameters will be ignored with training recipe." )
281292 if distribution is not None :
282293 logger .warning ("Argument distribution will be ignored with training_recipe." )
283- args = self ._setup_for_training_recipe (training_recipe , kwargs )
294+ args = self ._setup_for_training_recipe (training_recipe , recipe_overrides , kwargs )
284295 entry_point = args ["entry_point" ]
285296 source_dir = args ["source_dir" ]
286297 hyperparameters = args ["hyperparameters" ]
287298 if image_uri is None :
288- image_uri = args ["image_uri" ]
299+ image_uri = args ["default_image_uri" ]
289300 distribution = args ["distribution" ]
290301 elif entry_point is None :
291302 raise ValueError (
@@ -518,7 +529,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
518529 return init_params
519530
520531 @classmethod
521- def _setup_for_training_recipe (cls , training_recipe , kwargs ):
532+ def _setup_for_training_recipe (cls , training_recipe , recipe_overrides , kwargs ):
522533 """Performs training recipe specific setup and returns recipe specific args.
523534
524535 Updates kwargs and returns a dictionary of args to use for estimator
@@ -528,28 +539,25 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
528539 Args:
529540 training_recipe (str): A recipe which is a local file path, a url or a
530541 sagemaker training recipe.
542+ recipe_overrides (Dict): Dictionary specifying key values to override in the
543+ training_recipe.
531544 kwargs (dict): Dictionary of args used for estimator initialization.
532545 Returns:
533546 dict containing arg values for estimator initialization and setup.
534547
535548 """
549+ if recipe_overrides is None :
550+ recipe_overrides = dict ()
536551 cls .recipe_train_dir = tempfile .TemporaryDirectory (prefix = "training_" )
537552 cls .recipe_launcher_dir = tempfile .TemporaryDirectory (prefix = "launcher_" )
538553
539- adapter_repo = os .environ .get ("training_adapter_git" , None ) or cls .SM_ADAPTER_REPO
540- _run_clone_command (adapter_repo , cls .recipe_train_dir .name )
541- source_dir = os .path .join (cls .recipe_train_dir .name , "scripts" )
542-
543- model_type_to_script = {"llama_v3" : "llama_pretrain.py" }
544-
545- args = {"source_dir" : source_dir }
546- local_recipe_path = os .path .join (source_dir , "recipe.yaml" )
554+ temp_local_recipe = tempfile .NamedTemporaryFile (prefix = "recipe" ).name
547555 if training_recipe .endswith (".yaml" ):
548556 if os .path .isfile (training_recipe ):
549- shutil .copy (training_recipe , local_recipe_path )
557+ shutil .copy (training_recipe , temp_local_recipe )
550558 else :
551559 try :
552- urlretrieve (training_recipe , local_recipe_path )
560+ urlretrieve (training_recipe , temp_local_recipe )
553561 except Exception as e :
554562 raise ValueError (
555563 f"Could not fetch the provided recipe { training_recipe } : exception { str (e )} "
@@ -559,28 +567,27 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
559567 _run_clone_command (launcher_repo , cls .recipe_launcher_dir .name )
560568 recipe = os .path .join (
561569 cls .recipe_launcher_dir .name ,
562- "examples " ,
570+ "recipes-collection " ,
563571 "recipes" ,
564572 "training" ,
565573 training_recipe + ".yaml" ,
566574 )
567575 if os .path .isfile (recipe ):
568- shutil .copy (recipe , local_recipe_path )
576+ shutil .copy (recipe , temp_local_recipe )
569577 else :
570578 raise ValueError (f"Recipe { training_recipe } not found." )
571579
572- recipe = OmegaConf .load (local_recipe_path )
573-
574- if "model" not in recipe :
575- raise ValueError ("Supplied recipe does not contain required field model." )
576- if "model_type" not in recipe ["model" ]:
577- raise ValueError ("Supplied recipe does not contain required field model_type." )
578- model_type = recipe ["model" ]["model_type" ]
579- if model_type not in model_type_to_script :
580- raise ValueError (f"Model type { model_type } not supported" )
581- args ["model_type" ] = model_type
582- args ["entry_point" ] = model_type_to_script [model_type ]
583- args ["hyperparameters" ] = {"config-path" : "." , "config-name" : "recipe.yaml" }
580+ recipe = OmegaConf .load (temp_local_recipe )
581+
582+ if "instance_type" not in kwargs :
583+ raise ValueError ("Must pass instance type to estimator when using training recipes." )
584+ instance_type = kwargs ["instance_type" ].split ("." )[1 ]
585+ if instance_type .startswith (("p" , "g" )):
586+ device_type = "gpu"
587+ elif instance_type .startswith ("trn" ):
588+ device_type = "trainium"
589+ else :
590+ device_type = "cpu"
584591
585592 if "trainer" not in recipe :
586593 raise ValueError ("Supplied recipe does not contain required field trainer." )
@@ -597,17 +604,32 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
597604 )
598605 kwargs ["instance_count" ] = recipe ["trainer" ]["num_nodes" ]
599606
600- if "accelerator" not in recipe ["trainer" ]:
601- raise ValueError (
602- "Supplied recipe does not contain required field trainer -> accelerator."
603- )
604- accelerator = recipe ["trainer" ]["accelerator" ]
605- if accelerator == "gpu" :
606- # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
607- # to retrieve the image uri below before we go GA.
608- args ["image_uri" ] = (
609- "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
607+ args = dict ()
608+ # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
609+ # to retrieve the image uri below before we go GA.
610+ if device_type == "gpu" :
611+ adapter_repo = os .environ .get ("training_adapter_git" , None ) or cls .SM_ADAPTER_REPO
612+ _run_clone_command (adapter_repo , cls .recipe_train_dir .name )
613+
614+ model_type_to_entry = {
615+ "llama_v3" : ("llama" , "llama_pretrain.py" ),
616+ "mistral" : ("mistral" , "mistral_pretrain.py" ),
617+ "mixtral" : ("mixtral" , "mixtral_pretrain.py" ),
618+ }
619+
620+ if "model" not in recipe :
621+ raise ValueError ("Supplied recipe does not contain required field model." )
622+ if "model_type" not in recipe ["model" ]:
623+ raise ValueError ("Supplied recipe does not contain required field model_type." )
624+ model_type = recipe ["model" ]["model_type" ]
625+ if model_type not in model_type_to_entry :
626+ raise ValueError (f"Model type { model_type } not supported" )
627+
628+ args ["source_dir" ] = os .path .join (
629+ cls .recipe_train_dir .name , "examples" , model_type_to_entry [model_type ][0 ]
610630 )
631+ args ["entry_point" ] = model_type_to_entry [model_type ][1 ]
632+ args ["default_image_uri" ] = cls .SM_TRAINING_RECIPE_GPU_IMG
611633 smp_options = {
612634 "enabled" : True ,
613635 "parameters" : {
@@ -618,26 +640,29 @@ def _setup_for_training_recipe(cls, training_recipe, kwargs):
618640 "smdistributed" : {"modelparallel" : smp_options },
619641 "torch_distributed" : {"enabled" : True },
620642 }
643+ elif device_type == "trainium" :
644+ _run_clone_command (cls .SM_NEURONX_DIST_REPO , cls .recipe_train_dir .name )
645+ args ["source_dir" ] = os .path .join (cls .recipe_train_dir .name , "examples" )
646+ args ["entry_point" ] = "training_orchestrator.py"
647+ args ["default_image_uri" ] = cls .SM_NEURONX_DIST_IMG
648+ args ["distribution" ] = {
649+ "torch_distributed" : {"enabled" : True },
650+ }
621651 else :
622- raise ValueError (f"Accelerator type { accelerator } not yet supported." )
623-
624- try :
625- recipe ["run" ]["results_dir" ] = "/opt/ml/model/"
626- recipe ["exp_manager" ]["exp_dir" ] = "/opt/ml/model/"
627- recipe ["exp_manager" ]["explicit_log_dir" ] = "/opt/ml/output/tensorboard"
628- recipe ["exp_manager" ]["checkpoint_dir" ] = "/opt/ml/checkpoints"
629- recipe ["model" ]["data" ]["train_dir" ] = ["/opt/ml/input/data/train" ]
630- recipe ["model" ]["data" ]["val_dir" ] = ["/opt/ml/input/data/val" ]
631- except KeyError as e :
632- raise RuntimeError (
633- f"Error when trying to update recipe for sagemaker jobs with key { str (e )} ."
652+ raise ValueError (
653+ f"Devices of type { device_type } are not supported with training recipes."
634654 )
635655
656+ recipe_overrides .setdefault ("run" , dict ())["results_dir" ] = "/opt/ml/model"
657+ recipe_overrides .setdefault ("exp_manager" , dict ())["exp_dir" ] = "/opt/ml/model/"
658+ recipe = OmegaConf .merge (recipe , recipe_overrides )
659+
636660 if "container" in recipe and not recipe ["container" ]:
637661 logger .warning (
638662 "Ignoring container from training_recipe. Use image_uri arg for estimator."
639663 )
640664
641- OmegaConf .save (config = recipe , f = local_recipe_path )
665+ OmegaConf .save (config = recipe , f = os .path .join (args ["source_dir" ], "recipe.yaml" ))
666+ args ["hyperparameters" ] = {"config-path" : "." , "config-name" : "recipe.yaml" }
642667
643668 return args