2121from asyncpg import connection as pg_connection
2222from asyncpg import pool as pg_pool
2323
24+ from . import fuzzer
25+
2426
2527@contextlib .contextmanager
2628def silence_asyncio_long_exec_warning ():
@@ -36,7 +38,16 @@ def flt(log_record):
3638 logger .removeFilter (flt )
3739
3840
41+ def with_timeout (timeout ):
42+ def wrap (func ):
43+ func .__timeout__ = timeout
44+ return func
45+
46+ return wrap
47+
48+
3949class TestCaseMeta (type (unittest .TestCase )):
50+ TEST_TIMEOUT = None
4051
4152 @staticmethod
4253 def _iter_methods (bases , ns ):
@@ -64,7 +75,18 @@ def __new__(mcls, name, bases, ns):
6475 for methname , meth in mcls ._iter_methods (bases , ns ):
6576 @functools .wraps (meth )
6677 def wrapper (self , * args , __meth__ = meth , ** kwargs ):
67- self .loop .run_until_complete (__meth__ (self , * args , ** kwargs ))
78+ coro = __meth__ (self , * args , ** kwargs )
79+ timeout = getattr (__meth__ , '__timeout__' , mcls .TEST_TIMEOUT )
80+ if timeout :
81+ coro = asyncio .wait_for (coro , timeout , loop = self .loop )
82+ try :
83+ self .loop .run_until_complete (coro )
84+ except asyncio .TimeoutError :
85+ raise self .failureException (
86+ 'test timed out after {} seconds' .format (
87+ timeout )) from None
88+ else :
89+ self .loop .run_until_complete (coro )
6890 ns [methname ] = wrapper
6991
7092 return super ().__new__ (mcls , name , bases , ns )
@@ -169,7 +191,8 @@ def _start_default_cluster(server_settings={}, initdb_options=None):
169191
170192
171193def _shutdown_cluster (cluster ):
172- cluster .stop ()
194+ if cluster .get_status () == 'running' :
195+ cluster .stop ()
173196 cluster .destroy ()
174197
175198
@@ -220,9 +243,11 @@ def get_connection_spec(cls, kwargs={}):
220243 conn_spec ['user' ] = 'postgres'
221244 return conn_spec
222245
223- def create_pool (self , pool_class = pg_pool .Pool , ** kwargs ):
246+ def create_pool (self , pool_class = pg_pool .Pool ,
247+ connection_class = pg_connection .Connection , ** kwargs ):
224248 conn_spec = self .get_connection_spec (kwargs )
225- return create_pool (loop = self .loop , pool_class = pool_class , ** conn_spec )
249+ return create_pool (loop = self .loop , pool_class = pool_class ,
250+ connection_class = connection_class , ** conn_spec )
226251
227252 @classmethod
228253 def connect (cls , ** kwargs ):
@@ -238,6 +263,49 @@ def start_cluster(cls, ClusterCls, *,
238263 server_settings , _get_initdb_options (initdb_options ))
239264
240265
266+ class ProxiedClusterTestCase (ClusterTestCase ):
267+ @classmethod
268+ def get_server_settings (cls ):
269+ settings = dict (super ().get_server_settings ())
270+ settings ['listen_addresses' ] = '127.0.0.1'
271+ return settings
272+
273+ @classmethod
274+ def get_proxy_settings (cls ):
275+ return {'fuzzing-mode' : None }
276+
277+ @classmethod
278+ def setUpClass (cls ):
279+ super ().setUpClass ()
280+ conn_spec = cls .cluster .get_connection_spec ()
281+ host = conn_spec .get ('host' )
282+ if not host :
283+ host = '127.0.0.1'
284+ elif host .startswith ('/' ):
285+ host = '127.0.0.1'
286+ cls .proxy = fuzzer .TCPFuzzingProxy (
287+ backend_host = host ,
288+ backend_port = conn_spec ['port' ],
289+ )
290+ cls .proxy .start ()
291+
292+ @classmethod
293+ def tearDownClass (cls ):
294+ cls .proxy .stop ()
295+ super ().tearDownClass ()
296+
297+ @classmethod
298+ def get_connection_spec (cls , kwargs ):
299+ conn_spec = super ().get_connection_spec (kwargs )
300+ conn_spec ['host' ] = cls .proxy .listening_addr
301+ conn_spec ['port' ] = cls .proxy .listening_port
302+ return conn_spec
303+
304+ def tearDown (self ):
305+ self .proxy .reset ()
306+ super ().tearDown ()
307+
308+
241309def with_connection_options (** options ):
242310 if not options :
243311 raise ValueError ('no connection options were specified' )
0 commit comments