diff --git a/src/core/app/middleware/content_rewriting_middleware.py b/src/core/app/middleware/content_rewriting_middleware.py index e7f983e8c..1229287cb 100644 --- a/src/core/app/middleware/content_rewriting_middleware.py +++ b/src/core/app/middleware/content_rewriting_middleware.py @@ -1,4 +1,5 @@ import json +from collections.abc import AsyncIterator from typing import Any from fastapi import Request @@ -8,6 +9,13 @@ class ContentRewritingMiddleware(BaseHTTPMiddleware): + """Middleware that performs configurable prompt/response rewriting.""" + + #: Maximum number of response bytes buffered when rewriting a streaming + #: response. Exceeding this threshold disables rewriting to prevent + #: unbounded memory usage that could be triggered remotely. + _MAX_STREAM_REWRITE_BYTES = 512_000 + def __init__(self, app, rewriter: ContentRewriterService): super().__init__(app) self.rewriter = rewriter @@ -398,43 +406,72 @@ async def receive(): response = await call_next(request_for_next_call) # Step 3: Potentially rewrite the response - if isinstance(response, StreamingResponse): + if getattr(response, "body_iterator", None) is not None: + + async def stream_with_rewrite() -> AsyncIterator[bytes]: + buffer = bytearray() + rewrite_allowed = True + charset = getattr(response, "charset", None) or "utf-8" - async def new_iterator(): - response_body = b"" async for chunk in response.body_iterator: - response_body += chunk - rewritten_body = self.rewriter.rewrite_reply(response_body.decode()) - yield rewritten_body.encode("utf-8") + if not chunk: + continue + + if isinstance(chunk, (bytes, bytearray, memoryview)): + chunk_bytes = ( + chunk if isinstance(chunk, bytes) else bytes(chunk) + ) + else: + chunk_bytes = str(chunk).encode(charset, errors="ignore") + + if rewrite_allowed: + projected_length = len(buffer) + len(chunk_bytes) + if projected_length <= self._MAX_STREAM_REWRITE_BYTES: + buffer.extend(chunk_bytes) + continue + + rewrite_allowed = False + if buffer: + yield bytes(buffer) + buffer.clear() + + yield bytes(chunk_bytes) + + if rewrite_allowed and buffer: + text = buffer.decode(charset, errors="ignore") + rewritten = self.rewriter.rewrite_reply(text) + yield rewritten.encode(charset) + elif buffer: + yield bytes(buffer) background = response.background response.background = None return StreamingResponse( - new_iterator(), + stream_with_rewrite(), status_code=response.status_code, headers=dict(response.headers), media_type=response.media_type, background=background, ) - else: - response_body = response.body - try: - data = json.loads(response_body) - is_rewritten = False - if self._rewrite_chat_response(data): - is_rewritten = True + response_body = response.body + try: + data = json.loads(response_body) + is_rewritten = False - if self._rewrite_responses_output(data): - is_rewritten = True + if self._rewrite_chat_response(data): + is_rewritten = True - if is_rewritten: - new_body = json.dumps(data).encode("utf-8") - response.body = new_body - response.headers["content-length"] = str(len(new_body)) + if self._rewrite_responses_output(data): + is_rewritten = True - except json.JSONDecodeError: - pass + if is_rewritten: + new_body = json.dumps(data).encode("utf-8") + response.body = new_body + response.headers["content-length"] = str(len(new_body)) + + except json.JSONDecodeError: + pass return response diff --git a/tests/unit/app/middleware/test_content_rewriting_middleware.py b/tests/unit/app/middleware/test_content_rewriting_middleware.py new file mode 100644 index 000000000..000858eb0 --- /dev/null +++ b/tests/unit/app/middleware/test_content_rewriting_middleware.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from fastapi import FastAPI +from starlette.responses import StreamingResponse +from starlette.testclient import TestClient + +from src.core.app.middleware.content_rewriting_middleware import ( + ContentRewritingMiddleware, +) + + +class DummyRewriter: + """Simple stand-in for ``ContentRewriterService`` used in tests.""" + + def __init__(self) -> None: + self.rewrite_calls: list[str] = [] + + def rewrite_prompt( + self, prompt: str, prompt_type: str + ) -> str: # pragma: no cover - unused + return prompt + + def rewrite_reply(self, reply: str) -> str: + self.rewrite_calls.append(reply) + return reply.replace("foo", "bar") + + +def _create_streaming_app(rewriter: DummyRewriter) -> TestClient: + app = FastAPI() + + @app.get("/stream") + async def stream_endpoint() -> StreamingResponse: + async def generator(): + yield b"foo" + yield b"foo" + + return StreamingResponse(generator(), media_type="text/plain") + + app.add_middleware(ContentRewritingMiddleware, rewriter=rewriter) + return TestClient(app) + + +def test_streaming_rewrite_within_limit(monkeypatch): + monkeypatch.setattr(ContentRewritingMiddleware, "_MAX_STREAM_REWRITE_BYTES", 1024) + rewriter = DummyRewriter() + client = _create_streaming_app(rewriter) + + response = client.get("/stream") + + assert response.status_code == 200 + assert response.text == "barbar" + assert rewriter.rewrite_calls == ["foofoo"] + + +def test_streaming_rewrite_skips_when_limit_exceeded(monkeypatch): + monkeypatch.setattr(ContentRewritingMiddleware, "_MAX_STREAM_REWRITE_BYTES", 4) + rewriter = DummyRewriter() + client = _create_streaming_app(rewriter) + + response = client.get("/stream") + + assert response.status_code == 200 + assert response.text == "foofoo" + assert rewriter.rewrite_calls == []