|
4 | 4 |
|
5 | 5 | from testcontainers.core.container import DockerContainer |
6 | 6 | from testcontainers.core.exceptions import ContainerStartException |
7 | | -from testcontainers.core.utils import raise_for_deprecated_parameter |
8 | | -from testcontainers.core.waiting_utils import WaitStrategy, WaitStrategyTarget |
9 | 7 |
|
10 | | -logger = logging.getLogger(__name__) |
11 | | - |
12 | | -ADDITIONAL_TRANSIENT_ERRORS = [] |
13 | | -try: |
14 | | - from sqlalchemy.exc import DBAPIError |
15 | | - |
16 | | - ADDITIONAL_TRANSIENT_ERRORS.append(DBAPIError) |
17 | | -except ImportError: |
18 | | - logger.debug("SQLAlchemy not available, skipping DBAPIError handling") |
19 | | -SQL_TRANSIENT_EXCEPTIONS = (TimeoutError, ConnectionError, *ADDITIONAL_TRANSIENT_ERRORS) |
20 | | - |
21 | | - |
22 | | -class ConnectWaitStrategy(WaitStrategy): |
23 | | - """ |
24 | | - Wait strategy that tests database connectivity until it succeeds or times out. |
25 | | -
|
26 | | - This strategy performs database connection testing using SQLAlchemy directly, |
27 | | - handling transient connection errors and providing appropriate retry logic |
28 | | - for database connectivity testing. |
29 | | - """ |
| 8 | +from .sql_utils import SqlConnectWaitStrategy |
30 | 9 |
|
31 | | - def __init__(self, transient_exceptions: Optional[tuple] = None): |
32 | | - super().__init__() |
33 | | - self.transient_exceptions = transient_exceptions or (TimeoutError, ConnectionError) |
34 | | - |
35 | | - def wait_until_ready(self, container: WaitStrategyTarget) -> None: |
36 | | - """ |
37 | | - Test database connectivity with retry logic until it succeeds or times out. |
38 | | -
|
39 | | - Args: |
40 | | - container: The SQL container that must have get_connection_url method |
41 | | -
|
42 | | - Raises: |
43 | | - TimeoutError: If connection fails after timeout |
44 | | - AttributeError: If container doesn't have get_connection_url method |
45 | | - ImportError: If SQLAlchemy is not installed |
46 | | - Exception: Any non-transient errors from connection attempts |
47 | | - """ |
48 | | - import time |
49 | | - |
50 | | - if not hasattr(container, "get_connection_url"): |
51 | | - raise AttributeError(f"Container {container} must have a get_connection_url method") |
52 | | - |
53 | | - try: |
54 | | - import sqlalchemy |
55 | | - except ImportError as e: |
56 | | - logger.error("SQLAlchemy is required for database connectivity testing") |
57 | | - raise ImportError("SQLAlchemy is required for database containers") from e |
58 | | - |
59 | | - start_time = time.time() |
60 | | - |
61 | | - while True: |
62 | | - if time.time() - start_time > self._startup_timeout: |
63 | | - raise TimeoutError( |
64 | | - f"Database connection failed after {self._startup_timeout}s timeout. " |
65 | | - f"Hint: Check if the container is ready and the database is accessible." |
66 | | - ) |
67 | | - |
68 | | - try: |
69 | | - connection_url = container.get_connection_url() |
70 | | - engine = sqlalchemy.create_engine(connection_url) |
71 | | - |
72 | | - try: |
73 | | - with engine.connect(): |
74 | | - logger.info("Database connection test successful") |
75 | | - return |
76 | | - except Exception as e: |
77 | | - logger.debug(f"Database connection attempt failed: {e}") |
78 | | - raise |
79 | | - finally: |
80 | | - engine.dispose() |
81 | | - |
82 | | - except self.transient_exceptions as e: |
83 | | - logger.debug(f"Connection attempt failed: {e}, retrying in {self._poll_interval}s...") |
84 | | - except Exception as e: |
85 | | - logger.error(f"Connection failed with non-transient error: {e}") |
86 | | - raise |
87 | | - |
88 | | - time.sleep(self._poll_interval) |
| 10 | +logger = logging.getLogger(__name__) |
89 | 11 |
|
90 | 12 |
|
91 | 13 | class SqlContainer(DockerContainer): |
@@ -128,8 +50,6 @@ def _create_connection_url( |
128 | 50 | ValueError: If unexpected arguments are provided or required parameters are missing |
129 | 51 | ContainerStartException: If container is not started |
130 | 52 | """ |
131 | | - if raise_for_deprecated_parameter(kwargs, "db_name", "dbname"): |
132 | | - raise ValueError(f"Unexpected arguments: {','.join(kwargs)}") |
133 | 53 |
|
134 | 54 | if self._container is None: |
135 | 55 | raise ContainerStartException("Container has not been started") |
@@ -179,7 +99,7 @@ def start(self) -> "SqlContainer": |
179 | 99 |
|
180 | 100 | try: |
181 | 101 | self._configure() |
182 | | - self.waiting_for(ConnectWaitStrategy(SQL_TRANSIENT_EXCEPTIONS)) |
| 102 | + self.waiting_for(SqlConnectWaitStrategy()) |
183 | 103 | super().start() |
184 | 104 | self._transfer_seed() |
185 | 105 | logger.info("Database container started successfully") |
|
0 commit comments