 from ..utils import pretty_print
 from ..utils.generate_yaml import (
     generate_appwrapper,
+    head_worker_gpu_count_from_cluster,
 )
 from ..utils.kube_api_helpers import _kube_api_error_handling
 from ..utils.generate_yaml import is_openshift_cluster
@@ -118,48 +119,7 @@ def create_app_wrapper(self):
118119 f"Namespace { self .config .namespace } is of type { type (self .config .namespace )} . Check your Kubernetes Authentication."
119120 )
120121
121- # Before attempting to create the cluster AW, let's evaluate the ClusterConfig
122-
123- name = self .config .name
124- namespace = self .config .namespace
125- head_cpus = self .config .head_cpus
126- head_memory = self .config .head_memory
127- num_head_gpus = self .config .num_head_gpus
128- worker_cpu_requests = self .config .worker_cpu_requests
129- worker_cpu_limits = self .config .worker_cpu_limits
130- worker_memory_requests = self .config .worker_memory_requests
131- worker_memory_limits = self .config .worker_memory_limits
132- num_worker_gpus = self .config .num_worker_gpus
133- workers = self .config .num_workers
134- template = self .config .template
135- image = self .config .image
136- appwrapper = self .config .appwrapper
137- env = self .config .envs
138- image_pull_secrets = self .config .image_pull_secrets
139- write_to_file = self .config .write_to_file
140- local_queue = self .config .local_queue
141- labels = self .config .labels
142- return generate_appwrapper (
143- name = name ,
144- namespace = namespace ,
145- head_cpus = head_cpus ,
146- head_memory = head_memory ,
147- num_head_gpus = num_head_gpus ,
148- worker_cpu_requests = worker_cpu_requests ,
149- worker_cpu_limits = worker_cpu_limits ,
150- worker_memory_requests = worker_memory_requests ,
151- worker_memory_limits = worker_memory_limits ,
152- num_worker_gpus = num_worker_gpus ,
153- workers = workers ,
154- template = template ,
155- image = image ,
156- appwrapper = appwrapper ,
157- env = env ,
158- image_pull_secrets = image_pull_secrets ,
159- write_to_file = write_to_file ,
160- local_queue = local_queue ,
161- labels = labels ,
162- )
122+ return generate_appwrapper (self )
163123
164124 # creates a new cluster with the provided or default spec
165125 def up (self ):
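In the hunk above, create_app_wrapper no longer unpacks every ClusterConfiguration field into keyword arguments; generate_appwrapper now receives the Cluster instance and reads cluster.config itself. A minimal sketch of the resulting call path (module paths are inferred from the relative imports at the top of this diff; argument values are illustrative):

    from codeflare_sdk.cluster.cluster import Cluster
    from codeflare_sdk.cluster.config import ClusterConfiguration

    # The public entry point is unchanged; only the internal plumbing is simpler.
    cluster = Cluster(ClusterConfiguration(name="raytest", namespace="default"))
    aw = cluster.create_app_wrapper()  # internally: generate_appwrapper(cluster)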
@@ -305,7 +265,7 @@ def status(
 
         if print_to_console:
             # overriding the number of gpus with requested
-            cluster.worker_gpu = self.config.num_worker_gpus
+            _, cluster.worker_gpu = head_worker_gpu_count_from_cluster(self)
             pretty_print.print_cluster_status(cluster)
         elif print_to_console:
             if status == CodeFlareClusterStatus.UNKNOWN:
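head_worker_gpu_count_from_cluster is imported at the top of this diff but defined in utils/generate_yaml.py, which is not part of this section. A plausible sketch of its contract only, assuming it tallies GPU-type extended resources from the config (the body below is an assumption, not the actual implementation):

    from typing import Tuple

    def head_worker_gpu_count_from_cluster(cluster) -> Tuple[int, int]:
        # Assumption: any extended resource whose key ends in "gpu"
        # (e.g. "nvidia.com/gpu") counts toward the GPU total.
        def gpu_total(requests: dict) -> int:
            return sum(int(v) for k, v in requests.items() if k.endswith("gpu"))

        return (
            gpu_total(cluster.config.head_extended_resource_requests),
            gpu_total(cluster.config.worker_extended_resource_requests),
        )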
@@ -443,6 +403,29 @@ def job_logs(self, job_id: str) -> str:
443403 """
444404 return self .job_client .get_job_logs (job_id )
445405
406+ @staticmethod
407+ def _head_worker_extended_resources_from_rc_dict (rc : Dict ) -> Tuple [dict , dict ]:
408+ head_extended_resources , worker_extended_resources = {}, {}
409+ for resource in rc ["spec" ]["workerGroupSpecs" ][0 ]["template" ]["spec" ][
410+ "containers"
411+ ][0 ]["resources" ]["limits" ].keys ():
412+ if resource in ["memory" , "cpu" ]:
413+ continue
414+ worker_extended_resources [resource ] = rc ["spec" ]["workerGroupSpecs" ][0 ][
415+ "template"
416+ ]["spec" ]["containers" ][0 ]["resources" ]["limits" ][resource ]
417+
418+ for resource in rc ["spec" ]["headGroupSpec" ]["template" ]["spec" ]["containers" ][
419+ 0
420+ ]["resources" ]["limits" ].keys ():
421+ if resource in ["memory" , "cpu" ]:
422+ continue
423+ head_extended_resources [resource ] = rc ["spec" ]["headGroupSpec" ]["template" ][
424+ "spec"
425+ ]["containers" ][0 ]["resources" ]["limits" ][resource ]
426+
427+ return head_extended_resources , worker_extended_resources
428+
446429 def from_k8_cluster_object (
447430 rc ,
448431 appwrapper = True ,
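The new static helper above walks the resource limits of the first head and worker containers and keeps every key that is not cpu or memory. A quick illustration against a made-up minimal RayCluster dict (fixture values are invented; the key layout mirrors the lookups in the method):

    rc = {
        "spec": {
            "headGroupSpec": {
                "template": {
                    "spec": {
                        "containers": [
                            {"resources": {"limits": {"cpu": 2, "memory": "8G", "nvidia.com/gpu": 1}}}
                        ]
                    }
                }
            },
            "workerGroupSpecs": [
                {
                    "template": {
                        "spec": {
                            "containers": [
                                {"resources": {"limits": {"cpu": 4, "memory": "16G", "nvidia.com/gpu": 2}}}
                            ]
                        }
                    }
                }
            ],
        }
    }

    head, worker = Cluster._head_worker_extended_resources_from_rc_dict(rc)
    assert head == {"nvidia.com/gpu": 1}
    assert worker == {"nvidia.com/gpu": 2}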
@@ -456,6 +439,11 @@ def from_k8_cluster_object(
             else []
         )
 
+        (
+            head_extended_resources,
+            worker_extended_resources,
+        ) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
+
         cluster_config = ClusterConfiguration(
             name=rc["metadata"]["name"],
             namespace=rc["metadata"]["namespace"],
@@ -473,11 +461,8 @@ def from_k8_cluster_object(
             worker_memory_limits=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"][
                 "containers"
             ][0]["resources"]["limits"]["memory"],
-            num_worker_gpus=int(
-                rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][0][
-                    "resources"
-                ]["limits"]["nvidia.com/gpu"]
-            ),
+            worker_extended_resource_requests=worker_extended_resources,
+            head_extended_resource_requests=head_extended_resources,
             image=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
                 0
             ]["image"],
@@ -858,6 +843,11 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
                     protocol = "https"
             dashboard_url = f"{protocol}://{ingress.spec.rules[0].host}"
 
+    (
+        head_extended_resources,
+        worker_extended_resources,
+    ) = Cluster._head_worker_extended_resources_from_rc_dict(rc)
+
     return RayCluster(
         name=rc["metadata"]["name"],
         status=status,
@@ -872,17 +862,15 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
         worker_cpu=rc["spec"]["workerGroupSpecs"][0]["template"]["spec"]["containers"][
             0
         ]["resources"]["limits"]["cpu"],
-        worker_gpu=0,  # hard to detect currently how many gpus, can override it with what the user asked for
+        worker_extended_resources=worker_extended_resources,
         namespace=rc["metadata"]["namespace"],
         head_cpus=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
             "resources"
         ]["limits"]["cpu"],
         head_mem=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
             "resources"
         ]["limits"]["memory"],
-        head_gpu=rc["spec"]["headGroupSpec"]["template"]["spec"]["containers"][0][
-            "resources"
-        ]["limits"]["nvidia.com/gpu"],
+        head_extended_resources=head_extended_resources,
         dashboard=dashboard_url,
     )
 
@@ -907,12 +895,12 @@ def _copy_to_ray(cluster: Cluster) -> RayCluster:
         worker_mem_min=cluster.config.worker_memory_requests,
         worker_mem_max=cluster.config.worker_memory_limits,
         worker_cpu=cluster.config.worker_cpu_requests,
-        worker_gpu=cluster.config.num_worker_gpus,
+        worker_extended_resources=cluster.config.worker_extended_resource_requests,
         namespace=cluster.config.namespace,
         dashboard=cluster.cluster_dashboard_uri(),
         head_cpus=cluster.config.head_cpus,
         head_mem=cluster.config.head_memory,
-        head_gpu=cluster.config.num_head_gpus,
+        head_extended_resources=cluster.config.head_extended_resource_requests,
     )
     if ray.status == CodeFlareClusterStatus.READY:
         ray.status = RayClusterStatus.READY
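Net effect of the last two hunks: RayCluster now carries worker_extended_resources and head_extended_resources dictionaries instead of the scalar worker_gpu/head_gpu fields, so non-NVIDIA devices pass through unchanged. For instance (device names are generic Kubernetes extended-resource examples, not values from this diff):

    worker_extended_resources = {"nvidia.com/gpu": 2, "habana.ai/gaudi": 1}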