|
14 | 14 |
|
15 | 15 | from collections.abc import AsyncGenerator, Callable |
16 | 16 | from contextlib import asynccontextmanager |
| 17 | +from datetime import UTC, datetime |
17 | 18 | from typing import Any |
18 | 19 |
|
19 | 20 | import sentry_sdk |
@@ -54,6 +55,10 @@ def __init__(self, *, echo: bool = False): |
54 | 55 | self._session_factory: async_sessionmaker[AsyncSession] | None = None |
55 | 56 | self._echo = echo |
56 | 57 |
|
| 58 | + def get_database_url(self) -> str: |
| 59 | + """Get the current database URL from configuration.""" |
| 60 | + return get_database_url() |
| 61 | + |
57 | 62 | # ===================================================================== |
58 | 63 | # Connection & Session Management |
59 | 64 | # ===================================================================== |
@@ -84,12 +89,35 @@ async def connect(self, database_url: str | None = None, *, echo: bool | None = |
84 | 89 | echo_setting = echo if echo is not None else self._echo |
85 | 90 |
|
86 | 91 | logger.debug(f"Creating async SQLAlchemy engine (echo={echo_setting})") |
| 92 | + |
| 93 | + # Enhanced connection configuration based on SQLModel best practices |
| 94 | + connect_args = {} |
| 95 | + if "sqlite" in database_url: |
| 96 | + # SQLite-specific optimizations |
| 97 | + connect_args = { |
| 98 | + "check_same_thread": False, |
| 99 | + "timeout": 30, |
| 100 | + } |
| 101 | + elif "postgresql" in database_url: |
| 102 | + # PostgreSQL-specific optimizations |
| 103 | + connect_args = { |
| 104 | + "server_settings": { |
| 105 | + "timezone": "UTC", |
| 106 | + "application_name": "TuxBot", |
| 107 | + }, |
| 108 | + } |
| 109 | + |
87 | 110 | self._engine = create_async_engine( |
88 | 111 | database_url, |
89 | 112 | echo=echo_setting, |
90 | | - pool_pre_ping=True, |
| 113 | + future=True, # Enable SQLAlchemy 2.0 style |
| 114 | + # Connection pooling configuration |
| 115 | + pool_pre_ping=True, # Verify connections before use |
91 | 116 | pool_size=10, |
92 | 117 | max_overflow=20, |
| 118 | + pool_timeout=30, # Connection timeout |
| 119 | + pool_recycle=1800, # Recycle connections after 30 minutes |
| 120 | + connect_args=connect_args, |
93 | 121 | ) |
94 | 122 | self._session_factory = async_sessionmaker( |
95 | 123 | self._engine, |
@@ -121,6 +149,34 @@ async def disconnect(self) -> None: |
121 | 149 | self._session_factory = None |
122 | 150 | logger.info("Disconnected from database") |
123 | 151 |
|
| 152 | + async def health_check(self) -> dict[str, Any]: |
| 153 | + """Perform a database health check.""" |
| 154 | + if not self.is_connected(): |
| 155 | + return {"status": "disconnected", "error": "Database engine not connected"} |
| 156 | + |
| 157 | + try: |
| 158 | + async with self.session() as session: |
| 159 | + # Simple query to test connectivity |
| 160 | + from sqlalchemy import text # noqa: PLC0415 |
| 161 | + |
| 162 | + result = await session.execute(text("SELECT 1")) |
| 163 | + value = result.scalar() |
| 164 | + |
| 165 | + if value == 1: |
| 166 | + return { |
| 167 | + "status": "healthy", |
| 168 | + "pool_size": getattr(self._engine.pool, "size", "unknown") if self._engine else "unknown", |
| 169 | + "checked_connections": getattr(self._engine.pool, "checkedin", "unknown") |
| 170 | + if self._engine |
| 171 | + else "unknown", |
| 172 | + "timestamp": datetime.now(UTC).isoformat(), |
| 173 | + } |
| 174 | + return {"status": "unhealthy", "error": "Unexpected query result"} |
| 175 | + |
| 176 | + except Exception as exc: |
| 177 | + logger.error(f"Database health check failed: {exc}") |
| 178 | + return {"status": "unhealthy", "error": str(exc)} |
| 179 | + |
124 | 180 | @asynccontextmanager |
125 | 181 | async def session(self) -> AsyncGenerator[AsyncSession]: |
126 | 182 | """Return an async SQLAlchemy session context-manager.""" |
|
0 commit comments