@@ -107,35 +107,41 @@ def update_priority(yaml, item, dispatch_priority, priority_val):
107107
108108
109109def update_custompodresources (
110- item , min_cpu , max_cpu , min_memory , max_memory , gpu , workers
111- ):
110+ item , min_cpu , max_cpu , min_memory , max_memory , gpu , workers , head_cpus , head_memory , head_gpus ):
112111 if "custompodresources" in item .keys ():
113112 custompodresources = item .get ("custompodresources" )
114113 for i in range (len (custompodresources )):
114+ resource = custompodresources [i ]
115115 if i == 0 :
116116 # Leave head node resources as template default
117- continue
118- resource = custompodresources [i ]
119- for k , v in resource .items ():
120- if k == "replicas" and i == 1 :
121- resource [k ] = workers
122- if k == "requests" or k == "limits" :
123- for spec , _ in v .items ():
124- if spec == "cpu" :
125- if k == "limits" :
126- resource [k ][spec ] = max_cpu
127- else :
128- resource [k ][spec ] = min_cpu
129- if spec == "memory" :
130- if k == "limits" :
131- resource [k ][spec ] = str (max_memory ) + "G"
132- else :
133- resource [k ][spec ] = str (min_memory ) + "G"
134- if spec == "nvidia.com/gpu" :
135- if i == 0 :
136- resource [k ][spec ] = 0
137- else :
138- resource [k ][spec ] = gpu
117+ resource ["requests" ]["cpu" ] = head_cpus
118+ resource ["limits" ]["cpu" ] = head_cpus
119+ resource ["requests" ]["memory" ] = str (head_memory ) + "G"
120+ resource ["limits" ]["memory" ] = str (head_memory ) + "G"
121+ resource ["requests" ]["nvidia.com/gpu" ] = head_gpus
122+ resource ["limits" ]["nvidia.com/gpu" ] = head_gpus
123+
124+ else :
125+ for k , v in resource .items ():
126+ if k == "replicas" and i == 1 :
127+ resource [k ] = workers
128+ if k == "requests" or k == "limits" :
129+ for spec , _ in v .items ():
130+ if spec == "cpu" :
131+ if k == "limits" :
132+ resource [k ][spec ] = max_cpu
133+ else :
134+ resource [k ][spec ] = min_cpu
135+ if spec == "memory" :
136+ if k == "limits" :
137+ resource [k ][spec ] = str (max_memory ) + "G"
138+ else :
139+ resource [k ][spec ] = str (min_memory ) + "G"
140+ if spec == "nvidia.com/gpu" :
141+ if i == 0 :
142+ resource [k ][spec ] = 0
143+ else :
144+ resource [k ][spec ] = gpu
139145 else :
140146 sys .exit ("Error: malformed template" )
141147
@@ -205,11 +211,15 @@ def update_nodes(
205211 instascale ,
206212 env ,
207213 image_pull_secrets ,
214+ head_cpus ,
215+ head_memory ,
216+ head_gpus ,
208217):
209218 if "generictemplate" in item .keys ():
210219 head = item .get ("generictemplate" ).get ("spec" ).get ("headGroupSpec" )
220+ head ["rayStartParams" ]["num_gpus" ] = str (int (head_gpus ))
221+
211222 worker = item .get ("generictemplate" ).get ("spec" ).get ("workerGroupSpecs" )[0 ]
212-
213223 # Head counts as first worker
214224 worker ["replicas" ] = workers
215225 worker ["minReplicas" ] = workers
@@ -225,7 +235,7 @@ def update_nodes(
225235 update_env (spec , env )
226236 if comp == head :
227237 # TODO: Eventually add head node configuration outside of template
228- continue
238+ update_resources ( spec , head_cpus , head_cpus , head_memory , head_memory , head_gpus )
229239 else :
230240 update_resources (spec , min_cpu , max_cpu , min_memory , max_memory , gpu )
231241
@@ -350,6 +360,9 @@ def write_user_appwrapper(user_yaml, output_file_name):
350360def generate_appwrapper (
351361 name : str ,
352362 namespace : str ,
363+ head_cpus : int ,
364+ head_memory : int ,
365+ head_gpus : int ,
353366 min_cpu : int ,
354367 max_cpu : int ,
355368 min_memory : int ,
@@ -375,8 +388,7 @@ def generate_appwrapper(
375388 update_labels (user_yaml , instascale , instance_types )
376389 update_priority (user_yaml , item , dispatch_priority , priority_val )
377390 update_custompodresources (
378- item , min_cpu , max_cpu , min_memory , max_memory , gpu , workers
379- )
391+ item , min_cpu , max_cpu , min_memory , max_memory , gpu , workers , head_cpus , head_memory , head_gpus )
380392 update_nodes (
381393 item ,
382394 appwrapper_name ,
@@ -390,6 +402,9 @@ def generate_appwrapper(
390402 instascale ,
391403 env ,
392404 image_pull_secrets ,
405+ head_cpus ,
406+ head_memory ,
407+ head_gpus ,
393408 )
394409 update_dashboard_route (route_item , cluster_name , namespace )
395410 if local_interactive :
0 commit comments