 logger = logging.getLogger("sagemaker")


+def _setup_omegaconf_resolvers():
+    """Set up omegaconf resolvers for training recipes."""
+    if not OmegaConf.has_resolver("multiply"):
+        OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
+    if not OmegaConf.has_resolver("divide_ceil"):
+        OmegaConf.register_new_resolver(
+            "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True
+        )
+    if not OmegaConf.has_resolver("divide_floor"):
+        OmegaConf.register_new_resolver(
+            "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True
+        )
+    if not OmegaConf.has_resolver("add"):
+        OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))
+
+
 def _try_resolve_recipe(recipe, key=None):
     """Try to resolve recipe and return resolved recipe."""
     if key is not None:
@@ -60,6 +76,49 @@ def _try_resolve_recipe(recipe, key=None):
     return recipe[key]


+def _get_training_recipe_image_uri(image_cfg, region_name):
+    """Fetch image uri given image spec and region name to use for training."""
+    if isinstance(image_cfg, str):
+        return image_cfg
+    return retrieve(
+        image_cfg.get("framework"),
+        region=region_name,
+        version=image_cfg.get("version"),
+        image_scope="training",
+        **image_cfg.get("additional_args"),
+    )
+
+
+def _get_training_recipe_gpu_script(code_dir, recipe, source_dir):
+    """Return path to training script (entry point) when running a GPU recipe."""
+    model_type_to_script = {
+        "llama_v3": ("llama", "llama_pretrain.py"),
+        "mistral": ("mistral", "mistral_pretrain.py"),
+        "mixtral": ("mixtral", "mixtral_pretrain.py"),
+    }
+
+    if "model" not in recipe:
+        raise ValueError("Supplied recipe does not contain required field model.")
+    if "model_type" not in recipe["model"]:
+        raise ValueError("Supplied recipe does not contain required field model_type.")
+    model_type = recipe["model"]["model_type"]
+    if model_type not in model_type_to_script:
+        raise ValueError(f"Model type {model_type} not supported")
+
+    script_dir = os.path.join(code_dir, "examples", model_type_to_script[model_type][0])
+    script = model_type_to_script[model_type][1]
+    shutil.copyfile(os.path.join(script_dir, script), os.path.join(source_dir, script))
+    return script
+
+
+def _get_training_recipe_trainium_script(code_dir, source_dir):
+    """Return path to training script (entry point) when running a Trainium recipe."""
+    script_dir = os.path.join(code_dir, "examples")
+    script = "training_orchestrator.py"
+    shutil.copytree(script_dir, source_dir, dirs_exist_ok=True)
+    return script
+
+
 class PyTorch(Framework):
     """Handle end-to-end training and deployment of custom PyTorch code."""

@@ -294,13 +353,13 @@ def __init__(
         if training_recipe is not None:
             if entry_point is not None:
                 logger.warning("Argument entry_point will be ignored with training_recipe.")
-            if source_dir is not None:
-                logger.warning("Argument source_dir will be ignored with training_recipe.")
             if hyperparameters is not None:
                 logger.warning("Argument hyperparameters will be ignored with training_recipe.")
             if distribution is not None:
                 logger.warning("Argument distribution will be ignored with training_recipe.")
-            args = self._setup_for_training_recipe(training_recipe, recipe_overrides, kwargs)
+            args = self._setup_for_training_recipe(
+                training_recipe, recipe_overrides, source_dir, kwargs
+            )
             entry_point = args["entry_point"]
             source_dir = args["source_dir"]
             hyperparameters = args["hyperparameters"]
@@ -538,7 +597,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
         return init_params

     @classmethod
-    def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
+    def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, source_dir, kwargs):
         """Performs training recipe specific setup and returns recipe specific args.

         Updates kwargs and returns a dictionary of args to use for estimator
@@ -549,7 +608,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
             training_recipe (str): A recipe which is a local file path, a URL or a
                 sagemaker training recipe.
             recipe_overrides (Dict): Dictionary specifying key values to override in the
                 training_recipe.
+            source_dir (str): Path (absolute or relative) to a directory where to copy
+                the scripts for the training recipe. requirements.txt can also
+                go here.
             kwargs (dict): Dictionary of args used for estimator initialization.
         Returns:
             dict containing arg values for estimator initialization and setup.
@@ -559,6 +620,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
             region_name = kwargs.get("sagemaker_session").boto_region_name
         else:
             region_name = Session().boto_region_name
+
         training_recipes_cfg_filename = os.path.join(
             os.path.dirname(__file__), "training_recipes.json"
         )
@@ -567,12 +629,16 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):

         if recipe_overrides is None:
             recipe_overrides = dict()
-        cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
-        cls.recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
+        recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
+        recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
+        args = dict()
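+        # Recipe scripts are copied into the user's source_dir (defaulting to the
+        # current directory), so files such as requirements.txt ship with them.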
+        if source_dir is None:
+            args["source_dir"] = "."
+        else:
+            args["source_dir"] = source_dir

-        temp_local_recipe = tempfile.NamedTemporaryFile(
-            prefix="recipe_original", suffix=".yaml"
-        ).name
+        recipe_name = os.path.splitext(os.path.basename(training_recipe))[0]
+        temp_local_recipe = tempfile.NamedTemporaryFile(prefix=recipe_name, suffix=".yaml").name
         if training_recipe.endswith(".yaml"):
             if os.path.isfile(training_recipe):
                 shutil.copy(training_recipe, temp_local_recipe)
@@ -587,9 +653,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
             launcher_repo = os.environ.get(
                 "training_launcher_git", None
             ) or training_recipes_cfg.get("launcher_repo")
-            _run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
+            _run_clone_command(launcher_repo, recipe_launcher_dir.name)
             recipe = os.path.join(
-                cls.recipe_launcher_dir.name,
+                recipe_launcher_dir.name,
                 "recipes_collection",
                 "recipes",
                 training_recipe + ".yaml",
@@ -628,44 +694,19 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
                 )
             kwargs["instance_count"] = recipe["trainer"]["num_nodes"]

-        args = dict()
         # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
         # to retrieve the image uri below before we go GA.
         if device_type == "gpu":
             adapter_repo = os.environ.get("training_adapter_git", None) or training_recipes_cfg.get(
                 "adapter_repo"
             )
-            _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
-
-            model_type_to_entry = {
-                "llama_v3": ("llama", "llama_pretrain.py"),
-                "mistral": ("mistral", "mistral_pretrain.py"),
-                "mixtral": ("mixtral", "mixtral_pretrain.py"),
-            }
-
-            if "model" not in recipe:
-                raise ValueError("Supplied recipe does not contain required field model.")
-            if "model_type" not in recipe["model"]:
-                raise ValueError("Supplied recipe does not contain required field model_type.")
-            model_type = recipe["model"]["model_type"]
-            if model_type not in model_type_to_entry:
-                raise ValueError(f"Model type {model_type} not supported")
-
-            args["source_dir"] = os.path.join(
-                cls.recipe_train_dir.name, "examples", model_type_to_entry[model_type][0]
+            _run_clone_command(adapter_repo, recipe_train_dir.name)
+            script = _get_training_recipe_gpu_script(
+                recipe_train_dir.name, recipe, args["source_dir"]
+            )
+            args["default_image_uri"] = _get_training_recipe_image_uri(
+                training_recipes_cfg.get("gpu_image"), region_name
             )
-            args["entry_point"] = model_type_to_entry[model_type][1]
-            gpu_image_cfg = training_recipes_cfg.get("gpu_image")
-            if isinstance(gpu_image_cfg, str):
-                args["default_image_uri"] = gpu_image_cfg
-            else:
-                args["default_image_uri"] = retrieve(
-                    gpu_image_cfg.get("framework"),
-                    region=region_name,
-                    version=gpu_image_cfg.get("version"),
-                    image_scope="training",
-                    **gpu_image_cfg.get("additional_args"),
-                )
             smp_options = {
                 "enabled": True,
                 "parameters": {
@@ -677,55 +718,45 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
                 "torch_distributed": {"enabled": True},
             }
         elif device_type == "trainium":
-            _run_clone_command(
-                training_recipes_cfg.get("neuron_dist_repo"), cls.recipe_train_dir.name
+            _run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name)
+            script = _get_training_recipe_trainium_script(recipe_train_dir.name, args["source_dir"])
+            args["default_image_uri"] = _get_training_recipe_image_uri(
+                training_recipes_cfg.get("neuron_image"), region_name
             )
-            args["source_dir"] = os.path.join(cls.recipe_train_dir.name, "examples")
-            args["entry_point"] = "training_orchestrator.py"
-            neuron_image_cfg = training_recipes_cfg.get("neuron_image")
-            if isinstance(neuron_image_cfg, str):
-                args["default_image_uri"] = neuron_image_cfg
-            else:
-                args["default_image_uri"] = retrieve(
-                    neuron_image_cfg.get("framework"),
-                    region=region_name,
-                    version=neuron_image_cfg.get("version"),
-                    image_scope="training",
-                    **neuron_image_cfg.get("additional_args"),
-                )
             args["distribution"] = {
                 "torch_distributed": {"enabled": True},
             }
         else:
             raise ValueError(
                 f"Devices of type {device_type} are not supported with training recipes."
             )
+        args["entry_point"] = os.path.basename(script)
+
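+        # The cloned repos are no longer needed once the scripts have been
+        # copied into source_dir, so the temporary clones are removed here.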
+        recipe_train_dir.cleanup()
+        recipe_launcher_dir.cleanup()

         if "container" in recipe and not recipe["container"]:
             logger.warning(
                 "Ignoring container from training_recipe. Use image_uri arg for estimator."
             )

-        if not OmegaConf.has_resolver("multiply"):
-            OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
-        if not OmegaConf.has_resolver("divide_ceil"):
-            OmegaConf.register_new_resolver(
-                "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True
-            )
-        if not OmegaConf.has_resolver("divide_floor"):
-            OmegaConf.register_new_resolver(
-                "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True
-            )
-        if not OmegaConf.has_resolver("add"):
-            OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers))
+        _setup_omegaconf_resolvers()
         final_recipe = _try_resolve_recipe(recipe)
         if final_recipe is None:
             final_recipe = _try_resolve_recipe(recipe, "recipes")
         if final_recipe is None:
             final_recipe = _try_resolve_recipe(recipe, "training")
         if final_recipe is None:
             raise RuntimeError("Could not resolve provided recipe.")
-        OmegaConf.save(config=final_recipe, f=os.path.join(args["source_dir"], "recipe.yaml"))
-        args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}
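+        # The resolved recipe is written into source_dir under a unique,
+        # recipe-derived name; keeping the handle on the class stops the
+        # temporary file from being deleted while the estimator still needs it.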
+        cls.training_recipe_file = tempfile.NamedTemporaryFile(
+            dir=args["source_dir"],
+            prefix=recipe_name + "_",
+            suffix=".yaml",
+        )
+        OmegaConf.save(config=final_recipe, f=cls.training_recipe_file.name)
+        args["hyperparameters"] = {
+            "config-path": ".",
+            "config-name": os.path.basename(cls.training_recipe_file.name),
+        }

         return args