From e191cbeeead6af5add4496909c10bbf130e3126b Mon Sep 17 00:00:00 2001 From: John Stark Date: Tue, 26 Nov 2024 20:21:07 +0100 Subject: [PATCH 1/2] Treat headers case insenitively, internally --- python_multipart/multipart.py | 25 +++++++++++++------- tests/test_data/http/mixed_case_headers.http | 19 +++++++++++++++ tests/test_data/http/mixed_case_headers.yaml | 14 +++++++++++ 3 files changed, 49 insertions(+), 9 deletions(-) create mode 100644 tests/test_data/http/mixed_case_headers.http create mode 100644 tests/test_data/http/mixed_case_headers.yaml diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index ace4a8f..64b8033 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -15,13 +15,17 @@ from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError if TYPE_CHECKING: # pragma: no cover - from typing import Any, Callable, Literal, Protocol, TypedDict + from typing import Any, Callable, Literal, Optional, Protocol, TypedDict from typing_extensions import TypeAlias class SupportsRead(Protocol): def read(self, __n: int) -> bytes: ... + # Protocol for dict-like (dict, or CaseInsensitiveDict) + class SupportsGetStrBytes(Protocol): + def get(self, _key: str) -> Optional[bytes]: ... + class QuerystringCallbacks(TypedDict, total=False): on_field_start: Callable[[], None] on_field_name: Callable[[bytes, int, int], None] @@ -1617,7 +1621,8 @@ def _on_end() -> None: header_name: list[bytes] = [] header_value: list[bytes] = [] - headers: dict[bytes, bytes] = {} + # Header keys are always inserted in Title-Case + headers: dict[str, bytes] = {} f_multi: FileProtocol | FieldProtocol | None = None writer = None @@ -1652,7 +1657,9 @@ def on_header_value(data: bytes, start: int, end: int) -> None: header_value.append(data[start:end]) def on_header_end() -> None: - headers[b"".join(header_name)] = b"".join(header_value) + # Convert header name to title case. + header_name_tc = b"".join(header_name).decode().title() + headers[header_name_tc] = b"".join(header_value) del header_name[:] del header_value[:] @@ -1662,8 +1669,7 @@ def on_headers_finished() -> None: is_file = False # Parse the content-disposition header. - # TODO: handle mixed case - content_disp = headers.get(b"Content-Disposition") + content_disp = headers.get("Content-Disposition") disp, options = parse_options_header(content_disp) # Get the field and filename. @@ -1681,7 +1687,7 @@ def on_headers_finished() -> None: # Parse the given Content-Transfer-Encoding to determine what # we need to do with the incoming data. # TODO: check that we properly handle 8bit / 7bit encoding. - transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit") + transfer_encoding = headers.get("Content-Transfer-Encoding", b"7bit") if transfer_encoding in (b"binary", b"8bit", b"7bit"): writer = f_multi @@ -1760,7 +1766,7 @@ def __repr__(self) -> str: def create_form_parser( - headers: dict[str, bytes], + headers: SupportsGetStrBytes, on_field: OnFieldCallback | None, on_file: OnFileCallback | None, trust_x_headers: bool = False, @@ -1804,7 +1810,7 @@ def create_form_parser( def parse_form( - headers: dict[str, bytes], + headers: SupportsGetStrBytes, input_stream: SupportsRead, on_field: OnFieldCallback | None, on_file: OnFileCallback | None, @@ -1816,7 +1822,8 @@ def parse_form( callbacks that will get called whenever a field or file is parsed. Args: - headers: A dictionary-like object of HTTP headers. The only required header is Content-Type. + headers: A dictionary-like object of HTTP headers. The only required header is Content-Type, + in exactly this form if the input dict is case sensitive. input_stream: A file-like object that represents the request body. The read() method must return bytestrings. on_field: Callback to call with each parsed field. on_file: Callback to call with each parsed file. diff --git a/tests/test_data/http/mixed_case_headers.http b/tests/test_data/http/mixed_case_headers.http new file mode 100644 index 0000000..6a4e247 --- /dev/null +++ b/tests/test_data/http/mixed_case_headers.http @@ -0,0 +1,19 @@ +----boundary +ConTenT-TypE: text/plain; charset="UTF-8" +ConTenT-DisPoSitioN: form-data; name=field1 +ConTenT-TransfeR-EncoDinG: base64 + +VGVzdCAxMjM= +----boundary +content-type: text/plain; charset="UTF-8" +content-disposition: form-data; name=field2 +content-transfer-encoding: base64 + +VGVzdCAxMjM= +----boundary +CONTENT-TYPE: text/plain; charset="UTF-8" +CONTENT-DISPOSITION: form-data; name=Field3 +CONTENT-TRANSFER-ENCODING: base64 + +VGVzdCAxMjM= +----boundary-- \ No newline at end of file diff --git a/tests/test_data/http/mixed_case_headers.yaml b/tests/test_data/http/mixed_case_headers.yaml new file mode 100644 index 0000000..d333067 --- /dev/null +++ b/tests/test_data/http/mixed_case_headers.yaml @@ -0,0 +1,14 @@ +boundary: --boundary +expected: + - name: field1 + type: field + data: !!binary | + VGVzdCAxMjM= + - name: field2 + type: field + data: !!binary | + VGVzdCAxMjM= + - name: Field3 + type: field + data: !!binary | + VGVzdCAxMjM= From e7a340c68cc1f8741c3ce21c754c1672b9a44bb9 Mon Sep 17 00:00:00 2001 From: John Stark Date: Tue, 26 Nov 2024 20:43:21 +0100 Subject: [PATCH 2/2] Use Mapping instead of Protocol --- python_multipart/multipart.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index 64b8033..8401f67 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -15,17 +15,13 @@ from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError if TYPE_CHECKING: # pragma: no cover - from typing import Any, Callable, Literal, Optional, Protocol, TypedDict + from typing import Any, Callable, Literal, Mapping, Protocol, TypedDict from typing_extensions import TypeAlias class SupportsRead(Protocol): def read(self, __n: int) -> bytes: ... - # Protocol for dict-like (dict, or CaseInsensitiveDict) - class SupportsGetStrBytes(Protocol): - def get(self, _key: str) -> Optional[bytes]: ... - class QuerystringCallbacks(TypedDict, total=False): on_field_start: Callable[[], None] on_field_name: Callable[[bytes, int, int], None] @@ -1766,7 +1762,7 @@ def __repr__(self) -> str: def create_form_parser( - headers: SupportsGetStrBytes, + headers: Mapping[str, bytes], on_field: OnFieldCallback | None, on_file: OnFileCallback | None, trust_x_headers: bool = False, @@ -1810,7 +1806,7 @@ def create_form_parser( def parse_form( - headers: SupportsGetStrBytes, + headers: Mapping[str, bytes], input_stream: SupportsRead, on_field: OnFieldCallback | None, on_file: OnFileCallback | None,