Skip to content

Commit ec9d4e2

Browse files
committed
Replaced wait
1 parent f0bf381 commit ec9d4e2

File tree

1 file changed

+66
-17
lines changed
  • modules/generic/testcontainers/generic

1 file changed

+66
-17
lines changed

modules/generic/testcontainers/generic/sql.py

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from testcontainers.core.container import DockerContainer
66
from testcontainers.core.exceptions import ContainerStartException
77
from testcontainers.core.utils import raise_for_deprecated_parameter
8-
from testcontainers.core.waiting_utils import wait_container_is_ready
8+
from testcontainers.core.waiting_utils import WaitStrategy, WaitStrategyTarget
99

1010
logger = logging.getLogger(__name__)
1111

@@ -18,19 +18,67 @@
1818
logger.debug("SQLAlchemy not available, skipping DBAPIError handling")
1919

2020

21+
class DatabaseConnectionWaitStrategy(WaitStrategy):
22+
"""
23+
Wait strategy for database connection readiness using SqlContainer._connect().
24+
25+
This strategy implements retry logic and calls SqlContainer._connect()
26+
repeatedly until it succeeds or times out.
27+
"""
28+
29+
def __init__(self, sql_container: "SqlContainer"):
30+
super().__init__()
31+
self.sql_container = sql_container
32+
33+
def wait_until_ready(self, container: WaitStrategyTarget) -> None:
34+
"""
35+
Test database connectivity with retry logic by calling SqlContainer._connect().
36+
37+
Raises:
38+
TimeoutError: If connection fails after timeout
39+
Exception: Any non-transient errors from _connect()
40+
"""
41+
import time
42+
43+
start_time = time.time()
44+
45+
transient_exceptions = (TimeoutError, ConnectionError, *ADDITIONAL_TRANSIENT_ERRORS)
46+
47+
while True:
48+
if time.time() - start_time > self._startup_timeout:
49+
raise TimeoutError(
50+
f"Database connection failed after {self._startup_timeout}s timeout. "
51+
f"Hint: Check if the database container is ready and accessible."
52+
)
53+
54+
try:
55+
self.sql_container._connect()
56+
return
57+
except transient_exceptions as e:
58+
logger.debug(f"Database connection attempt failed: {e}, retrying in {self._poll_interval}s...")
59+
except Exception as e:
60+
logger.error(f"Database connection test failed with non-transient error: {e}")
61+
raise
62+
63+
time.sleep(self._poll_interval)
64+
65+
2166
class SqlContainer(DockerContainer):
2267
"""
2368
Generic SQL database container providing common functionality.
2469
2570
This class can serve as a base for database-specific container implementations.
2671
It provides connection management, URL construction, and basic lifecycle methods.
72+
Database connection readiness is automatically handled by DatabaseConnectionWaitStrategy.
2773
"""
2874

29-
@wait_container_is_ready(*ADDITIONAL_TRANSIENT_ERRORS)
3075
def _connect(self) -> None:
3176
"""
3277
Test database connectivity using SQLAlchemy.
3378
79+
This method performs a single connection test without retry logic.
80+
Retry logic is handled by the DatabaseConnectionWaitStrategy.
81+
3482
Raises:
3583
ImportError: If SQLAlchemy is not installed
3684
Exception: If connection fails
@@ -42,29 +90,17 @@ def _connect(self) -> None:
4290
raise ImportError("SQLAlchemy is required for database containers") from e
4391

4492
connection_url = self.get_connection_url()
45-
4693
engine = sqlalchemy.create_engine(connection_url)
94+
4795
try:
4896
with engine.connect():
4997
logger.info("Database connection test successful")
5098
except Exception as e:
51-
logger.error(f"Database connection test failed: {e}")
99+
logger.debug(f"Database connection attempt failed: {e}")
52100
raise
53101
finally:
54102
engine.dispose()
55103

56-
def get_connection_url(self) -> str:
57-
"""
58-
Get the database connection URL.
59-
60-
Returns:
61-
str: Database connection URL
62-
63-
Raises:
64-
NotImplementedError: Must be implemented by subclasses
65-
"""
66-
raise NotImplementedError("Subclasses must implement get_connection_url()")
67-
68104
def _create_connection_url(
69105
self,
70106
dialect: str,
@@ -147,9 +183,10 @@ def start(self) -> "SqlContainer":
147183

148184
try:
149185
self._configure()
186+
# Set up database connection wait strategy before starting
187+
self.waiting_for(DatabaseConnectionWaitStrategy(self))
150188
super().start()
151189
self._transfer_seed()
152-
self._connect()
153190
logger.info("Database container started successfully")
154191
except Exception as e:
155192
logger.error(f"Failed to start database container: {e}")
@@ -174,3 +211,15 @@ def _transfer_seed(self) -> None:
174211
database-specific seeding functionality.
175212
"""
176213
logger.debug("No seed data to transfer")
214+
215+
def get_connection_url(self) -> str:
216+
"""
217+
Get the database connection URL.
218+
219+
Returns:
220+
str: Database connection URL
221+
222+
Raises:
223+
NotImplementedError: Must be implemented by subclasses
224+
"""
225+
raise NotImplementedError("Subclasses must implement get_connection_url()")

0 commit comments

Comments
 (0)