1414from __future__ import absolute_import
1515
1616import logging
17+ import os
18+ import shutil
19+ import tempfile
1720from typing import Union , Optional , Dict
21+ from urllib .request import urlretrieve
1822
23+ from omegaconf import OmegaConf
1924from packaging .version import Version
2025
2126from sagemaker .estimator import Framework , EstimatorBase
2732 validate_distribution ,
2833 profiler_config_deprecation_warning ,
2934)
35+ from sagemaker .git_utils import _run_clone_command
3036from sagemaker .pytorch import defaults
3137from sagemaker .pytorch .model import PyTorchModel
3238from sagemaker .pytorch .training_compiler .config import TrainingCompilerConfig
@@ -44,16 +50,22 @@ class PyTorch(Framework):
4450 LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
4551 INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
4652
53+ # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
54+ # to retrieve the image uri below before GA.
55+ SM_ADAPTER_REPO = "git@github.com:aws/private-sagemaker-training-adapter-for-nemo-staging.git"
56+ SM_LAUNCHER_REPO = "git@github.com:aws/private-sagemaker-training-launcher-staging.git"
57+
4758 def __init__ (
4859 self ,
49- entry_point : Union [str , PipelineVariable ],
60+ entry_point : Optional [ Union [str , PipelineVariable ]] = None ,
5061 framework_version : Optional [str ] = None ,
5162 py_version : Optional [str ] = None ,
5263 source_dir : Optional [Union [str , PipelineVariable ]] = None ,
5364 hyperparameters : Optional [Dict [str , Union [str , PipelineVariable ]]] = None ,
5465 image_uri : Optional [Union [str , PipelineVariable ]] = None ,
5566 distribution : Optional [Dict ] = None ,
5667 compiler_config : Optional [TrainingCompilerConfig ] = None ,
68+ training_recipe : Optional [str ] = None ,
5769 ** kwargs ,
5870 ):
5971 """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -246,6 +258,10 @@ def __init__(
246258 compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`):
247259 Configures SageMaker Training Compiler to accelerate training.
248260
261+ training_recipe (str): Training recipe to use. This is a local file path,
262+ a url to fetch, or a recipe provided by SageMaker
263+ training.
264+
249265 **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
250266 constructor.
251267
@@ -255,6 +271,26 @@ def __init__(
255271 :class:`~sagemaker.estimator.Framework` and
256272 :class:`~sagemaker.estimator.EstimatorBase`.
257273 """
274+ if training_recipe is not None :
275+ if entry_point is not None :
276+ logger .warning ("Argument entry_point will be ignored with training_recipe." )
277+ if source_dir is not None :
278+ logger .warning ("Argument source_dir will be ignored with training_recipe." )
279+ if hyperparameters is not None :
280+ logger .warning ("Argument hyperparameters will be ignored with training recipe." )
281+ if distribution is not None :
282+ logger .warning ("Argument distribution will be ignored with training_recipe." )
283+ args = self ._setup_for_training_recipe (training_recipe , kwargs )
284+ entry_point = args ["entry_point" ]
285+ source_dir = args ["source_dir" ]
286+ hyperparameters = args ["hyperparameters" ]
287+ if image_uri is None :
288+ image_uri = args ["image_uri" ]
289+ distribution = args ["distribution" ]
290+ elif entry_point is None :
291+ raise ValueError (
292+ "Argument entry_point must be set when training_recipe is not provided"
293+ )
258294 validate_version_or_image_args (framework_version , py_version , image_uri )
259295 if py_version == "py2" :
260296 logger .warning (
@@ -480,3 +516,128 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
480516 )
481517
482518 return init_params
519+
@classmethod
def _setup_for_training_recipe(cls, training_recipe, kwargs):
    """Performs training recipe specific setup and returns recipe specific args.

    Clones the training adapter repository into a temporary directory, places
    the requested recipe next to the training scripts as ``recipe.yaml``,
    validates the fields this method relies on, rewrites the recipe paths for
    the SageMaker Jobs environment and saves the recipe back to disk.

    Side effects:
        * Sets ``cls.recipe_train_dir`` and ``cls.recipe_launcher_dir`` to new
          ``tempfile.TemporaryDirectory`` objects.
        * Sets ``kwargs["instance_count"]`` from ``trainer -> num_nodes`` in
          the recipe when the caller did not provide ``instance_count``.

    Args:
        training_recipe (str): A recipe which is a local file path, a url or a
            sagemaker training recipe. Values ending in ``.yaml`` are treated
            as a local file if one exists at that path, otherwise as a url;
            any other value is looked up by name in the launcher repository.
        kwargs (dict): Dictionary of args used for estimator initialization.
            May be updated in place (see side effects).

    Returns:
        dict containing arg values for estimator initialization and setup,
        with keys ``source_dir``, ``model_type``, ``entry_point``,
        ``hyperparameters``, ``image_uri`` and ``distribution``.

    Raises:
        ValueError: If the recipe cannot be fetched or found, a required field
            is missing, or the model type / accelerator is not supported.
        RuntimeError: If a ``KeyError`` is raised while updating the recipe
            paths for SageMaker Jobs.
    """
    # NOTE(review): the temporary directories are stored on the class, not on
    # an instance, so every call replaces the previous pair for all users of
    # the class -- confirm this sharing is intended.
    cls.recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_")
    cls.recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")

    # An environment variable may override the default adapter repository.
    adapter_repo = os.environ.get("training_adapter_git") or cls.SM_ADAPTER_REPO
    _run_clone_command(adapter_repo, cls.recipe_train_dir.name)
    source_dir = os.path.join(cls.recipe_train_dir.name, "scripts")

    # Maps the recipe's model -> model_type to the entry point script.
    model_type_to_script = {"llama_v3": "llama_pretrain.py"}

    args = {"source_dir": source_dir}
    local_recipe_path = os.path.join(source_dir, "recipe.yaml")
    if training_recipe.endswith(".yaml"):
        if os.path.isfile(training_recipe):
            shutil.copy(training_recipe, local_recipe_path)
        else:
            # Not a local file, so treat the value as a url to download.
            try:
                urlretrieve(training_recipe, local_recipe_path)
            except Exception as e:
                raise ValueError(
                    f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
                ) from e
    else:
        # A bare recipe name is resolved against the launcher repository,
        # which may also be overridden through an environment variable.
        launcher_repo = os.environ.get("training_launcher_git") or cls.SM_LAUNCHER_REPO
        _run_clone_command(launcher_repo, cls.recipe_launcher_dir.name)
        launcher_recipe_path = os.path.join(
            cls.recipe_launcher_dir.name,
            "examples",
            "recipes",
            "training",
            training_recipe + ".yaml",
        )
        if not os.path.isfile(launcher_recipe_path):
            raise ValueError(f"Recipe {training_recipe} not found.")
        shutil.copy(launcher_recipe_path, local_recipe_path)

    recipe = OmegaConf.load(local_recipe_path)

    if "model" not in recipe:
        raise ValueError("Supplied recipe does not contain required field model.")
    if "model_type" not in recipe["model"]:
        raise ValueError("Supplied recipe does not contain required field model_type.")
    model_type = recipe["model"]["model_type"]
    if model_type not in model_type_to_script:
        raise ValueError(f"Model type {model_type} not supported")
    args["model_type"] = model_type
    args["entry_point"] = model_type_to_script[model_type]
    args["hyperparameters"] = {"config-path": ".", "config-name": "recipe.yaml"}

    if "trainer" not in recipe:
        raise ValueError("Supplied recipe does not contain required field trainer.")
    trainer = recipe["trainer"]
    if "instance_count" in kwargs:
        # An explicit estimator argument takes precedence over the recipe.
        if "num_nodes" in trainer:
            logger.warning(
                "Using instance_count argument to estimator to set number "
                "of nodes. Ignoring trainer -> num_nodes in recipe."
            )
    else:
        if "num_nodes" not in trainer:
            raise ValueError(
                "Must set either instance_count argument for estimator or "
                "set trainer -> num_nodes in recipe."
            )
        kwargs["instance_count"] = trainer["num_nodes"]

    if "accelerator" not in trainer:
        raise ValueError(
            "Supplied recipe does not contain required field trainer -> accelerator."
        )
    accelerator = trainer["accelerator"]
    if accelerator != "gpu":
        raise ValueError(f"Accelerator type {accelerator} not yet supported.")

    # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve
    # to retrieve the image uri below before we go GA.
    args["image_uri"] = (
        "855988369404.dkr.ecr.us-west-2.amazonaws.com/chinmayee-dev:adaptor_sept9_v1"
    )
    smp_options = {
        "enabled": True,
        "parameters": {
            "placement_strategy": "cluster",
        },
    }
    args["distribution"] = {
        "smdistributed": {"modelparallel": smp_options},
        "torch_distributed": {"enabled": True},
    }

    # Point the recipe's output, log, checkpoint and data locations at the
    # fixed /opt/ml paths used for SageMaker Jobs.
    try:
        recipe["run"]["results_dir"] = "/opt/ml/model/"
        recipe["exp_manager"]["exp_dir"] = "/opt/ml/model/"
        recipe["exp_manager"]["explicit_log_dir"] = "/opt/ml/output/tensorboard"
        recipe["exp_manager"]["checkpoint_dir"] = "/opt/ml/checkpoints"
        recipe["model"]["data"]["train_dir"] = ["/opt/ml/input/data/train"]
        recipe["model"]["data"]["val_dir"] = ["/opt/ml/input/data/val"]
    except KeyError as e:
        raise RuntimeError(
            f"Error when trying to update recipe for sagemaker jobs with key {str(e)}."
        ) from e

    # NOTE(review): this warns only when "container" is present but falsy;
    # the message suggests it was meant to fire when a container IS set.
    # Behavior is kept as-is -- confirm the intended condition.
    if "container" in recipe and not recipe["container"]:
        logger.warning(
            "Ignoring container from training_recipe. Use image_uri arg for estimator."
        )

    OmegaConf.save(config=recipe, f=local_recipe_path)

    return args
0 commit comments