From bb38ff081e98c3072311a31831a78d5e1fd57ae4 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Mon, 27 Oct 2025 14:33:59 +0200 Subject: [PATCH 01/13] Added policies parsing and policies resolving (#3805) * Added poilcy resolution method * Moved main command proceessing on top * Fixed return type and keyless detection * Added Dynamic and Static policies * Added coverage for policy resolvers * Removed all policies except search (phase 1) * Applied comments --- redis/_parsers/commands.py | 204 ++++++++++++++++++++++++++++++++- redis/commands/policies.py | 130 +++++++++++++++++++++ redis/exceptions.py | 6 + tests/test_command_parser.py | 63 ++++++++++ tests/test_command_policies.py | 57 +++++++++ 5 files changed, 458 insertions(+), 2 deletions(-) create mode 100644 redis/commands/policies.py create mode 100644 tests/test_command_policies.py diff --git a/redis/_parsers/commands.py b/redis/_parsers/commands.py index b5109252ae..a7571ac195 100644 --- a/redis/_parsers/commands.py +++ b/redis/_parsers/commands.py @@ -1,11 +1,43 @@ +from dataclasses import dataclass +from enum import Enum from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union -from redis.exceptions import RedisError, ResponseError +from redis.exceptions import RedisError, ResponseError, IncorrectPolicyType from redis.utils import str_if_bytes if TYPE_CHECKING: from redis.asyncio.cluster import ClusterNode +class RequestPolicy(Enum): + ALL_NODES = 'all_nodes' + ALL_SHARDS = 'all_shards' + MULTI_SHARD = 'multi_shard' + SPECIAL = 'special' + DEFAULT_KEYLESS = 'default_keyless' + DEFAULT_KEYED = 'default_keyed' + +class ResponsePolicy(Enum): + ONE_SUCCEEDED = 'one_succeeded' + ALL_SUCCEEDED = 'all_succeeded' + AGG_LOGICAL_AND = 'agg_logical_and' + AGG_LOGICAL_OR = 'agg_logical_or' + AGG_MIN = 'agg_min' + AGG_MAX = 'agg_max' + AGG_SUM = 'agg_sum' + SPECIAL = 'special' + DEFAULT_KEYLESS = 'default_keyless' + DEFAULT_KEYED = 'default_keyed' + +class CommandPolicies: + def __init__( + self, + request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS, + response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS + ): + self.request_policy = request_policy + self.response_policy = response_policy + +PolicyRecords = dict[str, dict[str, CommandPolicies]] class AbstractCommandsParser: def _get_pubsub_keys(self, *args): @@ -64,7 +96,8 @@ class CommandsParser(AbstractCommandsParser): def __init__(self, redis_connection): self.commands = {} - self.initialize(redis_connection) + self.redis_connection = redis_connection + self.initialize(self.redis_connection) def initialize(self, r): commands = r.command() @@ -169,6 +202,173 @@ def _get_moveable_keys(self, redis_conn, *args): raise e return keys + def _is_keyless_command(self, command_name: str, subcommand_name: Optional[str]=None) -> bool: + """ + Determines whether a given command or subcommand is considered "keyless". + + A keyless command does not operate on specific keys, which is determined based + on the first key position in the command or subcommand details. If the command + or subcommand's first key position is zero or negative, it is treated as keyless. + + Parameters: + command_name: str + The name of the command to check. + subcommand_name: Optional[str], default=None + The name of the subcommand to check, if applicable. If not provided, + the check is performed only on the command. + + Returns: + bool + True if the specified command or subcommand is considered keyless, + False otherwise. + + Raises: + ValueError + If the specified subcommand is not found within the command or the + specified command does not exist in the available commands. + """ + if subcommand_name: + for subcommand in self.commands.get(command_name)['subcommands']: + if str_if_bytes(subcommand[0]) == subcommand_name: + parsed_subcmd = self.parse_subcommand(subcommand) + return parsed_subcmd['first_key_pos'] <= 0 + raise ValueError(f"Subcommand {subcommand_name} not found in command {command_name}") + else: + command_details = self.commands.get(command_name, None) + if command_details is not None: + return command_details['first_key_pos'] <= 0 + + raise ValueError(f"Command {command_name} not found in commands") + + def get_command_policies(self) -> PolicyRecords: + """ + Retrieve and process the command policies for all commands and subcommands. + + This method traverses through commands and subcommands, extracting policy details + from associated data structures and constructing a dictionary of commands with their + associated policies. It supports nested data structures and handles both main commands + and their subcommands. + + Returns: + PolicyRecords: A collection of commands and subcommands associated with their + respective policies. + + Raises: + IncorrectPolicyType: If an invalid policy type is encountered during policy extraction. + """ + command_with_policies = {} + + def extract_policies(data, module_name, command_name): + """ + Recursively extract policies from nested data structures. + + Args: + data: The data structure to search (can be list, dict, str, bytes, etc.) + command_name: The command name to associate with found policies + """ + if isinstance(data, (str, bytes)): + # Decode bytes to string if needed + policy = str_if_bytes(data.decode()) + + # Check if this is a policy string + if policy.startswith('request_policy') or policy.startswith('response_policy'): + if policy.startswith('request_policy'): + policy_type = policy.split(':')[1] + + try: + command_with_policies[module_name][command_name].request_policy = RequestPolicy(policy_type) + except ValueError: + raise IncorrectPolicyType(f"Incorrect request policy type: {policy_type}") + + if policy.startswith('response_policy'): + policy_type = policy.split(':')[1] + + try: + command_with_policies[module_name][command_name].response_policy = ResponsePolicy(policy_type) + except ValueError: + raise IncorrectPolicyType(f"Incorrect response policy type: {policy_type}") + + elif isinstance(data, list): + # For lists, recursively process each element + for item in data: + extract_policies(item, module_name, command_name) + + elif isinstance(data, dict): + # For dictionaries, recursively process each value + for value in data.values(): + extract_policies(value, module_name, command_name) + + for command, details in self.commands.items(): + # Check whether the command has keys + is_keyless = self._is_keyless_command(command) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + # Check if it's a core or module command + split_name = command.split('.') + + if len(split_name) > 1: + module_name = split_name[0] + command_name = split_name[1] + else: + module_name = 'core' + command_name = split_name[0] + + # Create a CommandPolicies object with default policies on the new command. + if command_with_policies.get(module_name, None) is None: + command_with_policies[module_name] = {command_name: CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + )} + else: + command_with_policies[module_name][command_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + ) + + tips = details.get('tips') + subcommands = details.get('subcommands') + + # Process tips for the main command + if tips: + extract_policies(tips, module_name, command_name) + + # Process subcommands + if subcommands: + for subcommand_details in subcommands: + # Get the subcommand name (first element) + subcmd_name = subcommand_details[0] + if isinstance(subcmd_name, bytes): + subcmd_name = subcmd_name.decode() + + # Check whether the subcommand has keys + is_keyless = self._is_keyless_command(command, subcmd_name) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + subcmd_name = subcmd_name.replace('|', ' ') + + # Create a CommandPolicies object with default policies on the new command. + command_with_policies[module_name][subcmd_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + ) + + # Recursively extract policies from the rest of the subcommand details + for subcommand_detail in subcommand_details[1:]: + extract_policies(subcommand_detail, module_name, subcmd_name) + + return command_with_policies class AsyncCommandsParser(AbstractCommandsParser): """ diff --git a/redis/commands/policies.py b/redis/commands/policies.py new file mode 100644 index 0000000000..a2f7f45924 --- /dev/null +++ b/redis/commands/policies.py @@ -0,0 +1,130 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from redis._parsers.commands import CommandPolicies, PolicyRecords, RequestPolicy, ResponsePolicy, CommandsParser + +STATIC_POLICIES: PolicyRecords = { + 'ft': { + 'explaincli': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'suglen': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), + 'profile': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'dropindex': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'aliasupdate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'alter': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'aggregate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'syndump': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'create': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'explain': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'sugget': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), + 'dictdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'aliasadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'dictadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'synupdate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'drop': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'info': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'sugadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), + 'dictdump': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'cursor': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'search': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'tagvals': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'aliasdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'sugdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), + 'spellcheck': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + } +} + +class PolicyResolver(ABC): + + @abstractmethod + def resolve(self, command_name: str) -> CommandPolicies: + """ + Resolves the command name and determines the associated command policies. + + Args: + command_name: The name of the command to resolve. + + Returns: + CommandPolicies: The policies associated with the specified command. + """ + pass + + @abstractmethod + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + """ + Factory method to instantiate a policy resolver with a fallback resolver. + + Args: + fallback: Fallback resolver + + Returns: + PolicyResolver: Returns a new policy resolver with the specified fallback resolver. + """ + pass + +class BasePolicyResolver(PolicyResolver): + """ + Base class for policy resolvers. + """ + def __init__(self, policies: PolicyRecords, fallback: Optional[PolicyResolver] = None) -> None: + self._policies = policies + self._fallback = fallback + + def resolve(self, command_name: str) -> CommandPolicies: + parts = command_name.split(".") + + if len(parts) > 2: + raise ValueError(f"Wrong command or module name: {command_name}") + + module, command = parts if len(parts) == 2 else ("core", parts[0]) + + if self._policies.get(module, None) is None: + if self._fallback is not None: + return self._fallback.resolve(command_name) + else: + raise ValueError(f"Module {module} not found") + + if self._policies.get(module).get(command, None) is None: + if self._fallback is not None: + return self._fallback.resolve(command_name) + else: + raise ValueError(f"Command {command} not found in module {module}") + + return self._policies.get(module).get(command) + + @abstractmethod + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + pass + + +class DynamicPolicyResolver(BasePolicyResolver): + """ + Resolves policy dynamically based on the COMMAND output. + """ + def __init__(self, commands_parser: CommandsParser, fallback: Optional[PolicyResolver] = None) -> None: + """ + Parameters: + commands_parser (CommandsParser): COMMAND output parser. + fallback (Optional[PolicyResolver]): An optional resolver to be used when the + primary policies cannot handle a specific request. + """ + self._commands_parser = commands_parser + super().__init__(commands_parser.get_command_policies(), fallback) + + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + return DynamicPolicyResolver(self._commands_parser, fallback) + + +class StaticPolicyResolver(BasePolicyResolver): + """ + Resolves policy from a static list of policy records. + """ + def __init__(self, fallback: Optional[PolicyResolver] = None) -> None: + """ + Parameters: + fallback (Optional[PolicyResolver]): An optional fallback policy resolver + used for resolving policies if static policies are inadequate. + """ + super().__init__(STATIC_POLICIES, fallback) + + def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": + return StaticPolicyResolver(fallback) \ No newline at end of file diff --git a/redis/exceptions.py b/redis/exceptions.py index 643444986b..458ba5843f 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -245,3 +245,9 @@ class InvalidPipelineStack(RedisClusterException): """ pass + +class IncorrectPolicyType(Exception): + """ + Raised when a policy type isn't matching to any known policy types. + """ + pass \ No newline at end of file diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index e3b44a147f..6be43e5823 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,5 +1,8 @@ +from pprint import pprint + import pytest from redis._parsers import CommandsParser +from redis._parsers.commands import RequestPolicy, ResponsePolicy from .conftest import ( assert_resp_response, @@ -106,3 +109,63 @@ def test_get_pubsub_keys(self, r): assert commands_parser.get_keys(r, *args2) == ["foo1", "foo2", "foo3"] assert commands_parser.get_keys(r, *args3) == ["*"] assert commands_parser.get_keys(r, *args4) == ["foo1", "foo2", "foo3"] + + @skip_if_server_version_lt("7.0.0") + @pytest.mark.onlycluster + def test_get_command_policies(self, r): + commands_parser = CommandsParser(r) + expected_command_policies = { + 'core': { + 'keys': ['keys', RequestPolicy.ALL_SHARDS, ResponsePolicy.DEFAULT_KEYLESS], + 'acl setuser': ['acl setuser', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], + 'exists': ['exists', RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], + 'config resetstat': ['config resetstat', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], + 'slowlog len': ['slowlog len', RequestPolicy.ALL_NODES, ResponsePolicy.AGG_SUM], + 'scan': ['scan', RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], + 'latency history': ['latency history', RequestPolicy.ALL_NODES, ResponsePolicy.SPECIAL], + 'memory doctor': ['memory doctor', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], + 'randomkey': ['randomkey', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], + 'mget': ['mget', RequestPolicy.MULTI_SHARD, ResponsePolicy.DEFAULT_KEYED], + 'function restore': ['function restore', RequestPolicy.ALL_SHARDS, ResponsePolicy.ALL_SUCCEEDED], + }, + 'json': { + 'debug': ['debug', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'get': ['get', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'ft': { + 'search': ['search', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], + 'create': ['create', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], + }, + 'bf': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'madd': ['madd', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'cf': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'mexists': ['mexists', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'tdigest': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'min': ['min', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'ts': { + 'create': ['create', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'info': ['info', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'topk': { + 'list': ['list', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'query': ['query', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + } + } + + actual_policies = commands_parser.get_command_policies() + assert len(actual_policies) > 0 + + for module_name, commands in expected_command_policies.items(): + for command, command_policies in commands.items(): + assert command in actual_policies[module_name] + assert command_policies == [ + command, + actual_policies[module_name][command].request_policy, + actual_policies[module_name][command].response_policy + ] \ No newline at end of file diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py new file mode 100644 index 0000000000..c0d057f0b0 --- /dev/null +++ b/tests/test_command_policies.py @@ -0,0 +1,57 @@ +from unittest.mock import Mock + +import pytest + +from redis._parsers import CommandsParser +from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy +from redis.commands.policies import DynamicPolicyResolver, StaticPolicyResolver + + +@pytest.mark.onlycluster +class TestBasePolicyResolver: + def test_resolve(self): + mock_command_parser = Mock(spec=CommandsParser) + zcount_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) + rpoplpush_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) + + mock_command_parser.get_command_policies.return_value = { + 'core': { + 'zcount': zcount_policy, + 'rpoplpush': rpoplpush_policy, + } + } + + dynamic_resolver = DynamicPolicyResolver(mock_command_parser) + assert dynamic_resolver.resolve('zcount') == zcount_policy + assert dynamic_resolver.resolve('rpoplpush') == rpoplpush_policy + + with pytest.raises(ValueError, match="Wrong command or module name: foo.bar.baz"): + dynamic_resolver.resolve('foo.bar.baz') + + with pytest.raises(ValueError, match="Module foo not found"): + dynamic_resolver.resolve('foo.bar') + + with pytest.raises(ValueError, match="Command foo not found in module core"): + dynamic_resolver.resolve('core.foo') + + # Test that policy fallback correctly + static_resolver = StaticPolicyResolver() + with_fallback_dynamic_resolver = dynamic_resolver.with_fallback(static_resolver) + + assert with_fallback_dynamic_resolver.resolve('ft.aggregate').request_policy == RequestPolicy.DEFAULT_KEYLESS + assert with_fallback_dynamic_resolver.resolve('ft.aggregate').response_policy == ResponsePolicy.DEFAULT_KEYLESS + + # Extended chain with one more resolver + mock_command_parser = Mock(spec=CommandsParser) + foo_bar_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS) + + mock_command_parser.get_command_policies.return_value = { + 'foo': { + 'bar': foo_bar_policy, + } + } + another_dynamic_resolver = DynamicPolicyResolver(mock_command_parser) + with_fallback_static_resolver = static_resolver.with_fallback(another_dynamic_resolver) + with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback(with_fallback_static_resolver) + + assert with_double_fallback_dynamic_resolver.resolve('foo.bar') == foo_bar_policy \ No newline at end of file From df45cc8749789afa264996ed4add9a17802a5fef Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Thu, 30 Oct 2025 08:39:23 +0200 Subject: [PATCH 02/13] Apply policies for normal and pipeline mode (#3818) * Added poilcy resolution method * Moved main command proceessing on top * Fixed return type and keyless detection * Added Dynamic and Static policies * Added coverage for policy resolvers * Applied request policies * Removed all policies except search (phase 1) * Policy applied for normal and pipeline mode * Added assertion with core command * Applied comments * Renamed method --- redis/_parsers/commands.py | 6 +- redis/cluster.py | 308 ++++++++++++++++++++++++++------- redis/commands/policies.py | 11 +- tests/conftest.py | 2 +- tests/test_cluster.py | 23 ++- tests/test_command_policies.py | 107 +++++++++++- 6 files changed, 378 insertions(+), 79 deletions(-) diff --git a/redis/_parsers/commands.py b/redis/_parsers/commands.py index a7571ac195..cff2296b27 100644 --- a/redis/_parsers/commands.py +++ b/redis/_parsers/commands.py @@ -11,10 +11,12 @@ class RequestPolicy(Enum): ALL_NODES = 'all_nodes' ALL_SHARDS = 'all_shards' + ALL_REPLICAS = 'all_replicas' MULTI_SHARD = 'multi_shard' SPECIAL = 'special' DEFAULT_KEYLESS = 'default_keyless' DEFAULT_KEYED = 'default_keyed' + DEFAULT_NODE = 'default_node' class ResponsePolicy(Enum): ONE_SUCCEEDED = 'one_succeeded' @@ -162,7 +164,9 @@ def get_keys(self, redis_conn, *args): for subcmd in command["subcommands"]: if str_if_bytes(subcmd[0]) == subcmd_name: command = self.parse_subcommand(subcmd) - is_subcmd = True + + if command['first_key_pos'] > 0: + is_subcmd = True # The command doesn't have keys in it if not is_subcmd: diff --git a/redis/cluster.py b/redis/cluster.py index 839721edf1..8fc7ef5ef7 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -11,12 +11,14 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from redis._parsers import CommandsParser, Encoder +from redis._parsers.commands import RequestPolicy, CommandPolicies, ResponsePolicy from redis._parsers.helpers import parse_scan from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface from redis.client import EMPTY_RESPONSE, CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args +from redis.commands.policies import PolicyResolver, StaticPolicyResolver from redis.connection import ( Connection, ConnectionPool, @@ -531,6 +533,7 @@ def __init__( cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, + policy_resolver: PolicyResolver = StaticPolicyResolver(), **kwargs, ): """ @@ -712,7 +715,34 @@ def __init__( ) self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) + # For backward compatibility, mapping from existing policies to new one + self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = { + self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS, + self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS, + self.__class__.ALL_NODES: RequestPolicy.ALL_NODES, + self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS, + self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE, + SLOT_ID: RequestPolicy.DEFAULT_KEYED, + } + + self._policies_callback_mapping: dict[Union[RequestPolicy, ResponsePolicy], Callable] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [self.get_random_primary_or_all_nodes(command_name)], + RequestPolicy.DEFAULT_KEYED: lambda command, *args: self.get_nodes_from_slot(command, *args), + RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], + RequestPolicy.ALL_SHARDS: self.get_primaries, + RequestPolicy.ALL_NODES: self.get_nodes, + RequestPolicy.ALL_REPLICAS: self.get_replicas, + RequestPolicy.MULTI_SHARD: lambda *args, **kwargs: self._split_multi_shard_command(*args, **kwargs), + RequestPolicy.SPECIAL: self.get_special_nodes, + ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, + ResponsePolicy.DEFAULT_KEYED: lambda res: res, + } + + self._policy_resolver = policy_resolver self.commands_parser = CommandsParser(self) + + # Node where FT.AGGREGATE command is executed. + self._aggregate_nodes = None self._lock = threading.RLock() def __enter__(self): @@ -775,6 +805,15 @@ def get_replicas(self): def get_random_node(self): return random.choice(list(self.nodes_manager.nodes_cache.values())) + def get_random_primary_or_all_nodes(self, command_name): + """ + Returns random primary or all nodes depends on READONLY mode. + """ + if self.read_from_replicas and command_name in READ_COMMANDS: + return self.get_random_node() + + return self.get_random_primary_node() + def get_nodes(self): return list(self.nodes_manager.nodes_cache.values()) @@ -804,6 +843,74 @@ def get_default_node(self): """ return self.nodes_manager.default_node + def get_nodes_from_slot(self, command: str, *args): + """ + Returns a list of nodes that hold the specified keys' slots. + """ + # get the node that holds the key's slot + slot = self.determine_slot(*args) + node = self.nodes_manager.get_node_from_slot( + slot, + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, + ) + return [node] + + def _split_multi_shard_command(self, *args, **kwargs) -> list[dict]: + """ + Splits the command with Multi-Shard policy, to the multiple commands + """ + keys = self._get_command_keys(*args) + commands = [] + + for key in keys: + commands.append({ + 'args': (args[0], key), + 'kwargs': kwargs, + }) + + return commands + + def get_special_nodes(self) -> Optional[list["ClusterNode"]]: + """ + Returns a list of nodes for commands with a special policy. + """ + if not self._aggregate_nodes: + raise RedisClusterException('Cannot execute FT.CURSOR commands without FT.AGGREGATE') + + return self._aggregate_nodes + + def get_random_primary_node(self) -> "ClusterNode": + """ + Returns a random primary node + """ + return random.choice(self.get_primaries()) + + def _evaluate_all_succeeded(self, res): + """ + Evaluate the result of a command with ResponsePolicy.ALL_SUCCEEDED + """ + first_successful_response = None + + if isinstance(res, dict): + for key, value in res.items(): + if value: + if first_successful_response is None: + first_successful_response = {key: value} + else: + return {key: False} + else: + for response in res: + if response: + if first_successful_response is None: + # Dynamically resolve type + first_successful_response = type(response)(response) + else: + return type(response)(False) + + return first_successful_response + + def set_default_node(self, node): """ Set the default node of the cluster. @@ -953,9 +1060,10 @@ def set_response_callback(self, command, callback): """Set a custom Response Callback""" self.cluster_response_callbacks[command] = callback - def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: - # Determine which nodes should be executed the command on. - # Returns a list of target nodes. + def _determine_nodes(self, *args, request_policy: RequestPolicy, **kwargs) -> List["ClusterNode"]: + """ + Determines a nodes the command should be executed on. + """ command = args[0].upper() if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: command = f"{args[0]} {args[1]}".upper() @@ -967,32 +1075,25 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: else: # get the nodes group for this command if it was predefined command_flag = self.command_flags.get(command) - if command_flag == self.__class__.RANDOM: - # return a random node - return [self.get_random_node()] - elif command_flag == self.__class__.PRIMARIES: - # return all primaries - return self.get_primaries() - elif command_flag == self.__class__.REPLICAS: - # return all replicas - return self.get_replicas() - elif command_flag == self.__class__.ALL_NODES: - # return all nodes - return self.get_nodes() - elif command_flag == self.__class__.DEFAULT_NODE: - # return the cluster's default node - return [self.nodes_manager.default_node] - elif command in self.__class__.SEARCH_COMMANDS[0]: - return [self.nodes_manager.default_node] + + if command_flag in self._command_flags_mapping: + request_policy = self._command_flags_mapping[command_flag] + + policy_callback = self._policies_callback_mapping[request_policy] + + if request_policy == RequestPolicy.DEFAULT_KEYED: + nodes = policy_callback(command, *args) + elif request_policy == RequestPolicy.MULTI_SHARD: + nodes = policy_callback(*args, **kwargs) + elif request_policy == RequestPolicy.DEFAULT_KEYLESS: + nodes = policy_callback(args[0]) else: - # get the node that holds the key's slot - slot = self.determine_slot(*args) - node = self.nodes_manager.get_node_from_slot( - slot, - self.read_from_replicas and command in READ_COMMANDS, - self.load_balancing_strategy if command in READ_COMMANDS else None, - ) - return [node] + nodes = policy_callback() + + if args[0].lower() == "ft.aggregate": + self._aggregate_nodes = nodes + + return nodes def _should_reinitialized(self): # To reinitialize the cluster on every MOVED error, @@ -1142,6 +1243,35 @@ def _internal_execute_command(self, *args, **kwargs): is_default_node = False target_nodes = None passed_targets = kwargs.pop("target_nodes", None) + command_policies = self._policy_resolver.resolve(args[0].lower()) + + if not command_policies: + command = args[0].upper() + if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: + command = f"{args[0]} {args[1]}".upper() + + # We only could resolve key properties if command is not + # in a list of pre-defined request policies + command_flag = self.command_flags.get(command) + if not command_flag: + # Fallback to default policy + if not self.get_default_node(): + keys = None + else: + keys = self._get_command_keys(*args) + if not keys or len(keys) == 0: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in self._command_flags_mapping: + command_policies = CommandPolicies(request_policy=self._command_flags_mapping[command_flag]) + else: + command_policies = CommandPolicies() + if passed_targets is not None and not self._is_nodes_flag(passed_targets): target_nodes = self._parse_target_nodes(passed_targets) target_nodes_specified = True @@ -1162,7 +1292,7 @@ def _internal_execute_command(self, *args, **kwargs): if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = self._determine_nodes( - *args, **kwargs, nodes_flag=passed_targets + *args, request_policy=command_policies.request_policy, nodes_flag=passed_targets ) if not target_nodes: raise RedisClusterException( @@ -1175,8 +1305,12 @@ def _internal_execute_command(self, *args, **kwargs): is_default_node = True for node in target_nodes: res[node.name] = self._execute_command(node, *args, **kwargs) + + if command_policies.response_policy == ResponsePolicy.ONE_SUCCEEDED: + break + # Return the processed result - return self._process_result(args[0], res, **kwargs) + return self._process_result(args[0], res, response_policy=command_policies.response_policy, **kwargs) except Exception as e: if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: if is_default_node: @@ -1316,7 +1450,7 @@ def close(self) -> None: # RedisCluster's __init__ can fail before nodes_manager is set pass - def _process_result(self, command, res, **kwargs): + def _process_result(self, command, res, response_policy: ResponsePolicy, **kwargs): """ Process the result of the executed command. The function would return a dict or a single value. @@ -1328,13 +1462,13 @@ def _process_result(self, command, res, **kwargs): Dict """ if command in self.result_callbacks: - return self.result_callbacks[command](command, res, **kwargs) + res = self.result_callbacks[command](command, res, **kwargs) elif len(res) == 1: # When we execute the command on a single node, we can # remove the dictionary and return a single response - return list(res.values())[0] - else: - return res + res = list(res.values())[0] + + return self._policies_callback_mapping[response_policy](res) def load_external_module(self, funcname, func): """ @@ -2155,6 +2289,7 @@ def __init__( retry: Optional[Retry] = None, lock=None, transaction=False, + policy_resolver: PolicyResolver = StaticPolicyResolver(), **kwargs, ): """ """ @@ -2193,6 +2328,31 @@ def __init__( PipelineStrategy(self) if not transaction else TransactionStrategy(self) ) + # For backward compatibility, mapping from existing policies to new one + self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = { + self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS, + self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS, + self.__class__.ALL_NODES: RequestPolicy.ALL_NODES, + self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS, + self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE, + SLOT_ID: RequestPolicy.DEFAULT_KEYED, + } + + self._policies_callback_mapping: dict[Union[RequestPolicy, ResponsePolicy], Callable] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [self.get_random_primary_or_all_nodes(command_name)], + RequestPolicy.DEFAULT_KEYED: lambda command, *args: self.get_nodes_from_slot(command, *args), + RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], + RequestPolicy.ALL_SHARDS: self.get_primaries, + RequestPolicy.ALL_NODES: self.get_nodes, + RequestPolicy.ALL_REPLICAS: self.get_replicas, + RequestPolicy.MULTI_SHARD: lambda *args, **kwargs: self._split_multi_shard_command(*args, **kwargs), + RequestPolicy.SPECIAL: self.get_special_nodes, + ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, + ResponsePolicy.DEFAULT_KEYED: lambda res: res, + } + + self._policy_resolver = policy_resolver + def __repr__(self): """ """ return f"{type(self).__name__}" @@ -2771,6 +2931,35 @@ def _send_cluster_commands( # we figure out the slot number that command maps to, then from # the slot determine the node. for c in attempt: + command_policies = self._pipe._policy_resolver.resolve(c.args[0].lower()) + + if not command_policies: + command = c.args[0].upper() + if len(c.args) >= 2 and f"{c.args[0]} {c.args[1]}".upper() in self._pipe.command_flags: + command = f"{c.args[0]} {c.args[1]}".upper() + + # We only could resolve key properties if command is not + # in a list of pre-defined request policies + command_flag = self.command_flags.get(command) + if not command_flag: + # Fallback to default policy + if not self._pipe.get_default_node(): + keys = None + else: + keys = self._pipe._get_command_keys(*c.args) + if not keys or len(keys) == 0: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in self._pipe._command_flags_mapping: + command_policies = CommandPolicies(request_policy=self._pipe._command_flags_mapping[command_flag]) + else: + command_policies = CommandPolicies() + while True: # refer to our internal node -> slot table that # tells us where a given command should route to. @@ -2781,7 +2970,7 @@ def _send_cluster_commands( target_nodes = self._parse_target_nodes(passed_targets) else: target_nodes = self._determine_nodes( - *c.args, node_flag=passed_targets + *c.args, request_policy=command_policies.request_policy, node_flag=passed_targets ) if not target_nodes: raise RedisClusterException( @@ -2944,7 +3133,7 @@ def _parse_target_nodes(self, target_nodes): ) return nodes - def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: + def _determine_nodes(self, *args, request_policy: RequestPolicy, **kwargs) -> List["ClusterNode"]: # Determine which nodes should be executed the command on. # Returns a list of target nodes. command = args[0].upper() @@ -2961,34 +3150,25 @@ def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: else: # get the nodes group for this command if it was predefined command_flag = self._pipe.command_flags.get(command) - if command_flag == self._pipe.RANDOM: - # return a random node - return [self._pipe.get_random_node()] - elif command_flag == self._pipe.PRIMARIES: - # return all primaries - return self._pipe.get_primaries() - elif command_flag == self._pipe.REPLICAS: - # return all replicas - return self._pipe.get_replicas() - elif command_flag == self._pipe.ALL_NODES: - # return all nodes - return self._pipe.get_nodes() - elif command_flag == self._pipe.DEFAULT_NODE: - # return the cluster's default node - return [self._nodes_manager.default_node] - elif command in self._pipe.SEARCH_COMMANDS[0]: - return [self._nodes_manager.default_node] + + if command_flag in self._pipe._command_flags_mapping: + request_policy = self._pipe._command_flags_mapping[command_flag] + + policy_callback = self._pipe._policies_callback_mapping[request_policy] + + if request_policy == RequestPolicy.DEFAULT_KEYED: + nodes = policy_callback(command, *args) + elif request_policy == RequestPolicy.MULTI_SHARD: + nodes = policy_callback(*args, **kwargs) + elif request_policy == RequestPolicy.DEFAULT_KEYLESS: + nodes = policy_callback(args[0]) else: - # get the node that holds the key's slot - slot = self._pipe.determine_slot(*args) - node = self._nodes_manager.get_node_from_slot( - slot, - self._pipe.read_from_replicas and command in READ_COMMANDS, - self._pipe.load_balancing_strategy - if command in READ_COMMANDS - else None, - ) - return [node] + nodes = policy_callback() + + if args[0].lower() == "ft.aggregate": + self._aggregate_nodes = nodes + + return nodes def multi(self): raise RedisClusterException( diff --git a/redis/commands/policies.py b/redis/commands/policies.py index a2f7f45924..7413f3f68d 100644 --- a/redis/commands/policies.py +++ b/redis/commands/policies.py @@ -24,12 +24,15 @@ 'info': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), 'sugadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), 'dictdump': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'cursor': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + 'cursor': CommandPolicies(request_policy=RequestPolicy.SPECIAL, response_policy=ResponsePolicy.DEFAULT_KEYLESS), 'search': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), 'tagvals': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), 'aliasdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), 'sugdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), 'spellcheck': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + }, + 'core': { + 'command': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), } } @@ -69,7 +72,7 @@ def __init__(self, policies: PolicyRecords, fallback: Optional[PolicyResolver] = self._policies = policies self._fallback = fallback - def resolve(self, command_name: str) -> CommandPolicies: + def resolve(self, command_name: str) -> Optional[CommandPolicies]: parts = command_name.split(".") if len(parts) > 2: @@ -81,13 +84,13 @@ def resolve(self, command_name: str) -> CommandPolicies: if self._fallback is not None: return self._fallback.resolve(command_name) else: - raise ValueError(f"Module {module} not found") + return None if self._policies.get(module).get(command, None) is None: if self._fallback is not None: return self._fallback.resolve(command_name) else: - raise ValueError(f"Command {command} not found in module {module}") + return None return self._policies.get(module).get(command) diff --git a/tests/conftest.py b/tests/conftest.py index 7eaccb1acb..9c174974ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ from tests.ssl_utils import get_tls_certificates REDIS_INFO = {} -default_redis_url = "redis://localhost:6379/0" +default_redis_url = "redis://localhost:16379/0" default_protocol = "2" default_redismod_url = "redis://localhost:6479" diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 2936bb0024..759c93ffc6 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -207,7 +207,28 @@ def cmd_init_mock(self, r): "first_key_pos": 1, "last_key_pos": 1, "step_count": 1, - } + }, + "cluster delslots": { + "name": "cluster delslots", + "flags": ["readonly", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + }, + "cluster delslotsrange": { + "name": "cluster delslotsrange", + "flags": ["readonly", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + }, + "cluster addslots": { + "name": "cluster delslotsrange", + "flags": ["readonly", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + }, } cmd_parser_initialize.side_effect = cmd_init_mock diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py index c0d057f0b0..ca3ecb1036 100644 --- a/tests/test_command_policies.py +++ b/tests/test_command_policies.py @@ -1,11 +1,15 @@ -from unittest.mock import Mock +import random +from unittest.mock import Mock, patch import pytest +from redis import ResponseError + from redis._parsers import CommandsParser from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy from redis.commands.policies import DynamicPolicyResolver, StaticPolicyResolver - +from redis.commands.search.aggregation import AggregateRequest +from redis.commands.search.field import TextField, NumericField @pytest.mark.onlycluster class TestBasePolicyResolver: @@ -28,11 +32,8 @@ def test_resolve(self): with pytest.raises(ValueError, match="Wrong command or module name: foo.bar.baz"): dynamic_resolver.resolve('foo.bar.baz') - with pytest.raises(ValueError, match="Module foo not found"): - dynamic_resolver.resolve('foo.bar') - - with pytest.raises(ValueError, match="Command foo not found in module core"): - dynamic_resolver.resolve('core.foo') + assert dynamic_resolver.resolve('foo.bar') is None + assert dynamic_resolver.resolve('core.foo') is None # Test that policy fallback correctly static_resolver = StaticPolicyResolver() @@ -54,4 +55,94 @@ def test_resolve(self): with_fallback_static_resolver = static_resolver.with_fallback(another_dynamic_resolver) with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback(with_fallback_static_resolver) - assert with_double_fallback_dynamic_resolver.resolve('foo.bar') == foo_bar_policy \ No newline at end of file + assert with_double_fallback_dynamic_resolver.resolve('foo.bar') == foo_bar_policy + +@pytest.mark.onlycluster +class TestClusterWithPolicies: + def test_resolves_correctly_policies(self, r, monkeypatch): + # original nodes selection method + determine_nodes = r._determine_nodes + determined_nodes = [] + primary_nodes = r.get_primaries() + calls = iter(list(range(len(primary_nodes)))) + + def wrapper(*args, request_policy: RequestPolicy, **kwargs): + nonlocal determined_nodes + determined_nodes = determine_nodes(*args, request_policy=request_policy, **kwargs) + return determined_nodes + + # Mock random.choice to always return a pre-defined sequence of nodes + monkeypatch.setattr(random, "choice", lambda seq: seq[next(calls)]) + + with patch.object(r, '_determine_nodes', side_effect=wrapper, autospec=True): + # Routed to a random primary node + r.ft().create_index( + ( + NumericField("random_num"), + TextField("title"), + TextField("body"), + TextField("parent"), + ) + ) + assert determined_nodes[0] == primary_nodes[0] + + # Routed to another random primary node + info = r.ft().info() + assert info['index_name'] == 'idx' + assert determined_nodes[0] == primary_nodes[1] + + expected_node = r.get_nodes_from_slot('ft.suglen', *['FT.SUGLEN', 'foo']) + r.ft().suglen('foo') + assert determined_nodes[0] == expected_node[0] + + # Indexing a document + r.hset( + "search", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + "parent": "redis", + "random_num": 10, + }, + ) + r.hset( + "ai", + mapping={ + "title": "RedisAI", + "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa + "parent": "redis", + "random_num": 3, + }, + ) + r.hset( + "json", + mapping={ + "title": "RedisJson", + "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa + "parent": "redis", + "random_num": 8, + }, + ) + + req = AggregateRequest("redis").group_by( + "@parent" + ).cursor(1) + cursor = r.ft().aggregate(req).cursor + + # Ensure that aggregate node was cached. + assert determined_nodes[0] == r._aggregate_nodes[0] + + r.ft().aggregate(cursor) + + # Verify that FT.CURSOR dispatched to the same node. + assert determined_nodes[0] == r._aggregate_nodes[0] + + # Error propagates to a user + with pytest.raises(ResponseError, match="Cursor not found, id: 0"): + r.ft().aggregate(cursor) + + assert determined_nodes[0] == primary_nodes[2] + + # Core commands also randomly distributed across masters + r.randomkey() + assert determined_nodes[0] == primary_nodes[0] \ No newline at end of file From 2e3aceb613c1f5094b78294ae6b8b8426af6f701 Mon Sep 17 00:00:00 2001 From: Vladyslav Vildanov <117659936+vladvildanov@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:58:45 +0200 Subject: [PATCH 03/13] Added async implementation for Request-Response policies (#3824) * Added poilcy resolution method * Moved main command proceessing on top * Fixed return type and keyless detection * Added Dynamic and Static policies * Added coverage for policy resolvers * Applied request policies * Removed all policies except search (phase 1) * Policy applied for normal and pipeline mode * Added assertion with core command * Applied comments * Added async support for Request/Responce policies * Fixed method name * Updated async to sync --- redis/_parsers/commands.py | 175 ++++++++++++++++++- redis/asyncio/cluster.py | 176 ++++++++++++++++---- redis/cluster.py | 83 +++++---- redis/commands/policies.py | 101 ++++++++++- tests/test_asyncio/test_command_parser.py | 69 ++++++++ tests/test_asyncio/test_command_policies.py | 147 ++++++++++++++++ 6 files changed, 678 insertions(+), 73 deletions(-) create mode 100644 tests/test_asyncio/test_command_parser.py create mode 100644 tests/test_asyncio/test_command_policies.py diff --git a/redis/_parsers/commands.py b/redis/_parsers/commands.py index cff2296b27..bb1d92d0c7 100644 --- a/redis/_parsers/commands.py +++ b/redis/_parsers/commands.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, Awaitable from redis.exceptions import RedisError, ResponseError, IncorrectPolicyType from redis.utils import str_if_bytes @@ -455,7 +455,9 @@ async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: for subcmd in command["subcommands"]: if str_if_bytes(subcmd[0]) == subcmd_name: command = self.parse_subcommand(subcmd) - is_subcmd = True + + if command['first_key_pos'] > 0: + is_subcmd = True # The command doesn't have keys in it if not is_subcmd: @@ -483,3 +485,172 @@ async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: else: raise e return keys + + async def _is_keyless_command(self, command_name: str, subcommand_name: Optional[str]=None) -> bool: + """ + Determines whether a given command or subcommand is considered "keyless". + + A keyless command does not operate on specific keys, which is determined based + on the first key position in the command or subcommand details. If the command + or subcommand's first key position is zero or negative, it is treated as keyless. + + Parameters: + command_name: str + The name of the command to check. + subcommand_name: Optional[str], default=None + The name of the subcommand to check, if applicable. If not provided, + the check is performed only on the command. + + Returns: + bool + True if the specified command or subcommand is considered keyless, + False otherwise. + + Raises: + ValueError + If the specified subcommand is not found within the command or the + specified command does not exist in the available commands. + """ + if subcommand_name: + for subcommand in self.commands.get(command_name)['subcommands']: + if str_if_bytes(subcommand[0]) == subcommand_name: + parsed_subcmd = self.parse_subcommand(subcommand) + return parsed_subcmd['first_key_pos'] <= 0 + raise ValueError(f"Subcommand {subcommand_name} not found in command {command_name}") + else: + command_details = self.commands.get(command_name, None) + if command_details is not None: + return command_details['first_key_pos'] <= 0 + + raise ValueError(f"Command {command_name} not found in commands") + + async def get_command_policies(self) -> Awaitable[PolicyRecords]: + """ + Retrieve and process the command policies for all commands and subcommands. + + This method traverses through commands and subcommands, extracting policy details + from associated data structures and constructing a dictionary of commands with their + associated policies. It supports nested data structures and handles both main commands + and their subcommands. + + Returns: + PolicyRecords: A collection of commands and subcommands associated with their + respective policies. + + Raises: + IncorrectPolicyType: If an invalid policy type is encountered during policy extraction. + """ + command_with_policies = {} + + def extract_policies(data, module_name, command_name): + """ + Recursively extract policies from nested data structures. + + Args: + data: The data structure to search (can be list, dict, str, bytes, etc.) + command_name: The command name to associate with found policies + """ + if isinstance(data, (str, bytes)): + # Decode bytes to string if needed + policy = str_if_bytes(data.decode()) + + # Check if this is a policy string + if policy.startswith('request_policy') or policy.startswith('response_policy'): + if policy.startswith('request_policy'): + policy_type = policy.split(':')[1] + + try: + command_with_policies[module_name][command_name].request_policy = RequestPolicy(policy_type) + except ValueError: + raise IncorrectPolicyType(f"Incorrect request policy type: {policy_type}") + + if policy.startswith('response_policy'): + policy_type = policy.split(':')[1] + + try: + command_with_policies[module_name][command_name].response_policy = ResponsePolicy( + policy_type) + except ValueError: + raise IncorrectPolicyType(f"Incorrect response policy type: {policy_type}") + + elif isinstance(data, list): + # For lists, recursively process each element + for item in data: + extract_policies(item, module_name, command_name) + + elif isinstance(data, dict): + # For dictionaries, recursively process each value + for value in data.values(): + extract_policies(value, module_name, command_name) + + for command, details in self.commands.items(): + # Check whether the command has keys + is_keyless = await self._is_keyless_command(command) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + # Check if it's a core or module command + split_name = command.split('.') + + if len(split_name) > 1: + module_name = split_name[0] + command_name = split_name[1] + else: + module_name = 'core' + command_name = split_name[0] + + # Create a CommandPolicies object with default policies on the new command. + if command_with_policies.get(module_name, None) is None: + command_with_policies[module_name] = {command_name: CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + )} + else: + command_with_policies[module_name][command_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + ) + + tips = details.get('tips') + subcommands = details.get('subcommands') + + # Process tips for the main command + if tips: + extract_policies(tips, module_name, command_name) + + # Process subcommands + if subcommands: + for subcommand_details in subcommands: + # Get the subcommand name (first element) + subcmd_name = subcommand_details[0] + if isinstance(subcmd_name, bytes): + subcmd_name = subcmd_name.decode() + + # Check whether the subcommand has keys + is_keyless = await self._is_keyless_command(command, subcmd_name) + + if is_keyless: + default_request_policy = RequestPolicy.DEFAULT_KEYLESS + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS + else: + default_request_policy = RequestPolicy.DEFAULT_KEYED + default_response_policy = ResponsePolicy.DEFAULT_KEYED + + subcmd_name = subcmd_name.replace('|', ' ') + + # Create a CommandPolicies object with default policies on the new command. + command_with_policies[module_name][subcmd_name] = CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy + ) + + # Recursively extract policies from the rest of the subcommand details + for subcommand_detail in subcommand_details[1:]: + extract_policies(subcommand_detail, module_name, subcmd_name) + + return command_with_policies \ No newline at end of file diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 4e0e06517d..665e780038 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -26,6 +26,7 @@ ) from redis._parsers import AsyncCommandsParser, Encoder +from redis._parsers.commands import RequestPolicy, ResponsePolicy, CommandPolicies from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -51,6 +52,7 @@ parse_cluster_slots, ) from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands +from redis.commands.policies import AsyncPolicyResolver, AsyncStaticPolicyResolver from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.credentials import CredentialProvider from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher @@ -310,6 +312,7 @@ def __init__( protocol: Optional[int] = 2, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, event_dispatcher: Optional[EventDispatcher] = None, + policy_resolver: AsyncPolicyResolver = AsyncStaticPolicyResolver(), ) -> None: if db: raise RedisClusterException( @@ -422,7 +425,32 @@ def __init__( self.load_balancing_strategy = load_balancing_strategy self.reinitialize_steps = reinitialize_steps self.reinitialize_counter = 0 + + # For backward compatibility, mapping from existing policies to new one + self._command_flags_mapping: dict[str, Union[RequestPolicy, ResponsePolicy]] = { + self.__class__.RANDOM: RequestPolicy.DEFAULT_KEYLESS, + self.__class__.PRIMARIES: RequestPolicy.ALL_SHARDS, + self.__class__.ALL_NODES: RequestPolicy.ALL_NODES, + self.__class__.REPLICAS: RequestPolicy.ALL_REPLICAS, + self.__class__.DEFAULT_NODE: RequestPolicy.DEFAULT_NODE, + SLOT_ID: RequestPolicy.DEFAULT_KEYED, + } + + self._policies_callback_mapping: dict[Union[RequestPolicy, ResponsePolicy], Callable] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [self.get_random_primary_or_all_nodes(command_name)], + RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot, + RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], + RequestPolicy.ALL_SHARDS: self.get_primaries, + RequestPolicy.ALL_NODES: self.get_nodes, + RequestPolicy.ALL_REPLICAS: self.get_replicas, + RequestPolicy.SPECIAL: self.get_special_nodes, + ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, + ResponsePolicy.DEFAULT_KEYED: lambda res: res, + } + + self._policy_resolver = policy_resolver self.commands_parser = AsyncCommandsParser() + self._aggregate_nodes = None self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.response_callbacks = kwargs["response_callbacks"] @@ -618,6 +646,43 @@ def get_node_from_key( return slot_cache[node_idx] + def get_random_primary_or_all_nodes(self, command_name): + """ + Returns random primary or all nodes depends on READONLY mode. + """ + if self.read_from_replicas and command_name in READ_COMMANDS: + return self.get_random_node() + + return self.get_random_primary_node() + + def get_random_primary_node(self) -> "ClusterNode": + """ + Returns a random primary node + """ + return random.choice(self.get_primaries()) + + async def get_nodes_from_slot(self, command: str, *args): + """ + Returns a list of nodes that hold the specified keys' slots. + """ + # get the node that holds the key's slot + return [ + self.nodes_manager.get_node_from_slot( + await self._determine_slot(command, *args), + self.read_from_replicas and command in READ_COMMANDS, + self.load_balancing_strategy if command in READ_COMMANDS else None, + ) + ] + + def get_special_nodes(self) -> Optional[list["ClusterNode"]]: + """ + Returns a list of nodes for commands with a special policy. + """ + if not self._aggregate_nodes: + raise RedisClusterException('Cannot execute FT.CURSOR commands without FT.AGGREGATE') + + return self._aggregate_nodes + def keyslot(self, key: EncodableT) -> int: """ Find the keyslot for a given key. @@ -642,7 +707,7 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No self.response_callbacks[command] = callback async def _determine_nodes( - self, command: str, *args: Any, node_flag: Optional[str] = None + self, command: str, *args: Any, request_policy: RequestPolicy, node_flag: Optional[str] = None ) -> List["ClusterNode"]: # Determine which nodes should be executed the command on. # Returns a list of target nodes. @@ -650,31 +715,22 @@ async def _determine_nodes( # get the nodes group for this command if it was predefined node_flag = self.command_flags.get(command) - if node_flag in self.node_flags: - if node_flag == self.__class__.DEFAULT_NODE: - # return the cluster's default node - return [self.nodes_manager.default_node] - if node_flag == self.__class__.PRIMARIES: - # return all primaries - return self.nodes_manager.get_nodes_by_server_type(PRIMARY) - if node_flag == self.__class__.REPLICAS: - # return all replicas - return self.nodes_manager.get_nodes_by_server_type(REPLICA) - if node_flag == self.__class__.ALL_NODES: - # return all nodes - return list(self.nodes_manager.nodes_cache.values()) - if node_flag == self.__class__.RANDOM: - # return a random node - return [random.choice(list(self.nodes_manager.nodes_cache.values()))] + if node_flag in self._command_flags_mapping: + request_policy = self._command_flags_mapping[node_flag] - # get the node that holds the key's slot - return [ - self.nodes_manager.get_node_from_slot( - await self._determine_slot(command, *args), - self.read_from_replicas and command in READ_COMMANDS, - self.load_balancing_strategy if command in READ_COMMANDS else None, - ) - ] + policy_callback = self._policies_callback_mapping[request_policy] + + if request_policy == RequestPolicy.DEFAULT_KEYED: + nodes = await policy_callback(command, *args) + elif request_policy == RequestPolicy.DEFAULT_KEYLESS: + nodes = policy_callback(command) + else: + nodes = policy_callback() + + if command.lower() == "ft.aggregate": + self._aggregate_nodes = nodes + + return nodes async def _determine_slot(self, command: str, *args: Any) -> int: if self.command_flags.get(command) == SLOT_ID: @@ -779,6 +835,31 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: target_nodes_specified = True retry_attempts = 0 + command_policies = await self._policy_resolver.resolve(args[0].lower()) + + if not command_policies and not target_nodes_specified: + command_flag = self.command_flags.get(command) + if not command_flag: + # Fallback to default policy + if not self.get_default_node(): + slot = None + else: + slot = await self._determine_slot(*args) + if not slot: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in self._command_flags_mapping: + command_policies = CommandPolicies(request_policy=self._command_flags_mapping[command_flag]) + else: + command_policies = CommandPolicies() + elif not command_policies and target_nodes_specified: + command_policies = CommandPolicies() + # Add one for the first execution execute_attempts = 1 + retry_attempts for _ in range(execute_attempts): @@ -794,7 +875,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = await self._determine_nodes( - *args, node_flag=passed_targets + *args, request_policy=command_policies.request_policy, node_flag=passed_targets ) if not target_nodes: raise RedisClusterException( @@ -805,10 +886,10 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: # Return the processed result ret = await self._execute_command(target_nodes[0], *args, **kwargs) if command in self.result_callbacks: - return self.result_callbacks[command]( + ret = self.result_callbacks[command]( command, {target_nodes[0].name: ret}, **kwargs ) - return ret + return self._policies_callback_mapping[command_policies.response_policy](ret) else: keys = [node.name for node in target_nodes] values = await asyncio.gather( @@ -823,7 +904,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: return self.result_callbacks[command]( command, dict(zip(keys, values)), **kwargs ) - return dict(zip(keys, values)) + return self._policies_callback_mapping[command_policies.response_policy](dict(zip(keys, values))) except Exception as e: if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were should be reinitialized. @@ -1739,6 +1820,7 @@ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: self.kwargs = kwargs self.position = position self.result: Union[Any, Exception] = None + self.command_policies: Optional[CommandPolicies] = None def __repr__(self) -> str: return f"[{self.position}] {self.args} ({self.kwargs})" @@ -1979,16 +2061,44 @@ async def _execute( nodes = {} for cmd in todo: passed_targets = cmd.kwargs.pop("target_nodes", None) + command_policies = await client._policy_resolver.resolve(cmd.args[0].lower()) + if passed_targets and not client._is_node_flag(passed_targets): target_nodes = client._parse_target_nodes(passed_targets) + + if not command_policies: + command_policies = CommandPolicies() else: + if not command_policies: + command_flag = client.command_flags.get(cmd.args[0]) + if not command_flag: + # Fallback to default policy + if not client.get_default_node(): + slot = None + else: + slot = await client._determine_slot(*cmd.args) + if not slot: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in client._command_flags_mapping: + command_policies = CommandPolicies( + request_policy=client._command_flags_mapping[command_flag]) + else: + command_policies = CommandPolicies() + target_nodes = await client._determine_nodes( - *cmd.args, node_flag=passed_targets + *cmd.args, request_policy=command_policies.request_policy, node_flag=passed_targets ) if not target_nodes: raise RedisClusterException( f"No targets were found to execute {cmd.args} command on" ) + cmd.command_policies = command_policies if len(target_nodes) > 1: raise RedisClusterException(f"Too many targets for command {cmd.args}") node = target_nodes[0] @@ -2009,8 +2119,10 @@ async def _execute( for cmd in todo: if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): try: - cmd.result = await client.execute_command( - *cmd.args, **cmd.kwargs + cmd.result = client._policies_callback_mapping[cmd.command_policies.response_policy]( + await client.execute_command( + *cmd.args, **cmd.kwargs + ) ) except Exception as e: cmd.result = e diff --git a/redis/cluster.py b/redis/cluster.py index 8fc7ef5ef7..4225a437db 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1245,7 +1245,11 @@ def _internal_execute_command(self, *args, **kwargs): passed_targets = kwargs.pop("target_nodes", None) command_policies = self._policy_resolver.resolve(args[0].lower()) - if not command_policies: + if passed_targets is not None and not self._is_nodes_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + target_nodes_specified = True + + if not command_policies and not target_nodes_specified: command = args[0].upper() if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags: command = f"{args[0]} {args[1]}".upper() @@ -1256,10 +1260,10 @@ def _internal_execute_command(self, *args, **kwargs): if not command_flag: # Fallback to default policy if not self.get_default_node(): - keys = None + slot = None else: - keys = self._get_command_keys(*args) - if not keys or len(keys) == 0: + slot = self.determine_slot(*args) + if not slot: command_policies = CommandPolicies() else: command_policies = CommandPolicies( @@ -1271,10 +1275,9 @@ def _internal_execute_command(self, *args, **kwargs): command_policies = CommandPolicies(request_policy=self._command_flags_mapping[command_flag]) else: command_policies = CommandPolicies() + elif not command_policies and target_nodes_specified: + command_policies = CommandPolicies() - if passed_targets is not None and not self._is_nodes_flag(passed_targets): - target_nodes = self._parse_target_nodes(passed_targets) - target_nodes_specified = True # If an error that allows retrying was thrown, the nodes and slots # cache were reinitialized. We will retry executing the command with # the updated cluster setup only when the target nodes can be @@ -2573,6 +2576,7 @@ def __init__(self, args, options=None, position=None): self.result = None self.node = None self.asking = False + self.command_policies: Optional[CommandPolicies] = None class NodeCommands: @@ -2933,33 +2937,6 @@ def _send_cluster_commands( for c in attempt: command_policies = self._pipe._policy_resolver.resolve(c.args[0].lower()) - if not command_policies: - command = c.args[0].upper() - if len(c.args) >= 2 and f"{c.args[0]} {c.args[1]}".upper() in self._pipe.command_flags: - command = f"{c.args[0]} {c.args[1]}".upper() - - # We only could resolve key properties if command is not - # in a list of pre-defined request policies - command_flag = self.command_flags.get(command) - if not command_flag: - # Fallback to default policy - if not self._pipe.get_default_node(): - keys = None - else: - keys = self._pipe._get_command_keys(*c.args) - if not keys or len(keys) == 0: - command_policies = CommandPolicies() - else: - command_policies = CommandPolicies( - request_policy=RequestPolicy.DEFAULT_KEYED, - response_policy=ResponsePolicy.DEFAULT_KEYED, - ) - else: - if command_flag in self._pipe._command_flags_mapping: - command_policies = CommandPolicies(request_policy=self._pipe._command_flags_mapping[command_flag]) - else: - command_policies = CommandPolicies() - while True: # refer to our internal node -> slot table that # tells us where a given command should route to. @@ -2968,7 +2945,38 @@ def _send_cluster_commands( passed_targets = c.options.pop("target_nodes", None) if passed_targets and not self._is_nodes_flag(passed_targets): target_nodes = self._parse_target_nodes(passed_targets) + + if not command_policies: + command_policies = CommandPolicies() else: + if not command_policies: + command = c.args[0].upper() + if len(c.args) >= 2 and f"{c.args[0]} {c.args[1]}".upper() in self._pipe.command_flags: + command = f"{c.args[0]} {c.args[1]}".upper() + + # We only could resolve key properties if command is not + # in a list of pre-defined request policies + command_flag = self.command_flags.get(command) + if not command_flag: + # Fallback to default policy + if not self._pipe.get_default_node(): + keys = None + else: + keys = self._pipe._get_command_keys(*c.args) + if not keys or len(keys) == 0: + command_policies = CommandPolicies() + else: + command_policies = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + else: + if command_flag in self._pipe._command_flags_mapping: + command_policies = CommandPolicies( + request_policy=self._pipe._command_flags_mapping[command_flag]) + else: + command_policies = CommandPolicies() + target_nodes = self._determine_nodes( *c.args, request_policy=command_policies.request_policy, node_flag=passed_targets ) @@ -2976,6 +2984,7 @@ def _send_cluster_commands( raise RedisClusterException( f"No targets were found to execute {c.args} command on" ) + c.command_policies = command_policies if len(target_nodes) > 1: raise RedisClusterException( f"Too many targets for command {c.args}" @@ -3100,8 +3109,10 @@ def _send_cluster_commands( if c.args[0] in self._pipe.cluster_response_callbacks: # Remove keys entry, it needs only for cache. c.options.pop("keys", None) - c.result = self._pipe.cluster_response_callbacks[c.args[0]]( - c.result, **c.options + c.result = self._pipe._policies_callback_mapping[c.command_policies.response_policy]( + self._pipe.cluster_response_callbacks[c.args[0]]( + c.result, **c.options + ) ) response.append(c.result) diff --git a/redis/commands/policies.py b/redis/commands/policies.py index 7413f3f68d..ba2cc8968c 100644 --- a/redis/commands/policies.py +++ b/redis/commands/policies.py @@ -1,7 +1,9 @@ +import asyncio from abc import ABC, abstractmethod from typing import Optional -from redis._parsers.commands import CommandPolicies, PolicyRecords, RequestPolicy, ResponsePolicy, CommandsParser +from redis._parsers.commands import CommandPolicies, PolicyRecords, RequestPolicy, ResponsePolicy, CommandsParser, \ + AsyncCommandsParser STATIC_POLICIES: PolicyRecords = { 'ft': { @@ -39,7 +41,7 @@ class PolicyResolver(ABC): @abstractmethod - def resolve(self, command_name: str) -> CommandPolicies: + def resolve(self, command_name: str) -> Optional[CommandPolicies]: """ Resolves the command name and determines the associated command policies. @@ -64,6 +66,34 @@ def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": """ pass +class AsyncPolicyResolver(ABC): + + @abstractmethod + async def resolve(self, command_name: str) -> Optional[CommandPolicies]: + """ + Resolves the command name and determines the associated command policies. + + Args: + command_name: The name of the command to resolve. + + Returns: + CommandPolicies: The policies associated with the specified command. + """ + pass + + @abstractmethod + def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": + """ + Factory method to instantiate an async policy resolver with a fallback resolver. + + Args: + fallback: Fallback resolver + + Returns: + AsyncPolicyResolver: Returns a new policy resolver with the specified fallback resolver. + """ + pass + class BasePolicyResolver(PolicyResolver): """ Base class for policy resolvers. @@ -98,6 +128,40 @@ def resolve(self, command_name: str) -> Optional[CommandPolicies]: def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": pass +class AsyncBasePolicyResolver(AsyncPolicyResolver): + """ + Async base class for policy resolvers. + """ + def __init__(self, policies: PolicyRecords, fallback: Optional[AsyncPolicyResolver] = None) -> None: + self._policies = policies + self._fallback = fallback + + async def resolve(self, command_name: str) -> Optional[CommandPolicies]: + parts = command_name.split(".") + + if len(parts) > 2: + raise ValueError(f"Wrong command or module name: {command_name}") + + module, command = parts if len(parts) == 2 else ("core", parts[0]) + + if self._policies.get(module, None) is None: + if self._fallback is not None: + return await self._fallback.resolve(command_name) + else: + return None + + if self._policies.get(module).get(command, None) is None: + if self._fallback is not None: + return await self._fallback.resolve(command_name) + else: + return None + + return self._policies.get(module).get(command) + + @abstractmethod + def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": + pass + class DynamicPolicyResolver(BasePolicyResolver): """ @@ -130,4 +194,35 @@ def __init__(self, fallback: Optional[PolicyResolver] = None) -> None: super().__init__(STATIC_POLICIES, fallback) def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": - return StaticPolicyResolver(fallback) \ No newline at end of file + return StaticPolicyResolver(fallback) + +class AsyncDynamicPolicyResolver(AsyncBasePolicyResolver): + """ + Async version of DynamicPolicyResolver. + """ + def __init__(self, policy_records: PolicyRecords, fallback: Optional[AsyncPolicyResolver] = None) -> None: + """ + Parameters: + policy_records (PolicyRecords): Policy records. + fallback (Optional[AsyncPolicyResolver]): An optional resolver to be used when the + primary policies cannot handle a specific request. + """ + super().__init__(policy_records, fallback) + + def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": + return AsyncDynamicPolicyResolver(self._policies, fallback) + +class AsyncStaticPolicyResolver(AsyncBasePolicyResolver): + """ + Async version of StaticPolicyResolver. + """ + def __init__(self, fallback: Optional[AsyncPolicyResolver] = None) -> None: + """ + Parameters: + fallback (Optional[AsyncPolicyResolver]): An optional fallback policy resolver + used for resolving policies if static policies are inadequate. + """ + super().__init__(STATIC_POLICIES, fallback) + + def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": + return AsyncStaticPolicyResolver(fallback) \ No newline at end of file diff --git a/tests/test_asyncio/test_command_parser.py b/tests/test_asyncio/test_command_parser.py new file mode 100644 index 0000000000..da714a13d7 --- /dev/null +++ b/tests/test_asyncio/test_command_parser.py @@ -0,0 +1,69 @@ +import pytest + +from redis._parsers import AsyncCommandsParser +from redis._parsers.commands import RequestPolicy, ResponsePolicy +from tests.conftest import skip_if_server_version_lt + + +class TestAsyncCommandParser: + @skip_if_server_version_lt("7.0.0") + @pytest.mark.onlycluster + @pytest.mark.asyncio + async def test_get_command_policies(self, r): + commands_parser = AsyncCommandsParser() + await commands_parser.initialize(node=r.get_default_node()) + expected_command_policies = { + 'core': { + 'keys': ['keys', RequestPolicy.ALL_SHARDS, ResponsePolicy.DEFAULT_KEYLESS], + 'acl setuser': ['acl setuser', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], + 'exists': ['exists', RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], + 'config resetstat': ['config resetstat', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], + 'slowlog len': ['slowlog len', RequestPolicy.ALL_NODES, ResponsePolicy.AGG_SUM], + 'scan': ['scan', RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], + 'latency history': ['latency history', RequestPolicy.ALL_NODES, ResponsePolicy.SPECIAL], + 'memory doctor': ['memory doctor', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], + 'randomkey': ['randomkey', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], + 'mget': ['mget', RequestPolicy.MULTI_SHARD, ResponsePolicy.DEFAULT_KEYED], + 'function restore': ['function restore', RequestPolicy.ALL_SHARDS, ResponsePolicy.ALL_SUCCEEDED], + }, + 'json': { + 'debug': ['debug', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'get': ['get', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'ft': { + 'search': ['search', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], + 'create': ['create', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], + }, + 'bf': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'madd': ['madd', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'cf': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'mexists': ['mexists', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'tdigest': { + 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'min': ['min', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'ts': { + 'create': ['create', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'info': ['info', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + }, + 'topk': { + 'list': ['list', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + 'query': ['query', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + } + } + + actual_policies = await commands_parser.get_command_policies() + assert len(actual_policies) > 0 + + for module_name, commands in expected_command_policies.items(): + for command, command_policies in commands.items(): + assert command in actual_policies[module_name] + assert command_policies == [ + command, + actual_policies[module_name][command].request_policy, + actual_policies[module_name][command].response_policy + ] \ No newline at end of file diff --git a/tests/test_asyncio/test_command_policies.py b/tests/test_asyncio/test_command_policies.py new file mode 100644 index 0000000000..2c0f0d2ddb --- /dev/null +++ b/tests/test_asyncio/test_command_policies.py @@ -0,0 +1,147 @@ +import random + +import pytest +from mock import patch + +from redis import ResponseError +from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy +from redis.asyncio import RedisCluster +from redis.commands.policies import AsyncDynamicPolicyResolver, AsyncStaticPolicyResolver +from redis.commands.search.aggregation import AggregateRequest +from redis.commands.search.field import NumericField, TextField + + +@pytest.mark.asyncio +@pytest.mark.onlycluster +class TestBasePolicyResolver: + async def test_resolve(self): + zcount_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) + rpoplpush_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) + + dynamic_resolver = AsyncDynamicPolicyResolver({ + 'core': { + 'zcount': zcount_policy, + 'rpoplpush': rpoplpush_policy, + } + }) + assert await dynamic_resolver.resolve('zcount') == zcount_policy + assert await dynamic_resolver.resolve('rpoplpush') == rpoplpush_policy + + with pytest.raises(ValueError, match="Wrong command or module name: foo.bar.baz"): + await dynamic_resolver.resolve('foo.bar.baz') + + assert await dynamic_resolver.resolve('foo.bar') is None + assert await dynamic_resolver.resolve('core.foo') is None + + # Test that policy fallback correctly + static_resolver = AsyncStaticPolicyResolver() + with_fallback_dynamic_resolver = dynamic_resolver.with_fallback(static_resolver) + resolved_policies = await with_fallback_dynamic_resolver.resolve('ft.aggregate') + + assert resolved_policies.request_policy == RequestPolicy.DEFAULT_KEYLESS + assert resolved_policies.response_policy == ResponsePolicy.DEFAULT_KEYLESS + + # Extended chain with one more resolver + foo_bar_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS) + + another_dynamic_resolver = AsyncDynamicPolicyResolver({ + 'foo': { + 'bar': foo_bar_policy, + } + }) + with_fallback_static_resolver = static_resolver.with_fallback(another_dynamic_resolver) + with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback(with_fallback_static_resolver) + + assert await with_double_fallback_dynamic_resolver.resolve('foo.bar') == foo_bar_policy + +@pytest.mark.onlycluster +@pytest.mark.asyncio +class TestClusterWithPolicies: + async def test_resolves_correctly_policies(self, r: RedisCluster, monkeypatch): + # original nodes selection method + determine_nodes = r._determine_nodes + determined_nodes = [] + primary_nodes = r.get_primaries() + calls = iter(list(range(len(primary_nodes)))) + + async def wrapper(*args, request_policy: RequestPolicy, **kwargs): + nonlocal determined_nodes + determined_nodes = await determine_nodes(*args, request_policy=request_policy, **kwargs) + return determined_nodes + + # Mock random.choice to always return a pre-defined sequence of nodes + monkeypatch.setattr(random, "choice", lambda seq: seq[next(calls)]) + + with patch.object(r, '_determine_nodes', side_effect=wrapper, autospec=True): + # Routed to a random primary node + await r.ft().create_index( + [ + NumericField("random_num"), + TextField("title"), + TextField("body"), + TextField("parent"), + ] + ) + assert determined_nodes[0] == primary_nodes[0] + + # Routed to another random primary node + info = await r.ft().info() + assert info['index_name'] == 'idx' + assert determined_nodes[0] == primary_nodes[1] + + expected_node = await r.get_nodes_from_slot('FT.SUGLEN', *['foo']) + await r.ft().suglen('foo') + assert determined_nodes[0] == expected_node[0] + + # Indexing a document + await r.hset( + "search", + mapping={ + "title": "RediSearch", + "body": "Redisearch impements a search engine on top of redis", + "parent": "redis", + "random_num": 10, + }, + ) + await r.hset( + "ai", + mapping={ + "title": "RedisAI", + "body": "RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa + "parent": "redis", + "random_num": 3, + }, + ) + await r.hset( + "json", + mapping={ + "title": "RedisJson", + "body": "RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa + "parent": "redis", + "random_num": 8, + }, + ) + + req = AggregateRequest("redis").group_by( + "@parent" + ).cursor(1) + res = await r.ft().aggregate(req) + cursor = res.cursor + + # Ensure that aggregate node was cached. + assert determined_nodes[0] == r._aggregate_nodes[0] + + await r.ft().aggregate(cursor) + + # Verify that FT.CURSOR dispatched to the same node. + assert determined_nodes[0] == r._aggregate_nodes[0] + + # Error propagates to a user + with pytest.raises(ResponseError, match="Cursor not found, id: 0"): + await r.ft().aggregate(cursor) + + assert determined_nodes[0] == primary_nodes[2] + + # Core commands also randomly distributed across masters + await r.randomkey() + assert determined_nodes[0] == primary_nodes[0] \ No newline at end of file From eff2af0d5e27a4dcd4ec2335a3b8b56125bcc59c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 09:35:23 +0200 Subject: [PATCH 04/13] Codestyle changes --- redis/_parsers/commands.py | 193 ++++++++++++-------- redis/asyncio/cluster.py | 55 ++++-- redis/cluster.py | 85 ++++++--- redis/commands/policies.py | 180 ++++++++++++++---- redis/exceptions.py | 4 +- tests/test_asyncio/test_command_parser.py | 164 +++++++++++++---- tests/test_asyncio/test_command_policies.py | 94 ++++++---- tests/test_command_parser.py | 168 +++++++++++++---- tests/test_command_policies.py | 81 +++++--- 9 files changed, 732 insertions(+), 292 deletions(-) diff --git a/redis/_parsers/commands.py b/redis/_parsers/commands.py index bb1d92d0c7..cbad42bd79 100644 --- a/redis/_parsers/commands.py +++ b/redis/_parsers/commands.py @@ -8,39 +8,44 @@ if TYPE_CHECKING: from redis.asyncio.cluster import ClusterNode + class RequestPolicy(Enum): - ALL_NODES = 'all_nodes' - ALL_SHARDS = 'all_shards' - ALL_REPLICAS = 'all_replicas' - MULTI_SHARD = 'multi_shard' - SPECIAL = 'special' - DEFAULT_KEYLESS = 'default_keyless' - DEFAULT_KEYED = 'default_keyed' - DEFAULT_NODE = 'default_node' + ALL_NODES = "all_nodes" + ALL_SHARDS = "all_shards" + ALL_REPLICAS = "all_replicas" + MULTI_SHARD = "multi_shard" + SPECIAL = "special" + DEFAULT_KEYLESS = "default_keyless" + DEFAULT_KEYED = "default_keyed" + DEFAULT_NODE = "default_node" + class ResponsePolicy(Enum): - ONE_SUCCEEDED = 'one_succeeded' - ALL_SUCCEEDED = 'all_succeeded' - AGG_LOGICAL_AND = 'agg_logical_and' - AGG_LOGICAL_OR = 'agg_logical_or' - AGG_MIN = 'agg_min' - AGG_MAX = 'agg_max' - AGG_SUM = 'agg_sum' - SPECIAL = 'special' - DEFAULT_KEYLESS = 'default_keyless' - DEFAULT_KEYED = 'default_keyed' + ONE_SUCCEEDED = "one_succeeded" + ALL_SUCCEEDED = "all_succeeded" + AGG_LOGICAL_AND = "agg_logical_and" + AGG_LOGICAL_OR = "agg_logical_or" + AGG_MIN = "agg_min" + AGG_MAX = "agg_max" + AGG_SUM = "agg_sum" + SPECIAL = "special" + DEFAULT_KEYLESS = "default_keyless" + DEFAULT_KEYED = "default_keyed" + class CommandPolicies: def __init__( - self, - request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS, - response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS + self, + request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS, + response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS, ): self.request_policy = request_policy self.response_policy = response_policy + PolicyRecords = dict[str, dict[str, CommandPolicies]] + class AbstractCommandsParser: def _get_pubsub_keys(self, *args): """ @@ -165,7 +170,7 @@ def get_keys(self, redis_conn, *args): if str_if_bytes(subcmd[0]) == subcmd_name: command = self.parse_subcommand(subcmd) - if command['first_key_pos'] > 0: + if command["first_key_pos"] > 0: is_subcmd = True # The command doesn't have keys in it @@ -206,7 +211,9 @@ def _get_moveable_keys(self, redis_conn, *args): raise e return keys - def _is_keyless_command(self, command_name: str, subcommand_name: Optional[str]=None) -> bool: + def _is_keyless_command( + self, command_name: str, subcommand_name: Optional[str] = None + ) -> bool: """ Determines whether a given command or subcommand is considered "keyless". @@ -232,15 +239,17 @@ def _is_keyless_command(self, command_name: str, subcommand_name: Optional[str]= specified command does not exist in the available commands. """ if subcommand_name: - for subcommand in self.commands.get(command_name)['subcommands']: + for subcommand in self.commands.get(command_name)["subcommands"]: if str_if_bytes(subcommand[0]) == subcommand_name: parsed_subcmd = self.parse_subcommand(subcommand) - return parsed_subcmd['first_key_pos'] <= 0 - raise ValueError(f"Subcommand {subcommand_name} not found in command {command_name}") + return parsed_subcmd["first_key_pos"] <= 0 + raise ValueError( + f"Subcommand {subcommand_name} not found in command {command_name}" + ) else: command_details = self.commands.get(command_name, None) if command_details is not None: - return command_details['first_key_pos'] <= 0 + return command_details["first_key_pos"] <= 0 raise ValueError(f"Command {command_name} not found in commands") @@ -265,7 +274,7 @@ def get_command_policies(self) -> PolicyRecords: def extract_policies(data, module_name, command_name): """ Recursively extract policies from nested data structures. - + Args: data: The data structure to search (can be list, dict, str, bytes, etc.) command_name: The command name to associate with found policies @@ -275,28 +284,38 @@ def extract_policies(data, module_name, command_name): policy = str_if_bytes(data.decode()) # Check if this is a policy string - if policy.startswith('request_policy') or policy.startswith('response_policy'): - if policy.startswith('request_policy'): - policy_type = policy.split(':')[1] + if policy.startswith("request_policy") or policy.startswith( + "response_policy" + ): + if policy.startswith("request_policy"): + policy_type = policy.split(":")[1] try: - command_with_policies[module_name][command_name].request_policy = RequestPolicy(policy_type) + command_with_policies[module_name][ + command_name + ].request_policy = RequestPolicy(policy_type) except ValueError: - raise IncorrectPolicyType(f"Incorrect request policy type: {policy_type}") + raise IncorrectPolicyType( + f"Incorrect request policy type: {policy_type}" + ) - if policy.startswith('response_policy'): - policy_type = policy.split(':')[1] + if policy.startswith("response_policy"): + policy_type = policy.split(":")[1] try: - command_with_policies[module_name][command_name].response_policy = ResponsePolicy(policy_type) + command_with_policies[module_name][ + command_name + ].response_policy = ResponsePolicy(policy_type) except ValueError: - raise IncorrectPolicyType(f"Incorrect response policy type: {policy_type}") - + raise IncorrectPolicyType( + f"Incorrect response policy type: {policy_type}" + ) + elif isinstance(data, list): # For lists, recursively process each element for item in data: extract_policies(item, module_name, command_name) - + elif isinstance(data, dict): # For dictionaries, recursively process each value for value in data.values(): @@ -314,29 +333,31 @@ def extract_policies(data, module_name, command_name): default_response_policy = ResponsePolicy.DEFAULT_KEYED # Check if it's a core or module command - split_name = command.split('.') + split_name = command.split(".") if len(split_name) > 1: module_name = split_name[0] command_name = split_name[1] else: - module_name = 'core' + module_name = "core" command_name = split_name[0] # Create a CommandPolicies object with default policies on the new command. if command_with_policies.get(module_name, None) is None: - command_with_policies[module_name] = {command_name: CommandPolicies( - request_policy=default_request_policy, - response_policy=default_response_policy - )} + command_with_policies[module_name] = { + command_name: CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy, + ) + } else: command_with_policies[module_name][command_name] = CommandPolicies( request_policy=default_request_policy, - response_policy=default_response_policy + response_policy=default_response_policy, ) - tips = details.get('tips') - subcommands = details.get('subcommands') + tips = details.get("tips") + subcommands = details.get("subcommands") # Process tips for the main command if tips: @@ -360,12 +381,12 @@ def extract_policies(data, module_name, command_name): default_request_policy = RequestPolicy.DEFAULT_KEYED default_response_policy = ResponsePolicy.DEFAULT_KEYED - subcmd_name = subcmd_name.replace('|', ' ') + subcmd_name = subcmd_name.replace("|", " ") # Create a CommandPolicies object with default policies on the new command. command_with_policies[module_name][subcmd_name] = CommandPolicies( request_policy=default_request_policy, - response_policy=default_response_policy + response_policy=default_response_policy, ) # Recursively extract policies from the rest of the subcommand details @@ -374,6 +395,7 @@ def extract_policies(data, module_name, command_name): return command_with_policies + class AsyncCommandsParser(AbstractCommandsParser): """ Parses Redis commands to get command keys. @@ -456,7 +478,7 @@ async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: if str_if_bytes(subcmd[0]) == subcmd_name: command = self.parse_subcommand(subcmd) - if command['first_key_pos'] > 0: + if command["first_key_pos"] > 0: is_subcmd = True # The command doesn't have keys in it @@ -486,7 +508,9 @@ async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: raise e return keys - async def _is_keyless_command(self, command_name: str, subcommand_name: Optional[str]=None) -> bool: + async def _is_keyless_command( + self, command_name: str, subcommand_name: Optional[str] = None + ) -> bool: """ Determines whether a given command or subcommand is considered "keyless". @@ -512,15 +536,17 @@ async def _is_keyless_command(self, command_name: str, subcommand_name: Optional specified command does not exist in the available commands. """ if subcommand_name: - for subcommand in self.commands.get(command_name)['subcommands']: + for subcommand in self.commands.get(command_name)["subcommands"]: if str_if_bytes(subcommand[0]) == subcommand_name: parsed_subcmd = self.parse_subcommand(subcommand) - return parsed_subcmd['first_key_pos'] <= 0 - raise ValueError(f"Subcommand {subcommand_name} not found in command {command_name}") + return parsed_subcmd["first_key_pos"] <= 0 + raise ValueError( + f"Subcommand {subcommand_name} not found in command {command_name}" + ) else: command_details = self.commands.get(command_name, None) if command_details is not None: - return command_details['first_key_pos'] <= 0 + return command_details["first_key_pos"] <= 0 raise ValueError(f"Command {command_name} not found in commands") @@ -555,23 +581,32 @@ def extract_policies(data, module_name, command_name): policy = str_if_bytes(data.decode()) # Check if this is a policy string - if policy.startswith('request_policy') or policy.startswith('response_policy'): - if policy.startswith('request_policy'): - policy_type = policy.split(':')[1] + if policy.startswith("request_policy") or policy.startswith( + "response_policy" + ): + if policy.startswith("request_policy"): + policy_type = policy.split(":")[1] try: - command_with_policies[module_name][command_name].request_policy = RequestPolicy(policy_type) + command_with_policies[module_name][ + command_name + ].request_policy = RequestPolicy(policy_type) except ValueError: - raise IncorrectPolicyType(f"Incorrect request policy type: {policy_type}") + raise IncorrectPolicyType( + f"Incorrect request policy type: {policy_type}" + ) - if policy.startswith('response_policy'): - policy_type = policy.split(':')[1] + if policy.startswith("response_policy"): + policy_type = policy.split(":")[1] try: - command_with_policies[module_name][command_name].response_policy = ResponsePolicy( - policy_type) + command_with_policies[module_name][ + command_name + ].response_policy = ResponsePolicy(policy_type) except ValueError: - raise IncorrectPolicyType(f"Incorrect response policy type: {policy_type}") + raise IncorrectPolicyType( + f"Incorrect response policy type: {policy_type}" + ) elif isinstance(data, list): # For lists, recursively process each element @@ -595,29 +630,31 @@ def extract_policies(data, module_name, command_name): default_response_policy = ResponsePolicy.DEFAULT_KEYED # Check if it's a core or module command - split_name = command.split('.') + split_name = command.split(".") if len(split_name) > 1: module_name = split_name[0] command_name = split_name[1] else: - module_name = 'core' + module_name = "core" command_name = split_name[0] # Create a CommandPolicies object with default policies on the new command. if command_with_policies.get(module_name, None) is None: - command_with_policies[module_name] = {command_name: CommandPolicies( - request_policy=default_request_policy, - response_policy=default_response_policy - )} + command_with_policies[module_name] = { + command_name: CommandPolicies( + request_policy=default_request_policy, + response_policy=default_response_policy, + ) + } else: command_with_policies[module_name][command_name] = CommandPolicies( request_policy=default_request_policy, - response_policy=default_response_policy + response_policy=default_response_policy, ) - tips = details.get('tips') - subcommands = details.get('subcommands') + tips = details.get("tips") + subcommands = details.get("subcommands") # Process tips for the main command if tips: @@ -641,16 +678,16 @@ def extract_policies(data, module_name, command_name): default_request_policy = RequestPolicy.DEFAULT_KEYED default_response_policy = ResponsePolicy.DEFAULT_KEYED - subcmd_name = subcmd_name.replace('|', ' ') + subcmd_name = subcmd_name.replace("|", " ") # Create a CommandPolicies object with default policies on the new command. command_with_policies[module_name][subcmd_name] = CommandPolicies( request_policy=default_request_policy, - response_policy=default_response_policy + response_policy=default_response_policy, ) # Recursively extract policies from the rest of the subcommand details for subcommand_detail in subcommand_details[1:]: extract_policies(subcommand_detail, module_name, subcmd_name) - return command_with_policies \ No newline at end of file + return command_with_policies diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index da2fd35f8a..09a086aa31 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -437,8 +437,12 @@ def __init__( SLOT_ID: RequestPolicy.DEFAULT_KEYED, } - self._policies_callback_mapping: dict[Union[RequestPolicy, ResponsePolicy], Callable] = { - RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [self.get_random_primary_or_all_nodes(command_name)], + self._policies_callback_mapping: dict[ + Union[RequestPolicy, ResponsePolicy], Callable + ] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [ + self.get_random_primary_or_all_nodes(command_name) + ], RequestPolicy.DEFAULT_KEYED: self.get_nodes_from_slot, RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], RequestPolicy.ALL_SHARDS: self.get_primaries, @@ -680,7 +684,9 @@ def get_special_nodes(self) -> Optional[list["ClusterNode"]]: Returns a list of nodes for commands with a special policy. """ if not self._aggregate_nodes: - raise RedisClusterException('Cannot execute FT.CURSOR commands without FT.AGGREGATE') + raise RedisClusterException( + "Cannot execute FT.CURSOR commands without FT.AGGREGATE" + ) return self._aggregate_nodes @@ -708,7 +714,11 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No self.response_callbacks[command] = callback async def _determine_nodes( - self, command: str, *args: Any, request_policy: RequestPolicy, node_flag: Optional[str] = None + self, + command: str, + *args: Any, + request_policy: RequestPolicy, + node_flag: Optional[str] = None, ) -> List["ClusterNode"]: # Determine which nodes should be executed the command on. # Returns a list of target nodes. @@ -855,7 +865,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: ) else: if command_flag in self._command_flags_mapping: - command_policies = CommandPolicies(request_policy=self._command_flags_mapping[command_flag]) + command_policies = CommandPolicies( + request_policy=self._command_flags_mapping[command_flag] + ) else: command_policies = CommandPolicies() elif not command_policies and target_nodes_specified: @@ -876,7 +888,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = await self._determine_nodes( - *args, request_policy=command_policies.request_policy, node_flag=passed_targets + *args, + request_policy=command_policies.request_policy, + node_flag=passed_targets, ) if not target_nodes: raise RedisClusterException( @@ -890,7 +904,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: ret = self.result_callbacks[command]( command, {target_nodes[0].name: ret}, **kwargs ) - return self._policies_callback_mapping[command_policies.response_policy](ret) + return self._policies_callback_mapping[ + command_policies.response_policy + ](ret) else: keys = [node.name for node in target_nodes] values = await asyncio.gather( @@ -905,7 +921,9 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: return self.result_callbacks[command]( command, dict(zip(keys, values)), **kwargs ) - return self._policies_callback_mapping[command_policies.response_policy](dict(zip(keys, values))) + return self._policies_callback_mapping[ + command_policies.response_policy + ](dict(zip(keys, values))) except Exception as e: if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: # The nodes and slots cache were should be reinitialized. @@ -2062,7 +2080,9 @@ async def _execute( nodes = {} for cmd in todo: passed_targets = cmd.kwargs.pop("target_nodes", None) - command_policies = await client._policy_resolver.resolve(cmd.args[0].lower()) + command_policies = await client._policy_resolver.resolve( + cmd.args[0].lower() + ) if passed_targets and not client._is_node_flag(passed_targets): target_nodes = client._parse_target_nodes(passed_targets) @@ -2088,12 +2108,17 @@ async def _execute( else: if command_flag in client._command_flags_mapping: command_policies = CommandPolicies( - request_policy=client._command_flags_mapping[command_flag]) + request_policy=client._command_flags_mapping[ + command_flag + ] + ) else: command_policies = CommandPolicies() target_nodes = await client._determine_nodes( - *cmd.args, request_policy=command_policies.request_policy, node_flag=passed_targets + *cmd.args, + request_policy=command_policies.request_policy, + node_flag=passed_targets, ) if not target_nodes: raise RedisClusterException( @@ -2120,11 +2145,9 @@ async def _execute( for cmd in todo: if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): try: - cmd.result = client._policies_callback_mapping[cmd.command_policies.response_policy]( - await client.execute_command( - *cmd.args, **cmd.kwargs - ) - ) + cmd.result = client._policies_callback_mapping[ + cmd.command_policies.response_policy + ](await client.execute_command(*cmd.args, **cmd.kwargs)) except Exception as e: cmd.result = e diff --git a/redis/cluster.py b/redis/cluster.py index 2e240d1ef5..9765dc59d2 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -727,14 +727,20 @@ def __init__( SLOT_ID: RequestPolicy.DEFAULT_KEYED, } - self._policies_callback_mapping: dict[Union[RequestPolicy, ResponsePolicy], Callable] = { - RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [self.get_random_primary_or_all_nodes(command_name)], - RequestPolicy.DEFAULT_KEYED: lambda command, *args: self.get_nodes_from_slot(command, *args), + self._policies_callback_mapping: dict[ + Union[RequestPolicy, ResponsePolicy], Callable + ] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [ + self.get_random_primary_or_all_nodes(command_name) + ], + RequestPolicy.DEFAULT_KEYED: lambda command, + *args: self.get_nodes_from_slot(command, *args), RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], RequestPolicy.ALL_SHARDS: self.get_primaries, RequestPolicy.ALL_NODES: self.get_nodes, RequestPolicy.ALL_REPLICAS: self.get_replicas, - RequestPolicy.MULTI_SHARD: lambda *args, **kwargs: self._split_multi_shard_command(*args, **kwargs), + RequestPolicy.MULTI_SHARD: lambda *args, + **kwargs: self._split_multi_shard_command(*args, **kwargs), RequestPolicy.SPECIAL: self.get_special_nodes, ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, ResponsePolicy.DEFAULT_KEYED: lambda res: res, @@ -866,10 +872,12 @@ def _split_multi_shard_command(self, *args, **kwargs) -> list[dict]: commands = [] for key in keys: - commands.append({ - 'args': (args[0], key), - 'kwargs': kwargs, - }) + commands.append( + { + "args": (args[0], key), + "kwargs": kwargs, + } + ) return commands @@ -878,7 +886,9 @@ def get_special_nodes(self) -> Optional[list["ClusterNode"]]: Returns a list of nodes for commands with a special policy. """ if not self._aggregate_nodes: - raise RedisClusterException('Cannot execute FT.CURSOR commands without FT.AGGREGATE') + raise RedisClusterException( + "Cannot execute FT.CURSOR commands without FT.AGGREGATE" + ) return self._aggregate_nodes @@ -912,7 +922,6 @@ def _evaluate_all_succeeded(self, res): return first_successful_response - def set_default_node(self, node): """ Set the default node of the cluster. @@ -1062,7 +1071,9 @@ def set_response_callback(self, command, callback): """Set a custom Response Callback""" self.cluster_response_callbacks[command] = callback - def _determine_nodes(self, *args, request_policy: RequestPolicy, **kwargs) -> List["ClusterNode"]: + def _determine_nodes( + self, *args, request_policy: RequestPolicy, **kwargs + ) -> List["ClusterNode"]: """ Determines a nodes the command should be executed on. """ @@ -1274,7 +1285,9 @@ def _internal_execute_command(self, *args, **kwargs): ) else: if command_flag in self._command_flags_mapping: - command_policies = CommandPolicies(request_policy=self._command_flags_mapping[command_flag]) + command_policies = CommandPolicies( + request_policy=self._command_flags_mapping[command_flag] + ) else: command_policies = CommandPolicies() elif not command_policies and target_nodes_specified: @@ -1297,7 +1310,9 @@ def _internal_execute_command(self, *args, **kwargs): if not target_nodes_specified: # Determine the nodes to execute the command on target_nodes = self._determine_nodes( - *args, request_policy=command_policies.request_policy, nodes_flag=passed_targets + *args, + request_policy=command_policies.request_policy, + nodes_flag=passed_targets, ) if not target_nodes: raise RedisClusterException( @@ -1315,7 +1330,12 @@ def _internal_execute_command(self, *args, **kwargs): break # Return the processed result - return self._process_result(args[0], res, response_policy=command_policies.response_policy, **kwargs) + return self._process_result( + args[0], + res, + response_policy=command_policies.response_policy, + **kwargs, + ) except Exception as e: if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY: if is_default_node: @@ -2348,14 +2368,20 @@ def __init__( SLOT_ID: RequestPolicy.DEFAULT_KEYED, } - self._policies_callback_mapping: dict[Union[RequestPolicy, ResponsePolicy], Callable] = { - RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [self.get_random_primary_or_all_nodes(command_name)], - RequestPolicy.DEFAULT_KEYED: lambda command, *args: self.get_nodes_from_slot(command, *args), + self._policies_callback_mapping: dict[ + Union[RequestPolicy, ResponsePolicy], Callable + ] = { + RequestPolicy.DEFAULT_KEYLESS: lambda command_name: [ + self.get_random_primary_or_all_nodes(command_name) + ], + RequestPolicy.DEFAULT_KEYED: lambda command, + *args: self.get_nodes_from_slot(command, *args), RequestPolicy.DEFAULT_NODE: lambda: [self.get_default_node()], RequestPolicy.ALL_SHARDS: self.get_primaries, RequestPolicy.ALL_NODES: self.get_nodes, RequestPolicy.ALL_REPLICAS: self.get_replicas, - RequestPolicy.MULTI_SHARD: lambda *args, **kwargs: self._split_multi_shard_command(*args, **kwargs), + RequestPolicy.MULTI_SHARD: lambda *args, + **kwargs: self._split_multi_shard_command(*args, **kwargs), RequestPolicy.SPECIAL: self.get_special_nodes, ResponsePolicy.DEFAULT_KEYLESS: lambda res: res, ResponsePolicy.DEFAULT_KEYED: lambda res: res, @@ -2959,7 +2985,11 @@ def _send_cluster_commands( else: if not command_policies: command = c.args[0].upper() - if len(c.args) >= 2 and f"{c.args[0]} {c.args[1]}".upper() in self._pipe.command_flags: + if ( + len(c.args) >= 2 + and f"{c.args[0]} {c.args[1]}".upper() + in self._pipe.command_flags + ): command = f"{c.args[0]} {c.args[1]}".upper() # We only could resolve key properties if command is not @@ -2981,12 +3011,17 @@ def _send_cluster_commands( else: if command_flag in self._pipe._command_flags_mapping: command_policies = CommandPolicies( - request_policy=self._pipe._command_flags_mapping[command_flag]) + request_policy=self._pipe._command_flags_mapping[ + command_flag + ] + ) else: command_policies = CommandPolicies() target_nodes = self._determine_nodes( - *c.args, request_policy=command_policies.request_policy, node_flag=passed_targets + *c.args, + request_policy=command_policies.request_policy, + node_flag=passed_targets, ) if not target_nodes: raise RedisClusterException( @@ -3117,7 +3152,9 @@ def _send_cluster_commands( if c.args[0] in self._pipe.cluster_response_callbacks: # Remove keys entry, it needs only for cache. c.options.pop("keys", None) - c.result = self._pipe._policies_callback_mapping[c.command_policies.response_policy]( + c.result = self._pipe._policies_callback_mapping[ + c.command_policies.response_policy + ]( self._pipe.cluster_response_callbacks[c.args[0]]( c.result, **c.options ) @@ -3152,7 +3189,9 @@ def _parse_target_nodes(self, target_nodes): ) return nodes - def _determine_nodes(self, *args, request_policy: RequestPolicy, **kwargs) -> List["ClusterNode"]: + def _determine_nodes( + self, *args, request_policy: RequestPolicy, **kwargs + ) -> List["ClusterNode"]: # Determine which nodes should be executed the command on. # Returns a list of target nodes. command = args[0].upper() diff --git a/redis/commands/policies.py b/redis/commands/policies.py index ba2cc8968c..4e8998af7f 100644 --- a/redis/commands/policies.py +++ b/redis/commands/policies.py @@ -2,44 +2,128 @@ from abc import ABC, abstractmethod from typing import Optional -from redis._parsers.commands import CommandPolicies, PolicyRecords, RequestPolicy, ResponsePolicy, CommandsParser, \ - AsyncCommandsParser +from redis._parsers.commands import ( + CommandPolicies, + PolicyRecords, + RequestPolicy, + ResponsePolicy, + CommandsParser, + AsyncCommandsParser, +) STATIC_POLICIES: PolicyRecords = { - 'ft': { - 'explaincli': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'suglen': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), - 'profile': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'dropindex': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'aliasupdate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'alter': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'aggregate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'syndump': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'create': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'explain': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'sugget': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), - 'dictdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'aliasadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'dictadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'synupdate': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'drop': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'info': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'sugadd': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), - 'dictdump': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'cursor': CommandPolicies(request_policy=RequestPolicy.SPECIAL, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'search': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'tagvals': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'aliasdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - 'sugdel': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED), - 'spellcheck': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), + "ft": { + "explaincli": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "suglen": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ), + "profile": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "dropindex": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "aliasupdate": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "alter": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "aggregate": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "syndump": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "create": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "explain": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "sugget": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ), + "dictdel": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "aliasadd": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "dictadd": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "synupdate": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "drop": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "info": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "sugadd": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ), + "dictdump": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "cursor": CommandPolicies( + request_policy=RequestPolicy.SPECIAL, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "search": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "tagvals": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "aliasdel": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + "sugdel": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ), + "spellcheck": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), + }, + "core": { + "command": CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ), }, - 'core': { - 'command': CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS), - } } -class PolicyResolver(ABC): +class PolicyResolver(ABC): @abstractmethod def resolve(self, command_name: str) -> Optional[CommandPolicies]: """ @@ -66,8 +150,8 @@ def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": """ pass -class AsyncPolicyResolver(ABC): +class AsyncPolicyResolver(ABC): @abstractmethod async def resolve(self, command_name: str) -> Optional[CommandPolicies]: """ @@ -94,11 +178,15 @@ def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver """ pass + class BasePolicyResolver(PolicyResolver): """ Base class for policy resolvers. """ - def __init__(self, policies: PolicyRecords, fallback: Optional[PolicyResolver] = None) -> None: + + def __init__( + self, policies: PolicyRecords, fallback: Optional[PolicyResolver] = None + ) -> None: self._policies = policies self._fallback = fallback @@ -128,11 +216,15 @@ def resolve(self, command_name: str) -> Optional[CommandPolicies]: def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": pass + class AsyncBasePolicyResolver(AsyncPolicyResolver): """ Async base class for policy resolvers. """ - def __init__(self, policies: PolicyRecords, fallback: Optional[AsyncPolicyResolver] = None) -> None: + + def __init__( + self, policies: PolicyRecords, fallback: Optional[AsyncPolicyResolver] = None + ) -> None: self._policies = policies self._fallback = fallback @@ -167,7 +259,10 @@ class DynamicPolicyResolver(BasePolicyResolver): """ Resolves policy dynamically based on the COMMAND output. """ - def __init__(self, commands_parser: CommandsParser, fallback: Optional[PolicyResolver] = None) -> None: + + def __init__( + self, commands_parser: CommandsParser, fallback: Optional[PolicyResolver] = None + ) -> None: """ Parameters: commands_parser (CommandsParser): COMMAND output parser. @@ -185,6 +280,7 @@ class StaticPolicyResolver(BasePolicyResolver): """ Resolves policy from a static list of policy records. """ + def __init__(self, fallback: Optional[PolicyResolver] = None) -> None: """ Parameters: @@ -196,11 +292,17 @@ def __init__(self, fallback: Optional[PolicyResolver] = None) -> None: def with_fallback(self, fallback: "PolicyResolver") -> "PolicyResolver": return StaticPolicyResolver(fallback) + class AsyncDynamicPolicyResolver(AsyncBasePolicyResolver): """ Async version of DynamicPolicyResolver. """ - def __init__(self, policy_records: PolicyRecords, fallback: Optional[AsyncPolicyResolver] = None) -> None: + + def __init__( + self, + policy_records: PolicyRecords, + fallback: Optional[AsyncPolicyResolver] = None, + ) -> None: """ Parameters: policy_records (PolicyRecords): Policy records. @@ -212,10 +314,12 @@ def __init__(self, policy_records: PolicyRecords, fallback: Optional[AsyncPolicy def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": return AsyncDynamicPolicyResolver(self._policies, fallback) + class AsyncStaticPolicyResolver(AsyncBasePolicyResolver): """ Async version of StaticPolicyResolver. """ + def __init__(self, fallback: Optional[AsyncPolicyResolver] = None) -> None: """ Parameters: @@ -225,4 +329,4 @@ def __init__(self, fallback: Optional[AsyncPolicyResolver] = None) -> None: super().__init__(STATIC_POLICIES, fallback) def with_fallback(self, fallback: "AsyncPolicyResolver") -> "AsyncPolicyResolver": - return AsyncStaticPolicyResolver(fallback) \ No newline at end of file + return AsyncStaticPolicyResolver(fallback) diff --git a/redis/exceptions.py b/redis/exceptions.py index 97ccacb354..dab17c5c1f 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -254,8 +254,10 @@ class ExternalAuthProviderError(ConnectionError): pass + class IncorrectPolicyType(Exception): """ Raised when a policy type isn't matching to any known policy types. """ - pass \ No newline at end of file + + pass diff --git a/tests/test_asyncio/test_command_parser.py b/tests/test_asyncio/test_command_parser.py index da714a13d7..0e17596564 100644 --- a/tests/test_asyncio/test_command_parser.py +++ b/tests/test_asyncio/test_command_parser.py @@ -13,47 +13,139 @@ async def test_get_command_policies(self, r): commands_parser = AsyncCommandsParser() await commands_parser.initialize(node=r.get_default_node()) expected_command_policies = { - 'core': { - 'keys': ['keys', RequestPolicy.ALL_SHARDS, ResponsePolicy.DEFAULT_KEYLESS], - 'acl setuser': ['acl setuser', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], - 'exists': ['exists', RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], - 'config resetstat': ['config resetstat', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], - 'slowlog len': ['slowlog len', RequestPolicy.ALL_NODES, ResponsePolicy.AGG_SUM], - 'scan': ['scan', RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], - 'latency history': ['latency history', RequestPolicy.ALL_NODES, ResponsePolicy.SPECIAL], - 'memory doctor': ['memory doctor', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], - 'randomkey': ['randomkey', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], - 'mget': ['mget', RequestPolicy.MULTI_SHARD, ResponsePolicy.DEFAULT_KEYED], - 'function restore': ['function restore', RequestPolicy.ALL_SHARDS, ResponsePolicy.ALL_SUCCEEDED], + "core": { + "keys": [ + "keys", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + "acl setuser": [ + "acl setuser", + RequestPolicy.ALL_NODES, + ResponsePolicy.ALL_SUCCEEDED, + ], + "exists": ["exists", RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], + "config resetstat": [ + "config resetstat", + RequestPolicy.ALL_NODES, + ResponsePolicy.ALL_SUCCEEDED, + ], + "slowlog len": [ + "slowlog len", + RequestPolicy.ALL_NODES, + ResponsePolicy.AGG_SUM, + ], + "scan": ["scan", RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], + "latency history": [ + "latency history", + RequestPolicy.ALL_NODES, + ResponsePolicy.SPECIAL, + ], + "memory doctor": [ + "memory doctor", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.SPECIAL, + ], + "randomkey": [ + "randomkey", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.SPECIAL, + ], + "mget": [ + "mget", + RequestPolicy.MULTI_SHARD, + ResponsePolicy.DEFAULT_KEYED, + ], + "function restore": [ + "function restore", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.ALL_SUCCEEDED, + ], }, - 'json': { - 'debug': ['debug', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'get': ['get', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "json": { + "debug": [ + "debug", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "get": [ + "get", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'ft': { - 'search': ['search', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], - 'create': ['create', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], + "ft": { + "search": [ + "search", + RequestPolicy.DEFAULT_KEYLESS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + "create": [ + "create", + RequestPolicy.DEFAULT_KEYLESS, + ResponsePolicy.DEFAULT_KEYLESS, + ], }, - 'bf': { - 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'madd': ['madd', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "bf": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "madd": [ + "madd", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'cf': { - 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'mexists': ['mexists', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "cf": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "mexists": [ + "mexists", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'tdigest': { - 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'min': ['min', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "tdigest": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "min": [ + "min", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'ts': { - 'create': ['create', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'info': ['info', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "ts": { + "create": [ + "create", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "info": [ + "info", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "topk": { + "list": [ + "list", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "query": [ + "query", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'topk': { - 'list': ['list', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'query': ['query', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - } } actual_policies = await commands_parser.get_command_policies() @@ -65,5 +157,5 @@ async def test_get_command_policies(self, r): assert command_policies == [ command, actual_policies[module_name][command].request_policy, - actual_policies[module_name][command].response_policy - ] \ No newline at end of file + actual_policies[module_name][command].response_policy, + ] diff --git a/tests/test_asyncio/test_command_policies.py b/tests/test_asyncio/test_command_policies.py index 2c0f0d2ddb..22683b158d 100644 --- a/tests/test_asyncio/test_command_policies.py +++ b/tests/test_asyncio/test_command_policies.py @@ -6,7 +6,10 @@ from redis import ResponseError from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy from redis.asyncio import RedisCluster -from redis.commands.policies import AsyncDynamicPolicyResolver, AsyncStaticPolicyResolver +from redis.commands.policies import ( + AsyncDynamicPolicyResolver, + AsyncStaticPolicyResolver, +) from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import NumericField, TextField @@ -15,44 +18,67 @@ @pytest.mark.onlycluster class TestBasePolicyResolver: async def test_resolve(self): - zcount_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) - rpoplpush_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) - - dynamic_resolver = AsyncDynamicPolicyResolver({ - 'core': { - 'zcount': zcount_policy, - 'rpoplpush': rpoplpush_policy, + zcount_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + rpoplpush_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + + dynamic_resolver = AsyncDynamicPolicyResolver( + { + "core": { + "zcount": zcount_policy, + "rpoplpush": rpoplpush_policy, + } } - }) - assert await dynamic_resolver.resolve('zcount') == zcount_policy - assert await dynamic_resolver.resolve('rpoplpush') == rpoplpush_policy + ) + assert await dynamic_resolver.resolve("zcount") == zcount_policy + assert await dynamic_resolver.resolve("rpoplpush") == rpoplpush_policy - with pytest.raises(ValueError, match="Wrong command or module name: foo.bar.baz"): - await dynamic_resolver.resolve('foo.bar.baz') + with pytest.raises( + ValueError, match="Wrong command or module name: foo.bar.baz" + ): + await dynamic_resolver.resolve("foo.bar.baz") - assert await dynamic_resolver.resolve('foo.bar') is None - assert await dynamic_resolver.resolve('core.foo') is None + assert await dynamic_resolver.resolve("foo.bar") is None + assert await dynamic_resolver.resolve("core.foo") is None # Test that policy fallback correctly static_resolver = AsyncStaticPolicyResolver() with_fallback_dynamic_resolver = dynamic_resolver.with_fallback(static_resolver) - resolved_policies = await with_fallback_dynamic_resolver.resolve('ft.aggregate') + resolved_policies = await with_fallback_dynamic_resolver.resolve("ft.aggregate") assert resolved_policies.request_policy == RequestPolicy.DEFAULT_KEYLESS assert resolved_policies.response_policy == ResponsePolicy.DEFAULT_KEYLESS # Extended chain with one more resolver - foo_bar_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS) - - another_dynamic_resolver = AsyncDynamicPolicyResolver({ - 'foo': { - 'bar': foo_bar_policy, + foo_bar_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ) + + another_dynamic_resolver = AsyncDynamicPolicyResolver( + { + "foo": { + "bar": foo_bar_policy, + } } - }) - with_fallback_static_resolver = static_resolver.with_fallback(another_dynamic_resolver) - with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback(with_fallback_static_resolver) + ) + with_fallback_static_resolver = static_resolver.with_fallback( + another_dynamic_resolver + ) + with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback( + with_fallback_static_resolver + ) + + assert ( + await with_double_fallback_dynamic_resolver.resolve("foo.bar") + == foo_bar_policy + ) - assert await with_double_fallback_dynamic_resolver.resolve('foo.bar') == foo_bar_policy @pytest.mark.onlycluster @pytest.mark.asyncio @@ -66,13 +92,15 @@ async def test_resolves_correctly_policies(self, r: RedisCluster, monkeypatch): async def wrapper(*args, request_policy: RequestPolicy, **kwargs): nonlocal determined_nodes - determined_nodes = await determine_nodes(*args, request_policy=request_policy, **kwargs) + determined_nodes = await determine_nodes( + *args, request_policy=request_policy, **kwargs + ) return determined_nodes # Mock random.choice to always return a pre-defined sequence of nodes monkeypatch.setattr(random, "choice", lambda seq: seq[next(calls)]) - with patch.object(r, '_determine_nodes', side_effect=wrapper, autospec=True): + with patch.object(r, "_determine_nodes", side_effect=wrapper, autospec=True): # Routed to a random primary node await r.ft().create_index( [ @@ -86,11 +114,11 @@ async def wrapper(*args, request_policy: RequestPolicy, **kwargs): # Routed to another random primary node info = await r.ft().info() - assert info['index_name'] == 'idx' + assert info["index_name"] == "idx" assert determined_nodes[0] == primary_nodes[1] - expected_node = await r.get_nodes_from_slot('FT.SUGLEN', *['foo']) - await r.ft().suglen('foo') + expected_node = await r.get_nodes_from_slot("FT.SUGLEN", *["foo"]) + await r.ft().suglen("foo") assert determined_nodes[0] == expected_node[0] # Indexing a document @@ -122,9 +150,7 @@ async def wrapper(*args, request_policy: RequestPolicy, **kwargs): }, ) - req = AggregateRequest("redis").group_by( - "@parent" - ).cursor(1) + req = AggregateRequest("redis").group_by("@parent").cursor(1) res = await r.ft().aggregate(req) cursor = res.cursor @@ -144,4 +170,4 @@ async def wrapper(*args, request_policy: RequestPolicy, **kwargs): # Core commands also randomly distributed across masters await r.randomkey() - assert determined_nodes[0] == primary_nodes[0] \ No newline at end of file + assert determined_nodes[0] == primary_nodes[0] diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index 6be43e5823..32478e04c3 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -115,47 +115,139 @@ def test_get_pubsub_keys(self, r): def test_get_command_policies(self, r): commands_parser = CommandsParser(r) expected_command_policies = { - 'core': { - 'keys': ['keys', RequestPolicy.ALL_SHARDS, ResponsePolicy.DEFAULT_KEYLESS], - 'acl setuser': ['acl setuser', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], - 'exists': ['exists', RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], - 'config resetstat': ['config resetstat', RequestPolicy.ALL_NODES, ResponsePolicy.ALL_SUCCEEDED], - 'slowlog len': ['slowlog len', RequestPolicy.ALL_NODES, ResponsePolicy.AGG_SUM], - 'scan': ['scan', RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], - 'latency history': ['latency history', RequestPolicy.ALL_NODES, ResponsePolicy.SPECIAL], - 'memory doctor': ['memory doctor', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], - 'randomkey': ['randomkey', RequestPolicy.ALL_SHARDS, ResponsePolicy.SPECIAL], - 'mget': ['mget', RequestPolicy.MULTI_SHARD, ResponsePolicy.DEFAULT_KEYED], - 'function restore': ['function restore', RequestPolicy.ALL_SHARDS, ResponsePolicy.ALL_SUCCEEDED], + "core": { + "keys": [ + "keys", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + "acl setuser": [ + "acl setuser", + RequestPolicy.ALL_NODES, + ResponsePolicy.ALL_SUCCEEDED, + ], + "exists": ["exists", RequestPolicy.MULTI_SHARD, ResponsePolicy.AGG_SUM], + "config resetstat": [ + "config resetstat", + RequestPolicy.ALL_NODES, + ResponsePolicy.ALL_SUCCEEDED, + ], + "slowlog len": [ + "slowlog len", + RequestPolicy.ALL_NODES, + ResponsePolicy.AGG_SUM, + ], + "scan": ["scan", RequestPolicy.SPECIAL, ResponsePolicy.SPECIAL], + "latency history": [ + "latency history", + RequestPolicy.ALL_NODES, + ResponsePolicy.SPECIAL, + ], + "memory doctor": [ + "memory doctor", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.SPECIAL, + ], + "randomkey": [ + "randomkey", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.SPECIAL, + ], + "mget": [ + "mget", + RequestPolicy.MULTI_SHARD, + ResponsePolicy.DEFAULT_KEYED, + ], + "function restore": [ + "function restore", + RequestPolicy.ALL_SHARDS, + ResponsePolicy.ALL_SUCCEEDED, + ], }, - 'json': { - 'debug': ['debug', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'get': ['get', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "json": { + "debug": [ + "debug", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "get": [ + "get", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'ft': { - 'search': ['search', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], - 'create': ['create', RequestPolicy.DEFAULT_KEYLESS, ResponsePolicy.DEFAULT_KEYLESS], + "ft": { + "search": [ + "search", + RequestPolicy.DEFAULT_KEYLESS, + ResponsePolicy.DEFAULT_KEYLESS, + ], + "create": [ + "create", + RequestPolicy.DEFAULT_KEYLESS, + ResponsePolicy.DEFAULT_KEYLESS, + ], }, - 'bf': { - 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'madd': ['madd', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "bf": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "madd": [ + "madd", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'cf': { - 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'mexists': ['mexists', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "cf": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "mexists": [ + "mexists", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'tdigest': { - 'add': ['add', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'min': ['min', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "tdigest": { + "add": [ + "add", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "min": [ + "min", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'ts': { - 'create': ['create', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'info': ['info', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], + "ts": { + "create": [ + "create", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "info": [ + "info", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + }, + "topk": { + "list": [ + "list", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], + "query": [ + "query", + RequestPolicy.DEFAULT_KEYED, + ResponsePolicy.DEFAULT_KEYED, + ], }, - 'topk': { - 'list': ['list', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - 'query': ['query', RequestPolicy.DEFAULT_KEYED, ResponsePolicy.DEFAULT_KEYED], - } } actual_policies = commands_parser.get_command_policies() @@ -165,7 +257,7 @@ def test_get_command_policies(self, r): for command, command_policies in commands.items(): assert command in actual_policies[module_name] assert command_policies == [ - command, - actual_policies[module_name][command].request_policy, - actual_policies[module_name][command].response_policy - ] \ No newline at end of file + command, + actual_policies[module_name][command].request_policy, + actual_policies[module_name][command].response_policy, + ] diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py index ca3ecb1036..5f222eb5ea 100644 --- a/tests/test_command_policies.py +++ b/tests/test_command_policies.py @@ -11,51 +11,76 @@ from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import TextField, NumericField + @pytest.mark.onlycluster class TestBasePolicyResolver: def test_resolve(self): mock_command_parser = Mock(spec=CommandsParser) - zcount_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) - rpoplpush_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYED, response_policy=ResponsePolicy.DEFAULT_KEYED) + zcount_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) + rpoplpush_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYED, + response_policy=ResponsePolicy.DEFAULT_KEYED, + ) mock_command_parser.get_command_policies.return_value = { - 'core': { - 'zcount': zcount_policy, - 'rpoplpush': rpoplpush_policy, + "core": { + "zcount": zcount_policy, + "rpoplpush": rpoplpush_policy, } } dynamic_resolver = DynamicPolicyResolver(mock_command_parser) - assert dynamic_resolver.resolve('zcount') == zcount_policy - assert dynamic_resolver.resolve('rpoplpush') == rpoplpush_policy + assert dynamic_resolver.resolve("zcount") == zcount_policy + assert dynamic_resolver.resolve("rpoplpush") == rpoplpush_policy - with pytest.raises(ValueError, match="Wrong command or module name: foo.bar.baz"): - dynamic_resolver.resolve('foo.bar.baz') + with pytest.raises( + ValueError, match="Wrong command or module name: foo.bar.baz" + ): + dynamic_resolver.resolve("foo.bar.baz") - assert dynamic_resolver.resolve('foo.bar') is None - assert dynamic_resolver.resolve('core.foo') is None + assert dynamic_resolver.resolve("foo.bar") is None + assert dynamic_resolver.resolve("core.foo") is None # Test that policy fallback correctly static_resolver = StaticPolicyResolver() with_fallback_dynamic_resolver = dynamic_resolver.with_fallback(static_resolver) - assert with_fallback_dynamic_resolver.resolve('ft.aggregate').request_policy == RequestPolicy.DEFAULT_KEYLESS - assert with_fallback_dynamic_resolver.resolve('ft.aggregate').response_policy == ResponsePolicy.DEFAULT_KEYLESS + assert ( + with_fallback_dynamic_resolver.resolve("ft.aggregate").request_policy + == RequestPolicy.DEFAULT_KEYLESS + ) + assert ( + with_fallback_dynamic_resolver.resolve("ft.aggregate").response_policy + == ResponsePolicy.DEFAULT_KEYLESS + ) # Extended chain with one more resolver mock_command_parser = Mock(spec=CommandsParser) - foo_bar_policy = CommandPolicies(request_policy=RequestPolicy.DEFAULT_KEYLESS, response_policy=ResponsePolicy.DEFAULT_KEYLESS) + foo_bar_policy = CommandPolicies( + request_policy=RequestPolicy.DEFAULT_KEYLESS, + response_policy=ResponsePolicy.DEFAULT_KEYLESS, + ) mock_command_parser.get_command_policies.return_value = { - 'foo': { - 'bar': foo_bar_policy, + "foo": { + "bar": foo_bar_policy, } } another_dynamic_resolver = DynamicPolicyResolver(mock_command_parser) - with_fallback_static_resolver = static_resolver.with_fallback(another_dynamic_resolver) - with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback(with_fallback_static_resolver) + with_fallback_static_resolver = static_resolver.with_fallback( + another_dynamic_resolver + ) + with_double_fallback_dynamic_resolver = dynamic_resolver.with_fallback( + with_fallback_static_resolver + ) + + assert ( + with_double_fallback_dynamic_resolver.resolve("foo.bar") == foo_bar_policy + ) - assert with_double_fallback_dynamic_resolver.resolve('foo.bar') == foo_bar_policy @pytest.mark.onlycluster class TestClusterWithPolicies: @@ -68,13 +93,15 @@ def test_resolves_correctly_policies(self, r, monkeypatch): def wrapper(*args, request_policy: RequestPolicy, **kwargs): nonlocal determined_nodes - determined_nodes = determine_nodes(*args, request_policy=request_policy, **kwargs) + determined_nodes = determine_nodes( + *args, request_policy=request_policy, **kwargs + ) return determined_nodes # Mock random.choice to always return a pre-defined sequence of nodes monkeypatch.setattr(random, "choice", lambda seq: seq[next(calls)]) - with patch.object(r, '_determine_nodes', side_effect=wrapper, autospec=True): + with patch.object(r, "_determine_nodes", side_effect=wrapper, autospec=True): # Routed to a random primary node r.ft().create_index( ( @@ -88,11 +115,11 @@ def wrapper(*args, request_policy: RequestPolicy, **kwargs): # Routed to another random primary node info = r.ft().info() - assert info['index_name'] == 'idx' + assert info["index_name"] == "idx" assert determined_nodes[0] == primary_nodes[1] - expected_node = r.get_nodes_from_slot('ft.suglen', *['FT.SUGLEN', 'foo']) - r.ft().suglen('foo') + expected_node = r.get_nodes_from_slot("ft.suglen", *["FT.SUGLEN", "foo"]) + r.ft().suglen("foo") assert determined_nodes[0] == expected_node[0] # Indexing a document @@ -124,9 +151,7 @@ def wrapper(*args, request_policy: RequestPolicy, **kwargs): }, ) - req = AggregateRequest("redis").group_by( - "@parent" - ).cursor(1) + req = AggregateRequest("redis").group_by("@parent").cursor(1) cursor = r.ft().aggregate(req).cursor # Ensure that aggregate node was cached. @@ -145,4 +170,4 @@ def wrapper(*args, request_policy: RequestPolicy, **kwargs): # Core commands also randomly distributed across masters r.randomkey() - assert determined_nodes[0] == primary_nodes[0] \ No newline at end of file + assert determined_nodes[0] == primary_nodes[0] From 9d11768fbb7a53bfc2dc08adf317203863e70b7b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 09:37:31 +0200 Subject: [PATCH 05/13] Codestyle changes --- redis/_parsers/commands.py | 5 ++--- redis/asyncio/cluster.py | 2 +- redis/cluster.py | 2 +- redis/commands/policies.py | 4 +--- tests/test_command_parser.py | 1 - 5 files changed, 5 insertions(+), 9 deletions(-) diff --git a/redis/_parsers/commands.py b/redis/_parsers/commands.py index cbad42bd79..9ec50a240f 100644 --- a/redis/_parsers/commands.py +++ b/redis/_parsers/commands.py @@ -1,8 +1,7 @@ -from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, Awaitable +from typing import TYPE_CHECKING, Any, Awaitable, Dict, Optional, Tuple, Union -from redis.exceptions import RedisError, ResponseError, IncorrectPolicyType +from redis.exceptions import IncorrectPolicyType, RedisError, ResponseError from redis.utils import str_if_bytes if TYPE_CHECKING: diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 09a086aa31..d70569bb95 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -26,7 +26,7 @@ ) from redis._parsers import AsyncCommandsParser, Encoder -from redis._parsers.commands import RequestPolicy, ResponsePolicy, CommandPolicies +from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, diff --git a/redis/cluster.py b/redis/cluster.py index 9765dc59d2..33b54b1bed 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from redis._parsers import CommandsParser, Encoder -from redis._parsers.commands import RequestPolicy, CommandPolicies, ResponsePolicy +from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy from redis._parsers.helpers import parse_scan from redis.backoff import ExponentialWithJitterBackoff, NoBackoff from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface diff --git a/redis/commands/policies.py b/redis/commands/policies.py index 4e8998af7f..c0c98d37f1 100644 --- a/redis/commands/policies.py +++ b/redis/commands/policies.py @@ -1,14 +1,12 @@ -import asyncio from abc import ABC, abstractmethod from typing import Optional from redis._parsers.commands import ( CommandPolicies, + CommandsParser, PolicyRecords, RequestPolicy, ResponsePolicy, - CommandsParser, - AsyncCommandsParser, ) STATIC_POLICIES: PolicyRecords = { diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index 32478e04c3..169963d786 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,4 +1,3 @@ -from pprint import pprint import pytest from redis._parsers import CommandsParser From 70b743d1fbdbdf474d8b805cf773ccb6513c5470 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 09:39:19 +0200 Subject: [PATCH 06/13] Codestyle changes --- tests/test_command_parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index 169963d786..fd4847ef66 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,4 +1,3 @@ - import pytest from redis._parsers import CommandsParser from redis._parsers.commands import RequestPolicy, ResponsePolicy From 898b842f9df6dcf57aecab73b1fa3b862a078e5d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 09:48:15 +0200 Subject: [PATCH 07/13] Revert changes --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0c99175740..9d2f51795a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,7 @@ from tests.ssl_utils import get_tls_certificates REDIS_INFO = {} -default_redis_url = "redis://localhost:16379/0" +default_redis_url = "redis://localhost:6379/0" default_protocol = "2" default_redismod_url = "redis://localhost:6479" From c9f10f3f0029856e257a760ece66fc0f6793a61a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 10:25:24 +0200 Subject: [PATCH 08/13] Marked tests with Redis 8.0 --- tests/test_asyncio/test_command_parser.py | 4 ++-- tests/test_asyncio/test_command_policies.py | 2 ++ tests/test_command_parser.py | 2 +- tests/test_command_policies.py | 2 ++ 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_asyncio/test_command_parser.py b/tests/test_asyncio/test_command_parser.py index 0e17596564..430a72f885 100644 --- a/tests/test_asyncio/test_command_parser.py +++ b/tests/test_asyncio/test_command_parser.py @@ -5,9 +5,9 @@ from tests.conftest import skip_if_server_version_lt +@pytest.mark.onlycluster +@skip_if_server_version_lt("8.0.0") class TestAsyncCommandParser: - @skip_if_server_version_lt("7.0.0") - @pytest.mark.onlycluster @pytest.mark.asyncio async def test_get_command_policies(self, r): commands_parser = AsyncCommandsParser() diff --git a/tests/test_asyncio/test_command_policies.py b/tests/test_asyncio/test_command_policies.py index 22683b158d..64de14a3f0 100644 --- a/tests/test_asyncio/test_command_policies.py +++ b/tests/test_asyncio/test_command_policies.py @@ -12,6 +12,7 @@ ) from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import NumericField, TextField +from tests.conftest import skip_if_server_version_lt @pytest.mark.asyncio @@ -82,6 +83,7 @@ async def test_resolve(self): @pytest.mark.onlycluster @pytest.mark.asyncio +@skip_if_server_version_lt("8.0.0") class TestClusterWithPolicies: async def test_resolves_correctly_policies(self, r: RedisCluster, monkeypatch): # original nodes selection method diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index fd4847ef66..26b0edc238 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -108,7 +108,7 @@ def test_get_pubsub_keys(self, r): assert commands_parser.get_keys(r, *args3) == ["*"] assert commands_parser.get_keys(r, *args4) == ["foo1", "foo2", "foo3"] - @skip_if_server_version_lt("7.0.0") + @skip_if_server_version_lt("8.0.0") @pytest.mark.onlycluster def test_get_command_policies(self, r): commands_parser = CommandsParser(r) diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py index 5f222eb5ea..02082078f1 100644 --- a/tests/test_command_policies.py +++ b/tests/test_command_policies.py @@ -10,6 +10,7 @@ from redis.commands.policies import DynamicPolicyResolver, StaticPolicyResolver from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import TextField, NumericField +from tests.conftest import skip_if_server_version_lt @pytest.mark.onlycluster @@ -83,6 +84,7 @@ def test_resolve(self): @pytest.mark.onlycluster +@skip_if_server_version_lt("8.0.0") class TestClusterWithPolicies: def test_resolves_correctly_policies(self, r, monkeypatch): # original nodes selection method From 8b1644773a58f48d033b7fa61b2a9b655496f896 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 11:26:03 +0200 Subject: [PATCH 09/13] Added timeouts for index creation --- tests/test_asyncio/test_command_policies.py | 4 ++++ tests/test_command_policies.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/tests/test_asyncio/test_command_policies.py b/tests/test_asyncio/test_command_policies.py index 64de14a3f0..5c3c8ac010 100644 --- a/tests/test_asyncio/test_command_policies.py +++ b/tests/test_asyncio/test_command_policies.py @@ -1,3 +1,4 @@ +import asyncio import random import pytest @@ -114,6 +115,9 @@ async def wrapper(*args, request_policy: RequestPolicy, **kwargs): ) assert determined_nodes[0] == primary_nodes[0] + # Wait for index creation + await asyncio.sleep(1) + # Routed to another random primary node info = await r.ft().info() assert info["index_name"] == "idx" diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py index 02082078f1..8b73c86c28 100644 --- a/tests/test_command_policies.py +++ b/tests/test_command_policies.py @@ -1,4 +1,5 @@ import random +from time import sleep from unittest.mock import Mock, patch import pytest @@ -11,6 +12,7 @@ from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import TextField, NumericField from tests.conftest import skip_if_server_version_lt +from tests.test_search import waitForIndex @pytest.mark.onlycluster @@ -115,6 +117,9 @@ def wrapper(*args, request_policy: RequestPolicy, **kwargs): ) assert determined_nodes[0] == primary_nodes[0] + # Wait for index creation + sleep(1) + # Routed to another random primary node info = r.ft().info() assert info["index_name"] == "idx" From b78101ca0afc32ca6859e67f6834f67b43366bc6 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 11:28:18 +0200 Subject: [PATCH 10/13] Codestyle changes --- tests/test_command_policies.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py index 8b73c86c28..cace819e29 100644 --- a/tests/test_command_policies.py +++ b/tests/test_command_policies.py @@ -12,7 +12,6 @@ from redis.commands.search.aggregation import AggregateRequest from redis.commands.search.field import TextField, NumericField from tests.conftest import skip_if_server_version_lt -from tests.test_search import waitForIndex @pytest.mark.onlycluster From f2910fbc98fae2af4f93cdd9045fbb9729387bbc Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 12:30:37 +0200 Subject: [PATCH 11/13] Fixed RESP3 responses --- tests/test_asyncio/test_command_policies.py | 22 +++++++++++++-------- tests/test_command_policies.py | 21 ++++++++++++-------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/tests/test_asyncio/test_command_policies.py b/tests/test_asyncio/test_command_policies.py index 5c3c8ac010..ad6bbfc75b 100644 --- a/tests/test_asyncio/test_command_policies.py +++ b/tests/test_asyncio/test_command_policies.py @@ -11,9 +11,9 @@ AsyncDynamicPolicyResolver, AsyncStaticPolicyResolver, ) -from redis.commands.search.aggregation import AggregateRequest +from redis.commands.search.aggregation import AggregateRequest, Cursor from redis.commands.search.field import NumericField, TextField -from tests.conftest import skip_if_server_version_lt +from tests.conftest import skip_if_server_version_lt, is_resp2_connection @pytest.mark.asyncio @@ -115,12 +115,14 @@ async def wrapper(*args, request_policy: RequestPolicy, **kwargs): ) assert determined_nodes[0] == primary_nodes[0] - # Wait for index creation - await asyncio.sleep(1) - # Routed to another random primary node info = await r.ft().info() - assert info["index_name"] == "idx" + + if is_resp2_connection(r): + assert info["index_name"] == "idx" + else: + assert info[b"index_name"] == b"idx" + assert determined_nodes[0] == primary_nodes[1] expected_node = await r.get_nodes_from_slot("FT.SUGLEN", *["foo"]) @@ -158,7 +160,11 @@ async def wrapper(*args, request_policy: RequestPolicy, **kwargs): req = AggregateRequest("redis").group_by("@parent").cursor(1) res = await r.ft().aggregate(req) - cursor = res.cursor + + if is_resp2_connection(r): + cursor = res.cursor + else: + cursor = Cursor(res[1]) # Ensure that aggregate node was cached. assert determined_nodes[0] == r._aggregate_nodes[0] @@ -169,7 +175,7 @@ async def wrapper(*args, request_policy: RequestPolicy, **kwargs): assert determined_nodes[0] == r._aggregate_nodes[0] # Error propagates to a user - with pytest.raises(ResponseError, match="Cursor not found, id: 0"): + with pytest.raises(ResponseError, match="Cursor not found, id:"): await r.ft().aggregate(cursor) assert determined_nodes[0] == primary_nodes[2] diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py index cace819e29..90828c618c 100644 --- a/tests/test_command_policies.py +++ b/tests/test_command_policies.py @@ -9,9 +9,9 @@ from redis._parsers import CommandsParser from redis._parsers.commands import CommandPolicies, RequestPolicy, ResponsePolicy from redis.commands.policies import DynamicPolicyResolver, StaticPolicyResolver -from redis.commands.search.aggregation import AggregateRequest +from redis.commands.search.aggregation import AggregateRequest, Cursor from redis.commands.search.field import TextField, NumericField -from tests.conftest import skip_if_server_version_lt +from tests.conftest import skip_if_server_version_lt, is_resp2_connection @pytest.mark.onlycluster @@ -116,12 +116,13 @@ def wrapper(*args, request_policy: RequestPolicy, **kwargs): ) assert determined_nodes[0] == primary_nodes[0] - # Wait for index creation - sleep(1) - # Routed to another random primary node info = r.ft().info() - assert info["index_name"] == "idx" + if is_resp2_connection(r): + assert info["index_name"] == "idx" + else: + assert info[b"index_name"] == b"idx" + assert determined_nodes[0] == primary_nodes[1] expected_node = r.get_nodes_from_slot("ft.suglen", *["FT.SUGLEN", "foo"]) @@ -158,7 +159,11 @@ def wrapper(*args, request_policy: RequestPolicy, **kwargs): ) req = AggregateRequest("redis").group_by("@parent").cursor(1) - cursor = r.ft().aggregate(req).cursor + + if is_resp2_connection(r): + cursor = r.ft().aggregate(req).cursor + else: + cursor = Cursor(r.ft().aggregate(req)[1]) # Ensure that aggregate node was cached. assert determined_nodes[0] == r._aggregate_nodes[0] @@ -169,7 +174,7 @@ def wrapper(*args, request_policy: RequestPolicy, **kwargs): assert determined_nodes[0] == r._aggregate_nodes[0] # Error propagates to a user - with pytest.raises(ResponseError, match="Cursor not found, id: 0"): + with pytest.raises(ResponseError, match="Cursor not found, id:"): r.ft().aggregate(cursor) assert determined_nodes[0] == primary_nodes[2] From c18d6c991457edf296d3b2d4737df379046ba334 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 12:32:20 +0200 Subject: [PATCH 12/13] Codestyle changes --- tests/test_asyncio/test_command_policies.py | 1 - tests/test_command_policies.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/test_asyncio/test_command_policies.py b/tests/test_asyncio/test_command_policies.py index ad6bbfc75b..7a52f256c9 100644 --- a/tests/test_asyncio/test_command_policies.py +++ b/tests/test_asyncio/test_command_policies.py @@ -1,4 +1,3 @@ -import asyncio import random import pytest diff --git a/tests/test_command_policies.py b/tests/test_command_policies.py index 90828c618c..633134a1b6 100644 --- a/tests/test_command_policies.py +++ b/tests/test_command_policies.py @@ -1,5 +1,4 @@ import random -from time import sleep from unittest.mock import Mock, patch import pytest From fd75d139165a0e06f5138b42660732189f12d7eb Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 4 Nov 2025 16:15:05 +0200 Subject: [PATCH 13/13] Added additional option for cluster --- docker-compose.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 46c70ba5a9..ef7fd813b3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,7 @@ x-client-libs-stack-image: &client-libs-stack-image image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.2}" x-client-libs-image: &client-libs-image - image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.2}" + image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.4-RC1-pre.2}" services: @@ -58,7 +58,7 @@ services: - TLS_ENABLED=yes - PORT=16379 - TLS_PORT=27379 - command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --tls-auth-clients optional --save ""} + command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --tls-auth-clients optional --save "" --tls-cluster yes} ports: - "16379-16384:16379-16384" - "27379-27384:27379-27384"