Skip to content

Commit b8d1524

Browse files
dulinriley authored and meta-codesync[bot] committed
Add cleanup function to Actor trait and use it from PythonActor (#1836)
Summary: Pull Request resolved: #1836 Fixes #1849 RFC: Add a "cleanup" method for actors to run on stop. This method is invoked after all child actors are stopped, but before the current actor exits. If an exception is thrown, it will become an error at shutdown which will propagate. The cleanup method in Python can be sync or async, and must match the syncness of endpoints. It takes one argument, which is an `Optional[Exception]`. If this was an abnormal exit, that exception is not None; specifically, it is set to the result of `ActorError::to_string` (wrapped in an exception object). Later on we may be able to preserve the original exception, if there is one. The motivation comes from actors that want to call `dist.destroy_process_group()`, as that is one of the most frequent cleanup actions among users of monarch. Actors should *not* call `stop()` on any actor or proc meshes they own. This will be handled automatically in the future, and they will have already been stopped by the time cleanup is invoked. This cleanup is per-actor, not per-proc. So if the cleanup is destroying process-wide resources (as does "destroy_process_group"), then the actor shouldn't be colocated with any other actors on the same proc using the same resource. If the cleanup takes too long, the result is ignored and stopping continues. ProcMesh::stop() already does a graceful stop of all actors, so this cleanup will be run automatically when proc meshes are stopped. Reviewed By: mariusae Differential Revision: D85624518 fbshipit-source-id: 172eeebf18eddc1a7f5928dcc31efd4cd9120287
1 parent bb4f303 commit b8d1524

File tree

7 files changed

+272
-5
lines changed

7 files changed

+272
-5
lines changed

hyperactor/src/actor.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,25 @@ pub trait Actor: Sized + Send + Debug + 'static {
8282
Ok(())
8383
}
8484

85+
/// Cleanup things used by this actor before shutting down. Notably this function
86+
/// is async and allows more complex cleanup. Simpler cleanup can be handled
87+
/// by the impl Drop for this Actor.
88+
/// If err is not None, it is the error that this actor is failing with. Any
89+
/// errors returned by this function will be logged and ignored.
90+
/// If err is None, any errors returned by this function will be propagated
91+
/// as an ActorError.
92+
/// This function is not called if there is a panic in the actor, as the
93+
/// actor may be in an indeterminate state. It is also not called if the
94+
/// process is killed, as there is no atexit handler or signal handler.
95+
async fn cleanup(
96+
&mut self,
97+
_this: &Instance<Self>,
98+
_err: Option<&ActorError>,
99+
) -> Result<(), anyhow::Error> {
100+
// Default implementation: no cleanup.
101+
Ok(())
102+
}
103+
85104
/// Spawn a child actor, given a spawning capability (usually given by [`Instance`]).
86105
/// The spawned actor will be supervised by the parent (spawning) actor.
87106
async fn spawn(
@@ -343,6 +362,11 @@ impl ActorErrorKind {
343362
Self::Generic(format!("initialization error: {}", err))
344363
}
345364

365+
/// Error during actor cleanup.
366+
pub fn cleanup(err: anyhow::Error) -> Self {
367+
Self::Generic(format!("cleanup error: {}", err))
368+
}
369+
346370
/// An underlying mailbox error.
347371
pub fn mailbox(err: MailboxError) -> Self {
348372
Self::Generic(err.to_string())

hyperactor/src/config.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ declare_attrs! {
155155
})
156156
pub attr STOP_ACTOR_TIMEOUT: Duration = Duration::from_secs(10);
157157

158+
/// Timeout used by proc for running the cleanup callback on an actor.
159+
/// Should be less than the timeout for STOP_ACTOR_TIMEOUT.
160+
@meta(CONFIG = ConfigAttr {
161+
env_name: Some("HYPERACTOR_CLEANUP_TIMEOUT".to_string()),
162+
py_name: None,
163+
})
164+
pub attr CLEANUP_TIMEOUT: Duration = Duration::from_secs(3);
165+
158166
/// Heartbeat interval for remote allocator
159167
@meta(CONFIG = ConfigAttr {
160168
env_name: Some("HYPERACTOR_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL".to_string()),

hyperactor/src/proc.rs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,13 +1180,15 @@ impl<A: Actor> Instance<A> {
11801180
// https://docs.rs/tokio/latest/tokio/task/struct.JoinError.html#method.is_panic
11811181
// What we do here is just to catch it early so we can handle it.
11821182

1183+
let mut did_panic = false;
11831184
let result = match AssertUnwindSafe(self.run(actor, &mut actor_loop_receivers, work_rx))
11841185
.catch_unwind()
11851186
.await
11861187
{
11871188
Ok(result) => result,
11881189
Err(err) => {
11891190
// This is only the error message. Backtrace is not included.
1191+
did_panic = true;
11901192
let err_msg = err
11911193
.downcast_ref::<&str>()
11921194
.copied()
@@ -1252,8 +1254,41 @@ impl<A: Actor> Instance<A> {
12521254
}
12531255
}
12541256
}
1255-
1256-
result
1257+
// Run the actor cleanup function before the actor stops to delete
1258+
// resources. If it times out, continue with stopping the actor.
1259+
// Don't call it if there was a panic, because the actor may
1260+
// be in an invalid state and unable to access anything, for example
1261+
// the GIL.
1262+
let cleanup_result = if !did_panic {
1263+
let cleanup_timeout = config::global::get(config::CLEANUP_TIMEOUT);
1264+
match RealClock
1265+
.timeout(cleanup_timeout, actor.cleanup(self, result.as_ref().err()))
1266+
.await
1267+
{
1268+
Ok(Ok(x)) => Ok(x),
1269+
Ok(Err(e)) => Err(ActorError::new(self.self_id(), ActorErrorKind::cleanup(e))),
1270+
Err(e) => Err(ActorError::new(
1271+
self.self_id(),
1272+
ActorErrorKind::cleanup(e.into()),
1273+
)),
1274+
}
1275+
} else {
1276+
Ok(())
1277+
};
1278+
if let Err(ref actor_err) = result {
1279+
// The original result error takes precedence over the cleanup error,
1280+
// so make sure the cleanup error is still logged in that case.
1281+
if let Err(ref err) = cleanup_result {
1282+
tracing::warn!(
1283+
cleanup_err = %err,
1284+
%actor_err,
1285+
"ignoring cleanup error after actor error",
1286+
);
1287+
}
1288+
}
1289+
// If the original exit was not an error, let cleanup errors be
1290+
// surfaced.
1291+
result.and(cleanup_result)
12571292
}
12581293

12591294
/// Initialize and run the actor until it fails or is stopped.

monarch_hyperactor/src/actor.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ use hyperactor::Named;
2222
use hyperactor::OncePortHandle;
2323
use hyperactor::PortHandle;
2424
use hyperactor::ProcId;
25+
use hyperactor::actor::ActorError;
26+
use hyperactor::attrs::Attrs;
2527
use hyperactor::mailbox::MessageEnvelope;
2628
use hyperactor::mailbox::Undeliverable;
2729
use hyperactor::message::Bind;
@@ -570,6 +572,63 @@ impl Actor for PythonActor {
570572
})?)
571573
}
572574

575+
async fn cleanup(
576+
&mut self,
577+
this: &Instance<Self>,
578+
err: Option<&ActorError>,
579+
) -> anyhow::Result<()> {
580+
// Calls the "__cleanup__" method on the python instance to allow the actor
581+
// to control its own cleanup.
582+
// No headers because this isn't in the context of a message.
583+
let cx = Context::new(this, Attrs::new());
584+
// Turn the ActorError into a representation of the error. We may not
585+
// have an original exception object or traceback, so we just pass in
586+
// the message.
587+
let err_as_str = err.map(|e| e.to_string());
588+
let future = Python::with_gil(|py| {
589+
let py_cx = match self.instance {
590+
Some(ref instance) => crate::context::PyContext::new(&cx, instance.clone_ref(py)),
591+
None => {
592+
let py_instance: crate::context::PyInstance = this.into();
593+
crate::context::PyContext::new(
594+
&cx,
595+
py_instance
596+
.into_py_any(py)?
597+
.downcast_bound(py)
598+
.map_err(PyErr::from)?
599+
.clone()
600+
.unbind(),
601+
)
602+
}
603+
}
604+
.into_bound_py_any(py)?;
605+
let actor = self.actor.bind(py);
606+
// Some tests don't use the Actor base class, so add this check
607+
// to be defensive.
608+
match actor.hasattr("__cleanup__") {
609+
Ok(false) | Err(_) => {
610+
// No cleanup found, default to returning None
611+
return Ok(None);
612+
}
613+
_ => {}
614+
}
615+
let awaitable = actor
616+
.call_method("__cleanup__", (&py_cx, err_as_str), None)
617+
.map_err(|err| anyhow::Error::from(SerializablePyErr::from(py, &err)))?;
618+
if awaitable.is_none() {
619+
Ok(None)
620+
} else {
621+
pyo3_async_runtimes::into_future_with_locals(self.get_task_locals(py), awaitable)
622+
.map(Some)
623+
.map_err(anyhow::Error::from)
624+
}
625+
})?;
626+
if let Some(future) = future {
627+
future.await.map_err(anyhow::Error::from)?;
628+
}
629+
Ok(())
630+
}
631+
573632
async fn handle_undeliverable_message(
574633
&mut self,
575634
ins: &Instance<Self>,

python/monarch/_src/actor/actor_mesh.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,6 +1145,34 @@ def __supervise__(self, cx: Context, *args: Any, **kwargs: Any) -> object:
11451145
# propagated to the next owner.
11461146
return None
11471147

1148+
async def __cleanup__(self, cx: Context, exc: str | Exception | None) -> None:
1149+
"""Cleans up any resources owned by this Actor before stopping. Automatically
1150+
called even if there is an error"""
1151+
_context.set(cx)
1152+
instance = self.instance
1153+
if instance is None:
1154+
# If there is no instance, there's nothing to clean up, the actor
1155+
# was never constructed
1156+
return None
1157+
1158+
# Forward the cleanup call on this actor to the user-provided instance.
1159+
cleanup = getattr(instance, "__cleanup__", None)
1160+
if cleanup is None:
1161+
return None
1162+
1163+
if isinstance(exc, str):
1164+
# Wrap the string in an exception object so the main API of __cleanup__
1165+
# is to take an optional exception object.
1166+
# The raw string is used for wider compatibility with other error
1167+
# types for now.
1168+
exc = Exception(exc)
1169+
1170+
if inspect.iscoroutinefunction(cleanup):
1171+
return await cleanup(exc)
1172+
else:
1173+
with fake_sync_state():
1174+
return cleanup(exc)
1175+
11481176
def __repr__(self) -> str:
11491177
return f"_Actor(instance={self.instance!r})"
11501178

@@ -1232,6 +1260,7 @@ def __init__(
12321260

12331261
async_endpoints = []
12341262
sync_endpoints = []
1263+
async_cleanup = None
12351264
for attr_name in dir(self._class):
12361265
attr_value = getattr(self._class, attr_name, None)
12371266
if isinstance(attr_value, EndpointProperty):
@@ -1255,13 +1284,28 @@ def __init__(
12551284
async_endpoints.append(attr_name)
12561285
else:
12571286
sync_endpoints.append(attr_name)
1287+
if attr_name == "__cleanup__" and attr_value is not None:
1288+
async_cleanup = inspect.iscoroutinefunction(attr_value)
12581289

12591290
if sync_endpoints and async_endpoints:
12601291
raise ValueError(
12611292
f"{self._class} mixes both async and sync endpoints."
12621293
"Synchronous endpoints cannot be mixed with async endpoints because they can cause the asyncio loop to deadlock if they wait."
12631294
f"sync: {sync_endpoints} async: {async_endpoints}"
12641295
)
1296+
if sync_endpoints and async_cleanup:
    # __cleanup__ must match the syncness of the endpoints: an async
    # __cleanup__ is rejected when the class exposes sync endpoints.
    raise ValueError(
        f"{self._class} has sync endpoints, but an async __cleanup__. Make sure __cleanup__ is also synchronous. "
        "Synchronous endpoints cannot be mixed with async endpoints because they can cause the asyncio loop to deadlock if they wait. "
        f"sync: {sync_endpoints}"
    )
# Check for False explicitly because None means there is no cleanup.
if async_endpoints and async_cleanup is False:
    # Report the async endpoints here: on this branch they are the ones
    # that conflict with the synchronous __cleanup__.
    raise ValueError(
        f"{self._class} has async endpoints, but a synchronous __cleanup__. Make sure __cleanup__ is also async. "
        "Synchronous endpoints cannot be mixed with async endpoints because they can cause the asyncio loop to deadlock if they wait. "
        f"async: {async_endpoints}"
    )
12651309

12661310
def __getattr__(self, attr: str) -> NotAnEndpoint:
12671311
if attr in dir(self._class):

python/tests/test_env_before_cuda.py renamed to python/tests/test_cuda.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
import os
1010
import sys
1111
import unittest
12-
from typing import Dict, List
12+
from typing import cast, Dict, List
1313

1414
import cloudpickle
15-
import monarch.actor
1615
import torch
16+
import torch.distributed as dist
17+
from monarch._src.actor.actor_mesh import ActorMesh
1718
from monarch._src.actor.host_mesh import create_local_host_mesh, fake_in_process_host
18-
from monarch.actor import Actor, endpoint
19+
from monarch.actor import Actor, current_rank, current_size, endpoint, this_host
1920

2021

2122
class CudaInitTestActor(Actor):
@@ -46,6 +47,42 @@ async def is_cuda_initialized(self) -> bool:
4647
return self.cuda_initialized
4748

4849

50+
class TorchDistributedActor(Actor):
    """Actor that initializes a torch.distributed process group and destroys it in __cleanup__"""

    def __init__(self) -> None:
        # Rank and world size are taken from this actor's position along
        # the "gpus" dimension.
        self.rank = int(current_rank()["gpus"])
        self.world_size = int(current_size()["gpus"])
        self.port = 29500
        # Set the rendezvous address and port in the environment before
        # init_process_group is called.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(self.port)

    @endpoint
    def init_torch_distributed(self) -> None:
        """Create the NCCL process group unless one is already initialized."""
        if not dist.is_initialized():
            dist.init_process_group(
                backend="nccl",
                world_size=self.world_size,
                rank=self.rank,
            )

    @endpoint
    def is_initialized(self) -> bool:
        """Return whether torch.distributed is initialized in this process."""
        return dist.is_initialized()

    # Cleanup is a special function called automatically on actor stop.
    def __cleanup__(self, exc: Exception | None) -> None:
        self.logger.info(f"Cleanup called with exception: {exc}")
        if dist.is_initialized():
            dist.destroy_process_group()
78+
79+
80+
class IsTorchInitializedActor(Actor):
81+
@endpoint
82+
def is_initialized(self) -> bool:
83+
return dist.is_initialized()
84+
85+
4986
class TestEnvBeforeCuda(unittest.IsolatedAsyncioTestCase):
5087
"""Test that the env vars are setup before cuda init"""
5188

@@ -149,3 +186,16 @@ async def test_proc_mesh_with_dictionary_env(self) -> None:
149186
env_vars.get("CUDA_DEVICE_MAX_CONNECTIONS"),
150187
"1",
151188
)
189+
190+
async def test_cleanup_torch_distributed(self) -> None:
191+
"""Test that calling stop on the actor destroys the process group"""
192+
proc_mesh = this_host().spawn_procs(per_host={"gpus": 1})
193+
194+
actor = proc_mesh.spawn("torch_init", TorchDistributedActor)
195+
tester = proc_mesh.spawn("check", IsTorchInitializedActor)
196+
await actor.init_torch_distributed.call_one()
197+
self.assertTrue(await actor.is_initialized.call_one())
198+
# Stop the actor and ensure cleanup is called, by using another actor
199+
# on the same proc.
200+
await cast(ActorMesh[TorchDistributedActor], actor).stop()
201+
self.assertFalse(await tester.is_initialized.call_one())

python/tests/test_python_actors.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1816,3 +1816,50 @@ def test_context_propagated_through_python_task_spawn_blocking():
18161816
p = this_host().spawn_procs()
18171817
a = p.spawn("test_pytokio_actor", TestPytokioActor)
18181818
a.context_propagated_through_spawn_blocking.call().get()
1819+
1820+
1821+
class ActorWithCleanup(Actor):
1822+
def __init__(self, counter: Counter) -> None:
1823+
self.counter = counter
1824+
1825+
@endpoint
1826+
def check(self) -> None:
1827+
pass
1828+
1829+
def __cleanup__(self, exc: Exception | None):
1830+
self.logger.info(f"Calling __cleanup__ on {self}, {exc=}")
1831+
self.counter.incr.call_one().get()
1832+
1833+
1834+
class ActorWithAsyncCleanup(Actor):
1835+
def __init__(self, counter: Counter) -> None:
1836+
self.counter = counter
1837+
1838+
@endpoint
1839+
async def check(self) -> None:
1840+
pass
1841+
1842+
# Cleanup should match the async-ness of the other endpoints.
1843+
async def __cleanup__(self, exc: Exception | None):
1844+
self.logger.info(f"Calling __cleanup__ on {self}, {exc=}")
1845+
await self.counter.incr.call_one()
1846+
1847+
1848+
def test_cleanup():
1849+
procs = this_host().spawn_procs(per_host={"gpus": 1})
1850+
counter = procs.spawn("counter", Counter, 0)
1851+
cleanup = procs.spawn("cleanup", ActorWithCleanup, counter)
1852+
# Call an endpoint to ensure it is constructed.
1853+
cleanup.check.call_one().get()
1854+
cast(ActorMesh[ActorWithCleanup], cleanup).stop().get()
1855+
assert counter.value.call_one().get() == 1
1856+
1857+
1858+
def test_cleanup_async():
    """Stopping an actor with an async __cleanup__ increments the counter once."""
    procs = this_host().spawn_procs(per_host={"gpus": 1})
    counter = procs.spawn("counter", Counter, 0)
    # Spawn the async variant so this test covers the coroutine
    # __cleanup__ path rather than repeating test_cleanup.
    cleanup = procs.spawn("cleanup", ActorWithAsyncCleanup, counter)
    # Call an endpoint to ensure it is constructed.
    cleanup.check.call_one().get()
    cast(ActorMesh[ActorWithAsyncCleanup], cleanup).stop().get()
    assert counter.value.call_one().get() == 1

0 commit comments

Comments
 (0)