diff --git a/mypy/typeshed/stubs/librt/librt/internal.pyi b/mypy/typeshed/stubs/librt/librt/internal.pyi index 8654e31c100e..78e7f9caa117 100644 --- a/mypy/typeshed/stubs/librt/librt/internal.pyi +++ b/mypy/typeshed/stubs/librt/librt/internal.pyi @@ -1,19 +1,27 @@ from mypy_extensions import u8 +# TODO: Remove Buffer -- right now we have hacky support for BOTH the old and new APIs + class Buffer: def __init__(self, source: bytes = ...) -> None: ... def getvalue(self) -> bytes: ... -def write_bool(data: Buffer, value: bool) -> None: ... -def read_bool(data: Buffer) -> bool: ... -def write_str(data: Buffer, value: str) -> None: ... -def read_str(data: Buffer) -> str: ... -def write_bytes(data: Buffer, value: bytes) -> None: ... -def read_bytes(data: Buffer) -> bytes: ... -def write_float(data: Buffer, value: float) -> None: ... -def read_float(data: Buffer) -> float: ... -def write_int(data: Buffer, value: int) -> None: ... -def read_int(data: Buffer) -> int: ... -def write_tag(data: Buffer, value: u8) -> None: ... -def read_tag(data: Buffer) -> u8: ... +class ReadBuffer: + def __init__(self, source: bytes) -> None: ... + +class WriteBuffer: + def getvalue(self) -> bytes: ... + +def write_bool(data: WriteBuffer | Buffer, value: bool) -> None: ... +def read_bool(data: ReadBuffer | Buffer) -> bool: ... +def write_str(data: WriteBuffer | Buffer, value: str) -> None: ... +def read_str(data: ReadBuffer | Buffer) -> str: ... +def write_bytes(data: WriteBuffer | Buffer, value: bytes) -> None: ... +def read_bytes(data: ReadBuffer | Buffer) -> bytes: ... +def write_float(data: WriteBuffer | Buffer, value: float) -> None: ... +def read_float(data: ReadBuffer | Buffer) -> float: ... +def write_int(data: WriteBuffer | Buffer, value: int) -> None: ... +def read_int(data: ReadBuffer | Buffer) -> int: ... +def write_tag(data: WriteBuffer | Buffer, value: u8) -> None: ... +def read_tag(data: ReadBuffer | Buffer) -> u8: ... def cache_version() -> u8: ... diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index 4ef53296ef0d..f2a2271e020e 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -705,13 +705,25 @@ def emit_cast( self.emit_lines(f" {dest} = {src};", "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) self.emit_line("}") - elif is_object_rprimitive(typ) or is_native_rprimitive(typ): + elif is_object_rprimitive(typ): if declare_dest: self.emit_line(f"PyObject *{dest};") self.emit_arg_check(src, dest, typ, "", optional) self.emit_line(f"{dest} = {src};") if optional: self.emit_line("}") + elif is_native_rprimitive(typ): + # Native primitive types have type check functions of form "CPy_Check(...)". + if declare_dest: + self.emit_line(f"PyObject *{dest};") + short_name = typ.name.rsplit(".", 1)[-1] + check = f"(CPy{short_name}_Check({src}))" + if likely: + check = f"(likely{check})" + self.emit_arg_check(src, dest, typ, check, optional) + self.emit_lines(f" {dest} = {src};", "else {") + self.emit_cast_error_handler(error, src, dest, typ, raise_exception) + self.emit_line("}") elif isinstance(typ, RUnion): self.emit_union_cast( src, dest, typ, declare_dest, error, optional, src_type, raise_exception diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 941670ab230d..66b98e5d6398 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -514,7 +514,7 @@ def __hash__(self) -> int: KNOWN_NATIVE_TYPES: Final = { name: RPrimitive(name, is_unboxed=False, is_refcounted=True) - for name in ["librt.internal.Buffer"] + for name in ["librt.internal.WriteBuffer", "librt.internal.ReadBuffer"] } diff --git a/mypyc/lib-rt/librt_internal.c b/mypyc/lib-rt/librt_internal.c index eaf451eff22b..ada2dfeb39a5 100644 --- a/mypyc/lib-rt/librt_internal.c +++ b/mypyc/lib-rt/librt_internal.c @@ -30,18 +30,30 @@ #define CPY_NONE_ERROR 2 #define CPY_NONE 1 -#define _CHECK_BUFFER(data, err) if (unlikely(_check_buffer(data) == CPY_NONE_ERROR)) \ - return err; -#define _CHECK_SIZE(data, need) if (unlikely(_check_size((BufferObject *)data, need) == CPY_NONE_ERROR)) \ - return CPY_NONE_ERROR; -#define _CHECK_READ(data, size, err) if (unlikely(_check_read((BufferObject *)data, size) == CPY_NONE_ERROR)) \ - return err; - -#define _READ(data, type) *(type *)(((BufferObject *)data)->buf + ((BufferObject *)data)->pos); \ - ((BufferObject *)data)->pos += sizeof(type); - -#define _WRITE(data, type, v) *(type *)(((BufferObject *)data)->buf + ((BufferObject *)data)->pos) = v; \ - ((BufferObject *)data)->pos += sizeof(type); +#define _CHECK_READ_BUFFER(data, err) if (unlikely(_check_read_buffer(data) == CPY_NONE_ERROR)) \ + return err; +#define _CHECK_WRITE_BUFFER(data, err) if (unlikely(_check_write_buffer(data) == CPY_NONE_ERROR)) \ + return err; +#define _CHECK_WRITE(data, need) if (unlikely(_check_size((WriteBufferObject *)data, need) == CPY_NONE_ERROR)) \ + return CPY_NONE_ERROR; +#define _CHECK_READ(data, size, err) if (unlikely(_check_read((ReadBufferObject *)data, size) == CPY_NONE_ERROR)) \ + return err; + +#define _READ(result, data, type) \ + do { \ + *(result) = *(type *)(((ReadBufferObject *)data)->ptr); \ + ((ReadBufferObject *)data)->ptr += sizeof(type); \ + } while (0) + +#define _WRITE(data, type, v) \ + do { \ + *(type *)(((WriteBufferObject *)data)->ptr) = v; \ + ((WriteBufferObject *)data)->ptr += sizeof(type); \ + } while (0) + +// +// ReadBuffer +// #if PY_BIG_ENDIAN uint16_t reverse_16(uint16_t number) { @@ -55,78 +67,59 @@ uint32_t reverse_32(uint32_t number) { typedef struct { PyObject_HEAD - Py_ssize_t pos; - Py_ssize_t end; - Py_ssize_t size; - char *buf; -} BufferObject; + char *ptr; // Current read location in the buffer + char *end; // End of the buffer + PyObject *source; // The object that contains the buffer +} ReadBufferObject; -static PyTypeObject BufferType; +static PyTypeObject ReadBufferType; static PyObject* -Buffer_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +ReadBuffer_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { - if (type != &BufferType) { - PyErr_SetString(PyExc_TypeError, "Buffer should not be subclassed"); + if (type != &ReadBufferType) { + PyErr_SetString(PyExc_TypeError, "ReadBuffer should not be subclassed"); return NULL; } - BufferObject *self = (BufferObject *)type->tp_alloc(type, 0); + ReadBufferObject *self = (ReadBufferObject *)type->tp_alloc(type, 0); if (self != NULL) { - self->pos = 0; - self->end = 0; - self->size = 0; - self->buf = NULL; + self->source = NULL; + self->ptr = NULL; + self->end = NULL; } return (PyObject *) self; } - static int -Buffer_init_internal(BufferObject *self, PyObject *source) { - if (source) { - if (!PyBytes_Check(source)) { - PyErr_SetString(PyExc_TypeError, "source must be a bytes object"); - return -1; - } - self->end = PyBytes_GET_SIZE(source); - // Allocate at least one byte to simplify resizing logic. - // The original bytes buffer has last null byte, so this is safe. - self->size = self->end + 1; - // This returns a pointer to internal bytes data, so make our own copy. - char *buf = PyBytes_AsString(source); - self->buf = PyMem_Malloc(self->size); - memcpy(self->buf, buf, self->end); - } else { - self->buf = PyMem_Malloc(START_SIZE); - self->size = START_SIZE; +ReadBuffer_init_internal(ReadBufferObject *self, PyObject *source) { + if (!PyBytes_CheckExact(source)) { + PyErr_SetString(PyExc_TypeError, "source must be a bytes object"); + return -1; } + self->source = Py_NewRef(source); + self->ptr = PyBytes_AS_STRING(source); + self->end = self->ptr + PyBytes_GET_SIZE(source); return 0; } static PyObject* -Buffer_internal(PyObject *source) { - BufferObject *self = (BufferObject *)BufferType.tp_alloc(&BufferType, 0); +ReadBuffer_internal(PyObject *source) { + ReadBufferObject *self = (ReadBufferObject *)ReadBufferType.tp_alloc(&ReadBufferType, 0); if (self == NULL) return NULL; - self->pos = 0; - self->end = 0; - self->size = 0; - self->buf = NULL; - if (Buffer_init_internal(self, source) == -1) { + self->ptr = NULL; + self->end = NULL; + self->source = NULL; + if (ReadBuffer_init_internal(self, source) == -1) { Py_DECREF(self); return NULL; } return (PyObject *)self; } -static PyObject* -Buffer_internal_empty(void) { - return Buffer_internal(NULL); -} - static int -Buffer_init(BufferObject *self, PyObject *args, PyObject *kwds) +ReadBuffer_init(ReadBufferObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"source", NULL}; PyObject *source = NULL; @@ -134,53 +127,166 @@ Buffer_init(BufferObject *self, PyObject *args, PyObject *kwds) if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O", kwlist, &source)) return -1; - return Buffer_init_internal(self, source); + return ReadBuffer_init_internal(self, source); +} + +static void +ReadBuffer_dealloc(ReadBufferObject *self) +{ + Py_CLEAR(self->source); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyMethodDef ReadBuffer_methods[] = { + {NULL} /* Sentinel */ +}; + +static PyTypeObject ReadBufferType = { + .ob_base = PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "ReadBuffer", + .tp_doc = PyDoc_STR("Mypy cache buffer objects"), + .tp_basicsize = sizeof(ReadBufferObject), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_new = ReadBuffer_new, + .tp_init = (initproc) ReadBuffer_init, + .tp_dealloc = (destructor) ReadBuffer_dealloc, + .tp_methods = ReadBuffer_methods, +}; + +// +// WriteBuffer +// + +typedef struct { + PyObject_HEAD + char *buf; // Beginning of the buffer + char *ptr; // Current write location in the buffer + char *end; // End of the buffer +} WriteBufferObject; + +static PyTypeObject WriteBufferType; + +static PyObject* +WriteBuffer_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + if (type != &WriteBufferType) { + PyErr_SetString(PyExc_TypeError, "WriteBuffer cannot be subclassed"); + return NULL; + } + + WriteBufferObject *self = (WriteBufferObject *)type->tp_alloc(type, 0); + if (self != NULL) { + self->buf = NULL; + self->ptr = NULL; + self->end = NULL; + } + return (PyObject *)self; +} + +static int +WriteBuffer_init_internal(WriteBufferObject *self) { + Py_ssize_t size = START_SIZE; + self->buf = PyMem_Malloc(size + 1); + if (self->buf == NULL) { + PyErr_NoMemory(); + return -1; + } + self->ptr = self->buf; + self->end = self->buf + size; + return 0; +} + +static PyObject* +WriteBuffer_internal(void) { + WriteBufferObject *self = (WriteBufferObject *)WriteBufferType.tp_alloc(&WriteBufferType, 0); + if (self == NULL) + return NULL; + self->buf = NULL; + self->ptr = NULL; + self->end = NULL; + if (WriteBuffer_init_internal(self) == -1) { + Py_DECREF(self); + return NULL; + } + return (PyObject *)self; +} + +static int +WriteBuffer_init(WriteBufferObject *self, PyObject *args, PyObject *kwds) +{ + if (!PyArg_ParseTuple(args, "")) { + return -1; + } + + if (kwds != NULL && PyDict_Size(kwds) > 0) { + PyErr_SetString(PyExc_TypeError, + "WriteBuffer() takes no keyword arguments"); + return -1; + } + + return WriteBuffer_init_internal(self); } static void -Buffer_dealloc(BufferObject *self) +WriteBuffer_dealloc(WriteBufferObject *self) { PyMem_Free(self->buf); + self->buf = NULL; Py_TYPE(self)->tp_free((PyObject *)self); } static PyObject* -Buffer_getvalue_internal(PyObject *self) +WriteBuffer_getvalue_internal(PyObject *self) { - return PyBytes_FromStringAndSize(((BufferObject *)self)->buf, ((BufferObject *)self)->end); + WriteBufferObject *obj = (WriteBufferObject *)self; + return PyBytes_FromStringAndSize(obj->buf, obj->ptr - obj->buf); } static PyObject* -Buffer_getvalue(BufferObject *self, PyObject *Py_UNUSED(ignored)) +WriteBuffer_getvalue(WriteBufferObject *self, PyObject *Py_UNUSED(ignored)) { - return PyBytes_FromStringAndSize(self->buf, self->end); + return PyBytes_FromStringAndSize(self->buf, self->ptr - self->buf); } -static PyMethodDef Buffer_methods[] = { - {"getvalue", (PyCFunction) Buffer_getvalue, METH_NOARGS, +static PyMethodDef WriteBuffer_methods[] = { + {"getvalue", (PyCFunction) WriteBuffer_getvalue, METH_NOARGS, "Return the buffer content as bytes object" }, {NULL} /* Sentinel */ }; -static PyTypeObject BufferType = { +static PyTypeObject WriteBufferType = { .ob_base = PyVarObject_HEAD_INIT(NULL, 0) - .tp_name = "Buffer", + .tp_name = "WriteBuffer", .tp_doc = PyDoc_STR("Mypy cache buffer objects"), - .tp_basicsize = sizeof(BufferObject), + .tp_basicsize = sizeof(WriteBufferObject), .tp_itemsize = 0, .tp_flags = Py_TPFLAGS_DEFAULT, - .tp_new = Buffer_new, - .tp_init = (initproc) Buffer_init, - .tp_dealloc = (destructor) Buffer_dealloc, - .tp_methods = Buffer_methods, + .tp_new = WriteBuffer_new, + .tp_init = (initproc) WriteBuffer_init, + .tp_dealloc = (destructor) WriteBuffer_dealloc, + .tp_methods = WriteBuffer_methods, }; +// ---------- + +static inline char +_check_read_buffer(PyObject *data) { + if (unlikely(Py_TYPE(data) != &ReadBufferType)) { + PyErr_Format( + PyExc_TypeError, "data must be a ReadBuffer object, got %s", Py_TYPE(data)->tp_name + ); + return CPY_NONE_ERROR; + } + return CPY_NONE; +} + static inline char -_check_buffer(PyObject *data) { - if (unlikely(Py_TYPE(data) != &BufferType)) { +_check_write_buffer(PyObject *data) { + if (unlikely(Py_TYPE(data) != &WriteBufferType)) { PyErr_Format( - PyExc_TypeError, "data must be a Buffer object, got %s", Py_TYPE(data)->tp_name + PyExc_TypeError, "data must be a WriteBuffer object, got %s", Py_TYPE(data)->tp_name ); return CPY_NONE_ERROR; } @@ -188,24 +294,28 @@ _check_buffer(PyObject *data) { } static inline char -_check_size(BufferObject *data, Py_ssize_t need) { - Py_ssize_t target = data->pos + need; - if (target <= data->size) +_check_size(WriteBufferObject *data, Py_ssize_t need) { + if (data->end - data->ptr >= need) return CPY_NONE; - do - data->size *= 2; - while (target >= data->size); - data->buf = PyMem_Realloc(data->buf, data->size); + Py_ssize_t index = data->ptr - data->buf; + Py_ssize_t target = index + need; + Py_ssize_t size = data->end - data->buf; + do { + size *= 2; + } while (target >= size); + data->buf = PyMem_Realloc(data->buf, size); if (unlikely(data->buf == NULL)) { PyErr_NoMemory(); return CPY_NONE_ERROR; } + data->ptr = data->buf + index; + data->end = data->buf + size; return CPY_NONE; } static inline char -_check_read(BufferObject *data, Py_ssize_t need) { - if (unlikely(data->pos + need > data->end)) { +_check_read(ReadBufferObject *data, Py_ssize_t need) { + if (unlikely((data->end - data->ptr) < need)) { PyErr_SetString(PyExc_ValueError, "reading past the buffer end"); return CPY_NONE_ERROR; } @@ -220,9 +330,9 @@ bool format: single byte static char read_bool_internal(PyObject *data) { - _CHECK_BUFFER(data, CPY_BOOL_ERROR) _CHECK_READ(data, 1, CPY_BOOL_ERROR) - char res = _READ(data, char) + char res; + _READ(&res, data, char); if (unlikely((res != 0) & (res != 1))) { PyErr_SetString(PyExc_ValueError, "invalid bool value"); return CPY_BOOL_ERROR; @@ -238,6 +348,7 @@ read_bool(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames if (unlikely(!CPyArg_ParseStackAndKeywordsOneArg(args, nargs, kwnames, &parser, &data))) { return NULL; } + _CHECK_READ_BUFFER(data, NULL) char res = read_bool_internal(data); if (unlikely(res == CPY_BOOL_ERROR)) return NULL; @@ -248,10 +359,8 @@ read_bool(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames static char write_bool_internal(PyObject *data, char value) { - _CHECK_BUFFER(data, CPY_NONE_ERROR) - _CHECK_SIZE(data, 1) - _WRITE(data, char, value) - ((BufferObject *)data)->end += 1; + _CHECK_WRITE(data, 1) + _WRITE(data, char, value); return CPY_NONE; } @@ -264,6 +373,7 @@ write_bool(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwname if (unlikely(!CPyArg_ParseStackAndKeywordsSimple(args, nargs, kwnames, &parser, &data, &value))) { return NULL; } + _CHECK_WRITE_BUFFER(data, NULL) if (unlikely(!PyBool_Check(value))) { PyErr_SetString(PyExc_TypeError, "value must be a bool"); return NULL; @@ -289,14 +399,14 @@ _read_short_int(PyObject *data, uint8_t first) { } if ((first & FOUR_BYTES_INT_BIT) == 0) { _CHECK_READ(data, 1, CPY_INT_TAG) - second = _READ(data, uint8_t) + _READ(&second, data, uint8_t); return ((((Py_ssize_t)second) << 6) + (Py_ssize_t)(first >> 2) + MIN_TWO_BYTES_INT) << 1; } // The caller is responsible to verify this is called only for short ints. _CHECK_READ(data, 3, CPY_INT_TAG) // TODO: check if compilers emit optimal code for these two reads, and tweak if needed. - second = _READ(data, uint8_t) - two_more = _READ(data, uint16_t) + _READ(&second, data, uint8_t); + _READ(&two_more, data, uint16_t); #if PY_BIG_ENDIAN two_more = reverse_16(two_more); #endif @@ -306,11 +416,10 @@ _read_short_int(PyObject *data, uint8_t first) { static PyObject* read_str_internal(PyObject *data) { - _CHECK_BUFFER(data, NULL) - // Read string length. _CHECK_READ(data, 1, NULL) - uint8_t first = _READ(data, uint8_t) + uint8_t first; + _READ(&first, data, uint8_t); if (unlikely(first == LONG_INT_TRAILER)) { // Fail fast for invalid/tampered data. PyErr_SetString(PyExc_ValueError, "invalid str size"); @@ -326,14 +435,12 @@ read_str_internal(PyObject *data) { } Py_ssize_t size = tagged_size >> 1; // Read string content. - char *buf = ((BufferObject *)data)->buf; + char *ptr = ((ReadBufferObject *)data)->ptr; _CHECK_READ(data, size, NULL) - PyObject *res = PyUnicode_FromStringAndSize( - buf + ((BufferObject *)data)->pos, (Py_ssize_t)size - ); + PyObject *res = PyUnicode_FromStringAndSize(ptr, (Py_ssize_t)size); if (unlikely(res == NULL)) return NULL; - ((BufferObject *)data)->pos += size; + ((ReadBufferObject *)data)->ptr += size; return res; } @@ -345,6 +452,7 @@ read_str(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) if (unlikely(!CPyArg_ParseStackAndKeywordsOneArg(args, nargs, kwnames, &parser, &data))) { return NULL; } + _CHECK_READ_BUFFER(data, NULL) return read_str_internal(data); } @@ -352,35 +460,30 @@ read_str(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) static inline char _write_short_int(PyObject *data, Py_ssize_t real_value) { if (real_value >= MIN_ONE_BYTE_INT && real_value <= MAX_ONE_BYTE_INT) { - _CHECK_SIZE(data, 1) - _WRITE(data, uint8_t, (uint8_t)(real_value - MIN_ONE_BYTE_INT) << 1) - ((BufferObject *)data)->end += 1; + _CHECK_WRITE(data, 1) + _WRITE(data, uint8_t, (uint8_t)(real_value - MIN_ONE_BYTE_INT) << 1); } else if (real_value >= MIN_TWO_BYTES_INT && real_value <= MAX_TWO_BYTES_INT) { - _CHECK_SIZE(data, 2) + _CHECK_WRITE(data, 2) #if PY_BIG_ENDIAN uint16_t to_write = ((uint16_t)(real_value - MIN_TWO_BYTES_INT) << 2) | TWO_BYTES_INT_BIT; _WRITE(data, uint16_t, reverse_16(to_write)) #else - _WRITE(data, uint16_t, ((uint16_t)(real_value - MIN_TWO_BYTES_INT) << 2) | TWO_BYTES_INT_BIT) + _WRITE(data, uint16_t, ((uint16_t)(real_value - MIN_TWO_BYTES_INT) << 2) | TWO_BYTES_INT_BIT); #endif - ((BufferObject *)data)->end += 2; } else { - _CHECK_SIZE(data, 4) + _CHECK_WRITE(data, 4) #if PY_BIG_ENDIAN uint32_t to_write = ((uint32_t)(real_value - MIN_FOUR_BYTES_INT) << 3) | FOUR_BYTES_INT_TRAILER; _WRITE(data, uint32_t, reverse_32(to_write)) #else - _WRITE(data, uint32_t, ((uint32_t)(real_value - MIN_FOUR_BYTES_INT) << 3) | FOUR_BYTES_INT_TRAILER) + _WRITE(data, uint32_t, ((uint32_t)(real_value - MIN_FOUR_BYTES_INT) << 3) | FOUR_BYTES_INT_TRAILER); #endif - ((BufferObject *)data)->end += 4; } return CPY_NONE; } static char write_str_internal(PyObject *data, PyObject *value) { - _CHECK_BUFFER(data, CPY_NONE_ERROR) - Py_ssize_t size; const char *chunk = PyUnicode_AsUTF8AndSize(value, &size); if (unlikely(chunk == NULL)) @@ -395,11 +498,10 @@ write_str_internal(PyObject *data, PyObject *value) { return CPY_NONE_ERROR; } // Write string content. - _CHECK_SIZE(data, size) - char *buf = ((BufferObject *)data)->buf; - memcpy(buf + ((BufferObject *)data)->pos, chunk, size); - ((BufferObject *)data)->pos += size; - ((BufferObject *)data)->end += size; + _CHECK_WRITE(data, size) + char *ptr = ((WriteBufferObject *)data)->ptr; + memcpy(ptr, chunk, size); + ((WriteBufferObject *)data)->ptr += size; return CPY_NONE; } @@ -412,6 +514,7 @@ write_str(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames if (unlikely(!CPyArg_ParseStackAndKeywordsSimple(args, nargs, kwnames, &parser, &data, &value))) { return NULL; } + _CHECK_WRITE_BUFFER(data, NULL) if (unlikely(!PyUnicode_Check(value))) { PyErr_SetString(PyExc_TypeError, "value must be a str"); return NULL; @@ -429,11 +532,10 @@ bytes format: size as int (see below) followed by bytes static PyObject* read_bytes_internal(PyObject *data) { - _CHECK_BUFFER(data, NULL) - // Read length. _CHECK_READ(data, 1, NULL) - uint8_t first = _READ(data, uint8_t) + uint8_t first; + _READ(&first, data, uint8_t); if (unlikely(first == LONG_INT_TRAILER)) { // Fail fast for invalid/tampered data. PyErr_SetString(PyExc_ValueError, "invalid bytes size"); @@ -449,14 +551,12 @@ read_bytes_internal(PyObject *data) { } Py_ssize_t size = tagged_size >> 1; // Read bytes content. - char *buf = ((BufferObject *)data)->buf; + char *ptr = ((ReadBufferObject *)data)->ptr; _CHECK_READ(data, size, NULL) - PyObject *res = PyBytes_FromStringAndSize( - buf + ((BufferObject *)data)->pos, (Py_ssize_t)size - ); + PyObject *res = PyBytes_FromStringAndSize(ptr, (Py_ssize_t)size); if (unlikely(res == NULL)) return NULL; - ((BufferObject *)data)->pos += size; + ((ReadBufferObject *)data)->ptr += size; return res; } @@ -468,13 +568,12 @@ read_bytes(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwname if (unlikely(!CPyArg_ParseStackAndKeywordsOneArg(args, nargs, kwnames, &parser, &data))) { return NULL; } + _CHECK_READ_BUFFER(data, NULL) return read_bytes_internal(data); } static char write_bytes_internal(PyObject *data, PyObject *value) { - _CHECK_BUFFER(data, CPY_NONE_ERROR) - const char *chunk = PyBytes_AsString(value); if (unlikely(chunk == NULL)) return CPY_NONE_ERROR; @@ -489,11 +588,10 @@ write_bytes_internal(PyObject *data, PyObject *value) { return CPY_NONE_ERROR; } // Write bytes content. - _CHECK_SIZE(data, size) - char *buf = ((BufferObject *)data)->buf; - memcpy(buf + ((BufferObject *)data)->pos, chunk, size); - ((BufferObject *)data)->pos += size; - ((BufferObject *)data)->end += size; + _CHECK_WRITE(data, size) + char *ptr = ((WriteBufferObject *)data)->ptr; + memcpy(ptr, chunk, size); + ((WriteBufferObject *)data)->ptr += size; return CPY_NONE; } @@ -506,6 +604,7 @@ write_bytes(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnam if (unlikely(!CPyArg_ParseStackAndKeywordsSimple(args, nargs, kwnames, &parser, &data, &value))) { return NULL; } + _CHECK_WRITE_BUFFER(data, NULL) if (unlikely(!PyBytes_Check(value))) { PyErr_SetString(PyExc_TypeError, "value must be a bytes object"); return NULL; @@ -524,13 +623,12 @@ float format: static double read_float_internal(PyObject *data) { - _CHECK_BUFFER(data, CPY_FLOAT_ERROR) _CHECK_READ(data, 8, CPY_FLOAT_ERROR) - char *buf = ((BufferObject *)data)->buf; - double res = PyFloat_Unpack8(buf + ((BufferObject *)data)->pos, 1); + char *ptr = ((ReadBufferObject *)data)->ptr; + double res = PyFloat_Unpack8(ptr, 1); if (unlikely((res == -1.0) && PyErr_Occurred())) return CPY_FLOAT_ERROR; - ((BufferObject *)data)->pos += 8; + ((ReadBufferObject *)data)->ptr += 8; return res; } @@ -542,6 +640,7 @@ read_float(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwname if (unlikely(!CPyArg_ParseStackAndKeywordsOneArg(args, nargs, kwnames, &parser, &data))) { return NULL; } + _CHECK_READ_BUFFER(data, NULL) double retval = read_float_internal(data); if (unlikely(retval == CPY_FLOAT_ERROR && PyErr_Occurred())) { return NULL; @@ -551,14 +650,12 @@ read_float(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwname static char write_float_internal(PyObject *data, double value) { - _CHECK_BUFFER(data, CPY_NONE_ERROR) - _CHECK_SIZE(data, 8) - char *buf = ((BufferObject *)data)->buf; - int res = PyFloat_Pack8(value, buf + ((BufferObject *)data)->pos, 1); + _CHECK_WRITE(data, 8) + char *ptr = ((WriteBufferObject *)data)->ptr; + int res = PyFloat_Pack8(value, ptr, 1); if (unlikely(res == -1)) return CPY_NONE_ERROR; - ((BufferObject *)data)->pos += 8; - ((BufferObject *)data)->end += 8; + ((WriteBufferObject *)data)->ptr += 8; return CPY_NONE; } @@ -571,6 +668,7 @@ write_float(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnam if (unlikely(!CPyArg_ParseStackAndKeywordsSimple(args, nargs, kwnames, &parser, &data, &value))) { return NULL; } + _CHECK_WRITE_BUFFER(data, NULL) if (unlikely(!PyFloat_Check(value))) { PyErr_SetString(PyExc_TypeError, "value must be a float"); return NULL; @@ -595,10 +693,10 @@ since negative integers are much more rare. static CPyTagged read_int_internal(PyObject *data) { - _CHECK_BUFFER(data, CPY_INT_TAG) _CHECK_READ(data, 1, CPY_INT_TAG) - uint8_t first = _READ(data, uint8_t) + uint8_t first; + _READ(&first, data, uint8_t); if (likely(first != LONG_INT_TRAILER)) { return _read_short_int(data, first); } @@ -607,7 +705,7 @@ read_int_internal(PyObject *data) { // Read byte length and sign. _CHECK_READ(data, 1, CPY_INT_TAG) - first = _READ(data, uint8_t) + _READ(&first, data, uint8_t); Py_ssize_t size_and_sign = _read_short_int(data, first); if (size_and_sign == CPY_INT_TAG) return CPY_INT_TAG; @@ -620,12 +718,11 @@ read_int_internal(PyObject *data) { // Construct an int object from the byte array. _CHECK_READ(data, size, CPY_INT_TAG) - char *buf = ((BufferObject *)data)->buf; - PyObject *num = _PyLong_FromByteArray( - (unsigned char *)(buf + ((BufferObject *)data)->pos), size, 1, 0); + char *ptr = ((ReadBufferObject *)data)->ptr; + PyObject *num = _PyLong_FromByteArray((unsigned char *)ptr, size, 1, 0); if (num == NULL) return CPY_INT_TAG; - ((BufferObject *)data)->pos += size; + ((ReadBufferObject *)data)->ptr += size; if (sign) { PyObject *old = num; num = PyNumber_Negative(old); @@ -645,6 +742,7 @@ read_int(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) if (unlikely(!CPyArg_ParseStackAndKeywordsOneArg(args, nargs, kwnames, &parser, &data))) { return NULL; } + _CHECK_READ_BUFFER(data, NULL) CPyTagged retval = read_int_internal(data); if (unlikely(retval == CPY_INT_TAG)) { return NULL; @@ -664,9 +762,8 @@ static inline int hex_to_int(char c) { static inline char _write_long_int(PyObject *data, CPyTagged value) { - _CHECK_SIZE(data, 1) - _WRITE(data, uint8_t, LONG_INT_TRAILER) - ((BufferObject *)data)->end += 1; + _CHECK_WRITE(data, 1) + _WRITE(data, uint8_t, LONG_INT_TRAILER); PyObject *hex_str = NULL; PyObject* int_value = CPyTagged_AsObject(value); @@ -731,8 +828,6 @@ _write_long_int(PyObject *data, CPyTagged value) { static char write_int_internal(PyObject *data, CPyTagged value) { - _CHECK_BUFFER(data, CPY_NONE_ERROR) - if (likely((value & CPY_INT_TAG) == 0)) { Py_ssize_t real_value = CPyTagged_ShortAsSsize_t(value); if (likely(real_value >= MIN_FOUR_BYTES_INT && real_value <= MAX_FOUR_BYTES_INT)) { @@ -754,6 +849,7 @@ write_int(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames if (unlikely(!CPyArg_ParseStackAndKeywordsSimple(args, nargs, kwnames, &parser, &data, &value))) { return NULL; } + _CHECK_WRITE_BUFFER(data, NULL) if (unlikely(!PyLong_Check(value))) { PyErr_SetString(PyExc_TypeError, "value must be an int"); return NULL; @@ -773,9 +869,9 @@ integer tag format (0 <= t <= 255): static uint8_t read_tag_internal(PyObject *data) { - _CHECK_BUFFER(data, CPY_LL_UINT_ERROR) _CHECK_READ(data, 1, CPY_LL_UINT_ERROR) - uint8_t ret = _READ(data, uint8_t) + uint8_t ret; + _READ(&ret, data, uint8_t); return ret; } @@ -787,6 +883,7 @@ read_tag(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) if (unlikely(!CPyArg_ParseStackAndKeywordsOneArg(args, nargs, kwnames, &parser, &data))) { return NULL; } + _CHECK_READ_BUFFER(data, NULL) uint8_t retval = read_tag_internal(data); if (unlikely(retval == CPY_LL_UINT_ERROR && PyErr_Occurred())) { return NULL; @@ -796,10 +893,8 @@ read_tag(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) static char write_tag_internal(PyObject *data, uint8_t value) { - _CHECK_BUFFER(data, CPY_NONE_ERROR) - _CHECK_SIZE(data, 1) - _WRITE(data, uint8_t, value) - ((BufferObject *)data)->end += 1; + _CHECK_WRITE(data, 1) + _WRITE(data, uint8_t, value); return CPY_NONE; } @@ -812,6 +907,7 @@ write_tag(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames if (unlikely(!CPyArg_ParseStackAndKeywordsSimple(args, nargs, kwnames, &parser, &data, &value))) { return NULL; } + _CHECK_WRITE_BUFFER(data, NULL) uint8_t unboxed = CPyLong_AsUInt8(value); if (unlikely(unboxed == CPY_LL_UINT_ERROR && PyErr_Occurred())) { CPy_TypeError("u8", value); @@ -834,6 +930,16 @@ cache_version(PyObject *self, PyObject *Py_UNUSED(ignored)) { return PyLong_FromLong(cache_version_internal()); } +static PyTypeObject * +ReadBuffer_type_internal(void) { + return &ReadBufferType; // Return borrowed reference +} + +static PyTypeObject * +WriteBuffer_type_internal(void) { + return &WriteBufferType; // Return borrowed reference +}; + static PyMethodDef librt_internal_module_methods[] = { {"write_bool", (PyCFunction)write_bool, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("write a bool")}, {"read_bool", (PyCFunction)read_bool, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("read a bool")}, @@ -859,18 +965,24 @@ NativeInternal_ABI_Version(void) { static int librt_internal_module_exec(PyObject *m) { - if (PyType_Ready(&BufferType) < 0) { + if (PyType_Ready(&ReadBufferType) < 0) { + return -1; + } + if (PyType_Ready(&WriteBufferType) < 0) { + return -1; + } + if (PyModule_AddObjectRef(m, "ReadBuffer", (PyObject *) &ReadBufferType) < 0) { return -1; } - if (PyModule_AddObjectRef(m, "Buffer", (PyObject *) &BufferType) < 0) { + if (PyModule_AddObjectRef(m, "WriteBuffer", (PyObject *) &WriteBufferType) < 0) { return -1; } // Export mypy internal C API, be careful with the order! - static void *NativeInternal_API[17] = { - (void *)Buffer_internal, - (void *)Buffer_internal_empty, - (void *)Buffer_getvalue_internal, + static void *NativeInternal_API[LIBRT_INTERNAL_API_LEN] = { + (void *)ReadBuffer_internal, + (void *)WriteBuffer_internal, + (void *)WriteBuffer_getvalue_internal, (void *)write_bool_internal, (void *)read_bool_internal, (void *)write_str_internal, @@ -885,6 +997,8 @@ librt_internal_module_exec(PyObject *m) (void *)write_bytes_internal, (void *)read_bytes_internal, (void *)cache_version_internal, + (void *)ReadBuffer_type_internal, + (void *)WriteBuffer_type_internal, }; PyObject *c_api_object = PyCapsule_New((void *)NativeInternal_API, "librt.internal._C_API", NULL); if (PyModule_Add(m, "_C_API", c_api_object) < 0) { diff --git a/mypyc/lib-rt/librt_internal.h b/mypyc/lib-rt/librt_internal.h index 1d16e1cb127f..329a0fd68c11 100644 --- a/mypyc/lib-rt/librt_internal.h +++ b/mypyc/lib-rt/librt_internal.h @@ -1,13 +1,16 @@ #ifndef LIBRT_INTERNAL_H #define LIBRT_INTERNAL_H -#define LIBRT_INTERNAL_ABI_VERSION 0 +#define LIBRT_INTERNAL_ABI_VERSION 1 +#define LIBRT_INTERNAL_API_LEN 19 #ifdef LIBRT_INTERNAL_MODULE -static PyObject *Buffer_internal(PyObject *source); -static PyObject *Buffer_internal_empty(void); -static PyObject *Buffer_getvalue_internal(PyObject *self); +static PyObject *ReadBuffer_internal(PyObject *source); +static PyObject *WriteBuffer_internal(void); +static PyObject *WriteBuffer_getvalue_internal(PyObject *self); +static PyObject *ReadBuffer_internal(PyObject *source); +static PyObject *ReadBuffer_internal_empty(void); static char write_bool_internal(PyObject *data, char value); static char read_bool_internal(PyObject *data); static char write_str_internal(PyObject *data, PyObject *value); @@ -22,14 +25,16 @@ static int NativeInternal_ABI_Version(void); static char write_bytes_internal(PyObject *data, PyObject *value); static PyObject *read_bytes_internal(PyObject *data); static uint8_t cache_version_internal(void); +static PyTypeObject *ReadBuffer_type_internal(void); +static PyTypeObject *WriteBuffer_type_internal(void); #else -static void **NativeInternal_API; +static void *NativeInternal_API[LIBRT_INTERNAL_API_LEN]; -#define Buffer_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[0]) -#define Buffer_internal_empty (*(PyObject* (*)(void)) NativeInternal_API[1]) -#define Buffer_getvalue_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[2]) +#define ReadBuffer_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[0]) +#define WriteBuffer_internal (*(PyObject* (*)(void)) NativeInternal_API[1]) +#define WriteBuffer_getvalue_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[2]) #define write_bool_internal (*(char (*)(PyObject *source, char value)) NativeInternal_API[3]) #define read_bool_internal (*(char (*)(PyObject *source)) NativeInternal_API[4]) #define write_str_internal (*(char (*)(PyObject *source, PyObject *value)) NativeInternal_API[5]) @@ -44,6 +49,8 @@ static void **NativeInternal_API; #define write_bytes_internal (*(char (*)(PyObject *source, PyObject *value)) NativeInternal_API[14]) #define read_bytes_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[15]) #define cache_version_internal (*(uint8_t (*)(void)) NativeInternal_API[16]) +#define ReadBuffer_type_internal (*(PyTypeObject* (*)(void)) NativeInternal_API[17]) +#define WriteBuffer_type_internal (*(PyTypeObject* (*)(void)) NativeInternal_API[18]) static int import_librt_internal(void) @@ -52,9 +59,10 @@ import_librt_internal(void) if (mod == NULL) return -1; Py_DECREF(mod); // we import just for the side effect of making the below work. - NativeInternal_API = (void **)PyCapsule_Import("librt.internal._C_API", 0); - if (NativeInternal_API == NULL) + void *capsule = PyCapsule_Import("librt.internal._C_API", 0); + if (capsule == NULL) return -1; + memcpy(NativeInternal_API, capsule, sizeof(NativeInternal_API)); if (NativeInternal_ABI_Version() != LIBRT_INTERNAL_ABI_VERSION) { PyErr_SetString(PyExc_ValueError, "ABI version conflict for librt.internal"); return -1; @@ -63,4 +71,13 @@ import_librt_internal(void) } #endif + +static inline bool CPyReadBuffer_Check(PyObject *obj) { + return Py_TYPE(obj) == ReadBuffer_type_internal(); +} + +static inline bool CPyWriteBuffer_Check(PyObject *obj) { + return Py_TYPE(obj) == WriteBuffer_type_internal(); +} + #endif // LIBRT_INTERNAL_H diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 10f4bc001e29..f685b1cfbcf5 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -333,31 +333,32 @@ error_kind=ERR_NEVER, ) -buffer_rprimitive = KNOWN_NATIVE_TYPES["librt.internal.Buffer"] +write_buffer_rprimitive = KNOWN_NATIVE_TYPES["librt.internal.WriteBuffer"] +read_buffer_rprimitive = KNOWN_NATIVE_TYPES["librt.internal.ReadBuffer"] -# Buffer(source) +# ReadBuffer(source) function_op( - name="librt.internal.Buffer", + name="librt.internal.ReadBuffer", arg_types=[bytes_rprimitive], - return_type=buffer_rprimitive, - c_function_name="Buffer_internal", + return_type=read_buffer_rprimitive, + c_function_name="ReadBuffer_internal", error_kind=ERR_MAGIC, ) -# Buffer() +# WriteBuffer() function_op( - name="librt.internal.Buffer", + name="librt.internal.WriteBuffer", arg_types=[], - return_type=buffer_rprimitive, - c_function_name="Buffer_internal_empty", + return_type=write_buffer_rprimitive, + c_function_name="WriteBuffer_internal", error_kind=ERR_MAGIC, ) method_op( name="getvalue", - arg_types=[buffer_rprimitive], + arg_types=[write_buffer_rprimitive], return_type=bytes_rprimitive, - c_function_name="Buffer_getvalue_internal", + c_function_name="WriteBuffer_getvalue_internal", error_kind=ERR_MAGIC, ) diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index a8ee7213ef96..0f8ec2b094f0 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1453,7 +1453,7 @@ class TestOverload: from typing import Final from mypy_extensions import u8 from librt.internal import ( - Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float, + WriteBuffer, ReadBuffer, write_bool, read_bool, write_str, read_str, write_float, read_float, write_int, read_int, write_tag, read_tag, write_bytes, read_bytes, cache_version, ) @@ -1462,7 +1462,7 @@ Tag = u8 TAG: Final[Tag] = 1 def foo() -> None: - b = Buffer() + b = WriteBuffer() write_str(b, "foo") write_bytes(b, b"bar") write_bool(b, True) @@ -1470,23 +1470,23 @@ def foo() -> None: write_int(b, 1) write_tag(b, TAG) - b = Buffer(b.getvalue()) - x = read_str(b) - xb = read_bytes(b) - y = read_bool(b) - z = read_float(b) - t = read_int(b) - u = read_tag(b) + rb = ReadBuffer(b.getvalue()) + x = read_str(rb) + xb = read_bytes(rb) + y = read_bool(rb) + z = read_float(rb) + t = read_int(rb) + u = read_tag(rb) v = cache_version() [out] def foo(): - r0, b :: librt.internal.Buffer + r0, b :: librt.internal.WriteBuffer r1 :: str r2 :: None r3 :: bytes r4, r5, r6, r7, r8 :: None r9 :: bytes - r10 :: librt.internal.Buffer + r10, rb :: librt.internal.ReadBuffer r11, x :: str r12, xb :: bytes r13, y :: bool @@ -1494,7 +1494,7 @@ def foo(): r15, t :: int r16, u, r17, v :: u8 L0: - r0 = Buffer_internal_empty() + r0 = WriteBuffer_internal() b = r0 r1 = 'foo' r2 = write_str_internal(b, r1) @@ -1504,20 +1504,20 @@ L0: r6 = write_float_internal(b, 0.1) r7 = write_int_internal(b, 2) r8 = write_tag_internal(b, 1) - r9 = Buffer_getvalue_internal(b) - r10 = Buffer_internal(r9) - b = r10 - r11 = read_str_internal(b) + r9 = WriteBuffer_getvalue_internal(b) + r10 = ReadBuffer_internal(r9) + rb = r10 + r11 = read_str_internal(rb) x = r11 - r12 = read_bytes_internal(b) + r12 = read_bytes_internal(rb) xb = r12 - r13 = read_bool_internal(b) + r13 = read_bool_internal(rb) y = r13 - r14 = read_float_internal(b) + r14 = read_float_internal(rb) z = r14 - r15 = read_int_internal(b) + r15 = read_int_internal(rb) t = r15 - r16 = read_tag_internal(b) + r16 = read_tag_internal(rb) u = r16 r17 = cache_version_internal() v = r17 diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 0805da184e1a..54c3268f038b 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -2711,14 +2711,18 @@ from native import Player Player.MIN = [case testBufferRoundTrip_librt_internal] -from typing import Final +from __future__ import annotations + +from typing import Final, Any from mypy_extensions import u8 from librt.internal import ( - Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float, + ReadBuffer, WriteBuffer, write_bool, read_bool, write_str, read_str, write_float, read_float, write_int, read_int, write_tag, read_tag, write_bytes, read_bytes, cache_version, ) +from testutil import assertRaises + Tag = u8 TAG_A: Final[Tag] = 33 TAG_B: Final[Tag] = 255 @@ -2726,11 +2730,89 @@ TAG_SPECIAL: Final[Tag] = 239 def test_buffer_basic() -> None: assert cache_version() == 0 - b = Buffer(b"foo") - assert b.getvalue() == b"foo" + w = WriteBuffer() + write_str(w, "foo") + r = ReadBuffer(w.getvalue()) + assert read_str(r) == "foo" + +def test_buffer_grow() -> None: + w = WriteBuffer() + n = 100 * 1000 + for i in range(n): + write_int(w, i & 63) + r = ReadBuffer(w.getvalue()) + for i in range(n): + assert read_int(r) == (i & 63) + with assertRaises(ValueError): + read_int(r) + +def test_buffer_primitive_types() -> None: + a1: Any = WriteBuffer() + w: WriteBuffer = a1 + write_str(w, "foo") + data = w.getvalue() + assert read_str(ReadBuffer(data)) == "foo" + a2: Any = ReadBuffer(b"foo") + with assertRaises(TypeError): + w2: WriteBuffer = a2 + + a3: Any = ReadBuffer(data) + r: ReadBuffer = a3 + assert read_str(r) == "foo" + a4: Any = WriteBuffer() + with assertRaises(TypeError): + r2: ReadBuffer = a4 + +def test_type_check_args_in_write_functions() -> None: + # Test calling wrapper functions with invalid arg types + from librt import internal + alias: Any = internal + w = WriteBuffer() + with assertRaises(TypeError): + alias.write_str(None, "foo") + with assertRaises(TypeError): + alias.write_str(w, None) + with assertRaises(TypeError): + alias.write_bool(None, True) + with assertRaises(TypeError): + alias.write_bool(w, None) + with assertRaises(TypeError): + alias.write_bytes(None, b"foo") + with assertRaises(TypeError): + alias.write_bytes(w, None) + with assertRaises(TypeError): + alias.write_float(None, 1.5) + with assertRaises(TypeError): + alias.write_float(w, None) + with assertRaises(TypeError): + alias.write_int(None, 15) + with assertRaises(TypeError): + alias.write_int(w, None) + with assertRaises(TypeError): + alias.write_tag(None, 15) + with assertRaises(TypeError): + alias.write_tag(w, None) + +def test_type_check_buffer_in_read_functions() -> None: + # Test calling wrapper functions with invalid arg types + from librt import internal + alias: Any = internal + with assertRaises(TypeError): + alias.read_str(None) + with assertRaises(TypeError): + alias.read_bool(None) + with assertRaises(TypeError): + alias.read_bytes(None) + with assertRaises(TypeError): + alias.read_float(None) + with assertRaises(TypeError): + alias.read_int(None) + with assertRaises(TypeError): + alias.read_tag(None) def test_buffer_roundtrip() -> None: - b = Buffer() + b: WriteBuffer | ReadBuffer + b = WriteBuffer() write_str(b, "foo") write_bool(b, True) write_str(b, "bar" * 1000) @@ -2757,7 +2839,7 @@ def test_buffer_roundtrip() -> None: write_int(b, 536860912) write_int(b, 1234567891) - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_str(b) == "foo" assert read_bool(b) is True assert read_str(b) == "bar" * 1000 @@ -2785,77 +2867,80 @@ def test_buffer_roundtrip() -> None: assert read_int(b) == 1234567891 def test_buffer_int_size() -> None: + b: WriteBuffer | ReadBuffer for i in (-10, -9, 0, 116, 117): - b = Buffer() + b = WriteBuffer() write_int(b, i) assert len(b.getvalue()) == 1 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_int(b) == i for i in (-100, -11, 118, 12344, 16283): - b = Buffer() + b = WriteBuffer() write_int(b, i) assert len(b.getvalue()) == 2 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_int(b) == i for i in (-10000, 16284, 123456789): - b = Buffer() + b = WriteBuffer() write_int(b, i) assert len(b.getvalue()) == 4 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_int(b) == i def test_buffer_int_powers() -> None: # 0, 1, 2 are tested above for p in range(2, 200): - b = Buffer() + b = WriteBuffer() write_int(b, 1 << p) write_int(b, (1 << p) - 1) write_int(b, -1 << p) write_int(b, (-1 << p) + 1) - b = Buffer(b.getvalue()) - assert read_int(b) == 1 << p - assert read_int(b) == (1 << p) - 1 - assert read_int(b) == -1 << p - assert read_int(b) == (-1 << p) + 1 + rb = ReadBuffer(b.getvalue()) + assert read_int(rb) == 1 << p + assert read_int(rb) == (1 << p) - 1 + assert read_int(rb) == -1 << p + assert read_int(rb) == (-1 << p) + 1 def test_positive_long_int_serialized_bytes() -> None: - b = Buffer() + b = WriteBuffer() n = 0x123456789ab write_int(b, n) x = b.getvalue() # Two prefix bytes, followed by little endian encoded integer in variable-length format assert x == b"\x0f\x2c\xab\x89\x67\x45\x23\x01" - b = Buffer(x) - assert read_int(b) == n + rb = ReadBuffer(x) + assert read_int(rb) == n def test_negative_long_int_serialized_bytes() -> None: - b = Buffer() + b = WriteBuffer() n = -0x123456789abcde write_int(b, n) x = b.getvalue() assert x == b"\x0f\x32\xde\xbc\x9a\x78\x56\x34\x12" - b = Buffer(x) - assert read_int(b) == n + rb = ReadBuffer(x) + assert read_int(rb) == n def test_buffer_str_size() -> None: + b: WriteBuffer | ReadBuffer for s in ("", "a", "a" * 117): - b = Buffer() + b = WriteBuffer() write_str(b, s) assert len(b.getvalue()) == len(s) + 1 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_str(b) == s for s in ("a" * 118, "a" * 16283): - b = Buffer() + b = WriteBuffer() write_str(b, s) assert len(b.getvalue()) == len(s) + 2 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_str(b) == s [file driver.py] from native import * test_buffer_basic() +test_buffer_grow() test_buffer_roundtrip() test_buffer_int_size() test_buffer_str_size() @@ -2864,11 +2949,13 @@ test_positive_long_int_serialized_bytes() test_negative_long_int_serialized_bytes() def test_buffer_basic_interpreted() -> None: - b = Buffer(b"foo") - assert b.getvalue() == b"foo" + b = WriteBuffer() + write_str(b, "foo") + b = ReadBuffer(b.getvalue()) + assert read_str(b) == "foo" def test_buffer_roundtrip_interpreted() -> None: - b = Buffer() + b = WriteBuffer() write_str(b, "foo") write_bool(b, True) write_str(b, "bar" * 1000) @@ -2893,7 +2980,7 @@ def test_buffer_roundtrip_interpreted() -> None: write_int(b, 536860912) write_int(b, 1234567891) - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_str(b) == "foo" assert read_bool(b) is True assert read_str(b) == "bar" * 1000 @@ -2920,47 +3007,47 @@ def test_buffer_roundtrip_interpreted() -> None: def test_buffer_int_size_interpreted() -> None: for i in (-10, -9, 0, 116, 117): - b = Buffer() + b = WriteBuffer() write_int(b, i) assert len(b.getvalue()) == 1 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_int(b) == i for i in (-100, -11, 118, 12344, 16283): - b = Buffer() + b = WriteBuffer() write_int(b, i) assert len(b.getvalue()) == 2 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_int(b) == i for i in (-10000, 16284, 123456789): - b = Buffer() + b = WriteBuffer() write_int(b, i) assert len(b.getvalue()) == 4 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_int(b) == i def test_buffer_int_powers_interpreted() -> None: # 0, 1, 2 are tested above for p in range(2, 9): - b = Buffer() + b = WriteBuffer() write_int(b, 1 << p) write_int(b, -1 << p) - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_int(b) == 1 << p assert read_int(b) == -1 << p def test_buffer_str_size_interpreted() -> None: for s in ("", "a", "a" * 117): - b = Buffer() + b = WriteBuffer() write_str(b, s) assert len(b.getvalue()) == len(s) + 1 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_str(b) == s for s in ("a" * 118, "a" * 16283): - b = Buffer() + b = WriteBuffer() write_str(b, s) assert len(b.getvalue()) == len(s) + 2 - b = Buffer(b.getvalue()) + b = ReadBuffer(b.getvalue()) assert read_str(b) == s test_buffer_basic_interpreted() @@ -2970,12 +3057,12 @@ test_buffer_str_size_interpreted() test_buffer_int_powers_interpreted() [case testBufferEmpty_librt_internal] -from librt.internal import Buffer, write_int, read_int +from librt.internal import WriteBuffer, ReadBuffer, write_int, read_int def test_empty() -> None: - b = Buffer(b"") + b = WriteBuffer() write_int(b, 42) - b1 = Buffer(b.getvalue()) + b1 = ReadBuffer(b.getvalue()) assert read_int(b1) == 42 [case testEnumMethodCalls] @@ -5362,37 +5449,37 @@ test_deletable_attr() [case testBufferCorruptedData_librt_internal] from librt.internal import ( - Buffer, read_bool, read_str, read_float, read_int, read_tag, read_bytes + ReadBuffer, read_bool, read_str, read_float, read_int, read_tag, read_bytes ) from random import randbytes def check(data: bytes) -> None: - b = Buffer(data) + b = ReadBuffer(data) try: while True: read_bool(b) except ValueError: pass - b = Buffer(data) + b = ReadBuffer(data) read_tag(b) # Always succeeds try: while True: read_int(b) except ValueError: pass - b = Buffer(data) + b = ReadBuffer(data) try: while True: read_str(b) except ValueError: pass - b = Buffer(data) + b = ReadBuffer(data) try: while True: read_bytes(b) except ValueError: pass - b = Buffer(data) + b = ReadBuffer(data) try: while True: read_float(b)