
Commit a55083a

Replace RPC with Scheduler HTTP API (#499)
* Replace RPC with scheduler API
* Name worker the same as pod
* Pass the names of the workers from the scheduler API to the pod deletion function
* Remove line adding name to worker
* Generalize scheduler API endpoint
* Fix retrieving JSON response
* Remove MIME type handling now that it is fixed upstream
* Reinstate RPC as fallback option, and last-in-first-out as second fallback

Co-authored-by: Jacob Tomlinson <jtomlinson@nvidia.com>
1 parent 050146c commit a55083a
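In essence (a rough sketch paraphrasing the diff below, not an authoritative reference; the two helper function names here are made up), worker retirement moves from the scheduler's Dask RPC to the HTTP API served on the dashboard port, with the RPC kept only as a fallback:

```python
import aiohttp
from distributed.core import rpc


async def retire_via_rpc(comm_address, n):
    # Old path: ask the scheduler over the Dask RPC which workers to close
    async with rpc(comm_address) as scheduler:
        return await scheduler.workers_to_close(n=n, attribute="name")


async def retire_via_http(dashboard_address, n):
    # New path: POST to the scheduler HTTP API exposed on the dashboard port
    async with aiohttp.ClientSession() as session:
        url = f"{dashboard_address}/api/v1/retire_workers"
        async with session.post(url, json={"n": n}) as resp:
            return await resp.json()
```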

2 files changed: +67 -16 lines changed

dask_kubernetes/common/networking.py

Lines changed: 11 additions & 4 deletions
@@ -11,13 +11,17 @@
 
 
 async def get_external_address_for_scheduler_service(
-    core_api, service, port_forward_cluster_ip=None, service_name_resolution_retries=20
+    core_api,
+    service,
+    port_forward_cluster_ip=None,
+    service_name_resolution_retries=20,
+    port_name="comm",
 ):
     """Take a service object and return the scheduler address."""
     [port] = [
         port.port
         for port in service.spec.ports
-        if port.name == service.metadata.name or port.name == "comm"
+        if port.name == service.metadata.name or port.name == port_name
     ]
     if service.spec.type == "LoadBalancer":
         lb = service.status.load_balancer.ingress[0]
@@ -104,13 +108,16 @@ async def port_forward_dashboard(service_name, namespace):
     return port
 
 
-async def get_scheduler_address(service_name, namespace):
+async def get_scheduler_address(service_name, namespace, port_name="comm"):
     async with kubernetes.client.api_client.ApiClient() as api_client:
         api = kubernetes.client.CoreV1Api(api_client)
         service = await api.read_namespaced_service(service_name, namespace)
         port_forward_cluster_ip = None
         address = await get_external_address_for_scheduler_service(
-            api, service, port_forward_cluster_ip=port_forward_cluster_ip
+            api,
+            service,
+            port_forward_cluster_ip=port_forward_cluster_ip,
+            port_name=port_name,
         )
         return address
 
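For context, a hypothetical usage sketch of the new `port_name` argument (the service name and namespace below are invented, and Kubernetes credentials are assumed to be configured): the default "comm" port resolves the scheduler's TCP address, while "dashboard" resolves the HTTP address that the operator now uses for the scheduler API.

```python
from dask_kubernetes.common.networking import get_scheduler_address


async def example_addresses():
    # Default port_name="comm": the scheduler's TCP comm address (used for RPC)
    comm_address = await get_scheduler_address("mycluster-service", "default")
    # port_name="dashboard": the HTTP address serving the scheduler API
    dashboard_address = await get_scheduler_address(
        "mycluster-service", "default", port_name="dashboard"
    )
    return comm_address, dashboard_address
```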

dask_kubernetes/operator/operator.py

Lines changed: 56 additions & 12 deletions
@@ -1,12 +1,14 @@
 import asyncio
-
-from distributed.core import rpc
+import aiohttp
+from contextlib import suppress
 
 import kopf
 import kubernetes_asyncio as kubernetes
 
 from uuid import uuid4
 
+from distributed.core import rpc
+
 from dask_kubernetes.common.auth import ClusterAuth
 from dask_kubernetes.common.networking import (
     get_scheduler_address,
@@ -195,6 +197,52 @@ async def daskworkergroup_create(spec, name, namespace, logger, **kwargs):
     )
 
 
+async def retire_workers(
+    n_workers, scheduler_service_name, worker_group_name, namespace, logger
+):
+    # Try gracefully retiring via the HTTP API
+    dashboard_address = await get_scheduler_address(
+        scheduler_service_name,
+        namespace,
+        port_name="dashboard",
+    )
+    async with aiohttp.ClientSession() as session:
+        url = f"{dashboard_address}/api/v1/retire_workers"
+        params = {"n": n_workers}
+        async with session.post(url, json=params) as resp:
+            if resp.status <= 300:
+                retired_workers = await resp.json()
+                return [retired_workers[w]["name"] for w in retired_workers.keys()]
+
+    # Otherwise try gracefully retiring via the RPC
+    logger.info(
+        f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC"
+    )
+    # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
+    with suppress(Exception):
+        comm_address = await get_scheduler_address(
+            scheduler_service_name,
+            namespace,
+        )
+        async with rpc(comm_address) as scheduler_comm:
+            return await scheduler_comm.workers_to_close(
+                n=n_workers,
+                attribute="name",
+            )
+
+    # Finally fall back to last-in-first-out scaling
+    logger.info(
+        f"Scaling {worker_group_name} failed via the Dask RPC, falling back to LIFO scaling"
+    )
+    async with kubernetes.client.api_client.ApiClient() as api_client:
+        api = kubernetes.client.CoreV1Api(api_client)
+        workers = await api.list_namespaced_pod(
+            namespace=namespace,
+            label_selector=f"dask.org/workergroup-name={worker_group_name}",
+        )
+        return [w["metadata"]["name"] for w in workers.items[:-n_workers]]
+
+
 @kopf.on.update("daskworkergroup")
 async def daskworkergroup_update(spec, name, namespace, logger, **kwargs):
     async with kubernetes.client.api_client.ApiClient() as api_client:
@@ -226,17 +274,13 @@ async def daskworkergroup_update(spec, name, namespace, logger, **kwargs):
             f"Scaled worker group {name} up to {spec['worker']['replicas']} workers."
         )
         if workers_needed < 0:
-            service_address = await get_scheduler_address(
-                f"{spec['cluster']}-service", namespace
+            worker_ids = await retire_workers(
+                n_workers=-workers_needed,
+                scheduler_service_name=f"{spec['cluster']}-service",
+                worker_group_name=name,
+                namespace=namespace,
+                logger=logger,
             )
-            logger.info(
-                f"Asking scheduler to retire {-workers_needed} on {service_address}"
-            )
-            async with rpc(service_address) as scheduler:
-                worker_ids = await scheduler.workers_to_close(
-                    n=-workers_needed, attribute="name"
-                )
-            # TODO: Check that were deting workers in the right worker group
             logger.info(f"Workers to close: {worker_ids}")
             for wid in worker_ids:
                 await api.delete_namespaced_pod(
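As a small illustration of what the HTTP path above assumes (the worker addresses and names here are invented, not taken from a real cluster): `/api/v1/retire_workers` is expected to return a JSON mapping of worker address to worker info, and the operator extracts each worker's `name`, which, per this commit, matches the pod name so it can be passed straight to pod deletion.

```python
# Hypothetical JSON body returned by the scheduler's retire_workers endpoint
retired_workers = {
    "tcp://10.244.1.7:40125": {"name": "mycluster-default-worker-abc12"},
    "tcp://10.244.2.9:37841": {"name": "mycluster-default-worker-de345"},
}

# The operator pulls out the names and deletes the matching pods
worker_names = [retired_workers[w]["name"] for w in retired_workers.keys()]
print(worker_names)
# ['mycluster-default-worker-abc12', 'mycluster-default-worker-de345']
```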
