1111import os
1212import socket
1313import struct
14+ import time
1415import urllib .parse
1516
1617from . import cursor
@@ -60,6 +61,22 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6061 self ._stmt_cache = collections .OrderedDict ()
6162 self ._stmts_to_close = set ()
6263
64+ if command_timeout is not None :
65+ try :
66+ if isinstance (command_timeout , bool ):
67+ raise ValueError
68+
69+ command_timeout = float (command_timeout )
70+
71+ if command_timeout < 0 :
72+ raise ValueError
73+
74+ except ValueError :
75+ raise ValueError (
76+ 'invalid command_timeout value: '
77+ 'expected non-negative float (got {!r})' .format (
78+ command_timeout )) from None
79+
6380 self ._command_timeout = command_timeout
6481
6582 self ._listeners = {}
@@ -187,7 +204,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
187204 if not args :
188205 return await self ._protocol .query (query , timeout )
189206
190- _ , status , _ = await self ._do_execute (query , args , 0 , timeout , True )
207+ _ , status , _ = await self ._execute (query , args , 0 , timeout , True )
191208 return status .decode ()
192209
193210 async def executemany (self , command : str , args , timeout : float = None ):
@@ -208,8 +225,7 @@ async def executemany(self, command: str, args, timeout: float=None):
208225
209226 .. versionadded:: 0.7.0
210227 """
211- stmt = await self ._get_statement (command , timeout )
212- return await self ._protocol .bind_execute_many (stmt , args , '' , timeout )
228+ return await self ._executemany (command , args , timeout )
213229
214230 async def _get_statement (self , query , timeout ):
215231 cache = self ._stmt_cache_max_size > 0
@@ -281,7 +297,7 @@ async def fetch(self, query, *args, timeout=None) -> list:
281297
282298 :return list: A list of :class:`Record` instances.
283299 """
284- return await self ._do_execute (query , args , 0 , timeout )
300+ return await self ._execute (query , args , 0 , timeout )
285301
286302 async def fetchval (self , query , * args , column = 0 , timeout = None ):
287303 """Run a query and return a value in the first row.
@@ -297,7 +313,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
297313
298314 :return: The value of the specified column of the first record.
299315 """
300- data = await self ._do_execute (query , args , 1 , timeout )
316+ data = await self ._execute (query , args , 1 , timeout )
301317 if not data :
302318 return None
303319 return data [0 ][column ]
@@ -311,7 +327,7 @@ async def fetchrow(self, query, *args, timeout=None):
311327
312328 :return: The first row as a :class:`Record` instance.
313329 """
314- data = await self ._do_execute (query , args , 1 , timeout )
330+ data = await self ._execute (query , args , 1 , timeout )
315331 if not data :
316332 return None
317333 return data [0 ]
@@ -430,7 +446,9 @@ async def _cleanup_stmts(self):
430446 to_close = self ._stmts_to_close
431447 self ._stmts_to_close = set ()
432448 for stmt in to_close :
433- await self ._protocol .close_statement (stmt , False )
449+ # It is imperative that statements are cleaned properly,
450+ # so we ignore the timeout.
451+ await self ._protocol .close_statement (stmt , protocol .NO_TIMEOUT )
434452
435453 def _request_portal_name (self ):
436454 return self ._get_unique_id ()
@@ -554,13 +572,37 @@ def _drop_global_statement_cache(self):
554572 else :
555573 self ._drop_local_statement_cache ()
556574
557- async def _do_execute (self , query , args , limit , timeout ,
558- return_status = False ):
559- stmt = await self ._get_statement (query , timeout )
575+ def _execute (self , query , args , limit , timeout , return_status = False ):
576+ executor = lambda stmt , timeout : self ._protocol .bind_execute (
577+ stmt , args , '' , limit , return_status , timeout )
578+ timeout = self ._protocol ._get_timeout (timeout )
579+ return self ._do_execute (query , executor , timeout )
580+
581+ def _executemany (self , query , args , timeout ):
582+ executor = lambda stmt , timeout : self ._protocol .bind_execute_many (
583+ stmt , args , '' , timeout )
584+ timeout = self ._protocol ._get_timeout (timeout )
585+ return self ._do_execute (query , executor , timeout )
586+
587+ async def _do_execute (self , query , executor , timeout , retry = True ):
588+ if timeout is None :
589+ stmt = await self ._get_statement (query , None )
590+ else :
591+ before = time .monotonic ()
592+ stmt = await self ._get_statement (query , timeout )
593+ after = time .monotonic ()
594+ timeout -= after - before
595+ before = after
560596
561597 try :
562- result = await self ._protocol .bind_execute (
563- stmt , args , '' , limit , return_status , timeout )
598+ if timeout is None :
599+ result = await executor (stmt , None )
600+ else :
601+ try :
602+ result = await executor (stmt , timeout )
603+ finally :
604+ after = time .monotonic ()
605+ timeout -= after - before
564606
565607 except exceptions .InvalidCachedStatementError as e :
566608 # PostgreSQL will raise an exception when it detects
@@ -586,13 +628,11 @@ async def _do_execute(self, query, args, limit, timeout,
586628 # for discussion.
587629 #
588630 self ._drop_global_statement_cache ()
589-
590- if self ._protocol .is_in_transaction ():
631+ if self ._protocol .is_in_transaction () or not retry :
591632 raise
592633 else :
593- stmt = await self ._get_statement (query , timeout )
594- result = await self ._protocol .bind_execute (
595- stmt , args , '' , limit , return_status , timeout )
634+ result = await self ._do_execute (
635+ query , executor , timeout , retry = False )
596636
597637 return result
598638
0 commit comments