diff --git a/multipart/multipart.py b/multipart/multipart.py index 170151f..2c26b62 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1937,6 +1937,10 @@ def parse_form( content_length = float("inf") bytes_read = 0 + # If the input stream is a text stream, use its binary buffer + if isinstance (input_stream, io .TextIOBase): + input_stream = input_stream .buffer + while True: # Read only up to the Content-Length given. max_readable = min(content_length - bytes_read, 1048576) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 3a814fb..6df98a2 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -5,7 +5,7 @@ import sys import tempfile import unittest -from io import BytesIO +from io import BytesIO, TextIOBase from typing import TYPE_CHECKING from unittest.mock import Mock @@ -1268,7 +1268,7 @@ def test_create_form_parser_error(self): with self.assertRaises(ValueError): create_form_parser(headers, None, None) - def test_parse_form(self): + def test_parse_form_bytes(self): on_field = Mock() on_file = Mock() @@ -1280,6 +1280,37 @@ def test_parse_form(self): # 15 - i.e. all data is written. self.assertEqual(on_file.call_args[0][0].size, 15) + def test_parse_form_file(self): + on_field = Mock() + on_file = Mock() + + with open ('12345678.txt', 'wt+') as f: + f .write ('123456789012345') + f .seek (0o0) + parse_form({"Content-Type": "application/octet-stream"}, f .buffer, on_field, on_file) + + assert on_file.call_count == 1 + + # Assert that the first argument of the call (a File object) has size + # 15 - i.e. all data is written. + self.assertEqual(on_file.call_args[0][0].size, 15) + + def test_parse_form_text(self): + on_field = Mock() + on_file = Mock() + + with open ('12345678.txt', 'wt+') as f: + f .write ('123456789012345') + f .seek (0o0) + self .assertTrue (isinstance (f, TextIOBase)) + parse_form({"Content-Type": "application/octet-stream"}, f, on_field, on_file) + + assert on_file.call_count == 1 + + # Assert that the first argument of the call (a File object) has size + # 15 - i.e. all data is written. + self.assertEqual(on_file.call_args[0][0].size, 15) + def test_parse_form_content_length(self): files = []