1+ from __future__ import annotations
2+
13import asyncio
4+ import atexit
5+ from contextlib import suppress
26from enum import Enum
7+ import time
8+ from typing import ClassVar
9+ import weakref
10+
311import kubernetes_asyncio as kubernetes
412
5- from distributed .core import rpc
13+ from distributed .core import Status , rpc
614from distributed .deploy import Cluster
715
8- from distributed .utils import Log , Logs , LoopRunner
16+ from distributed .utils import Log , Logs , LoopRunner , TimeoutError
917
1018from dask_kubernetes .common .auth import ClusterAuth
1119from dask_kubernetes .operator import (
@@ -103,6 +111,8 @@ class KubeCluster(Cluster):
103111 KubeCluster.from_name
104112 """
105113
114+ _instances : ClassVar [weakref .WeakSet [KubeCluster ]] = weakref .WeakSet ()
115+
106116 def __init__ (
107117 self ,
108118 name ,
@@ -133,6 +143,8 @@ def __init__(
133143 self ._loop_runner = LoopRunner (loop = loop , asynchronous = asynchronous )
134144 self .loop = self ._loop_runner .loop
135145
146+ self ._instances .add (self )
147+
136148 super ().__init__ (asynchronous = asynchronous , ** kwargs )
137149 if not self .asynchronous :
138150 self ._loop_runner .start ()
@@ -363,11 +375,11 @@ async def _delete_worker_group(self, name):
363375 name = f"{ self .name } -cluster-{ name } " ,
364376 )
365377
def close(self, timeout=3600):
    """Delete the dask cluster.

    Hands ``self._close`` to ``self.sync`` and returns whatever ``sync``
    returns.

    Parameters
    ----------
    timeout
        Forwarded to ``self.sync`` as the keyword argument ``timeout``
        alongside ``self._close``. Defaults to 3600 (presumably seconds,
        judging by how ``_close`` compares it against ``time.time()``).
    """
    closer = self._close
    return self.sync(closer, timeout=timeout)
369381
370- async def _close (self ):
382+ async def _close (self , timeout = None ):
371383 await super ()._close ()
372384 if self .shutdown_on_close :
373385 async with kubernetes .client .api_client .ApiClient () as api_client :
@@ -379,7 +391,12 @@ async def _close(self):
379391 namespace = self .namespace ,
380392 name = self .cluster_name ,
381393 )
394+ start = time .time ()
382395 while (await self ._get_cluster ()) is not None :
396+ if time .time () > start + timeout :
397+ raise TimeoutError (
398+ f"Timed out deleting cluster resource { self .cluster_name } "
399+ )
383400 await asyncio .sleep (1 )
384401
385402 def scale (self , n , worker_group = "default" ):
@@ -537,3 +554,19 @@ def from_name(cls, name, **kwargs):
537554 >>> cluster = KubeCluster.from_name(name="simple-cluster")
538555 """
539556 return cls (name = name , create_mode = CreateMode .CONNECT_ONLY , ** kwargs )
557+
558+
@atexit.register
def reap_clusters():
    """Best-effort shutdown of still-open clusters at interpreter exit.

    Registered with ``atexit``. Every ``KubeCluster`` currently tracked in
    ``KubeCluster._instances`` that has ``shutdown_on_close`` set and whose
    status is not ``Status.closed`` is closed with a 10 second timeout.
    A ``TimeoutError`` raised while closing a cluster is suppressed so the
    remaining clusters are still processed.
    """

    async def _reap_clusters():
        # Iterate over a snapshot so changes to the WeakSet while we are
        # suspended in an ``await`` cannot disturb the iteration.
        for cluster in list(KubeCluster._instances):
            if not cluster.shutdown_on_close or cluster.status == Status.closed:
                continue
            await ClusterAuth.load_first(cluster.auth)
            with suppress(TimeoutError):
                if cluster.asynchronous:
                    # In asynchronous mode ``close`` hands back an awaitable.
                    await cluster.close(timeout=10)
                else:
                    cluster.close(timeout=10)

    # ``asyncio.run`` creates a fresh event loop and closes it afterwards.
    # ``asyncio.get_event_loop()`` is unreliable here: at exit there is
    # normally no current loop (or only an already-closed one), a situation
    # in which that call is deprecated and may raise ``RuntimeError``.
    # NOTE(review): ``asyncio.run`` itself raises ``RuntimeError`` if a loop
    # is still running in this thread; not expected during atexit - confirm.
    asyncio.run(_reap_clusters())
0 commit comments