@@ -500,28 +500,37 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
500500 ):
501501 raise error
502502
503- # COMMAND EXECUTION AND PROTOCOL PARSING
504- async def execute_command (self , * args , ** options ):
505- """Execute a command and return a parsed response"""
506- await self .initialize ()
507- pool = self .connection_pool
508- command_name = args [0 ]
509- conn = self .connection or await pool .get_connection (command_name , ** options )
510-
511- if self .single_connection_client :
512- await self ._single_conn_lock .acquire ()
503+ async def _try_send_command_parse_response (self , conn , * args , ** options ):
513504 try :
514505 return await conn .retry .call_with_retry (
515506 lambda : self ._send_command_parse_response (
516- conn , command_name , * args , ** options
507+ conn , args [ 0 ] , * args , ** options
517508 ),
518509 lambda error : self ._disconnect_raise (conn , error ),
519510 )
511+ except asyncio .CancelledError :
512+ await conn .disconnect (nowait = True )
513+ raise
520514 finally :
521515 if self .single_connection_client :
522516 self ._single_conn_lock .release ()
523517 if not self .connection :
524- await pool .release (conn )
518+ await self .connection_pool .release (conn )
519+
520+ # COMMAND EXECUTION AND PROTOCOL PARSING
521+ async def execute_command (self , * args , ** options ):
522+ """Execute a command and return a parsed response"""
523+ await self .initialize ()
524+ pool = self .connection_pool
525+ command_name = args [0 ]
526+ conn = self .connection or await pool .get_connection (command_name , ** options )
527+
528+ if self .single_connection_client :
529+ await self ._single_conn_lock .acquire ()
530+
531+ return await asyncio .shield (
532+ self ._try_send_command_parse_response (conn , * args , ** options )
533+ )
525534
526535 async def parse_response (
527536 self , connection : Connection , command_name : Union [str , bytes ], ** options
@@ -765,10 +774,18 @@ async def _disconnect_raise_connect(self, conn, error):
765774 is not a TimeoutError. Otherwise, try to reconnect
766775 """
767776 await conn .disconnect ()
777+
768778 if not (conn .retry_on_timeout and isinstance (error , TimeoutError )):
769779 raise error
770780 await conn .connect ()
771781
782+ async def _try_execute (self , conn , command , * arg , ** kwargs ):
783+ try :
784+ return await command (* arg , ** kwargs )
785+ except asyncio .CancelledError :
786+ await conn .disconnect ()
787+ raise
788+
772789 async def _execute (self , conn , command , * args , ** kwargs ):
773790 """
774791 Connect manually upon disconnection. If the Redis server is down,
@@ -777,9 +794,11 @@ async def _execute(self, conn, command, *args, **kwargs):
777794 called by the # connection to resubscribe us to any channels and
778795 patterns we were previously listening to
779796 """
780- return await conn .retry .call_with_retry (
781- lambda : command (* args , ** kwargs ),
782- lambda error : self ._disconnect_raise_connect (conn , error ),
797+ return await asyncio .shield (
798+ conn .retry .call_with_retry (
799+ lambda : self ._try_execute (conn , command , * args , ** kwargs ),
800+ lambda error : self ._disconnect_raise_connect (conn , error ),
801+ )
783802 )
784803
785804 async def parse_response (self , block : bool = True , timeout : float = 0 ):
@@ -1181,6 +1200,18 @@ async def _disconnect_reset_raise(self, conn, error):
11811200 await self .reset ()
11821201 raise
11831202
1203+ async def _try_send_command_parse_response (self , conn , * args , ** options ):
1204+ try :
1205+ return await conn .retry .call_with_retry (
1206+ lambda : self ._send_command_parse_response (
1207+ conn , args [0 ], * args , ** options
1208+ ),
1209+ lambda error : self ._disconnect_reset_raise (conn , error ),
1210+ )
1211+ except asyncio .CancelledError :
1212+ await conn .disconnect ()
1213+ raise
1214+
11841215 async def immediate_execute_command (self , * args , ** options ):
11851216 """
11861217 Execute a command immediately, but don't auto-retry on a
@@ -1196,13 +1227,13 @@ async def immediate_execute_command(self, *args, **options):
11961227 command_name , self .shard_hint
11971228 )
11981229 self .connection = conn
1199-
1200- return await conn . retry . call_with_retry (
1201- lambda : self ._send_command_parse_response (
1202- conn , command_name , * args , ** options
1203- ),
1204- lambda error : self . _disconnect_reset_raise ( conn , error ),
1205- )
1230+ try :
1231+ return await asyncio . shield (
1232+ self ._try_send_command_parse_response ( conn , * args , ** options )
1233+ )
1234+ except asyncio . CancelledError :
1235+ await conn . disconnect ()
1236+ raise
12061237
12071238 def pipeline_execute_command (self , * args , ** options ):
12081239 """
@@ -1369,6 +1400,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
13691400 await self .reset ()
13701401 raise
13711402
1403+ async def _try_execute (self , conn , execute , stack , raise_on_error ):
1404+ try :
1405+ return await conn .retry .call_with_retry (
1406+ lambda : execute (conn , stack , raise_on_error ),
1407+ lambda error : self ._disconnect_raise_reset (conn , error ),
1408+ )
1409+ except asyncio .CancelledError :
1410+ # not supposed to be possible, yet here we are
1411+ await conn .disconnect (nowait = True )
1412+ raise
1413+ finally :
1414+ await self .reset ()
1415+
13721416 async def execute (self , raise_on_error : bool = True ):
13731417 """Execute all the commands in the current pipeline"""
13741418 stack = self .command_stack
@@ -1391,15 +1435,10 @@ async def execute(self, raise_on_error: bool = True):
13911435
13921436 try :
13931437 return await asyncio .shield (
1394- conn .retry .call_with_retry (
1395- lambda : execute (conn , stack , raise_on_error ),
1396- lambda error : self ._disconnect_raise_reset (conn , error ),
1397- )
1438+ self ._try_execute (conn , execute , stack , raise_on_error )
13981439 )
1399- except asyncio .CancelledError :
1400- # not supposed to be possible, yet here we are
1401- await conn .disconnect (nowait = True )
1402- raise
1440+ except RuntimeError :
1441+ await self .reset ()
14031442 finally :
14041443 await self .reset ()
14051444
0 commit comments