Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 59 additions & 22 deletions src/core/app/middleware/content_rewriting_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from collections.abc import AsyncIterator
from typing import Any

from fastapi import Request
Expand All @@ -8,6 +9,13 @@


class ContentRewritingMiddleware(BaseHTTPMiddleware):
"""Middleware that performs configurable prompt/response rewriting."""

#: Maximum number of response bytes buffered when rewriting a streaming
#: response. Exceeding this threshold disables rewriting to prevent
#: unbounded memory usage that could be triggered remotely.
_MAX_STREAM_REWRITE_BYTES = 512_000

def __init__(self, app, rewriter: ContentRewriterService):
super().__init__(app)
self.rewriter = rewriter
Expand Down Expand Up @@ -398,43 +406,72 @@ async def receive():
response = await call_next(request_for_next_call)

# Step 3: Potentially rewrite the response
if isinstance(response, StreamingResponse):
if getattr(response, "body_iterator", None) is not None:

async def stream_with_rewrite() -> AsyncIterator[bytes]:
buffer = bytearray()
rewrite_allowed = True
charset = getattr(response, "charset", None) or "utf-8"

async def new_iterator():
response_body = b""
async for chunk in response.body_iterator:
response_body += chunk
rewritten_body = self.rewriter.rewrite_reply(response_body.decode())
yield rewritten_body.encode("utf-8")
if not chunk:
continue

if isinstance(chunk, (bytes, bytearray, memoryview)):
chunk_bytes = (
chunk if isinstance(chunk, bytes) else bytes(chunk)
)
else:
chunk_bytes = str(chunk).encode(charset, errors="ignore")

if rewrite_allowed:
projected_length = len(buffer) + len(chunk_bytes)
if projected_length <= self._MAX_STREAM_REWRITE_BYTES:
buffer.extend(chunk_bytes)
continue

rewrite_allowed = False
if buffer:
yield bytes(buffer)
buffer.clear()

yield bytes(chunk_bytes)

if rewrite_allowed and buffer:
text = buffer.decode(charset, errors="ignore")
rewritten = self.rewriter.rewrite_reply(text)
yield rewritten.encode(charset)
elif buffer:
yield bytes(buffer)

background = response.background
response.background = None

return StreamingResponse(
new_iterator(),
stream_with_rewrite(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
background=background,
)
else:
response_body = response.body
try:
data = json.loads(response_body)
is_rewritten = False

if self._rewrite_chat_response(data):
is_rewritten = True
response_body = response.body
try:
data = json.loads(response_body)
is_rewritten = False

if self._rewrite_responses_output(data):
is_rewritten = True
if self._rewrite_chat_response(data):
is_rewritten = True

if is_rewritten:
new_body = json.dumps(data).encode("utf-8")
response.body = new_body
response.headers["content-length"] = str(len(new_body))
if self._rewrite_responses_output(data):
is_rewritten = True

except json.JSONDecodeError:
pass
if is_rewritten:
new_body = json.dumps(data).encode("utf-8")
response.body = new_body
response.headers["content-length"] = str(len(new_body))

except json.JSONDecodeError:
pass

return response
64 changes: 64 additions & 0 deletions tests/unit/app/middleware/test_content_rewriting_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from fastapi import FastAPI
from starlette.responses import StreamingResponse
from starlette.testclient import TestClient

from src.core.app.middleware.content_rewriting_middleware import (
ContentRewritingMiddleware,
)


class DummyRewriter:
    """Simple stand-in for ``ContentRewriterService`` used in tests.

    Every reply handed to :meth:`rewrite_reply` is recorded in
    ``rewrite_calls`` so a test can assert on exactly what was submitted
    for rewriting (and whether rewriting happened at all).
    """

    def __init__(self) -> None:
        # Replies received by rewrite_reply, in call order.
        self.rewrite_calls: list[str] = []

    def rewrite_prompt(
        self, prompt: str, prompt_type: str
    ) -> str:  # pragma: no cover - unused
        """Return *prompt* unchanged; *prompt_type* is ignored."""
        return prompt

    def rewrite_reply(self, reply: str) -> str:
        """Record *reply* and return it with each ``foo`` replaced by ``bar``."""
        self.rewrite_calls.append(reply)
        return reply.replace("foo", "bar")


def _create_streaming_app(rewriter: DummyRewriter) -> TestClient:
    """Build a test client for a minimal app wrapped by the rewriting middleware.

    The app exposes a single ``GET /stream`` endpoint that streams two
    ``b"foo"`` chunks (six bytes in total) as ``text/plain``.

    Args:
        rewriter: Rewriter instance handed to ``ContentRewritingMiddleware``.

    Returns:
        A ``TestClient`` bound to the assembled application.
    """
    app = FastAPI()

    @app.get("/stream")
    async def stream_endpoint() -> StreamingResponse:
        async def generator():
            # Two separate chunks so the middleware sees more than one
            # iteration of the body iterator.
            yield b"foo"
            yield b"foo"

        return StreamingResponse(generator(), media_type="text/plain")

    app.add_middleware(ContentRewritingMiddleware, rewriter=rewriter)
    return TestClient(app)


def test_streaming_rewrite_within_limit(monkeypatch):
    """A stream that fits under the buffer cap is rewritten as one piece."""
    # 1024 bytes comfortably exceeds the 6-byte body the endpoint streams.
    monkeypatch.setattr(ContentRewritingMiddleware, "_MAX_STREAM_REWRITE_BYTES", 1024)
    rewriter = DummyRewriter()
    client = _create_streaming_app(rewriter)

    response = client.get("/stream")

    assert response.status_code == 200
    assert response.text == "barbar"
    # A single call with the concatenated chunks shows the body was
    # buffered in full before being passed to the rewriter.
    assert rewriter.rewrite_calls == ["foofoo"]


def test_streaming_rewrite_skips_when_limit_exceeded(monkeypatch):
    """A stream larger than the buffer cap is passed through unmodified."""
    # The endpoint streams 6 bytes in two 3-byte chunks, so a 4-byte cap
    # is exceeded when the second chunk arrives.
    monkeypatch.setattr(ContentRewritingMiddleware, "_MAX_STREAM_REWRITE_BYTES", 4)
    rewriter = DummyRewriter()
    client = _create_streaming_app(rewriter)

    response = client.get("/stream")

    assert response.status_code == 200
    # The original bytes arrive intact and in order.
    assert response.text == "foofoo"
    # The rewriter must never be invoked once the cap is exceeded.
    assert rewriter.rewrite_calls == []
Loading