1414from __future__ import absolute_import
1515
1616import logging
17+ import math
1718import os
1819import shutil
1920import tempfile
2021from typing import Union , Optional , Dict
2122from urllib .request import urlretrieve
2223
23- from omegaconf import OmegaConf
24+ import omegaconf
25+ from omegaconf import OmegaConf , dictconfig
2426from packaging .version import Version
2527
2628from sagemaker .estimator import Framework , EstimatorBase
4244logger = logging .getLogger ("sagemaker" )
4345
4446
47+ def _try_resolve_recipe (recipe , key = None ):
48+ """Try to resolve recipe and return resolved recipe."""
49+ if key is not None :
50+ recipe = dictconfig .DictConfig ({key : recipe })
51+ try :
52+ OmegaConf .resolve (recipe )
53+ except omegaconf .errors .OmegaConfBaseException :
54+ return None
55+ if key is None :
56+ return recipe
57+ return recipe [key ]
58+
59+
4560class PyTorch (Framework ):
4661 """Handle end-to-end training and deployment of custom PyTorch code."""
4762
@@ -551,7 +566,9 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
551566 cls .recipe_train_dir = tempfile .TemporaryDirectory (prefix = "training_" )
552567 cls .recipe_launcher_dir = tempfile .TemporaryDirectory (prefix = "launcher_" )
553568
554- temp_local_recipe = tempfile .NamedTemporaryFile (prefix = "recipe" ).name
569+ temp_local_recipe = tempfile .NamedTemporaryFile (
570+ prefix = "recipe_original" , suffix = ".yaml"
571+ ).name
555572 if training_recipe .endswith (".yaml" ):
556573 if os .path .isfile (training_recipe ):
557574 shutil .copy (training_recipe , temp_local_recipe )
@@ -567,7 +584,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
567584 _run_clone_command (launcher_repo , cls .recipe_launcher_dir .name )
568585 recipe = os .path .join (
569586 cls .recipe_launcher_dir .name ,
570- "recipes-collection " ,
587+ "recipes_collection " ,
571588 "recipes" ,
572589 "training" ,
573590 training_recipe + ".yaml" ,
@@ -578,6 +595,7 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
578595 raise ValueError (f"Recipe { training_recipe } not found." )
579596
580597 recipe = OmegaConf .load (temp_local_recipe )
598+ os .unlink (temp_local_recipe )
581599
582600 if "instance_type" not in kwargs :
583601 raise ValueError ("Must pass instance type to estimator when using training recipes." )
@@ -662,7 +680,26 @@ def _setup_for_training_recipe(cls, training_recipe, recipe_overrides, kwargs):
662680 "Ignoring container from training_recipe. Use image_uri arg for estimator."
663681 )
664682
665- OmegaConf .save (config = recipe , f = os .path .join (args ["source_dir" ], "recipe.yaml" ))
683+ if not OmegaConf .has_resolver ("multiply" ):
684+ OmegaConf .register_new_resolver ("multiply" , lambda x , y : x * y , replace = True )
685+ if not OmegaConf .has_resolver ("divide_ceil" ):
686+ OmegaConf .register_new_resolver (
687+ "divide_ceil" , lambda x , y : int (math .ceil (x / y )), replace = True
688+ )
689+ if not OmegaConf .has_resolver ("divide_floor" ):
690+ OmegaConf .register_new_resolver (
691+ "divide_floor" , lambda x , y : int (math .floor (x / y )), replace = True
692+ )
693+ if not OmegaConf .has_resolver ("add" ):
694+ OmegaConf .register_new_resolver ("add" , lambda * numbers : sum (numbers ))
695+ final_recipe = _try_resolve_recipe (recipe )
696+ if final_recipe is None :
697+ final_recipe = _try_resolve_recipe (recipe , "recipes" )
698+ if final_recipe is None :
699+ final_recipe = _try_resolve_recipe (recipe , "training" )
700+ if final_recipe is None :
701+ raise RuntimeError ("Could not resolve provided recipe." )
702+ OmegaConf .save (config = final_recipe , f = os .path .join (args ["source_dir" ], "recipe.yaml" ))
666703 args ["hyperparameters" ] = {"config-path" : "." , "config-name" : "recipe.yaml" }
667704
668705 return args
0 commit comments