Skip to content

Commit 92e0118

Browse files
samluryemeta-codesync[bot]
authored andcommitted
Make sure context is propagated through PythonTask.spawn_blocking (#1730)
Summary: Pull Request resolved: #1730 Similar to `PythonTask.spawn`, we also need to make sure that the monarch context is propagated through `PythonTask.spawn_blocking`. ghstack-source-id: 322807310 exported-using-ghexport Reviewed By: mariusae, shayne-fletcher, zdevito Differential Revision: D85981930 fbshipit-source-id: 0176304354d33d5ac377df90c31d08c96d9ca3f3
1 parent fcdf726 commit 92e0118

File tree

2 files changed

+61
-11
lines changed

2 files changed

+61
-11
lines changed

monarch_hyperactor/src/pytokio.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ use hyperactor::config;
4949
use hyperactor::config::CONFIG;
5050
use hyperactor::config::ConfigAttr;
5151
use monarch_types::SerializablePyErr;
52+
use monarch_types::py_global;
5253
use pyo3::IntoPyObjectExt;
5354
#[cfg(test)]
5455
use pyo3::PyClass;
@@ -77,6 +78,9 @@ declare_attrs! {
7778
pub attr ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK: bool = false;
7879
}
7980

81+
py_global!(context, "monarch._src.actor.actor_mesh", "context");
82+
py_global!(actor_mesh_module, "monarch._src.actor", "actor_mesh");
83+
8084
fn current_traceback() -> PyResult<Option<PyObject>> {
8185
if config::global::get(ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK) {
8286
Python::with_gil(|py| {
@@ -336,10 +340,7 @@ impl PyPythonTask {
336340
// context() from the context in which the PythonTask was constructed.
337341
// We need to do this manually because the value of the contextvar isn't
338342
// maintained inside the tokio runtime.
339-
let monarch_context = py
340-
.import("monarch._src.actor.actor_mesh")?
341-
.call_method0("context")?
342-
.unbind();
343+
let monarch_context = context(py).call0()?.unbind();
343344
PyPythonTask::new(async move {
344345
let (coroutine_iterator, none) = Python::with_gil(|py| {
345346
coro.into_bound(py)
@@ -355,11 +356,11 @@ impl PyPythonTask {
355356
let action: PyResult<Action> = Python::with_gil(|py| {
356357
// We may be executing in a new thread at this point, so we need to set the value
357358
// of context().
358-
let _context = py
359-
.import("monarch._src.actor.actor_mesh")?
360-
.getattr("_context")?;
359+
let _context = actor_mesh_module(py).getattr("_context")?;
361360
let old_context = _context.call_method1("get", (PyNone::get(py),))?;
362-
_context.call_method1("set", (monarch_context.clone_ref(py),))?;
361+
_context
362+
.call_method1("set", (monarch_context.clone_ref(py),))
363+
.expect("failed to set _context");
363364

364365
let result = match last {
365366
Ok(value) => coroutine_iterator.bind(py).call_method1("send", (value,)),
@@ -369,7 +370,9 @@ impl PyPythonTask {
369370
};
370371

371372
// Reset context() so that when this tokio thread yields, it has its original state.
372-
_context.call_method1("set", (old_context,))?;
373+
_context
374+
.call_method1("set", (old_context,))
375+
.expect("failed to restore _context");
373376
match result {
374377
Ok(task) => Ok(Action::Wait(
375378
task.extract::<Py<PyPythonTask>>()
@@ -415,14 +418,29 @@ impl PyPythonTask {
415418
}
416419

417420
#[staticmethod]
418-
fn spawn_blocking(f: PyObject) -> PyResult<PyShared> {
421+
fn spawn_blocking(py: Python<'_>, f: PyObject) -> PyResult<PyShared> {
419422
let (tx, rx) = watch::channel(None);
420423
let traceback = current_traceback()?;
421424
let traceback1 = traceback
422425
.as_ref()
423426
.map_or_else(|| None, |t| Python::with_gil(|py| Some(t.clone_ref(py))));
427+
let monarch_context = context(py).call0()?.unbind();
428+
// The `_context` contextvar needs to be propagated through to the thread that
429+
// runs the blocking tokio task. Upon completion, the original value of `_context`
430+
// is restored.
424431
let handle = get_tokio_runtime().spawn_blocking(move || {
425-
let result = Python::with_gil(|py| f.call0(py));
432+
let result = Python::with_gil(|py| {
433+
let _context = actor_mesh_module(py).getattr("_context")?;
434+
let old_context = _context.call_method1("get", (PyNone::get(py),))?;
435+
_context
436+
.call_method1("set", (monarch_context.clone_ref(py),))
437+
.expect("failed to set _context");
438+
let result = f.call0(py);
439+
_context
440+
.call_method1("set", (old_context,))
441+
.expect("failed to restore _context");
442+
result
443+
});
426444
send_result(tx, result, traceback1);
427445
});
428446
Ok(PyShared {

python/tests/test_python_actors.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,3 +1784,35 @@ def test_instance_name():
17841784
assert "actor=" not in logs.contents
17851785
finally:
17861786
monarch.actor.config.prefix_python_logs_with_actor = True
1787+
1788+
1789+
class TestPytokioActor(Actor):
1790+
@endpoint
1791+
def context_propagated_through_spawn(self) -> None:
1792+
cx = context()
1793+
1794+
async def task():
1795+
assert cx is context()
1796+
1797+
PythonTask.from_coroutine(coro=task()).spawn().block_on()
1798+
1799+
@endpoint
1800+
def context_propagated_through_spawn_blocking(self) -> None:
1801+
cx = context()
1802+
1803+
def task():
1804+
assert cx is context()
1805+
1806+
PythonTask.spawn_blocking(task).block_on()
1807+
1808+
1809+
def test_context_propagated_through_python_task_spawn():
1810+
p = this_host().spawn_procs()
1811+
a = p.spawn("test_pytokio_actor", TestPytokioActor)
1812+
a.context_propagated_through_spawn.call().get()
1813+
1814+
1815+
def test_context_propagated_through_python_task_spawn_blocking():
1816+
p = this_host().spawn_procs()
1817+
a = p.spawn("test_pytokio_actor", TestPytokioActor)
1818+
a.context_propagated_through_spawn_blocking.call().get()

0 commit comments

Comments
 (0)