-
Notifications
You must be signed in to change notification settings - Fork 1.1k
PYTHON-5355 Addition of API to move to and from NumPy ndarrays and BSON BinaryVectors #2590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 14 commits
797e665
47fc92c
7882f75
b7556fb
5753e3b
d3407d7
9ae90e8
aae159f
9fadf97
40120e7
be06ce7
f03b943
3cc5041
e3b894b
0b0a50b
10da245
73910ce
8dec0d3
9420ec1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -65,6 +65,9 @@ | |||||
| from array import array as _array | ||||||
| from mmap import mmap as _mmap | ||||||
|
|
||||||
| import numpy as np | ||||||
| import numpy.typing as npt | ||||||
|
|
||||||
|
|
||||||
| class UuidRepresentation: | ||||||
| UNSPECIFIED = 0 | ||||||
|
|
@@ -234,13 +237,20 @@ class BinaryVector: | |||||
|
|
||||||
| __slots__ = ("data", "dtype", "padding") | ||||||
|
|
||||||
| def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, padding: int = 0): | ||||||
| def __init__( | ||||||
| self, | ||||||
| data: Union[Sequence[float | int], npt.NDArray[np.number]], | ||||||
| dtype: BinaryVectorDtype, | ||||||
| padding: int = 0, | ||||||
| ): | ||||||
| """ | ||||||
| :param data: Sequence of numbers representing the mathematical vector. | ||||||
| :param dtype: The data type stored in binary | ||||||
| :param padding: The number of bits in the final byte that are to be ignored | ||||||
| when a vector element's size is less than a byte | ||||||
| and the length of the vector is not a multiple of 8. | ||||||
| (Padding is equivalent to a negative value of `count` in | ||||||
| `numpy.unpackbits <https://numpy.org/doc/stable/reference/generated/numpy.unpackbits.html>`_) | ||||||
| """ | ||||||
| self.data = data | ||||||
| self.dtype = dtype | ||||||
|
|
@@ -424,10 +434,20 @@ def from_vector( | |||||
| ) -> Binary: | ||||||
| ... | ||||||
|
|
||||||
| @classmethod | ||||||
| @overload | ||||||
| def from_vector( | ||||||
| cls: Type[Binary], | ||||||
| vector: npt.NDArray[np.number], | ||||||
| dtype: BinaryVectorDtype, | ||||||
| padding: int = 0, | ||||||
| ) -> Binary: | ||||||
| ... | ||||||
|
|
||||||
| @classmethod | ||||||
| def from_vector( | ||||||
| cls: Type[Binary], | ||||||
| vector: Union[BinaryVector, list[int], list[float]], | ||||||
| vector: Union[BinaryVector, list[int], list[float], npt.NDArray[np.number]], | ||||||
| dtype: Optional[BinaryVectorDtype] = None, | ||||||
| padding: Optional[int] = None, | ||||||
| ) -> Binary: | ||||||
|
|
@@ -459,25 +479,60 @@ def from_vector( | |||||
| vector = vector.data # type: ignore | ||||||
|
|
||||||
| padding = 0 if padding is None else padding | ||||||
| if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8 | ||||||
| format_str = "b" | ||||||
| if padding: | ||||||
| raise ValueError(f"padding does not apply to {dtype=}") | ||||||
| elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8 | ||||||
| format_str = "B" | ||||||
| if 0 <= padding > 7: | ||||||
| raise ValueError(f"{padding=}. It must be in [0,1, ..7].") | ||||||
| if padding and not vector: | ||||||
| raise ValueError("Empty vector with non-zero padding.") | ||||||
| elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32 | ||||||
| format_str = "f" | ||||||
| if padding: | ||||||
| raise ValueError(f"padding does not apply to {dtype=}") | ||||||
| else: | ||||||
| raise NotImplementedError("%s not yet supported" % dtype) | ||||||
|
|
||||||
| if not isinstance(dtype, BinaryVectorDtype): | ||||||
| raise TypeError( | ||||||
| "dtype must be a bson.BinaryVectorDtype, such as BinaryVectorDtype.FLOAT32" | ||||||
caseyclements marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| ) | ||||||
| metadata = struct.pack("<sB", dtype.value, padding) | ||||||
| data = struct.pack(f"<{len(vector)}{format_str}", *vector) # type: ignore | ||||||
|
|
||||||
| if isinstance(vector, list): | ||||||
| if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8 | ||||||
| format_str = "b" | ||||||
| if padding: | ||||||
| raise ValueError(f"padding does not apply to {dtype=}") | ||||||
| elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8 | ||||||
| format_str = "B" | ||||||
| if 0 <= padding > 7: | ||||||
| raise ValueError(f"{padding=}. It must be in [0,1, ..7].") | ||||||
| if padding and not vector: | ||||||
| raise ValueError("Empty vector with non-zero padding.") | ||||||
| elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32 | ||||||
| format_str = "f" | ||||||
| if padding: | ||||||
| raise ValueError(f"padding does not apply to {dtype=}") | ||||||
| else: | ||||||
| raise NotImplementedError("%s not yet supported" % dtype) | ||||||
| data = struct.pack(f"<{len(vector)}{format_str}", *vector) | ||||||
| else: # vector is numpy array or incorrect type. | ||||||
| try: | ||||||
| import numpy as np | ||||||
| except ImportError as exc: | ||||||
| raise ImportError( | ||||||
| "Failed to create binary from vector. Check type. If numpy array, numpy must be installed." | ||||||
| ) from exc | ||||||
| if not isinstance(vector, np.ndarray): | ||||||
| raise TypeError("Vector must be a numpy array.") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| if vector.ndim != 1: | ||||||
| raise ValueError( | ||||||
| "from_numpy_vector only supports 1D arrays as it creates a single vector." | ||||||
| ) | ||||||
|
|
||||||
| if dtype == BinaryVectorDtype.FLOAT32: | ||||||
| vector = vector.astype(np.dtype("float32"), copy=False) | ||||||
| elif dtype == BinaryVectorDtype.INT8: | ||||||
| if vector.min() >= -128 and vector.max() <= 127: | ||||||
| vector = vector.astype(np.dtype("int8"), copy=False) | ||||||
| else: | ||||||
| raise ValueError("Values found outside INT8 range.") | ||||||
| elif dtype == BinaryVectorDtype.PACKED_BIT: | ||||||
| if vector.min() >= 0 and vector.max() <= 127: | ||||||
| vector = vector.astype(np.dtype("uint8"), copy=False) | ||||||
| else: | ||||||
| raise ValueError("Values found outside UINT8 range.") | ||||||
| else: | ||||||
| raise NotImplementedError("%s not yet supported" % dtype) | ||||||
| data = vector.tobytes() | ||||||
|
|
||||||
| if padding and len(vector) and not (data[-1] & ((1 << padding) - 1)) == 0: | ||||||
| raise ValueError( | ||||||
| "Vector has a padding P, but bits in the final byte lower than P are non-zero. They must be zero." | ||||||
|
|
@@ -549,6 +604,54 @@ def subtype(self) -> int: | |||||
| """Subtype of this binary data.""" | ||||||
| return self.__subtype | ||||||
|
|
||||||
| def as_numpy_vector(self) -> BinaryVector: | ||||||
|
||||||
| """From the Binary, create a BinaryVector where data is a 1-dim numpy array. | ||||||
| dtype still follows our typing (BinaryVectorDtype), | ||||||
| and padding is as we define it, notably equivalent to a negative value of count | ||||||
| in `numpy.unpackbits <https://numpy.org/doc/stable/reference/generated/numpy.unpackbits.html>`_. | ||||||
|
|
||||||
| :return: BinaryVector | ||||||
|
|
||||||
| .. versionadded:: 4.16 | ||||||
| """ | ||||||
| if self.subtype != VECTOR_SUBTYPE: | ||||||
| raise ValueError(f"Cannot decode subtype {self.subtype} as a vector") | ||||||
| try: | ||||||
| import numpy as np | ||||||
| except ImportError as exc: | ||||||
| raise ImportError( | ||||||
| "Converting binary to numpy.ndarray requires numpy to be installed." | ||||||
| ) from exc | ||||||
|
|
||||||
| dtype, padding = struct.unpack_from("<sB", self, 0) | ||||||
| dtype = BinaryVectorDtype(dtype) | ||||||
| n_bytes = len(self) - 2 | ||||||
|
|
||||||
| if dtype == BinaryVectorDtype.INT8: | ||||||
| data = np.frombuffer(self[2:], dtype="int8") | ||||||
| elif dtype == BinaryVectorDtype.FLOAT32: | ||||||
| if n_bytes % 4: | ||||||
| raise ValueError( | ||||||
| "Corrupt data. N bytes for a float32 vector must be a multiple of 4." | ||||||
| ) | ||||||
| data = np.frombuffer(self[2:], dtype="float32") | ||||||
| elif dtype == BinaryVectorDtype.PACKED_BIT: | ||||||
| # data packed as uint8 | ||||||
| if padding and not n_bytes: | ||||||
| raise ValueError("Corrupt data. Vector has a padding P, but no data.") | ||||||
| if padding > 7 or padding < 0: | ||||||
| raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.") | ||||||
| data = np.frombuffer(self[2:], dtype="uint8") | ||||||
|
||||||
| if padding and np.unpackbits(data[-1])[-padding:].sum() > 0: | ||||||
| warnings.warn( | ||||||
| "Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.", | ||||||
| DeprecationWarning, | ||||||
| stacklevel=2, | ||||||
| ) | ||||||
| else: | ||||||
| raise ValueError(f"Unsupported dtype code: {dtype!r}") | ||||||
| return BinaryVector(data, dtype, padding) | ||||||
|
|
||||||
| def __getnewargs__(self) -> Tuple[bytes, int]: # type: ignore[override] | ||||||
| # Work around http://bugs.python.org/issue7382 | ||||||
| data = super().__getnewargs__()[0] | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |
| set shell := ["bash", "-c"] | ||
|
|
||
| # Commonly used command segments. | ||
| typing_run := "uv run --group typing --extra aws --extra encryption --extra ocsp --extra snappy --extra test --extra zstd" | ||
| typing_run := "uv run --group typing --extra aws --extra encryption --with numpy --extra ocsp --extra snappy --extra test --extra zstd" | ||
| docs_run := "uv run --extra docs" | ||
| doc_build := "./doc/_build" | ||
| mypy_args := "--install-types --non-interactive" | ||
|
|
@@ -39,14 +39,14 @@ typing: && resync | |
|
|
||
| [group('typing')] | ||
| typing-mypy: && resync | ||
| {{typing_run}} mypy {{mypy_args}} bson gridfs tools pymongo | ||
| {{typing_run}} mypy {{mypy_args}} --config-file mypy_test.ini test | ||
| {{typing_run}} mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py | ||
| {{typing_run}} python -m mypy {{mypy_args}} bson gridfs tools pymongo | ||
| {{typing_run}} python -m mypy {{mypy_args}} --config-file mypy_test.ini test | ||
| {{typing_run}} python -m mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py | ||
|
|
||
| [group('typing')] | ||
| typing-pyright: && resync | ||
| {{typing_run}} pyright test/test_typing.py test/test_typing_strict.py | ||
| {{typing_run}} pyright -p strict_pyrightconfig.json test/test_typing_strict.py | ||
| {{typing_run}} python -m pyright test/test_typing.py test/test_typing_strict.py | ||
| {{typing_run}} python -m pyright -p strict_pyrightconfig.json test/test_typing_strict.py | ||
Jibola marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| [group('lint')] | ||
| lint *args="": && resync | ||
|
|
@@ -58,7 +58,13 @@ lint-manual *args="": && resync | |
|
|
||
| [group('test')] | ||
| test *args="-v --durations=5 --maxfail=10": && resync | ||
| uv run --extra test pytest {{args}} | ||
| uv run --extra test python -m pytest {{args}} | ||
|
|
||
| [group('test')] | ||
| test-bson *args="-v --durations=5 --maxfail=10": && resync | ||
|
||
| uv run --extra test --with numpy python -m pytest test/test_bson.py | ||
|
|
||
|
|
||
|
|
||
| [group('test')] | ||
| run-tests *args: && resync | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| import array | ||
| import collections | ||
| import datetime | ||
| import importlib.util | ||
| import mmap | ||
| import os | ||
| import pickle | ||
|
|
@@ -71,6 +72,8 @@ | |
| from bson.timestamp import Timestamp | ||
| from bson.tz_util import FixedOffset, utc | ||
|
|
||
| _NUMPY_AVAILABLE = importlib.util.find_spec("numpy") is not None | ||
|
|
||
|
|
||
| class NotADict(abc.MutableMapping): | ||
| """Non-dict type that implements the mapping protocol.""" | ||
|
|
@@ -871,6 +874,62 @@ def test_binaryvector_equality(self): | |
| BinaryVector([1], BinaryVectorDtype.INT8), BinaryVector([2], BinaryVectorDtype.INT8) | ||
| ) | ||
|
|
||
| @unittest.skipIf(not _NUMPY_AVAILABLE, "numpy optional-dependency not installed.") | ||
| def test_vector_from_numpy(self): | ||
|
Comment on lines +877 to +878
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NIT, because we don't have a general policy around this, but there are three separate test themes in here.
I'd push for splitting those three into three separate unit tests for better semantic representation, but I'm not hard-pressed on this. |
||
| """Follows test_vector except for input type numpy.ndarray""" | ||
| # Simple data values could be treated as any of our BinaryVectorDtypes | ||
| import numpy as np | ||
|
|
||
| arr = np.array([2, 3]) | ||
| # INT8 | ||
| binary_vector_int8 = Binary.from_vector(arr, BinaryVectorDtype.INT8) | ||
| # as_vector | ||
| vector = binary_vector_int8.as_vector() | ||
| assert isinstance(vector, BinaryVector) | ||
| assert vector.data == arr.tolist() | ||
| # as_numpy_vector | ||
| vector_np = binary_vector_int8.as_numpy_vector() | ||
| assert isinstance(vector_np, BinaryVector) | ||
| assert np.all(vector.data == arr) | ||
| # PACKED_BIT | ||
| binary_vector_uint8 = Binary.from_vector(arr, BinaryVectorDtype.PACKED_BIT) | ||
| # as_vector | ||
| vector = binary_vector_uint8.as_vector() | ||
| assert isinstance(vector, BinaryVector) | ||
| assert vector.data == arr.tolist() | ||
| # as_numpy_vector | ||
| vector_np = binary_vector_uint8.as_numpy_vector() | ||
| assert isinstance(vector_np, BinaryVector) | ||
| assert np.all(vector_np.data == arr) | ||
| # FLOAT32 | ||
| binary_vector_float32 = Binary.from_vector(arr, BinaryVectorDtype.FLOAT32) | ||
| # as_vector | ||
| vector = binary_vector_float32.as_vector() | ||
| assert isinstance(vector, BinaryVector) | ||
| assert vector.data == arr.tolist() | ||
| # as_numpy_vector | ||
| vector_np = binary_vector_float32.as_numpy_vector() | ||
| assert isinstance(vector_np, BinaryVector) | ||
| assert np.all(vector_np.data == arr) | ||
|
|
||
| # Invalid cases | ||
| with self.assertRaises(ValueError): | ||
| Binary.from_vector(np.array([-1]), BinaryVectorDtype.PACKED_BIT) | ||
| with self.assertRaises(ValueError): | ||
| Binary.from_vector(np.array([128]), BinaryVectorDtype.PACKED_BIT) | ||
| with self.assertRaises(ValueError): | ||
| Binary.from_vector(np.array([-198]), BinaryVectorDtype.INT8) | ||
|
|
||
| # Unexpected cases | ||
| # Creating a vector of INT8 from a list of doubles will be caught by struct.pack | ||
| # Numpy's default behavior is to cast to the type requested. | ||
| list_floats = [-1.1, 1.1] | ||
| cast_bin = Binary.from_vector(np.array(list_floats), BinaryVectorDtype.INT8) | ||
| vector = cast_bin.as_vector() | ||
| vector_np = cast_bin.as_numpy_vector() | ||
| assert vector.data != list_floats | ||
| assert vector.data == vector_np.data.tolist() == [-1, 1] | ||
|
|
||
| def test_unicode_regex(self): | ||
| """Tests we do not get a segfault for C extension on unicode RegExs. | ||
| This had been happening. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is `BinaryVector` a subtype of list?