From 6835dd99a3c8c83cfc23bf9c44d7e714e6ec8126 Mon Sep 17 00:00:00 2001
From: matdev83 <211248003+matdev83@users.noreply.github.com>
Date: Fri, 7 Nov 2025 15:05:46 +0100
Subject: [PATCH] Fix streaming content rewriting DoS vulnerability
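
The streaming branch of ContentRewritingMiddleware previously accumulated
the entire upstream response body in memory and only emitted the rewritten
result after the stream had finished. A large or unbounded upstream stream
could therefore grow that buffer without limit (memory-exhaustion DoS), and
clients received no data until the very end, defeating streaming entirely.

Rewrite streaming responses incrementally instead: decode chunks with an
incremental UTF-8 decoder, collect them in a bounded working buffer, and
flush the buffer through the reply rewriter once it exceeds a fixed limit.
Carry an overlap of (longest search pattern - 1) characters between flushes
so matches that straddle chunk boundaries are still rewritten.
ContentRewriterService now caches the longest reply search pattern behind
the new max_reply_search_length property so the overlap is not recomputed
for every chunk, and streams are passed through untouched when no reply
rules are configured.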
---
 .../content_rewriting_middleware.py           | 48 ++++++++++--
 src/core/services/content_rewriter_service.py | 30 +++++++
 .../test_content_rewriting_middleware.py      | 78 +++++++++++++++++++
 3 files changed, 149 insertions(+), 7 deletions(-)
 create mode 100644 tests/unit/core/app/middleware/test_content_rewriting_middleware.py

diff --git a/src/core/app/middleware/content_rewriting_middleware.py b/src/core/app/middleware/content_rewriting_middleware.py
index e7f983e8c..eb3fe9718 100644
--- a/src/core/app/middleware/content_rewriting_middleware.py
+++ b/src/core/app/middleware/content_rewriting_middleware.py
@@ -1,4 +1,6 @@
+import codecs
 import json
+from collections.abc import AsyncIterator
 from typing import Any
 
 from fastapi import Request
@@ -399,19 +401,51 @@ async def receive():
         # Step 3: Potentially rewrite the response
         if isinstance(response, StreamingResponse):
+            original_iterator = response.body_iterator
+
+            if original_iterator is None:
+                return response
+
+            # If no reply rules are configured we can pass the stream through.
+            if not self.rewriter.reply_rules:
+                return response
+
+            max_pattern_length = self.rewriter.max_reply_search_length
+            overlap = max(0, max_pattern_length - 1)
+            # Keep the working buffer bounded to avoid unbounded memory growth.
+            chunk_limit = max(65536, (overlap + 1) * 2)
+
+            async def rewriting_iterator() -> AsyncIterator[bytes]:
+                decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
+                buffer = ""
+
+                async for chunk in original_iterator:
+                    if isinstance(chunk, bytes):
+                        text_chunk = decoder.decode(chunk, final=False)
+                    else:
+                        text_chunk = decoder.decode(chunk.encode("utf-8"), final=False)
+                    buffer += text_chunk
+
+                    while len(buffer) > chunk_limit:
+                        emit_upto = len(buffer) - overlap if overlap else len(buffer)
+                        if emit_upto <= 0:
+                            break
+                        to_emit = buffer[:emit_upto]
+                        buffer = buffer[emit_upto:]
+                        if to_emit:
+                            rewritten = self.rewriter.rewrite_reply(to_emit)
+                            yield rewritten.encode("utf-8")
 
-            async def new_iterator():
-                response_body = b""
-                async for chunk in response.body_iterator:
-                    response_body += chunk
-                rewritten_body = self.rewriter.rewrite_reply(response_body.decode())
-                yield rewritten_body.encode("utf-8")
+                remaining = buffer + decoder.decode(b"", final=True)
+                if remaining:
+                    rewritten_remaining = self.rewriter.rewrite_reply(remaining)
+                    yield rewritten_remaining.encode("utf-8")
 
             background = response.background
             response.background = None
             return StreamingResponse(
-                new_iterator(),
+                rewriting_iterator(),
                 status_code=response.status_code,
                 headers=dict(response.headers),
                 media_type=response.media_type,
diff --git a/src/core/services/content_rewriter_service.py b/src/core/services/content_rewriter_service.py
index a523e450a..0eb7b30fb 100644
--- a/src/core/services/content_rewriter_service.py
+++ b/src/core/services/content_rewriter_service.py
@@ -37,6 +37,7 @@ def __init__(
         self.prompt_system_rules: list[ReplacementRule] = []
         self.prompt_user_rules: list[ReplacementRule] = []
         self.reply_rules: list[ReplacementRule] = []
+        self._max_reply_search_length: int | None = None
         self.load_rules()
 
     def load_rules(self) -> None:
@@ -50,6 +51,35 @@ def load_rules(self) -> None:
         self.reply_rules = self._load_rules_from_dir(
             os.path.join(self.config_path, "replies")
         )
+        self.refresh_rule_cache()
+
+    def refresh_rule_cache(self) -> None:
+        """Recompute cached values derived from replacement rules."""
+
+        self._max_reply_search_length = self._calculate_max_search_length(
+            self.reply_rules
+        )
+
+    @staticmethod
+    def _calculate_max_search_length(rules: list[ReplacementRule]) -> int:
+        """Return the length of the longest search string in ``rules``."""
+
+        max_length = 0
+        for rule in rules:
+            search = rule.search
+            if search:
+                search_length = len(search)
+                if search_length > max_length:
+                    max_length = search_length
+        return max_length
+
+    @property
+    def max_reply_search_length(self) -> int:
+        """Maximum length of reply rule search patterns (cached)."""
+
+        if self._max_reply_search_length is None:
+            self.refresh_rule_cache()
+        return self._max_reply_search_length or 0
 
     def _load_rules_from_dir(self, directory: str) -> list[ReplacementRule]:
         """Loads rules from a specific directory."""
diff --git a/tests/unit/core/app/middleware/test_content_rewriting_middleware.py b/tests/unit/core/app/middleware/test_content_rewriting_middleware.py
new file mode 100644
index 000000000..16c5e622a
--- /dev/null
+++ b/tests/unit/core/app/middleware/test_content_rewriting_middleware.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import asyncio
+
+import pytest
+from src.core.app.middleware.content_rewriting_middleware import (
+    ContentRewritingMiddleware,
+)
+from src.core.domain.replacement_rule import ReplacementMode, ReplacementRule
+from src.core.services.content_rewriter_service import ContentRewriterService
+from starlette.requests import Request
+from starlette.responses import StreamingResponse
+
+
+def _build_request() -> Request:
+    scope = {
+        "type": "http",
+        "method": "GET",
+        "path": "/",
+        "headers": [],
+        "client": ("127.0.0.1", 12345),
+        "server": ("127.0.0.1", 8000),
+        "scheme": "http",
+        "http_version": "1.1",
+    }
+
+    async def receive() -> dict[str, object]:
+        return {"type": "http.request", "body": b"", "more_body": False}
+
+    return Request(scope, receive)
+
+
+@pytest.mark.asyncio
+async def test_streaming_rewrite_preserves_streaming_and_cross_chunk_matches() -> None:
+    rewriter = ContentRewriterService(config_path="non-existent")
+    rewriter.reply_rules = [
+        ReplacementRule(
+            mode=ReplacementMode.REPLACE,
+            search="HELLO",
+            replace="BYE",
+        )
+    ]
+    rewriter.refresh_rule_cache()
+
+    middleware = ContentRewritingMiddleware(lambda request: None, rewriter)
+    request = _build_request()
+
+    chunks = [b"start HEL", b"LO mid HEL", b"LO end"]
+    chunk_iterated: list[bytes] = []
+
+    async def chunk_generator():
+        for chunk in chunks:
+            chunk_iterated.append(chunk)
+            yield chunk
+            await asyncio.sleep(0)
+
+    async def call_next(_request: Request) -> StreamingResponse:
+        return StreamingResponse(chunk_generator(), media_type="text/plain")
+
+    response = await middleware.dispatch(request, call_next)
+
+    # Nothing has been consumed yet.
+    assert not chunk_iterated
+
+    body_iter = response.body_iterator
+    assert body_iter is not None
+
+    collected = b""
+    try:
+        while True:
+            collected += await body_iter.__anext__()
+    except StopAsyncIteration:
+        pass
+
+    # All chunks were produced lazily during iteration.
+    assert chunk_iterated == chunks
+
+    assert collected.decode("utf-8") == "start BYE mid BYE end"
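
The core of the fix is the overlap-buffered flush loop inside
rewriting_iterator. As a rough standalone sketch of that technique
(illustrative only: stream_rewrite is hypothetical, and a single literal
str.replace rule stands in for ContentRewriterService.rewrite_reply with
its full rule set):

    def stream_rewrite(chunks, search, replace, chunk_limit=65536):
        # Retain the last len(search) - 1 characters between flushes so a
        # pattern that arrives split across incoming chunks can still be
        # matched once the rest of it shows up.
        overlap = max(0, len(search) - 1)
        buffer = ""
        for chunk in chunks:
            buffer += chunk
            while len(buffer) > chunk_limit:
                emit_upto = len(buffer) - overlap if overlap else len(buffer)
                if emit_upto <= 0:
                    break
                to_emit, buffer = buffer[:emit_upto], buffer[emit_upto:]
                yield to_emit.replace(search, replace)
        if buffer:
            yield buffer.replace(search, replace)

    # "HELLO" is rewritten even though it arrives split across two chunks.
    assert "".join(stream_rewrite(["start HEL", "LO end"], "HELLO", "BYE")) == "start BYE end"

Retaining max_pattern_length - 1 trailing characters lets a match that was
split across incoming chunks be seen whole once the rest arrives, while the
flush threshold keeps memory bounded no matter how large the response grows.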