Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions src/core/app/middleware/content_rewriting_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import codecs
import json
from collections.abc import AsyncIterator
from typing import Any

from fastapi import Request
Expand Down Expand Up @@ -399,19 +401,51 @@ async def receive():

# Step 3: Potentially rewrite the response
if isinstance(response, StreamingResponse):
original_iterator = response.body_iterator

if original_iterator is None:
return response

# If no reply rules are configured we can pass the stream through.
if not self.rewriter.reply_rules:
return response

max_pattern_length = self.rewriter.max_reply_search_length
overlap = max(0, max_pattern_length - 1)
# Keep the working buffer bounded to avoid unbounded memory growth.
chunk_limit = max(65536, (overlap + 1) * 2)

async def rewriting_iterator() -> AsyncIterator[bytes]:
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
buffer = ""

async for chunk in original_iterator:
if isinstance(chunk, bytes):
text_chunk = decoder.decode(chunk, final=False)
else:
text_chunk = decoder.decode(chunk.encode("utf-8"), final=False)
buffer += text_chunk

while len(buffer) > chunk_limit:
emit_upto = len(buffer) - overlap if overlap else len(buffer)
if emit_upto <= 0:
break
to_emit = buffer[:emit_upto]
buffer = buffer[emit_upto:]
if to_emit:
rewritten = self.rewriter.rewrite_reply(to_emit)
yield rewritten.encode("utf-8")

async def new_iterator():
response_body = b""
async for chunk in response.body_iterator:
response_body += chunk
rewritten_body = self.rewriter.rewrite_reply(response_body.decode())
yield rewritten_body.encode("utf-8")
remaining = buffer + decoder.decode(b"", final=True)
if remaining:
rewritten_remaining = self.rewriter.rewrite_reply(remaining)
yield rewritten_remaining.encode("utf-8")

background = response.background
response.background = None

return StreamingResponse(
new_iterator(),
rewriting_iterator(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
Expand Down
30 changes: 30 additions & 0 deletions src/core/services/content_rewriter_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
self.prompt_system_rules: list[ReplacementRule] = []
self.prompt_user_rules: list[ReplacementRule] = []
self.reply_rules: list[ReplacementRule] = []
self._max_reply_search_length: int | None = None
self.load_rules()

def load_rules(self) -> None:
Expand All @@ -50,6 +51,35 @@ def load_rules(self) -> None:
self.reply_rules = self._load_rules_from_dir(
os.path.join(self.config_path, "replies")
)
self.refresh_rule_cache()

def refresh_rule_cache(self) -> None:
"""Recompute cached values derived from replacement rules."""

self._max_reply_search_length = self._calculate_max_search_length(
self.reply_rules
)

@staticmethod
def _calculate_max_search_length(rules: list[ReplacementRule]) -> int:
"""Return the length of the longest search string in ``rules``."""

max_length = 0
for rule in rules:
search = rule.search
if search:
search_length = len(search)
if search_length > max_length:
max_length = search_length
return max_length

@property
def max_reply_search_length(self) -> int:
"""Maximum length of reply rule search patterns (cached)."""

if self._max_reply_search_length is None:
self.refresh_rule_cache()
return self._max_reply_search_length or 0

def _load_rules_from_dir(self, directory: str) -> list[ReplacementRule]:
"""Loads rules from a specific directory."""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import asyncio

import pytest
from src.core.app.middleware.content_rewriting_middleware import (
ContentRewritingMiddleware,
)
from src.core.domain.replacement_rule import ReplacementMode, ReplacementRule
from src.core.services.content_rewriter_service import ContentRewriterService
from starlette.requests import Request
from starlette.responses import StreamingResponse


def _build_request() -> Request:
    """Build a minimal HTTP GET request suitable for middleware dispatch."""

    async def receive() -> dict[str, object]:
        # A single empty-body message; more_body=False ends the stream.
        return {"type": "http.request", "body": b"", "more_body": False}

    scope = {
        "type": "http",
        "method": "GET",
        "path": "/",
        "headers": [],
        "client": ("127.0.0.1", 12345),
        "server": ("127.0.0.1", 8000),
        "scheme": "http",
        "http_version": "1.1",
    }
    return Request(scope, receive)


@pytest.mark.asyncio
async def test_streaming_rewrite_preserves_streaming_and_cross_chunk_matches() -> None:
    """The middleware must stream lazily and rewrite patterns split across chunks."""
    rewriter = ContentRewriterService(config_path="non-existent")
    rewriter.reply_rules = [
        ReplacementRule(
            mode=ReplacementMode.REPLACE,
            search="HELLO",
            replace="BYE",
        )
    ]
    # Rules were assigned directly, so recompute the cached search length.
    rewriter.refresh_rule_cache()

    middleware = ContentRewritingMiddleware(lambda request: None, rewriter)
    request = _build_request()

    # "HELLO" is deliberately split across chunk boundaries.
    chunks = [b"start HEL", b"LO mid HEL", b"LO end"]
    chunk_iterated: list[bytes] = []

    async def chunk_generator():
        for chunk in chunks:
            chunk_iterated.append(chunk)
            yield chunk
            await asyncio.sleep(0)

    async def call_next(_request: Request) -> StreamingResponse:
        return StreamingResponse(chunk_generator(), media_type="text/plain")

    response = await middleware.dispatch(request, call_next)

    # Dispatch alone must not eagerly consume the upstream stream.
    assert not chunk_iterated

    body_iter = response.body_iterator
    assert body_iter is not None

    pieces: list[bytes] = []
    async for piece in body_iter:
        pieces.append(piece)
    collected = b"".join(pieces)

    # All chunks were produced lazily during iteration.
    assert chunk_iterated == chunks

    assert collected.decode("utf-8") == "start BYE mid BYE end"
Loading