 
 KUBERNETES_DATETIME_FORMAT: Final[str] = "%Y-%m-%dT%H:%M:%SZ"
 
-DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[
-    str
-] = "kubernetes.dask.org/cooldown-until"
+DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: Final[str] = "kubernetes.dask.org/cooldown-until"
 
 # Load operator plugins from other packages
 PLUGINS: list[Any] = []
@@ -59,20 +57,15 @@ def _get_annotations(meta: kopf.Meta) -> dict[str, str]:
     return {
         annotation_key: annotation_value
         for annotation_key, annotation_value in meta.annotations.items()
-        if not any(
-            annotation_key.startswith(namespace)
-            for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE
-        )
+        if not any(annotation_key.startswith(namespace) for namespace in _ANNOTATION_NAMESPACES_TO_IGNORE)
     }
 
 
 def _get_labels(meta: kopf.Meta) -> dict[str, str]:
     return {
         label_key: label_value
         for label_key, label_value in meta.labels.items()
-        if not any(
-            label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE
-        )
+        if not any(label_key.startswith(namespace) for namespace in _LABEL_NAMESPACES_TO_IGNORE)
     }
 
 
@@ -351,21 +344,15 @@ async def daskcluster_create_components(
             annotations.update(**scheduler_spec["metadata"]["annotations"])
         if "labels" in scheduler_spec["metadata"]:
             labels.update(**scheduler_spec["metadata"]["labels"])
-    data = build_scheduler_deployment_spec(
-        name, scheduler_spec.get("spec"), annotations, labels
-    )
+    data = build_scheduler_deployment_spec(name, scheduler_spec.get("spec"), annotations, labels)
     kopf.adopt(data)
     scheduler_deployment = await Deployment(data, namespace=namespace)
     if not await scheduler_deployment.exists():
         await scheduler_deployment.create()
-        logger.info(
-            f"Scheduler deployment {scheduler_deployment.name} created in {namespace}."
-        )
+        logger.info(f"Scheduler deployment {scheduler_deployment.name} created in {namespace}.")
 
     # Create scheduler service
-    data = build_scheduler_service_spec(
-        name, scheduler_spec.get("service"), annotations, labels
-    )
+    data = build_scheduler_service_spec(name, scheduler_spec.get("service"), annotations, labels)
     kopf.adopt(data)
     scheduler_service = await Service(data, namespace=namespace)
     if not await scheduler_service.exists():
@@ -389,6 +376,92 @@ async def daskcluster_create_components(
 
     patch.status["phase"] = "Pending"
 
+@kopf.on.update("daskcluster.kubernetes.dask.org")
+async def daskcluster_update(
+    spec: kopf.Spec,
+    status: kopf.Status,
+    meta: kopf.Meta,
+    name: str | None,
+    namespace: str | None,
+    diff: kopf.Diff,
+    patch: kopf.Patch,
+    logger: kopf.Logger,
+    **__: Any,
+) -> None:
+    """When the DaskCluster resource is updated, update all the components."""
+    assert name
+    assert namespace
+    logger.info(f"Handling update for DaskCluster '{name}'")
+
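+    # kopf passes a structured diff whose items carry the changed field path as a tuple;
+    # only reconcile the sub-specs that actually changed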
+    scheduler_changed = any(d.field[:2] == ("spec", "scheduler") for d in diff)
+    worker_changed = any(d.field[:2] == ("spec", "worker") for d in diff)
+
+    base_annotations = _get_annotations(meta)
+    base_labels = _get_labels(meta)
+
+    if scheduler_changed:
+        logger.info("Scheduler spec changed, reconciling scheduler components.")
+        scheduler_spec_part = spec.get("scheduler", {})
+
+        scheduler_annotations = base_annotations.copy()
+        scheduler_labels = base_labels.copy()
+        if "metadata" in scheduler_spec_part:
+            scheduler_annotations.update(scheduler_spec_part.get("metadata", {}).get("annotations", {}))
+            scheduler_labels.update(scheduler_spec_part.get("metadata", {}).get("labels", {}))
+
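+        # Rebuild the desired scheduler Deployment from the updated spec and patch the live object,
+        # or recreate it if it has gone missing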
+        desired_dep_spec = build_scheduler_deployment_spec(
+            name, scheduler_spec_part.get("spec"), scheduler_annotations, scheduler_labels
+        )
+        scheduler_deployment = await Deployment(
+            SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace
+        )
+        if await scheduler_deployment.exists():
+            await scheduler_deployment.patch(desired_dep_spec)
+            logger.info(f"Scheduler deployment {scheduler_deployment.name} patched.")
+        else:
+            logger.warning(f"Scheduler deployment {scheduler_deployment.name} not found. Recreating.")
+            kopf.adopt(desired_dep_spec)
+            scheduler_deployment = await Deployment(desired_dep_spec, namespace=namespace)
+            await scheduler_deployment.create()
+
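+        # Reconcile the scheduler Service the same way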
+        desired_svc_spec = build_scheduler_service_spec(
+            name, scheduler_spec_part.get("service"), scheduler_annotations, scheduler_labels
+        )
+        scheduler_service = await Service(
+            SCHEDULER_NAME_TEMPLATE.format(cluster_name=name), namespace=namespace
+        )
+        if await scheduler_service.exists():
+            await scheduler_service.patch(desired_svc_spec)
+            logger.info(f"Scheduler service {scheduler_service.name} patched.")
+        else:
+            logger.warning(f"Scheduler service {scheduler_service.name} not found. Recreating.")
+            kopf.adopt(desired_svc_spec)
+            scheduler_service = await Service(desired_svc_spec, namespace=namespace)
+            await scheduler_service.create()
+
+    if worker_changed:
+        logger.info("Worker spec changed, reconciling default worker group.")
+        worker_spec_part = spec.get("worker", {})
+
+        worker_annotations = base_annotations.copy()
+        worker_labels = base_labels.copy()
+        if "metadata" in worker_spec_part:
+            worker_annotations.update(worker_spec_part.get("metadata", {}).get("annotations", {}))
+            worker_labels.update(worker_spec_part.get("metadata", {}).get("labels", {}))
+
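+        # The default worker group is named "<cluster>-default"; patch it with the rebuilt spec,
+        # or recreate it if it has gone missing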
+        desired_wg_spec = build_default_worker_group_spec(
+            name, worker_spec_part, worker_annotations, worker_labels
+        )
+        worker_group = await DaskWorkerGroup(f"{name}-default", namespace=namespace)
+
+        if await worker_group.exists():
+            await worker_group.patch(desired_wg_spec)
+            logger.info(f"Worker group {worker_group.name} patched.")
+        else:
+            logger.warning(f"Worker group {worker_group.name} not found. Recreating.")
+            kopf.adopt(desired_wg_spec)
+            worker_group = await DaskWorkerGroup(desired_wg_spec, namespace=namespace)
+            await worker_group.create()
+
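+    # Record the generation that was just reconciled so the status reflects the handled update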
+    patch.status["observedGeneration"] = meta.get("generation")
+    logger.info(f"Update handler finished for DaskCluster '{name}'.")
 
 @kopf.on.field("service", field="status", labels={"dask.org/component": "scheduler"})
 async def handle_scheduler_service_status(
@@ -400,23 +473,17 @@ async def handle_scheduler_service_status(
 ) -> None:
     assert namespace
     # If the Service is a LoadBalancer with no ingress endpoints mark the cluster as Pending
-    if spec["type"] == "LoadBalancer" and not len(
-        status.get("loadBalancer", {}).get("ingress", [])
-    ):
+    if spec["type"] == "LoadBalancer" and not len(status.get("loadBalancer", {}).get("ingress", [])):
         phase = "Pending"
     # Otherwise mark it as Running
     else:
         phase = "Running"
-    cluster = await DaskCluster.get(
-        labels["dask.org/cluster-name"], namespace=namespace
-    )
+    cluster = await DaskCluster.get(labels["dask.org/cluster-name"], namespace=namespace)
     await cluster.patch({"status": {"phase": phase}})
 
 
 @kopf.on.create("daskworkergroup.kubernetes.dask.org")
-async def daskworkergroup_create(
-    body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any
-) -> None:
+async def daskworkergroup_create(body: kopf.Body, namespace: str | None, logger: kopf.Logger, **kwargs: Any) -> None:
     assert namespace
     wg = await DaskWorkerGroup(body, namespace=namespace)
     cluster = await wg.cluster()
@@ -463,9 +530,7 @@ async def retire_workers(
             )
 
     # Otherwise try gracefully retiring via the RPC
-    logger.debug(
-        f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC"
-    )
+    logger.debug(f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC")
     # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
     with suppress(Exception):
         comm_address = await get_scheduler_address(
@@ -499,9 +564,7 @@ def retire_workers_lifo(workers, n_workers: int) -> list[str]:
     return [w.name for w in workers[-n_workers:]]
 
 
-async def check_scheduler_idle(
-    scheduler_service_name: str, namespace: str | None, logger: kopf.Logger
-) -> float:
+async def check_scheduler_idle(scheduler_service_name: str, namespace: str | None, logger: kopf.Logger) -> float:
     assert namespace
     # Try getting idle time via HTTP API
     dashboard_address = await get_scheduler_address(
@@ -525,9 +588,7 @@ async def check_scheduler_idle(
             )
 
     # Otherwise try gracefully checking via the RPC
-    logger.debug(
-        f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC"
-    )
+    logger.debug(f"Checking {scheduler_service_name} idleness failed via the HTTP API, falling back to the Dask RPC")
     # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
     with suppress(Exception):
         comm_address = await get_scheduler_address(
@@ -573,9 +634,7 @@ def idle_since_func(dask_scheduler: Scheduler) -> float:
         return float(idle_since)
 
 
-async def get_desired_workers(
-    scheduler_service_name: str, namespace: str | None
-) -> Any:
+async def get_desired_workers(scheduler_service_name: str, namespace: str | None) -> Any:
     assert namespace
     # Try gracefully retiring via the HTTP API
     dashboard_address = await get_scheduler_address(
@@ -602,9 +661,7 @@ async def get_desired_workers(
         async with rpc(comm_address) as scheduler_comm:
             return await scheduler_comm.adaptive_target()
     except Exception as e:
-        raise SchedulerCommError(
-            "Unable to get number of desired workers from scheduler"
-        ) from e
+        raise SchedulerCommError("Unable to get number of desired workers from scheduler") from e
 
 
 worker_group_scale_locks: dict[str, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
@@ -669,13 +726,9 @@ async def daskworkergroup_replica_update(
             if "labels" in worker_spec["metadata"]:
                 labels.update(**worker_spec["metadata"]["labels"])
 
-        batch_size = int(
-            dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0
-        )
+        batch_size = int(dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0)
         batch_size = min(workers_needed, batch_size) if batch_size else workers_needed
-        batch_delay = int(
-            dask.config.get("kubernetes.controller.worker-allocation.delay") or 0
-        )
+        batch_delay = int(dask.config.get("kubernetes.controller.worker-allocation.delay") or 0)
         if workers_needed > 0:
             for _ in range(batch_size):
                 data = build_worker_deployment_spec(
@@ -701,9 +754,7 @@ async def daskworkergroup_replica_update(
         if workers_needed < 0:
             worker_ids = await retire_workers(
                 n_workers=-workers_needed,
-                scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(
-                    cluster_name=cluster_name
-                ),
+                scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
                 worker_group_name=name,
                 namespace=namespace,
                 logger=logger,
@@ -712,15 +763,11 @@ async def daskworkergroup_replica_update(
             for wid in worker_ids:
                 worker_deployment = await Deployment(wid, namespace=namespace)
                 await worker_deployment.delete()
-            logger.info(
-                f"Scaled worker group {name} down to {desired_workers} workers."
-            )
+            logger.info(f"Scaled worker group {name} down to {desired_workers} workers.")
 
 
 @kopf.on.delete("daskworkergroup.kubernetes.dask.org", optional=True)
-async def daskworkergroup_remove(
-    name: str | None, namespace: str | None, **__: Any
-) -> None:
+async def daskworkergroup_remove(name: str | None, namespace: str | None, **__: Any) -> None:
     assert name
     assert namespace
     lock_key = f"{name}/{namespace}"
@@ -742,9 +789,7 @@ async def daskjob_create(
     patch.status["jobStatus"] = "JobCreated"
 
 
-@kopf.on.field(
-    "daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated"
-)
+@kopf.on.field("daskjob.kubernetes.dask.org", field="status.jobStatus", new="JobCreated")
 async def daskjob_create_components(
     spec: kopf.Spec,
     name: str | None,
@@ -776,9 +821,7 @@ async def daskjob_create_components(
     kopf.adopt(cluster_spec)
     cluster = await DaskCluster(cluster_spec, namespace=namespace)
     await cluster.create()
-    logger.info(
-        f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}."
-    )
+    logger.info(f"Cluster {cluster_spec['metadata']['name']} for job {name} created in {namespace}.")
 
     labels = _get_labels(meta)
     annotations = _get_annotations(meta)
@@ -881,9 +924,7 @@ async def handle_runner_status_change_failed(
 
 
 @kopf.on.create("daskautoscaler.kubernetes.dask.org")
-async def daskautoscaler_create(
-    body: kopf.Body, logger: kopf.Logger, **__: Any
-) -> None:
+async def daskautoscaler_create(body: kopf.Body, logger: kopf.Logger, **__: Any) -> None:
     """When an autoscaler is created make it a child of the associated cluster for cascade deletion."""
     autoscaler = await DaskAutoscaler(body)
     cluster = await autoscaler.cluster()
@@ -916,16 +957,10 @@ async def daskautoscaler_adapt(
         return
 
     autoscaler = await DaskAutoscaler.get(name, namespace=namespace)
-    worker_group = await DaskWorkerGroup.get(
-        f"{spec['cluster']}-default", namespace=namespace
-    )
+    worker_group = await DaskWorkerGroup.get(f"{spec['cluster']}-default", namespace=namespace)
 
     current_replicas = worker_group.replicas
-    cooldown_until = float(
-        autoscaler.annotations.get(
-            DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()
-        )
-    )
+    cooldown_until = float(autoscaler.annotations.get(DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION, time.time()))
 
     # Cooldown autoscaling to prevent thrashing
     if time.time() < cooldown_until:
@@ -957,9 +992,7 @@ async def daskautoscaler_adapt(
 
         cooldown_until = time.time() + 15
 
-        await autoscaler.annotate(
-            {DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)}
-        )
+        await autoscaler.annotate({DASK_AUTOSCALER_COOLDOWN_UNTIL_ANNOTATION: str(cooldown_until)})
 
         logger.info(
             "Autoscaler updated %s worker count from %d to %d",
@@ -968,9 +1001,7 @@ async def daskautoscaler_adapt(
             desired_workers,
         )
     else:
-        logger.debug(
-            "Not autoscaling %s with %d workers", spec["cluster"], current_replicas
-        )
+        logger.debug("Not autoscaling %s with %d workers", spec["cluster"], current_replicas)
 
 
 @kopf.timer("daskcluster.kubernetes.dask.org", interval=5.0)
@@ -990,9 +1021,7 @@ async def daskcluster_autoshutdown(
                 logger=logger,
             )
         except Exception:  # TODO: Not use broad "Exception" catch here
-            logger.warning(
-                "Unable to connect to scheduler, skipping autoshutdown check."
-            )
+            logger.warning("Unable to connect to scheduler, skipping autoshutdown check.")
             return
         if idle_since and time.time() > idle_since + idle_timeout:
             cluster = await DaskCluster.get(name, namespace=namespace)