@@ -76,6 +76,96 @@ def validate_source_dir(script, directory):
7676 return True
7777
7878
def get_mp_parameters(distribution):
    """Extract the user-supplied model parallelism parameters.

    Looks up ``distribution["smdistributed"]["modelparallel"]`` and, when that
    entry has ``"enabled"`` set to exactly ``True``, validates and returns its
    ``"parameters"`` dictionary (an empty dict if the key is absent).

    Args:
        distribution: distribution dictionary defined by the user

    Returns:
        params: dictionary containing model parallelism parameters
        to be used for training, or None when model parallelism is not
        configured or not enabled

    Raises:
        ValueError: Propagated from ``validate_mp_config`` when the
            parameters are invalid.
    """
    try:
        modelparallel = distribution["smdistributed"]["modelparallel"]
    except KeyError:
        # No model parallelism section at all: nothing to return.
        return None

    # Only a literal ``True`` enables the feature; other truthy values do not.
    if modelparallel.get("enabled", False) is not True:
        return None

    mp_params = modelparallel.get("parameters", {})
    validate_mp_config(mp_params)
    return mp_params
98+
99+
def validate_mp_config(config):
    """Validate the configuration dictionary for model parallelism.

    Only ``partitions`` is required; every other key is checked only when it
    is present. All invalid values are reported as ``ValueError`` (values of
    an unexpected type never leak a ``TypeError`` from a comparison).

    Args:
        config (dict): Dictionary holding configuration keys and values.

    Raises:
        ValueError: If ``partitions`` is missing or any of the keys have
            incorrect values.
    """
    if "partitions" not in config:
        raise ValueError("'partitions' is a required parameter.")

    def is_int(value):
        # bool is a subclass of int, so True/False must be excluded explicitly.
        return isinstance(value, int) and not isinstance(value, bool)

    def validate_positive(key):
        if key in config and (not is_int(config[key]) or config[key] < 1):
            raise ValueError(f"The number of {key} must be a positive integer.")

    def validate_in(key, vals):
        if key in config and config[key] not in vals:
            raise ValueError(f"{key} must be a value in: {vals}.")

    def validate_bool(key):
        # A membership test against [True, False] would also accept 0 and 1
        # (they compare equal to False/True), so check the type instead.
        if key in config and not isinstance(config[key], bool):
            raise ValueError(f"{key} must be a value in: {[True, False]}.")

    validate_in("pipeline", ["simple", "interleaved", "_only_forward"])
    validate_in("placement_strategy", ["spread", "cluster"])
    validate_in("optimize", ["speed", "memory"])

    for key in ("microbatches", "partitions"):
        validate_positive(key)

    for key in ("auto_partition", "contiguous", "load_partition", "horovod", "ddp"):
        validate_bool(key)

    if "partition_file" in config and not isinstance(config["partition_file"], str):
        raise ValueError("'partition_file' must be a str.")

    if config.get("auto_partition") is False and "default_partition" not in config:
        raise ValueError("default_partition must be supplied if auto_partition is set to False!")

    if "default_partition" in config:
        default_partition = config["default_partition"]
        # Check the type first so the comparison below cannot raise TypeError.
        if not is_int(default_partition) or default_partition < 0:
            raise ValueError("default_partition must be a non-negative integer!")
        if default_partition >= config["partitions"]:
            raise ValueError("default_partition must be less than the number of partitions!")

    if "memory_weight" in config:
        try:
            # The chained comparison is also False for NaN.
            weight_in_range = 0.0 <= config["memory_weight"] <= 1.0
        except TypeError:
            # Values that cannot be ordered against a float (str, None, ...).
            weight_in_range = False
        if not weight_in_range:
            raise ValueError("memory_weight must be between 0.0 and 1.0!")

    if "ddp_port" in config and "ddp" not in config:
        raise ValueError("`ddp_port` needs `ddp` to be set as well")

    if "ddp_dist_backend" in config and "ddp" not in config:
        raise ValueError("`ddp_dist_backend` needs `ddp` to be set as well")

    if "ddp_port" in config:
        port = config["ddp_port"]
        # Port numbers are 16-bit unsigned values.
        if not is_int(port) or not 0 <= port <= 65535:
            raise ValueError(f"Invalid port number {port}.")

    if config.get("horovod", False) and config.get("ddp", False):
        raise ValueError("'ddp' and 'horovod' cannot be simultaneously enabled.")
167+
168+
79169def tar_and_upload_dir (
80170 session ,
81171 bucket ,
0 commit comments