@@ -43,7 +43,7 @@ class MfnAppTextFormat():
4343mfntestfailed = MfnAppTextFormat .STYLE_BOLD + MfnAppTextFormat .COLOR_RED + 'FAILED' + MfnAppTextFormat .END + MfnAppTextFormat .END
4444
4545class MFNTest ():
46- def __init__ (self , test_name = None , timeout = None , workflow_filename = None , new_user = False , delete_user = False ):
46+ def __init__ (self , test_name = None , timeout = None , workflow_filename = None , new_user = False , delete_user = False , num_gpu = None ):
4747
4848 self ._settings = self ._get_settings ()
4949
@@ -84,6 +84,9 @@ def __init__(self, test_name=None, timeout=None, workflow_filename=None, new_use
8484 if timeout is not None :
8585 self ._settings ["timeout" ] = timeout
8686
87+ if num_gpu is not None :
88+ self ._settings ["num_gpu" ] = num_gpu
89+
8790 self ._log_clear_timestamp = int (time .time () * 1000.0 * 1000.0 )
8891
8992 # will be the deployed workflow object in self._client
@@ -190,6 +193,9 @@ def _get_resource_info_map(self, workflow_description=None, resource_info_map=No
190193 resource_info ["resource_req_filename" ] = "requirements/" + resource_ref + "_requirements.txt"
191194 resource_info ["resource_env_filename" ] = "environment_variables/" + resource_ref + "_environment_variables.txt"
192195 resource_info_map [resource_ref ] = resource_info
196+ resource_info_map [resource_ref ]['num_gpu' ] = self ._settings ['num_gpu' ]
197+ print ("resource_info_map: " + json .dumps (resource_info_map ))
198+
193199
194200 elif "States" in workflow_description :
195201 states = workflow_description ["States" ]
@@ -203,6 +209,9 @@ def _get_resource_info_map(self, workflow_description=None, resource_info_map=No
203209 resource_info ["resource_req_filename" ] = "requirements/" + resource_name + "_requirements.txt"
204210 resource_info ["resource_env_filename" ] = "environment_variables/" + resource_name + "_environment_variables.txt"
205211 resource_info_map [resource_name ] = resource_info
212+ resource_info_map [resource_name ]['num_gpu' ] = self ._settings ['num_gpu' ]
213+ print ("resource_info_map: " + json .dumps (resource_info_map ))
214+
206215
207216 if "Type" in state and state ["Type" ] == "Parallel" :
208217 branches = state ['Branches' ]
@@ -219,10 +228,6 @@ def _get_resource_info_map(self, workflow_description=None, resource_info_map=No
219228 print ("ERROR: invalid workflow description." )
220229 assert False
221230
222- #resource_info_map[resource_name]['on_gpu'] = True
223-
224- #print("resource_info_map: " + str(resource_info_map))
225-
226231 return resource_info_map
227232
228233 def _delete_resource_if_existing (self , existing_resources , resource_name ):
@@ -299,7 +304,7 @@ def deploy_workflow(self):
299304 try :
300305 wf = self ._client .add_workflow (self ._workflow_name )
301306 wf .json = json .dumps (self ._workflow_description )
302- wf .deploy (self ._settings ["timeout" ])
307+ wf .deploy (self ._settings ["timeout" ]) #, num_gpu=self._settings['num_gpu'])
303308 self ._workflow = wf
304309 if self ._workflow .status != "failed" :
305310 print ("MFN workflow " + self ._workflow_name + " deployed." )
0 commit comments