Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
345 changes: 174 additions & 171 deletions src/core/services/streaming/tool_call_repair_processor.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,105 @@
from __future__ import annotations
import logging
from typing import Any
from src.core.domain.streaming_response_processor import (
IStreamProcessor,
StreamingContent,
)
from src.core.interfaces.tool_call_repair_service_interface import (
IToolCallRepairService,
)
from src.core.services.streaming.stream_utils import get_stream_id
logger = logging.getLogger(__name__)
class ToolCallRepairProcessor(IStreamProcessor):
"""
Stream processor that uses ToolCallRepairService to detect and repair
tool calls within streaming content.
"""
def __init__(
    self,
    tool_call_repair_service: IToolCallRepairService,
    *,
    max_buffer_bytes: int | None = None,
) -> None:
    """Initialize the processor.

    Args:
        tool_call_repair_service: Service used to detect and repair tool
            calls found in buffered stream text.
        max_buffer_bytes: Optional explicit cap (in UTF-8 bytes) for the
            per-stream pending-text buffer; overrides any cap the service
            advertises.
    """
    self.tool_call_repair_service = tool_call_repair_service
    # Cap precedence: explicit argument > integer `max_buffer_bytes`
    # attribute on the service > 64 KiB default.
    service_cap = getattr(tool_call_repair_service, "max_buffer_bytes", None)
    if max_buffer_bytes is not None:
        self._max_buffer_bytes = max_buffer_bytes
    elif isinstance(service_cap, int):
        self._max_buffer_bytes = service_cap
    else:
        self._max_buffer_bytes = 64 * 1024
    # Per-stream buffered state, keyed by stream id
    # (each value holds at least {"pending_text": str}).
    self._buffers: dict[str, dict[str, str]] = {}
async def process(self, content: StreamingContent) -> StreamingContent:
"""
Processes a streaming content chunk, attempting to repair tool calls.
"""
if content.is_empty and not content.is_done:
return content # Nothing to process
stream_id = get_stream_id(content)
state = self._buffers.get(stream_id, {"pending_text": ""})
metadata = dict(content.metadata or {})
detected_tool_calls: list[dict[str, Any]] = []
chunk_text = content.content or ""
reasoning_segments: list[str] = []
for key in ("reasoning_content", "reasoning"):
value = metadata.pop(key, None)
if isinstance(value, str) and value:
reasoning_segments.append(value)
if reasoning_segments:
state["pending_text"] += "".join(reasoning_segments)
if chunk_text:
state["pending_text"] += chunk_text
repaired_content_parts: list[str] = []
buffer_text = state["pending_text"]
for marker in ("<use_mcp_tool", "<patch_file"):
marker_index = buffer_text.find(marker)
if marker_index > 0:
prefix = buffer_text[:marker_index]
if prefix.strip():
repaired_content_parts.append(prefix)
buffer_text = buffer_text[marker_index:]
state["pending_text"] = buffer_text
break
if buffer_text:
repaired_json = self.tool_call_repair_service.repair_tool_calls(buffer_text)
if repaired_json:
detected_tool_calls.append(repaired_json)
snippet = getattr(
self.tool_call_repair_service, "last_tool_snippet", None
)
if snippet:
idx = buffer_text.find(snippet)
if idx != -1:
prefix = buffer_text[:idx]
suffix = buffer_text[idx + len(snippet) :]
if prefix.strip():
repaired_content_parts.append(prefix)
buffer_text = suffix
state["pending_text"] = buffer_text
if not detected_tool_calls:
trimmed = self._trim_buffer(state["pending_text"])
if trimmed:
repaired_content_parts.append(trimmed)
state["pending_text"] = state["pending_text"][len(trimmed) :]
from __future__ import annotations

import logging
from typing import Any

from src.core.domain.streaming_response_processor import (
IStreamProcessor,
StreamingContent,
)
from src.core.interfaces.tool_call_repair_service_interface import (
IToolCallRepairService,
)
from src.core.services.streaming.stream_utils import get_stream_id

logger = logging.getLogger(__name__)


class ToolCallRepairProcessor(IStreamProcessor):
"""
Stream processor that uses ToolCallRepairService to detect and repair
tool calls within streaming content.
"""

def __init__(
self,
tool_call_repair_service: IToolCallRepairService,
*,
max_buffer_bytes: int | None = None,
) -> None:
self.tool_call_repair_service = tool_call_repair_service
service_cap = getattr(tool_call_repair_service, "max_buffer_bytes", None)
if max_buffer_bytes is not None:
self._max_buffer_bytes = max_buffer_bytes
elif isinstance(service_cap, int):
self._max_buffer_bytes = service_cap
else:
self._max_buffer_bytes = 64 * 1024

self._buffers: dict[str, dict[str, str]] = {}

async def process(self, content: StreamingContent) -> StreamingContent:
"""
Processes a streaming content chunk, attempting to repair tool calls.
"""
if content.is_empty and not content.is_done:
return content # Nothing to process

stream_id = get_stream_id(content)
state = self._buffers.get(stream_id, {"pending_text": ""})
metadata = dict(content.metadata or {})
detected_tool_calls: list[dict[str, Any]] = []

chunk_text = content.content or ""
reasoning_segments: list[str] = []
for key in ("reasoning_content", "reasoning"):
value = metadata.pop(key, None)
if isinstance(value, str) and value:
reasoning_segments.append(value)

if reasoning_segments:
state["pending_text"] += "".join(reasoning_segments)

if chunk_text:
state["pending_text"] += chunk_text

repaired_content_parts: list[str] = []

buffer_text = state["pending_text"]

for marker in ("<use_mcp_tool", "<patch_file"):
marker_index = buffer_text.find(marker)
if marker_index > 0:
prefix = buffer_text[:marker_index]
if prefix.strip():
repaired_content_parts.append(prefix)
buffer_text = buffer_text[marker_index:]
state["pending_text"] = buffer_text
break

if buffer_text:
repaired_json = self.tool_call_repair_service.repair_tool_calls(buffer_text)
if repaired_json:
detected_tool_calls.append(repaired_json)
snippet = getattr(
self.tool_call_repair_service, "last_tool_snippet", None
)
if snippet:
idx = buffer_text.find(snippet)
if idx != -1:
prefix = buffer_text[:idx]
suffix = buffer_text[idx + len(snippet) :]
if prefix.strip():
repaired_content_parts.append(prefix)
buffer_text = suffix
state["pending_text"] = buffer_text

if not detected_tool_calls:
trimmed = self._trim_buffer(state["pending_text"])
if trimmed:
repaired_content_parts.append(trimmed)
state["pending_text"] = state["pending_text"][len(trimmed) :]

if content.is_done:
pending_text = state["pending_text"]
Comment on lines +97 to 104
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Restore pass-through for plain text streaming.

Because _trim_buffer only returns data once pending_text exceeds self._max_buffer_bytes, this branch now emits nothing for ordinary chunks that do not contain tool calls. With the default cap (~64 KiB), plain-text streams return empty chunks until the terminal is_done flush, effectively disabling streaming for normal responses. Please reinstate immediate pass-through for non-tool-call text (keeping only a minimal tail for future detection) so we do not regress streaming behaviour.

🤖 Prompt for AI Agents
In src/core/services/streaming/tool_call_repair_processor.py around lines 97 to
104, the branch that handles "no detected_tool_calls" currently relies on
_trim_buffer and therefore emits nothing until pending_text exceeds
_max_buffer_bytes; restore immediate pass-through for ordinary text by emitting
all but a small tail used for future detection: if pending_text length <=
self._max_buffer_bytes, append the whole pending_text to repaired_content_parts
and clear state["pending_text"]; otherwise append
pending_text[:-self._max_buffer_bytes] to repaired_content_parts and set
state["pending_text"] to the last self._max_buffer_bytes characters; keep the
existing behavior for content.is_done (flush remaining pending_text).

if pending_text:
Expand Down Expand Up @@ -133,72 +133,75 @@ async def process(self, content: StreamingContent) -> StreamingContent:
if not handled and pending_text:
repaired_content_parts.append(pending_text)
state["pending_text"] = ""

if content.is_done or content.is_cancellation:
self._buffers.pop(stream_id, None)
else:
if state["pending_text"]:
self._buffers[stream_id] = state
else:
self._buffers.pop(stream_id, None)

new_content_str = "".join(repaired_content_parts)
if detected_tool_calls:
logger.debug(
"ToolCallRepairProcessor captured tool call(s): %s", detected_tool_calls
)
existing_calls = metadata.get("tool_calls")
if isinstance(existing_calls, list):
metadata["tool_calls"] = existing_calls + detected_tool_calls
else:
metadata["tool_calls"] = detected_tool_calls
metadata.setdefault("finish_reason", "tool_calls")

if new_content_str or detected_tool_calls or content.is_done:
return StreamingContent(
content=new_content_str,
is_done=content.is_done,
is_cancellation=content.is_cancellation,
metadata=metadata,
usage=content.usage,
raw_data=content.raw_data,
)

return StreamingContent(
content="",
is_cancellation=content.is_cancellation,
metadata=metadata,
usage=content.usage,
raw_data=content.raw_data,
) # Return empty if nothing to yield

def _trim_buffer(self, buffer: str) -> str:
    """Flush enough leading content to honor the buffer cap.

    Returns the prefix of *buffer* that should be emitted downstream so
    that the retained remainder fits within ``self._max_buffer_bytes``
    (measured in UTF-8 bytes). Returns ``""`` when the buffer is empty or
    already under the cap.
    """

    if not buffer:
        return ""

    # The cap is enforced on UTF-8 byte length, not character count.
    encoded_length = len(buffer.encode("utf-8"))
    if encoded_length <= self._max_buffer_bytes:
        return ""

    # Number of bytes that must leave the buffer to get back under the cap.
    overflow = encoded_length - self._max_buffer_bytes
    flushed_chars = []
    consumed = 0

    # Walk whole characters (never splitting a multibyte sequence) until at
    # least `overflow` bytes worth of text has been collected.
    for ch in buffer:
        char_bytes = len(ch.encode("utf-8"))
        flushed_chars.append(ch)
        consumed += char_bytes
        if consumed >= overflow:
            break

    flush_text = "".join(flushed_chars)

    logger.warning(
        "ToolCallRepairProcessor buffer exceeded %d bytes; flushed %d characters",
        self._max_buffer_bytes,
        len(flush_text),
    )

    return flush_text

if content.is_done or content.is_cancellation:
self._buffers.pop(stream_id, None)
else:
if state["pending_text"]:
self._buffers[stream_id] = state
else:
self._buffers.pop(stream_id, None)

new_content_str = "".join(repaired_content_parts)
if detected_tool_calls:
logger.debug(
"ToolCallRepairProcessor captured tool call(s): %s", detected_tool_calls
)
existing_calls = metadata.get("tool_calls")
if isinstance(existing_calls, list):
metadata["tool_calls"] = existing_calls + detected_tool_calls
else:
metadata["tool_calls"] = detected_tool_calls
metadata.setdefault("finish_reason", "tool_calls")

if new_content_str or detected_tool_calls or content.is_done:
return StreamingContent(
content=new_content_str,
is_done=content.is_done,
is_cancellation=content.is_cancellation,
metadata=metadata,
usage=content.usage,
raw_data=content.raw_data,
)

return StreamingContent(
content="",
is_cancellation=content.is_cancellation,
metadata=metadata,
usage=content.usage,
raw_data=content.raw_data,
) # Return empty if nothing to yield

def _trim_buffer(self, buffer: str) -> str:
"""Flush enough leading content to honor the buffer cap."""

if not buffer:
return ""

encoded_buffer = buffer.encode("utf-8")
encoded_length = len(encoded_buffer)
if encoded_length <= self._max_buffer_bytes:
return ""

overflow = encoded_length - self._max_buffer_bytes
cutoff = overflow

# Adjust cutoff to avoid splitting multibyte UTF-8 characters.
# decode() with strict errors will raise if the slice ends mid-character.
while cutoff <= encoded_length:
try:
flush_text = encoded_buffer[:cutoff].decode("utf-8")
break
except UnicodeDecodeError:
cutoff += 1
else:
# If decoding still fails (extremely unlikely), fall back to flushing all
flush_text = buffer

logger.warning(
"ToolCallRepairProcessor buffer exceeded %d bytes; flushed %d characters",
self._max_buffer_bytes,
len(flush_text),
)

return flush_text
Loading
Loading