diff --git a/scripts/gen_bridge_client.py b/scripts/gen_bridge_client.py index 89ae54ec1..f0a4457dc 100644 --- a/scripts/gen_bridge_client.py +++ b/scripts/gen_bridge_client.py @@ -171,6 +171,7 @@ def generate_rust_service_call(service_descriptor: ServiceDescriptor) -> str: py: Python<'p>, call: RpcCall, ) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::${descriptor_name}; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs index dfbd432a1..abe4a2354 100644 --- a/temporalio/bridge/src/client.rs +++ b/temporalio/bridge/src/client.rs @@ -92,6 +92,7 @@ pub fn connect_client<'a>( config: ClientConfig, ) -> PyResult> { let opts: ClientOptions = config.try_into()?; + runtime_ref.runtime.assert_same_process("create client")?; let runtime = runtime_ref.runtime.clone(); runtime_ref.runtime.future_into_py(py, async move { Ok(ClientRef { diff --git a/temporalio/bridge/src/client_rpc_generated.rs b/temporalio/bridge/src/client_rpc_generated.rs index 659f5d8cf..0b2d2ffa8 100644 --- a/temporalio/bridge/src/client_rpc_generated.rs +++ b/temporalio/bridge/src/client_rpc_generated.rs @@ -15,6 +15,7 @@ impl ClientRef { py: Python<'p>, call: RpcCall, ) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::WorkflowService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { @@ -566,6 +567,7 @@ impl ClientRef { py: Python<'p>, call: RpcCall, ) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::OperatorService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { @@ -628,6 +630,7 @@ impl ClientRef { } fn call_cloud_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::CloudService; let mut 
retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { @@ -842,6 +845,7 @@ impl ClientRef { } fn call_test_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::TestService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { @@ -881,6 +885,7 @@ impl ClientRef { } fn call_health_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult> { + self.runtime.assert_same_process("use client")?; use temporal_client::HealthService; let mut retry_client = self.retry_client.clone(); self.runtime.future_into_py(py, async move { diff --git a/temporalio/bridge/src/runtime.rs b/temporalio/bridge/src/runtime.rs index 72cc905ae..a75aeb3e3 100644 --- a/temporalio/bridge/src/runtime.rs +++ b/temporalio/bridge/src/runtime.rs @@ -1,5 +1,5 @@ use futures::channel::mpsc::Receiver; -use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::exceptions::{PyAssertionError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pythonize::pythonize; use std::collections::HashMap; @@ -33,6 +33,7 @@ pub struct RuntimeRef { #[derive(Clone)] pub(crate) struct Runtime { + pub(crate) pid: u32, pub(crate) core: Arc, metrics_call_buffer: Option>>, log_forwarder_handle: Option>>, @@ -173,6 +174,7 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult { Ok(RuntimeRef { runtime: Runtime { + pid: std::process::id(), core: Arc::new(core), metrics_call_buffer, log_forwarder_handle, @@ -197,6 +199,18 @@ impl Runtime { let _guard = self.core.tokio_handle().enter(); pyo3_async_runtimes::generic::future_into_py::(py, fut) } + + pub(crate) fn assert_same_process(&self, action: &'static str) -> PyResult<()> { + let current_pid = std::process::id(); + if self.pid != current_pid { + Err(PyAssertionError::new_err(format!( + "Cannot {} across forks (original runtime PID is {}, current is {})", + action, self.pid, current_pid, + ))) 
+ } else { + Ok(()) + } + } } impl Drop for Runtime { diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 92b43f356..549f4268f 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -474,6 +474,7 @@ pub fn new_worker( config: WorkerConfig, ) -> PyResult { enter_sync!(runtime_ref.runtime); + runtime_ref.runtime.assert_same_process("create worker")?; let event_loop_task_locals = Arc::new(OnceLock::new()); let config = convert_worker_config(config, event_loop_task_locals.clone())?; let worker = temporal_sdk_core::init_worker( @@ -495,6 +496,9 @@ pub fn new_replay_worker<'a>( config: WorkerConfig, ) -> PyResult> { enter_sync!(runtime_ref.runtime); + runtime_ref + .runtime + .assert_same_process("create replay worker")?; let event_loop_task_locals = Arc::new(OnceLock::new()); let config = convert_worker_config(config, event_loop_task_locals.clone())?; let (history_pusher, stream) = HistoryPusher::new(runtime_ref.runtime.clone()); @@ -519,6 +523,7 @@ pub fn new_replay_worker<'a>( #[pymethods] impl WorkerRef { fn validate<'p>(&self, py: Python<'p>) -> PyResult> { + self.runtime.assert_same_process("use worker")?; let worker = self.worker.as_ref().unwrap().clone(); // Set custom slot supplier task locals so they can run futures. // Event loop is assumed to be running at this point. 
@@ -538,6 +543,7 @@ impl WorkerRef { } fn poll_workflow_activation<'p>(&self, py: Python<'p>) -> PyResult> { + self.runtime.assert_same_process("use worker")?; let worker = self.worker.as_ref().unwrap().clone(); self.runtime.future_into_py(py, async move { let bytes = match worker.poll_workflow_activation().await { @@ -550,6 +556,7 @@ impl WorkerRef { } fn poll_activity_task<'p>(&self, py: Python<'p>) -> PyResult> { + self.runtime.assert_same_process("use worker")?; let worker = self.worker.as_ref().unwrap().clone(); self.runtime.future_into_py(py, async move { let bytes = match worker.poll_activity_task().await { @@ -562,6 +569,7 @@ impl WorkerRef { } fn poll_nexus_task<'p>(&self, py: Python<'p>) -> PyResult> { + self.runtime.assert_same_process("use worker")?; let worker = self.worker.as_ref().unwrap().clone(); self.runtime.future_into_py(py, async move { let bytes = match worker.poll_nexus_task().await { diff --git a/temporalio/runtime.py b/temporalio/runtime.py index 64fa12192..345d7ca77 100644 --- a/temporalio/runtime.py +++ b/temporalio/runtime.py @@ -24,22 +24,61 @@ import temporalio.bridge.runtime import temporalio.common -_default_runtime: Optional[Runtime] = None + +class _RuntimeRef: + def __init__( + self, + ) -> None: + self._default_runtime: Runtime | None = None + self._prevent_default = False + + def default(self) -> Runtime: + if not self._default_runtime: + if self._prevent_default: + raise RuntimeError( + "Cannot create default Runtime after Runtime.prevent_default has been called" + ) + self._default_runtime = Runtime(telemetry=TelemetryConfig()) + # Lazily created default; reused by all subsequent calls. + return self._default_runtime + + def prevent_default(self): + if self._default_runtime: + raise RuntimeError( + "Runtime.prevent_default called after default runtime has been created or set" + ) + self._prevent_default = True + + def set_default( + self, runtime: Runtime, *, error_if_already_set: bool = True + ) -> None: + if self._default_runtime and 
error_if_already_set: + raise RuntimeError("Runtime default already set") + + self._default_runtime = runtime + + +_runtime_ref: _RuntimeRef = _RuntimeRef() class Runtime: """Runtime for Temporal Python SDK. - Users are encouraged to use :py:meth:`default`. It can be set with + Most users are encouraged to use :py:meth:`default`. It can be set with :py:meth:`set_default`. Every time a new runtime is created, a new internal thread pool is created. - Runtimes do not work across forks. + Runtimes do not work across forks. Advanced users should consider using + :py:meth:`prevent_default` and :py:meth:`set_default` to ensure each + fork creates its own runtime. + """ @classmethod def default(cls) -> Runtime: - """Get the default runtime, creating if not already created. + """Get the default runtime, creating if not already created. If :py:meth:`prevent_default` + is called before this method it will raise a RuntimeError instead of creating a default + runtime. If the default runtime needs to be different, it should be done with :py:meth:`set_default` before this is called or ever used. @@ -47,10 +86,20 @@ def default(cls) -> Runtime: Returns: The default runtime. """ - global _default_runtime - if not _default_runtime: - _default_runtime = cls(telemetry=TelemetryConfig()) - return _default_runtime + global _runtime_ref + return _runtime_ref.default() + + @classmethod + def prevent_default(cls): + """Prevent :py:meth:`default` from lazily creating a :py:class:`Runtime`. + + Raises a RuntimeError if a default :py:class:`Runtime` has already been created. + + Explicitly setting a default runtime with :py:meth:`set_default` bypasses this setting and + future calls to :py:meth:`default` will return the provided runtime. 
+ """ + global _runtime_ref + _runtime_ref.prevent_default() @staticmethod def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None: @@ -65,10 +114,8 @@ def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None: error_if_already_set: If True and default is already set, this will raise a RuntimeError. """ - global _default_runtime - if _default_runtime and error_if_already_set: - raise RuntimeError("Runtime default already set") - _default_runtime = runtime + global _runtime_ref + _runtime_ref.set_default(runtime, error_if_already_set=error_if_already_set) def __init__(self, *, telemetry: TelemetryConfig) -> None: """Create a default runtime with the given telemetry config. diff --git a/tests/conftest.py b/tests/conftest.py index 8ffd3a456..c0f8bc5e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ import asyncio +import multiprocessing.context import os import sys -from typing import AsyncGenerator +from typing import AsyncGenerator, Iterator import pytest import pytest_asyncio @@ -133,6 +134,23 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: await env.shutdown() +@pytest.fixture(scope="session") +def mp_fork_ctx() -> Iterator[multiprocessing.context.BaseContext | None]: + mp_ctx = None + try: + mp_ctx = multiprocessing.get_context("fork") + except ValueError: + pass + + try: + yield mp_ctx + finally: + if mp_ctx: + for p in mp_ctx.active_children(): + p.terminate() + p.join() + + @pytest_asyncio.fixture async def client(env: WorkflowEnvironment) -> Client: return env.client diff --git a/tests/helpers/fork.py b/tests/helpers/fork.py new file mode 100644 index 000000000..e6d84652f --- /dev/null +++ b/tests/helpers/fork.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import asyncio +import multiprocessing +import multiprocessing.context +import sys +from dataclasses import dataclass +from typing import Any + +import pytest + + +@dataclass +class 
_ForkTestResult: + status: str + err_name: str | None + err_msg: str | None + + def __eq__(self, value: object) -> bool: + if not isinstance(value, _ForkTestResult): + return False + + if self.err_msg and value.err_msg: + valid_err_msg = ( + self.err_msg in value.err_msg or value.err_msg in self.err_msg + ) + else: + valid_err_msg = self.err_msg == value.err_msg + + return ( + value.status == self.status + and value.err_name == self.err_name + and valid_err_msg + ) + + @staticmethod + def assertion_error(message: str) -> _ForkTestResult: + return _ForkTestResult( + status="error", err_name="AssertionError", err_msg=message + ) + + +class _TestFork: + _expected: _ForkTestResult + + async def coro(self) -> Any: + raise NotImplementedError() + + def entry(self): + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + try: + event_loop.run_until_complete(self.coro()) + payload = _ForkTestResult(status="ok", err_name=None, err_msg=None) + except BaseException as err: + payload = _ForkTestResult( + status="error", err_name=err.__class__.__name__, err_msg=str(err) + ) + + self._child_conn.send(payload) + self._child_conn.close() + + def run(self, mp_fork_context: multiprocessing.context.BaseContext | None): + process_factory = getattr(mp_fork_context, "Process", None) + + if not mp_fork_context or not process_factory: + pytest.skip("fork context not available") + + self._parent_conn, self._child_conn = mp_fork_context.Pipe(duplex=False) + # start fork + child_process = process_factory(target=self.entry, args=(), daemon=False) + child_process.start() + # close parent's handle on child_conn + self._child_conn.close() + + # get run info from pipe + payload = self._parent_conn.recv() + self._parent_conn.close() + + assert payload == self._expected diff --git a/tests/test_client.py b/tests/test_client.py index 63dec2810..458492ff6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,7 @@ import dataclasses import json +import multiprocessing +import 
multiprocessing.context import os import uuid from datetime import datetime, timedelta, timezone @@ -90,6 +92,7 @@ new_worker, worker_versioning_enabled, ) +from tests.helpers.fork import _ForkTestResult, _TestFork from tests.helpers.worker import ( ExternalWorker, KSAction, @@ -1541,3 +1544,39 @@ async def get_schedule_result() -> Tuple[int, Optional[str]]: ) await handle.delete() + + +class TestForkCreateClient(_TestFork): + async def coro(self): + await Client.connect( + self._env.client.config()["service_client"].config.target_host + ) + + def test_fork_create_client( + self, + env: WorkflowEnvironment, + mp_fork_ctx: multiprocessing.context.BaseContext | None, + ): + self._expected = _ForkTestResult.assertion_error( + "Cannot create client across forks" + ) + self._env = env + self.run(mp_fork_ctx) + + +class TestForkUseClient(_TestFork): + async def coro(self): + await self._client.start_workflow( + "some-workflow", + id=f"workflow-{uuid.uuid4()}", + task_queue=f"tq-{uuid.uuid4()}", + ) + + def test_fork_use_client( + self, client: Client, mp_fork_ctx: multiprocessing.context.BaseContext | None + ): + self._expected = _ForkTestResult.assertion_error( + "Cannot use client across forks" + ) + self._client = client + self.run(mp_fork_ctx) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 4505ebfcf..9b318bbb7 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -7,6 +7,8 @@ from typing import List, cast from urllib.request import urlopen +import pytest + from temporalio import workflow from temporalio.client import Client from temporalio.runtime import ( @@ -16,6 +18,7 @@ Runtime, TelemetryConfig, TelemetryFilter, + _RuntimeRef, ) from temporalio.worker import Worker from tests.helpers import assert_eq_eventually, assert_eventually, find_free_port @@ -254,3 +257,53 @@ async def check_metrics() -> None: # Wait for metrics to appear and match the expected buckets await assert_eventually(check_metrics) + + +def 
test_runtime_ref_creates_default(): + ref = _RuntimeRef() + assert not ref._default_runtime + ref.default() + assert ref._default_runtime + + +def test_runtime_ref_prevents_default(): + ref = _RuntimeRef() + ref.prevent_default() + with pytest.raises(RuntimeError) as exc_info: + ref.default() + assert exc_info.match( + "Cannot create default Runtime after Runtime.prevent_default has been called" + ) + + # explicitly setting a default runtime will allow future calls to `default()` + explicit_runtime = Runtime(telemetry=TelemetryConfig()) + ref.set_default(explicit_runtime) + + assert ref.default() is explicit_runtime + + +def test_runtime_ref_prevent_default_errors_after_default(): + ref = _RuntimeRef() + ref.default() + with pytest.raises(RuntimeError) as exc_info: + ref.prevent_default() + + assert exc_info.match( + "Runtime.prevent_default called after default runtime has been created" + ) + + +def test_runtime_ref_set_default(): + ref = _RuntimeRef() + explicit_runtime = Runtime(telemetry=TelemetryConfig()) + ref.set_default(explicit_runtime) + assert ref.default() is explicit_runtime + + new_runtime = Runtime(telemetry=TelemetryConfig()) + + with pytest.raises(RuntimeError) as exc_info: + ref.set_default(new_runtime) + assert exc_info.match("Runtime default already set") + + ref.set_default(new_runtime, error_if_already_set=False) + assert ref.default() is new_runtime diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index 32f27f631..9ad4be3af 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -2,6 +2,8 @@ import asyncio import concurrent.futures +import multiprocessing +import multiprocessing.context import uuid from datetime import timedelta from typing import Any, Awaitable, Callable, Optional, Sequence @@ -58,6 +60,7 @@ new_worker, worker_versioning_enabled, ) +from tests.helpers.fork import _ForkTestResult, _TestFork from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name @@ -1271,3 
+1274,43 @@ def shutdown(self) -> None: if self.next_exception_task: self.next_exception_task.cancel() setattr(self.worker._bridge_worker, self.attr, self.orig_poll_call) + + +class TestForkCreateWorker(_TestFork): + async def coro(self): + self._worker = Worker( + self._client, + task_queue=f"task-queue-{uuid.uuid4()}", + activities=[never_run_activity], + workflows=[], + nexus_service_handlers=[], + ) + + def test_fork_create_worker( + self, client: Client, mp_fork_ctx: multiprocessing.context.BaseContext | None + ): + self._expected = _ForkTestResult.assertion_error( + "Cannot create worker across forks" + ) + self._client = client + self.run(mp_fork_ctx) + + +class TestForkUseWorker(_TestFork): + async def coro(self): + await self._pre_fork_worker.run() + + def test_fork_use_worker( + self, client: Client, mp_fork_ctx: multiprocessing.context.BaseContext | None + ): + self._expected = _ForkTestResult.assertion_error( + "Cannot use worker across forks" + ) + self._pre_fork_worker = Worker( + client, + task_queue=f"task-queue-{uuid.uuid4()}", + activities=[never_run_activity], + workflows=[], + nexus_service_handlers=[], + ) + self.run(mp_fork_ctx)