Skip to content

Commit 797e665

Browse files
committed
Initial addition of numpy ndarrays to BinaryVector. New tests
1 parent 6a796c8 commit 797e665

File tree

5 files changed

+384
-21
lines changed

5 files changed

+384
-21
lines changed

bson/binary.py

Lines changed: 107 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@
6666
from mmap import mmap as _mmap
6767

6868

69+
_NUMPY_AVAILABLE = False
70+
try:
71+
import numpy as np
72+
73+
_NUMPY_AVAILABLE = True
74+
except ImportError:
75+
np = None # type: ignore
76+
77+
6978
class UuidRepresentation:
7079
UNSPECIFIED = 0
7180
"""An unspecified UUID representation.
@@ -234,13 +243,22 @@ class BinaryVector:
234243

235244
__slots__ = ("data", "dtype", "padding")
236245

237-
def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, padding: int = 0):
246+
def __init__(
247+
self,
248+
data: Union[Sequence[float | int], np.ndarray],
249+
dtype: BinaryVectorDtype,
250+
padding: int = 0,
251+
):
238252
"""
239253
:param data: Sequence of numbers representing the mathematical vector.
240254
:param dtype: The data type stored in binary
241255
:param padding: The number of bits in the final byte that are to be ignored
242256
when a vector element's size is less than a byte
243257
and the length of the vector is not a multiple of 8.
258+
(This is equivalent to a negative value of `count` in `numpy.unpackbits`_)
259+
260+
.. _numpy.unpackbits: https://numpy.org/doc/stable/reference/generated/numpy.unpackbits.html
261+
244262
"""
245263
self.data = data
246264
self.dtype = dtype
@@ -425,9 +443,19 @@ def from_vector(
425443
...
426444

427445
@classmethod
446+
@overload
428447
def from_vector(
429448
cls: Type[Binary],
430-
vector: Union[BinaryVector, list[int], list[float]],
449+
vector: np.ndarray,
450+
dtype: BinaryVectorDtype,
451+
padding: int = 0,
452+
) -> Binary:
453+
...
454+
455+
@classmethod
456+
def from_vector(
457+
cls: Type[Binary],
458+
vector: Union[BinaryVector, list[int], list[float], np.ndarray],
431459
dtype: Optional[BinaryVectorDtype] = None,
432460
padding: Optional[int] = None,
433461
) -> Binary:
@@ -459,25 +487,29 @@ def from_vector(
459487
vector = vector.data # type: ignore
460488

461489
padding = 0 if padding is None else padding
462-
if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8
463-
format_str = "b"
464-
if padding:
465-
raise ValueError(f"padding does not apply to {dtype=}")
466-
elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8
467-
format_str = "B"
468-
if 0 <= padding > 7:
469-
raise ValueError(f"{padding=}. It must be in [0,1, ..7].")
470-
if padding and not vector:
471-
raise ValueError("Empty vector with non-zero padding.")
472-
elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32
473-
format_str = "f"
474-
if padding:
475-
raise ValueError(f"padding does not apply to {dtype=}")
490+
metadata = struct.pack("<sB", dtype.value, padding)
491+
492+
if isinstance(vector, np.ndarray):
493+
data = _numpy_vector_to_bytes(vector, dtype)
476494
else:
477-
raise NotImplementedError("%s not yet supported" % dtype)
495+
if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8
496+
format_str = "b"
497+
if padding:
498+
raise ValueError(f"padding does not apply to {dtype=}")
499+
elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8
500+
format_str = "B"
501+
if 0 <= padding > 7:
502+
raise ValueError(f"{padding=}. It must be in [0,1, ..7].")
503+
if padding and not vector:
504+
raise ValueError("Empty vector with non-zero padding.")
505+
elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32
506+
format_str = "f"
507+
if padding:
508+
raise ValueError(f"padding does not apply to {dtype=}")
509+
else:
510+
raise NotImplementedError("%s not yet supported" % dtype)
511+
data = struct.pack(f"<{len(vector)}{format_str}", *vector) # type: ignore
478512

479-
metadata = struct.pack("<sB", dtype.value, padding)
480-
data = struct.pack(f"<{len(vector)}{format_str}", *vector) # type: ignore
481513
if padding and len(vector) and not (data[-1] & ((1 << padding) - 1)) == 0:
482514
raise ValueError(
483515
"Vector has a padding P, but bits in the final byte lower than P are non-zero. They must be zero."
@@ -549,6 +581,33 @@ def subtype(self) -> int:
549581
"""Subtype of this binary data."""
550582
return self.__subtype
551583

584+
def as_numpy_vector(self) -> BinaryVector:
    """From the Binary, create a BinaryVector where data is a 1-dim numpy array.

    dtype still follows our typing (BinaryVectorDtype),
    and padding is as we define it, notably equivalent to a negative value of count
    in `numpy.unpackbits <https://numpy.org/doc/stable/reference/generated/numpy.unpackbits.html>`_.

    The first two bytes of the payload are read as the dtype code and the
    padding; the remaining bytes are interpreted as the vector elements.

    :return: BinaryVector
    :raises ValueError: if this Binary's subtype is not the vector subtype,
        or if the dtype code is not one of INT8, FLOAT32 or PACKED_BIT.
    :raises ImportError: if numpy is not installed.

    .. versionadded:: 4.16
    """
    if self.subtype != VECTOR_SUBTYPE:
        raise ValueError(f"Cannot decode subtype {self.subtype} as a vector")
    if not _NUMPY_AVAILABLE:
        raise ImportError("Converting binary to numpy.ndarray requires numpy to be installed.")

    # Header layout: one dtype byte followed by one padding byte.
    dtype, padding = struct.unpack_from("<sB", self, 0)
    dtype = BinaryVectorDtype(dtype)

    if dtype == BinaryVectorDtype.INT8:
        element_type = "int8"
    elif dtype == BinaryVectorDtype.FLOAT32:
        # The payload is packed little-endian ("<" struct formats), so decode
        # with an explicit little-endian dtype rather than the host's native
        # float32, which would yield wrong values on big-endian platforms.
        element_type = "<f4"
    elif dtype == BinaryVectorDtype.PACKED_BIT:
        element_type = "uint8"
    else:
        raise ValueError(f"Unsupported dtype code: {dtype!r}")

    # Skip the two header bytes; everything after them is element data.
    data = np.frombuffer(self[2:], dtype=element_type)
    return BinaryVector(data, dtype, padding)
610+
552611
def __getnewargs__(self) -> Tuple[bytes, int]: # type: ignore[override]
553612
# Work around http://bugs.python.org/issue7382
554613
data = super().__getnewargs__()[0]
@@ -575,3 +634,32 @@ def __repr__(self) -> str:
575634
return f"<Binary(REDACTED, {self.__subtype})>"
576635
else:
577636
return f"Binary({bytes.__repr__(self)}, {self.__subtype})"
637+
638+
639+
def _numpy_vector_to_bytes(
    vector: np.ndarray,
    dtype: BinaryVectorDtype,
) -> bytes:
    """Serialize a 1-D numpy array into the element bytes of a vector payload.

    The array is cast to the numpy dtype matching ``dtype`` (without copying
    when it already has that dtype) and returned via ``ndarray.tobytes()``.
    For INT8 and PACKED_BIT the values are range-checked *before* the cast,
    because ``astype`` would otherwise wrap or truncate silently.

    :param vector: One-dimensional numpy array holding the vector elements.
    :param dtype: The BinaryVectorDtype the elements are to be stored as.
    :return: The packed element bytes (without the dtype/padding header).
    :raises ImportError: if numpy is not installed.
    :raises TypeError: if ``vector`` is not a numpy.ndarray.
    :raises ValueError: if ``vector`` is not 1-D, or holds values outside the
        range representable by ``dtype``.
    :raises NotImplementedError: if ``dtype`` is not INT8, FLOAT32 or PACKED_BIT.
    """
    if not _NUMPY_AVAILABLE:
        raise ImportError("Converting numpy.ndarray to binary requires numpy to be installed.")

    # Explicit raises rather than ``assert``: assertions are stripped under
    # ``python -O`` and must not be relied on for input validation.
    if not isinstance(vector, np.ndarray):
        raise TypeError(f"vector must be a numpy.ndarray, not {type(vector).__name__}")
    if vector.ndim != 1:
        raise ValueError("from_numpy_vector only supports 1D arrays as it creates a single vector.")

    if dtype == BinaryVectorDtype.FLOAT32:
        # Explicit little-endian so the bytes match the "<" struct packing
        # used for plain sequences, independent of host byte order.
        target, bounds, label = "<f4", None, ""
    elif dtype == BinaryVectorDtype.INT8:
        target, bounds, label = "int8", (-128, 127), "INT8"
    elif dtype == BinaryVectorDtype.PACKED_BIT:
        # Each element is one full unsigned byte of packed bits: [0, 255].
        target, bounds, label = "uint8", (0, 255), "UINT8"
    else:
        raise NotImplementedError("%s not yet supported" % dtype)

    # min()/max() raise on zero-size arrays, and an empty vector has nothing
    # to range-check anyway.
    if bounds is not None and vector.size:
        low, high = bounds
        # Written as a negated conjunction so NaN values (for which every
        # comparison is False) are rejected as well.
        if not (vector.min() >= low and vector.max() <= high):
            raise ValueError(f"Values found outside {label} range.")

    return vector.astype(np.dtype(target), copy=False).tobytes()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ ocsp = ["requirements/ocsp.txt"]
8787
snappy = ["requirements/snappy.txt"]
8888
test = ["requirements/test.txt"]
8989
zstd = ["requirements/zstd.txt"]
90+
numpy = ["requirements/numpy.txt"]
9091

9192
[tool.pytest.ini_options]
9293
minversion = "7"

requirements/numpy.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
numpy>=1.21

test/test_bson.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@
7171
from bson.timestamp import Timestamp
7272
from bson.tz_util import FixedOffset, utc
7373

74+
_NUMPY_AVAILABLE = False
75+
try:
76+
import numpy as np
77+
78+
_NUMPY_AVAILABLE = True
79+
except ImportError:
80+
np = None # type: ignore
81+
7482

7583
class NotADict(abc.MutableMapping):
7684
"""Non-dict type that implements the mapping protocol."""
@@ -735,6 +743,60 @@ def test_uuid_legacy(self):
735743
transformed = bin.as_uuid(UuidRepresentation.PYTHON_LEGACY)
736744
self.assertEqual(id, transformed)
737745

746+
@unittest.skipIf(not _NUMPY_AVAILABLE, "numpy optional-dependency not installed.")
def test_vector_from_numpy(self):
    """Follows test_vector except for input type numpy.ndarray"""
    # Simple data values could be treated as any of our BinaryVectorDtypes
    arr = np.array([2, 3])
    for dtype in (
        BinaryVectorDtype.INT8,
        BinaryVectorDtype.PACKED_BIT,
        BinaryVectorDtype.FLOAT32,
    ):
        binary_vector = Binary.from_vector(arr, dtype)
        # as_vector: data comes back as a plain list
        vector = binary_vector.as_vector()
        assert isinstance(vector, BinaryVector)
        assert vector.data == arr.tolist()
        # as_numpy_vector: data comes back as a numpy array.
        # The comparison must use vector_np (not vector) so that the numpy
        # round-trip is what is actually being verified.
        vector_np = binary_vector.as_numpy_vector()
        assert isinstance(vector_np, BinaryVector)
        assert np.all(vector_np.data == arr)

    # Invalid cases
    with self.assertRaises(ValueError):
        Binary.from_vector(np.array([-1]), BinaryVectorDtype.PACKED_BIT)
    with self.assertRaises(ValueError):
        # 256 does not fit in an unsigned 8-bit element.
        Binary.from_vector(np.array([256]), BinaryVectorDtype.PACKED_BIT)
    with self.assertRaises(ValueError):
        Binary.from_vector(np.array([-198]), BinaryVectorDtype.INT8)

    # Unexpected cases
    # Creating a vector of INT8 from a list of doubles will be caught by struct.pack
    # Numpy's default behavior is to cast to the type requested.
    list_floats = [-1.1, 1.1]
    cast_bin = Binary.from_vector(np.array(list_floats), BinaryVectorDtype.INT8)
    vector = cast_bin.as_vector()
    vector_np = cast_bin.as_numpy_vector()
    assert vector.data != list_floats
    assert vector.data == vector_np.data.tolist() == [-1, 1]
799+
738800
def test_vector(self):
739801
"""Tests of subtype 9"""
740802
# We start with valid cases, across the 3 dtypes implemented.

0 commit comments

Comments
 (0)