diff --git a/src/core/services/tool_text_renderer.py b/src/core/services/tool_text_renderer.py index 4245dd0d..7a9257a9 100644 --- a/src/core/services/tool_text_renderer.py +++ b/src/core/services/tool_text_renderer.py @@ -5,6 +5,7 @@ import logging from abc import ABC, abstractmethod from collections.abc import Callable +from contextvars import ContextVar, Token from typing import Any from src.core.domain.chat import ToolCall @@ -279,31 +280,35 @@ def reset_renderer_registry() -> None: # Context manager to temporarily override the renderer for a block of code -_override: str | None = None +_override_var: ContextVar[str | None] = ContextVar( + "tool_text_renderer_override", default=None +) class OverrideRenderer: def __init__(self, renderer_name: str): self.renderer_name = renderer_name - self.original_override = _override + self._token: Token[str | None] | None = None def __enter__(self) -> None: - global _override - _override = self.renderer_name + self._token = _override_var.set(self.renderer_name) def __exit__(self, exc_type: Any, _: Any, traceback: Any) -> None: - global _override - _override = self.original_override + if self._token is not None: + _override_var.reset(self._token) + else: + _override_var.set(None) def render_tool_call(tool_call: ToolCall) -> str | None: """Render a tool call using the currently active renderer.""" - renderer_name = _override or _renderer_registry.default_renderer + current_override = _override_var.get() + renderer_name = current_override or _renderer_registry.default_renderer renderer = get_renderer(renderer_name) text = renderer.render(tool_call) if text: return text - if (_override or "").strip().lower() in {"", "none"}: + if (current_override or "").strip().lower() in {"", "none"}: return None fallback_name = _renderer_registry.fallback_renderer if fallback_name and fallback_name != renderer_name: diff --git a/tests/unit/core/services/test_tool_text_renderer.py b/tests/unit/core/services/test_tool_text_renderer.py new file mode 100644 index 00000000..02b03da5 --- /dev/null +++ b/tests/unit/core/services/test_tool_text_renderer.py @@ -0,0 +1,47 @@ +import asyncio +import json + +import pytest +from src.core.domain.chat import FunctionCall, ToolCall +from src.core.services.tool_text_renderer import ( + OverrideRenderer, + render_tool_call, + reset_renderer_registry, +) + + +@pytest.mark.asyncio +async def test_override_is_session_isolated() -> None: + """Ensure renderer overrides do not leak across concurrent sessions.""" + reset_renderer_registry() + tool_call = ToolCall( + id="call-1", + function=FunctionCall( + name="shell", + arguments=json.dumps({"command": ["echo", "hello"]}), + ), + ) + + start_override = asyncio.Event() + release_override = asyncio.Event() + + async def session_with_override() -> str | None: + with OverrideRenderer("markdown"): + start_override.set() + await release_override.wait() + return render_tool_call(tool_call) + + async def concurrent_session() -> str | None: + await start_override.wait() + result = render_tool_call(tool_call) + release_override.set() + return result + + override_result, default_result = await asyncio.gather( + session_with_override(), + concurrent_session(), + ) + + assert override_result is not None and "```bash" in override_result + assert default_result is None + assert render_tool_call(tool_call) is None