Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/core/services/tool_text_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextvars import ContextVar, Token
from typing import Any

from src.core.domain.chat import ToolCall
Expand Down Expand Up @@ -279,31 +280,35 @@ def reset_renderer_registry() -> None:


# Context manager to temporarily override the renderer for a block of code
_override: str | None = None
_override_var: ContextVar[str | None] = ContextVar(
"tool_text_renderer_override", default=None
)


class OverrideRenderer:
def __init__(self, renderer_name: str):
self.renderer_name = renderer_name
self.original_override = _override
self._token: Token[str | None] | None = None

def __enter__(self) -> None:
global _override
_override = self.renderer_name
self._token = _override_var.set(self.renderer_name)

def __exit__(self, exc_type: Any, _: Any, traceback: Any) -> None:
global _override
_override = self.original_override
if self._token is not None:
_override_var.reset(self._token)
else:
_override_var.set(None)


def render_tool_call(tool_call: ToolCall) -> str | None:
"""Render a tool call using the currently active renderer."""
renderer_name = _override or _renderer_registry.default_renderer
current_override = _override_var.get()
renderer_name = current_override or _renderer_registry.default_renderer
renderer = get_renderer(renderer_name)
text = renderer.render(tool_call)
if text:
return text
if (_override or "").strip().lower() in {"", "none"}:
if (current_override or "").strip().lower() in {"", "none"}:
return None
fallback_name = _renderer_registry.fallback_renderer
if fallback_name and fallback_name != renderer_name:
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/core/services/test_tool_text_renderer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import asyncio
import json

import pytest
from src.core.domain.chat import FunctionCall, ToolCall
from src.core.services.tool_text_renderer import (
OverrideRenderer,
render_tool_call,
reset_renderer_registry,
)


@pytest.mark.asyncio
async def test_override_is_session_isolated() -> None:
"""Ensure renderer overrides do not leak across concurrent sessions."""
reset_renderer_registry()
tool_call = ToolCall(
id="call-1",
function=FunctionCall(
name="shell",
arguments=json.dumps({"command": ["echo", "hello"]}),
),
)

start_override = asyncio.Event()
release_override = asyncio.Event()

async def session_with_override() -> str | None:
with OverrideRenderer("markdown"):
start_override.set()
await release_override.wait()
return render_tool_call(tool_call)

async def concurrent_session() -> str | None:
await start_override.wait()
result = render_tool_call(tool_call)
release_override.set()
return result

override_result, default_result = await asyncio.gather(
session_with_override(),
concurrent_session(),
)

assert override_result is not None and "```bash" in override_result
assert default_result is None
assert render_tool_call(tool_call) is None
Loading