diff --git a/tests/test_search_binary_preservation.py b/tests/test_search_binary_preservation.py new file mode 100644 index 00000000..86157506 --- /dev/null +++ b/tests/test_search_binary_preservation.py @@ -0,0 +1,192 @@ +""" +Binary preservation tests for search results. + +These tests are in a separate file because the main search test suite (test_search.py) +has compatibility issues with the current Valkey search module version. Most existing +search tests fail due to unsupported field types and parameters (e.g., TEXT fields, +SKIPINITIALSCAN, etc.). + +Our binary preservation functionality works correctly with the current search module +using direct FT.CREATE commands and KNN vector queries, so we maintain these tests +separately to ensure the feature remains properly tested while the broader search +test compatibility issues are resolved. +""" + +import struct + +import pytest +import valkey + +from .conftest import _get_client, is_resp2_connection, skip_ifmodversion_lt + + +@pytest.mark.valkeymod +@skip_ifmodversion_lt("1.0.0", "search") +def test_vector_binary_preservation_default_behavior(request): + """Test that default behavior still corrupts binary data (backward compatibility)""" + client = _get_client(valkey.Valkey, request, decode_responses=False) + + # Create index with vector field using direct command + client.execute_command( + "FT.CREATE", "test_idx", "SCHEMA", + "embedding", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", + "DISTANCE_METRIC", "COSINE" + ) + + # Create vector data as bytes (simulating embeddings) + vec1 = [0.1, 0.2, 0.3] + vec1_bytes = struct.pack('3f', *vec1) + + # Store document with vector + client.hset("doc:1", mapping={"embedding": vec1_bytes}) + + # Search without preserve_bytes (default behavior) using KNN query + results = client.ft("test_idx").search( + "*=>[KNN 1 @embedding $vec]", {"vec": vec1_bytes} + ) + + if is_resp2_connection(client): + doc = results.docs[0] + # Default behavior should decode bytes to string (corrupting binary data) + assert isinstance(doc.embedding, str) + assert doc.embedding != vec1_bytes # Should be corrupted + + client.execute_command("FT.DROPINDEX", "test_idx") + + +@pytest.mark.valkeymod +@skip_ifmodversion_lt("1.0.0", "search") +def test_vector_binary_preservation_enabled(request): + """Test that preserve_bytes=True preserves binary vector data""" + client = _get_client(valkey.Valkey, request, decode_responses=False) + + # Create index with vector field using direct command + client.execute_command( + "FT.CREATE", "test_idx", "SCHEMA", + "embedding", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", + "DISTANCE_METRIC", "COSINE" + ) + + # Create vector data as bytes (simulating embeddings) + vec1 = [0.1, 0.2, 0.3] + vec1_bytes = struct.pack('3f', *vec1) + + # Store document with vector + client.hset("doc:1", mapping={"embedding": vec1_bytes}) + + # Search with preserve_bytes=True using KNN query + results = client.ft("test_idx").search( + "*=>[KNN 1 @embedding $vec]", {"vec": vec1_bytes}, preserve_bytes=True + ) + + if is_resp2_connection(client): + doc = results.docs[0] + # With preserve_bytes=True, binary data should be preserved + assert isinstance(doc.embedding, bytes) + assert doc.embedding == vec1_bytes + + client.execute_command("FT.DROPINDEX", "test_idx") + + +@pytest.mark.valkeymod +@skip_ifmodversion_lt("1.0.0", "search") +def test_multiple_field_types_and_vectors(request): + """Test binary preservation with multiple field types and vector dimensions""" + client = _get_client(valkey.Valkey, request, decode_responses=False) + + # Create index with diverse field types and different vector dimensions + client.execute_command( + "FT.CREATE", "test_idx", "SCHEMA", + "title", "TAG", + "price", "NUMERIC", + "embedding_3d", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", + "DISTANCE_METRIC", "COSINE", + "embedding_4d", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "4", + "DISTANCE_METRIC", "L2", + "binary_data", "TAG" + ) + + # Create test data with different vector dimensions + vec_3d = [0.1, 0.2, 0.3] + vec_3d_bytes = struct.pack("3f", *vec_3d) + vec_4d = [0.4, 0.5, 0.6, 0.7] + vec_4d_bytes = struct.pack("4f", *vec_4d) + + # Store multiple documents + for i in range(3): + client.hset(f"doc:{i + 1}", mapping={ + "title": f"item_{i + 1}", + "price": 10.0 + i, + "embedding_3d": vec_3d_bytes, + "embedding_4d": vec_4d_bytes, + "binary_data": b"binary_content" + }) + + # Test with multiple results (KNN 3 instead of KNN 1) + results = client.ft("test_idx").search( + "*=>[KNN 3 @embedding_3d $vec]", + {"vec": vec_3d_bytes}, + preserve_bytes=True, + binary_fields=["embedding_3d", "embedding_4d"] + ) + + if is_resp2_connection(client): + assert len(results.docs) == 3 + for doc in results.docs: + # Vector fields should be preserved as bytes + assert isinstance(doc.embedding_3d, bytes) + assert doc.embedding_3d == vec_3d_bytes + assert isinstance(doc.embedding_4d, bytes) + assert doc.embedding_4d == vec_4d_bytes + # Non-binary fields should be strings + assert isinstance(doc.title, str) + assert isinstance(doc.binary_data, str) + + client.execute_command("FT.DROPINDEX", "test_idx") + + +@pytest.mark.valkeymod +@skip_ifmodversion_lt("1.0.0", "search") +def test_binary_fields_selective_preservation(request): + """Test that binary_fields parameter selectively preserves specific fields""" + client = _get_client(valkey.Valkey, request, decode_responses=False) + + # Create index with vector and tag fields using direct command + client.execute_command( + "FT.CREATE", "test_idx", "SCHEMA", + "embedding1", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", + "DISTANCE_METRIC", "COSINE", + "embedding2", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", + "DISTANCE_METRIC", "COSINE", + "binary_tag", "TAG" + ) + + # Create vector data as bytes + vec1 = [0.1, 0.2, 0.3] + vec1_bytes = struct.pack("3f", *vec1) + vec2 = [0.4, 0.5, 0.6] + vec2_bytes = struct.pack("3f", *vec2) + + # Store document with vectors and tag + client.hset("doc:1", mapping={ + "embedding1": vec1_bytes, + "embedding2": vec2_bytes, + "binary_tag": b"test_tag" + }) + + # Search with selective binary preservation (only embedding1) using KNN query + results = client.ft("test_idx").search( + "*=>[KNN 1 @embedding1 $vec]", + {"vec": vec1_bytes}, + preserve_bytes=True, + binary_fields=["embedding1"] + ) + + if is_resp2_connection(client): + doc = results.docs[0] + assert isinstance(doc.embedding1, bytes) + assert doc.embedding1 == vec1_bytes + assert isinstance(doc.embedding2, str) + assert isinstance(doc.binary_tag, str) + + client.execute_command("FT.DROPINDEX", "test_idx") diff --git a/valkey/commands/search/_util.py b/valkey/commands/search/_util.py index dd1dff33..59f697f6 100644 --- a/valkey/commands/search/_util.py +++ b/valkey/commands/search/_util.py @@ -5,3 +5,15 @@ def to_string(s): return s.decode("utf-8", "ignore") else: return s # Not a string we care about + + +def to_string_or_bytes(s, preserve_bytes=False, binary_fields=None, field_name=None): + """Convert value to string or preserve as bytes based on parameters.""" + if isinstance(s, str): + return s + elif isinstance(s, bytes): + if preserve_bytes and (binary_fields is None or field_name in binary_fields): + return s # Keep as bytes + return s.decode("utf-8", "ignore") + else: + return s # Not a string we care about diff --git a/valkey/commands/search/commands.py b/valkey/commands/search/commands.py index 14823d6f..efcc080d 100644 --- a/valkey/commands/search/commands.py +++ b/valkey/commands/search/commands.py @@ -80,6 +80,8 @@ def _parse_search(self, res, **kwargs): duration=kwargs["duration"], has_payload=kwargs["query"]._with_payloads, with_scores=kwargs["query"]._with_scores, + preserve_bytes=kwargs.get("preserve_bytes", False), + binary_fields=kwargs.get("binary_fields", None), ) def _parse_aggregate(self, res, **kwargs): @@ -96,6 +98,8 @@ def _parse_profile(self, res, **kwargs): duration=kwargs["duration"], has_payload=query._with_payloads, with_scores=query._with_scores, + preserve_bytes=kwargs.get("preserve_bytes", False), + binary_fields=kwargs.get("binary_fields", None), ) return result, parse_to_dict(res[1]) @@ -484,6 +488,8 @@ def search( self, query: Union[str, Query], query_params: Union[Dict[str, Union[str, int, float, bytes]], None] = None, + preserve_bytes: bool = False, + binary_fields: Optional[List[str]] = None, ): """ Search the index for a given query, and return a result of documents @@ -493,6 +499,11 @@ def search( - **query**: the search query. Either a text for simple queries with default parameters, or a Query object for complex queries. See RediSearch's documentation on query format + - **preserve_bytes**: If True, preserve binary field values as bytes + instead of converting to UTF-8 strings + - **binary_fields**: List of field names to preserve as bytes when + preserve_bytes=True. If None, all binary fields + are preserved For more information see `FT.SEARCH `_. """ # noqa @@ -504,7 +515,8 @@ def search( return res return self._parse_results( - SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0, + preserve_bytes=preserve_bytes, binary_fields=binary_fields ) def explain( @@ -911,6 +923,8 @@ async def search( self, query: Union[str, Query], query_params: Dict[str, Union[str, int, float]] = None, + preserve_bytes: bool = False, + binary_fields: Optional[List[str]] = None, ): """ Search the index for a given query, and return a result of documents @@ -920,6 +934,11 @@ async def search( - **query**: the search query. Either a text for simple queries with default parameters, or a Query object for complex queries. See RediSearch's documentation on query format + - **preserve_bytes**: If True, preserve binary field values as bytes + instead of converting to UTF-8 strings + - **binary_fields**: List of field names to preserve as bytes when + preserve_bytes=True. If None, all binary fields + are preserved For more information see `FT.SEARCH `_. """ # noqa @@ -931,7 +950,8 @@ async def search( return res return self._parse_results( - SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0, + preserve_bytes=preserve_bytes, binary_fields=binary_fields ) async def aggregate( diff --git a/valkey/commands/search/result.py b/valkey/commands/search/result.py index 5b19e6fa..a5d5862b 100644 --- a/valkey/commands/search/result.py +++ b/valkey/commands/search/result.py @@ -1,4 +1,4 @@ -from ._util import to_string +from ._util import to_string, to_string_or_bytes from .document import Document @@ -9,7 +9,8 @@ class Result: """ def __init__( - self, res, hascontent, duration=0, has_payload=False, with_scores=False + self, res, hascontent, duration=0, has_payload=False, with_scores=False, + preserve_bytes=False, binary_fields=None ): """ - **snippets**: An optional dictionary of the form @@ -39,18 +40,19 @@ def __init__( fields = {} if hascontent and res[i + fields_offset] is not None: - fields = ( - dict( - dict( - zip( - map(to_string, res[i + fields_offset][::2]), - map(to_string, res[i + fields_offset][1::2]), - ) - ) + field_names = list(map(to_string, res[i + fields_offset][::2])) + field_values = res[i + fields_offset][1::2] + + # Process field values with binary preservation + processed_values = [] + for field_name, field_value in zip(field_names, field_values): + processed_value = to_string_or_bytes( + field_value, preserve_bytes, binary_fields, field_name ) - if hascontent - else {} - ) + processed_values.append(processed_value) + + fields = dict(zip(field_names, processed_values)) + try: del fields["id"] except KeyError: