Skip to content

Commit f1e28d6

Browse files
committed
Use content_type in Fields not header
1 parent 12b3f26 commit f1e28d6

File tree

2 files changed

+37
-49
lines changed

2 files changed

+37
-49
lines changed

python_multipart/multipart.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ def finalize(self) -> None: ...
6868
def close(self) -> None: ...
6969

7070
class FieldProtocol(_FormProtocol, Protocol):
71-
def __init__(self, name: bytes, headers: dict[str, bytes]) -> None: ...
71+
def __init__(self, name: bytes, content_type: str | None = None) -> None: ...
7272

7373
def set_none(self) -> None: ...
7474

7575
class FileProtocol(_FormProtocol, Protocol):
7676
def __init__(
77-
self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, headers: dict[str, bytes]
77+
self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, content_type: str | None = None
7878
) -> None: ...
7979

8080
OnFieldCallback = Callable[[FieldProtocol], None]
@@ -223,12 +223,13 @@ class Field:
223223
224224
Args:
225225
name: The name of the form field.
226+
content_type: The value of the Content-Type header for this field.
226227
"""
227228

228-
def __init__(self, name: bytes, headers: dict[str, bytes] = {}) -> None:
229+
def __init__(self, name: bytes, content_type: str | None = None) -> None:
229230
self._name = name
230231
self._value: list[bytes] = []
231-
self._headers: dict[str, bytes] = headers
232+
self._content_type = content_type
232233

233234
# We cache the joined version of _value for speed.
234235
self._cache = _missing
@@ -321,9 +322,9 @@ def value(self) -> bytes | None:
321322
return self._cache
322323

323324
@property
324-
def headers(self) -> dict[str, bytes]:
325-
"""This property returns the headers of the field."""
326-
return self._headers
325+
def content_type(self) -> str | None:
326+
"""This property returns the content_type value of the field."""
327+
return self._content_type
327328

328329
def __eq__(self, other: object) -> bool:
329330
if isinstance(other, Field):
@@ -362,14 +363,15 @@ class File:
362363
file_name: The name of the file that this [`File`][python_multipart.File] represents.
363364
field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
364365
the file was uploaded with Content-Type application/octet-stream.
366+
content_type: The value of the Content-Type header.
365367
config: The configuration for this File. See above for valid configuration keys and their corresponding values.
366368
""" # noqa: E501
367369

368370
def __init__(
369371
self,
370372
file_name: bytes | None,
371373
field_name: bytes | None = None,
372-
headers: dict[str, bytes] = {},
374+
content_type: str | None = None,
373375
config: FileConfig = {},
374376
) -> None:
375377
# Save configuration, set other variables default.
@@ -382,7 +384,7 @@ def __init__(
382384
# Save the provided field/file name and content type.
383385
self._field_name = field_name
384386
self._file_name = file_name
385-
self._headers = headers
387+
self._content_type = content_type
386388

387389
# Our actual file name is None by default, since, depending on our
388390
# config, we may not actually use the provided name.
@@ -436,14 +438,9 @@ def in_memory(self) -> bool:
436438
return self._in_memory
437439

438440
@property
439-
def headers(self) -> dict[str, bytes]:
440-
"""The headers for this part."""
441-
return self._headers
442-
443-
@property
444-
def content_type(self) -> bytes | None:
445-
"""The Content-Type value for this part."""
446-
return self._headers.get("content-type")
441+
def content_type(self) -> str | None:
442+
"""The Content-Type value for this part, if it was set."""
443+
return self._content_type
447444

448445
def flush_to_disk(self) -> None:
449446
"""If the file is already on-disk, do nothing. Otherwise, copy from
@@ -1565,7 +1562,7 @@ def __init__(
15651562

15661563
def on_start() -> None:
15671564
nonlocal file
1568-
file = FileClass(file_name, None, headers={}, config=cast("FileConfig", self.config))
1565+
file = FileClass(file_name, None, content_type=None, config=cast("FileConfig", self.config))
15691566

15701567
def on_data(data: bytes, start: int, end: int) -> None:
15711568
nonlocal file
@@ -1604,7 +1601,7 @@ def on_field_name(data: bytes, start: int, end: int) -> None:
16041601
def on_field_data(data: bytes, start: int, end: int) -> None:
16051602
nonlocal f
16061603
if f is None:
1607-
f = FieldClass(b"".join(name_buffer), headers={})
1604+
f = FieldClass(b"".join(name_buffer), content_type=None)
16081605
del name_buffer[:]
16091606
f.write(data[start:end])
16101607

@@ -1614,7 +1611,7 @@ def on_field_end() -> None:
16141611
if f is None:
16151612
# If we get here, it's because there was no field data.
16161613
# We create a field, set it to None, and then continue.
1617-
f = FieldClass(b"".join(name_buffer), headers={})
1614+
f = FieldClass(b"".join(name_buffer), content_type=None)
16181615
del name_buffer[:]
16191616
f.set_none()
16201617

@@ -1700,10 +1697,14 @@ def on_headers_finished() -> None:
17001697
# TODO: check for errors
17011698

17021699
# Create the proper class.
1700+
content_type_b = headers.get("content-type")
1701+
content_type = content_type_b.decode("latin-1") if content_type_b is not None else None
17031702
if file_name is None:
1704-
f_multi = FieldClass(field_name, headers=headers)
1703+
f_multi = FieldClass(field_name, content_type=content_type)
17051704
else:
1706-
f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config), headers=headers)
1705+
f_multi = FileClass(
1706+
file_name, field_name, config=cast("FileConfig", self.config), content_type=content_type
1707+
)
17071708
is_file = True
17081709

17091710
# Parse the given Content-Transfer-Encoding to determine what

tests/test_multipart.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,7 @@ def assert_file_data(self, f: File, data: bytes) -> None:
758758
file_data = o.read()
759759
self.assertEqual(file_data, data)
760760

761-
def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, data: bytes) -> None:
761+
def assert_file(self, field_name: bytes, file_name: bytes, content_type: str | None, data: bytes) -> None:
762762
# Find this file.
763763
found = None
764764
for f in self.files:
@@ -770,7 +770,7 @@ def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, da
770770
self.assertIsNotNone(found)
771771
assert found is not None
772772

773-
self.assertEqual(found.content_type, content_type.encode())
773+
self.assertEqual(found.content_type, content_type)
774774

775775
try:
776776
# Assert about this file.
@@ -911,7 +911,8 @@ def test_feed_single_bytes(self, param: TestParams) -> None:
911911
self.assert_field(name, e["data"])
912912

913913
elif type == "file":
914-
self.assert_file(name, e["file_name"].encode("latin-1"), "text/plain", e["data"])
914+
content_type = "text/plain"
915+
self.assert_file(name, e["file_name"].encode("latin-1"), content_type, e["data"])
915916

916917
else:
917918
assert False
@@ -949,24 +950,16 @@ def test_feed_blocks(self) -> None:
949950
# Assert that our field is here.
950951
self.assert_field(b"field", b"0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ")
951952

952-
def test_file_headers(self) -> None:
953+
def test_file_content_type_header(self) -> None:
953954
"""
954-
This test checks headers for a file part are read.
955+
This test checks the content-type for a file part is passed on.
955956
"""
956957
# Load test data.
957958
test_file = "header_with_number.http"
958959
with open(os.path.join(http_tests_dir, test_file), "rb") as f:
959960
test_data = f.read()
960961

961-
expected_headers = {
962-
"content-disposition": b'form-data; filename="secret.txt"; name="files"',
963-
"content-type": b"text/plain; charset=utf-8",
964-
"x-funky-header-1": b"bar",
965-
"abcdefghijklmnopqrstuvwxyz01234": b"foo",
966-
"abcdefghijklmnopqrstuvwxyz56789": b"bar",
967-
"other!#$%&'*+-.^_`|~": b"baz",
968-
"content-length": b"6",
969-
}
962+
expected_content_type = "text/plain; charset=utf-8"
970963

971964
# Create form parser.
972965
self.make(boundary="b8825ae386be4fdc9644d87e392caad3")
@@ -975,22 +968,19 @@ def test_file_headers(self) -> None:
975968

976969
# Assert that our field is here.
977970
self.assertEqual(1, len(self.files))
978-
actual_headers = self.files[0].headers
979-
self.assertEqual(len(actual_headers), len(expected_headers))
971+
actual_content_type = self.files[0].content_type
972+
self.assertEqual(actual_content_type, expected_content_type)
980973

981-
for k, v in expected_headers.items():
982-
self.assertEqual(v, actual_headers[k])
983-
984-
def test_field_headers(self) -> None:
974+
def test_field_content_type_header(self) -> None:
985975
"""
986-
This test checks headers for a field part are read.
976+
This test checks content-tpye for a field part are read and passed.
987977
"""
988978
# Load test data.
989979
test_file = "single_field.http"
990980
with open(os.path.join(http_tests_dir, test_file), "rb") as f:
991981
test_data = f.read()
992982

993-
expected_headers = {"content-disposition": b'form-data; name="field"'}
983+
expected_content_type = None
994984

995985
# Create form parser.
996986
self.make(boundary="----WebKitFormBoundaryTkr3kCBQlBe1nrhc")
@@ -999,11 +989,8 @@ def test_field_headers(self) -> None:
999989

1000990
# Assert that our field is here.
1001991
self.assertEqual(1, len(self.fields))
1002-
actual_headers = self.fields[0].headers
1003-
self.assertEqual(len(actual_headers), len(expected_headers))
1004-
1005-
for k, v in expected_headers.items():
1006-
self.assertEqual(v, actual_headers[k])
992+
actual_content_type = self.fields[0].content_type
993+
self.assertEqual(actual_content_type, expected_content_type)
1007994

1008995
def test_request_body_fuzz(self) -> None:
1009996
"""

0 commit comments

Comments
 (0)