@@ -49,6 +49,7 @@ use hyperactor::config;
4949use hyperactor:: config:: CONFIG ;
5050use hyperactor:: config:: ConfigAttr ;
5151use monarch_types:: SerializablePyErr ;
52+ use monarch_types:: py_global;
5253use pyo3:: IntoPyObjectExt ;
5354#[ cfg( test) ]
5455use pyo3:: PyClass ;
@@ -77,6 +78,9 @@ declare_attrs! {
7778 pub attr ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK : bool = false ;
7879}
7980
81+ py_global ! ( context, "monarch._src.actor.actor_mesh" , "context" ) ;
82+ py_global ! ( actor_mesh_module, "monarch._src.actor" , "actor_mesh" ) ;
83+
8084fn current_traceback ( ) -> PyResult < Option < PyObject > > {
8185 if config:: global:: get ( ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK ) {
8286 Python :: with_gil ( |py| {
@@ -336,10 +340,7 @@ impl PyPythonTask {
336340 // context() from the context in which the PythonTask was constructed.
337341 // We need to do this manually because the value of the contextvar isn't
338342 // maintained inside the tokio runtime.
339- let monarch_context = py
340- . import ( "monarch._src.actor.actor_mesh" ) ?
341- . call_method0 ( "context" ) ?
342- . unbind ( ) ;
343+ let monarch_context = context ( py) . call0 ( ) ?. unbind ( ) ;
343344 PyPythonTask :: new ( async move {
344345 let ( coroutine_iterator, none) = Python :: with_gil ( |py| {
345346 coro. into_bound ( py)
@@ -355,11 +356,11 @@ impl PyPythonTask {
355356 let action: PyResult < Action > = Python :: with_gil ( |py| {
356357 // We may be executing in a new thread at this point, so we need to set the value
357358 // of context().
358- let _context = py
359- . import ( "monarch._src.actor.actor_mesh" ) ?
360- . getattr ( "_context" ) ?;
359+ let _context = actor_mesh_module ( py) . getattr ( "_context" ) ?;
361360 let old_context = _context. call_method1 ( "get" , ( PyNone :: get ( py) , ) ) ?;
362- _context. call_method1 ( "set" , ( monarch_context. clone_ref ( py) , ) ) ?;
361+ _context
362+ . call_method1 ( "set" , ( monarch_context. clone_ref ( py) , ) )
363+ . expect ( "failed to set _context" ) ;
363364
364365 let result = match last {
365366 Ok ( value) => coroutine_iterator. bind ( py) . call_method1 ( "send" , ( value, ) ) ,
@@ -369,7 +370,9 @@ impl PyPythonTask {
369370 } ;
370371
371372 // Reset context() so that when this tokio thread yields, it has its original state.
372- _context. call_method1 ( "set" , ( old_context, ) ) ?;
373+ _context
374+ . call_method1 ( "set" , ( old_context, ) )
375+ . expect ( "failed to restore _context" ) ;
373376 match result {
374377 Ok ( task) => Ok ( Action :: Wait (
375378 task. extract :: < Py < PyPythonTask > > ( )
@@ -415,14 +418,29 @@ impl PyPythonTask {
415418 }
416419
417420 #[ staticmethod]
418- fn spawn_blocking ( f : PyObject ) -> PyResult < PyShared > {
421+ fn spawn_blocking ( py : Python < ' _ > , f : PyObject ) -> PyResult < PyShared > {
419422 let ( tx, rx) = watch:: channel ( None ) ;
420423 let traceback = current_traceback ( ) ?;
421424 let traceback1 = traceback
422425 . as_ref ( )
423426 . map_or_else ( || None , |t| Python :: with_gil ( |py| Some ( t. clone_ref ( py) ) ) ) ;
427+ let monarch_context = context ( py) . call0 ( ) ?. unbind ( ) ;
428+ // The `_context` contextvar needs to be propagated through to the thread that
429+ // runs the blocking tokio task. Upon completion, the original value of `_context`
430+ // is restored.
424431 let handle = get_tokio_runtime ( ) . spawn_blocking ( move || {
425- let result = Python :: with_gil ( |py| f. call0 ( py) ) ;
432+ let result = Python :: with_gil ( |py| {
433+ let _context = actor_mesh_module ( py) . getattr ( "_context" ) ?;
434+ let old_context = _context. call_method1 ( "get" , ( PyNone :: get ( py) , ) ) ?;
435+ _context
436+ . call_method1 ( "set" , ( monarch_context. clone_ref ( py) , ) )
437+ . expect ( "failed to set _context" ) ;
438+ let result = f. call0 ( py) ;
439+ _context
440+ . call_method1 ( "set" , ( old_context, ) )
441+ . expect ( "failed to restore _context" ) ;
442+ result
443+ } ) ;
426444 send_result ( tx, result, traceback1) ;
427445 } ) ;
428446 Ok ( PyShared {
0 commit comments