Skip to content

Commit a66cd3b

Browse files
committed
feat(database): enhance DatabaseService and BaseController with advanced features
- Updated DatabaseService to include retry logic for database operations, improved connection pooling settings, and added methods for database metrics and query performance analysis. - Refactored BaseController to support dynamic filter building and introduced PostgreSQL-specific query capabilities, including JSON querying, array containment checks, and full-text search. - Added bulk upsert functionality to streamline record insertion and conflict resolution. - Enhanced model definitions to utilize PostgreSQL features like JSONB and arrays for flexible data storage.
1 parent 8b39df0 commit a66cd3b

File tree

5 files changed

+715
-51
lines changed

5 files changed

+715
-51
lines changed

src/tux/database/controllers/base.py

Lines changed: 255 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from loguru import logger
88
from pydantic import BaseModel
9-
from sqlalchemy import func
9+
from sqlalchemy import Table, and_, func, text
1010
from sqlalchemy.ext.asyncio import AsyncSession
1111
from sqlalchemy.orm import selectinload
1212
from sqlmodel import SQLModel, delete, select, update
@@ -58,8 +58,9 @@ async def find_one(self, filters: Any | None = None, order_by: Any | None = None
5858
"""Find one record."""
5959
async with self.db.session() as session:
6060
stmt = select(self.model)
61-
if filters is not None:
62-
stmt = stmt.where(filters)
61+
filter_expr = self._build_filters(filters)
62+
if filter_expr is not None:
63+
stmt = stmt.where(filter_expr)
6364
if order_by is not None:
6465
stmt = stmt.order_by(order_by)
6566
result = await session.execute(stmt)
@@ -75,8 +76,9 @@ async def find_all(
7576
"""Find all records with performance optimizations."""
7677
async with self.db.session() as session:
7778
stmt = select(self.model)
78-
if filters is not None:
79-
stmt = stmt.where(filters)
79+
filter_expr = self._build_filters(filters)
80+
if filter_expr is not None:
81+
stmt = stmt.where(filter_expr)
8082
if order_by is not None:
8183
stmt = stmt.order_by(order_by)
8284
if limit is not None:
@@ -474,6 +476,18 @@ async def with_transaction[R](self, operation: Callable[[AsyncSession], Awaitabl
474476
# Utility Methods
475477
# ------------------------------------------------------------------
476478

479+
def _build_filters(self, filters: Any) -> Any:
480+
"""Convert dictionary filters to SQLAlchemy filter expressions."""
481+
if filters is None:
482+
return None
483+
484+
if isinstance(filters, dict):
485+
filter_expressions: list[Any] = [getattr(self.model, key) == value for key, value in filters.items()] # type: ignore[reportUnknownArgumentType]
486+
return and_(*filter_expressions) if filter_expressions else None # type: ignore[arg-type]
487+
488+
# If it's already a proper filter expression, return as-is
489+
return filters
490+
477491
async def get_or_create(self, defaults: dict[str, Any] | None = None, **filters: Any) -> tuple[ModelT, bool]:
478492
"""Get a record by filters, or create it if it doesn't exist.
479493
@@ -566,6 +580,242 @@ async def bulk_delete(self, record_ids: list[Any]) -> int:
566580
await session.commit()
567581
return len(record_ids)
568582

583+
# ------------------------------------------------------------------
584+
# PostgreSQL-Specific Features - Based on py-pglite Examples
585+
# ------------------------------------------------------------------
586+
587+
async def find_with_json_query(
588+
self,
589+
json_field: str,
590+
json_path: str,
591+
value: Any,
592+
order_by: Any | None = None,
593+
) -> list[ModelT]:
594+
"""
595+
Query records using PostgreSQL JSON operators.
596+
597+
Args:
598+
json_field: Name of the JSON field to query
599+
json_path: JSON path expression (e.g., "$.metadata.key")
600+
value: Value to match
601+
order_by: Optional ordering clause
602+
603+
Example:
604+
guilds = await controller.find_with_json_query(
605+
"metadata", "$.settings.auto_mod", True
606+
)
607+
"""
608+
async with self.db.session() as session:
609+
# Use PostgreSQL JSON path operators
610+
stmt = select(self.model).where(
611+
text(f"{json_field}::jsonb @> :value::jsonb"),
612+
)
613+
614+
if order_by is not None:
615+
stmt = stmt.order_by(order_by)
616+
617+
result = await session.execute(stmt, {"value": f'{{"{json_path.replace("$.", "")}": {value}}}'})
618+
return list(result.scalars().all())
619+
620+
async def find_with_array_contains(
621+
self,
622+
array_field: str,
623+
value: str | list[str],
624+
order_by: Any | None = None,
625+
) -> list[ModelT]:
626+
"""
627+
Query records where array field contains specific value(s).
628+
629+
Args:
630+
array_field: Name of the array field
631+
value: Single value or list of values to check for
632+
order_by: Optional ordering clause
633+
634+
Example:
635+
guilds = await controller.find_with_array_contains("tags", "gaming")
636+
"""
637+
async with self.db.session() as session:
638+
if isinstance(value, str):
639+
# Single value containment check
640+
stmt = select(self.model).where(
641+
text(f":value = ANY({array_field})"),
642+
)
643+
params = {"value": value}
644+
else:
645+
# Multiple values overlap check
646+
stmt = select(self.model).where(
647+
text(f"{array_field} && :values"),
648+
)
649+
params = {"values": value}
650+
651+
if order_by is not None:
652+
stmt = stmt.order_by(order_by)
653+
654+
result = await session.execute(stmt, params)
655+
return list(result.scalars().all())
656+
657+
async def find_with_full_text_search(
658+
self,
659+
text_field: str,
660+
search_query: str,
661+
rank_order: bool = True,
662+
) -> list[tuple[ModelT, float]]:
663+
"""
664+
Perform full-text search using PostgreSQL's built-in capabilities.
665+
666+
Args:
667+
text_field: Field to search in
668+
search_query: Search query
669+
rank_order: Whether to order by relevance rank
670+
671+
Returns:
672+
List of tuples (model, rank) if rank_order=True, else just models
673+
"""
674+
async with self.db.session() as session:
675+
if rank_order:
676+
stmt = (
677+
select(
678+
self.model,
679+
func.ts_rank(
680+
func.to_tsvector("english", getattr(self.model, text_field)),
681+
func.plainto_tsquery("english", search_query),
682+
).label("rank"),
683+
)
684+
.where(
685+
func.to_tsvector("english", getattr(self.model, text_field)).match(
686+
func.plainto_tsquery("english", search_query),
687+
),
688+
)
689+
.order_by(text("rank DESC"))
690+
)
691+
692+
result = await session.execute(stmt)
693+
return [(row[0], float(row[1])) for row in result.fetchall()]
694+
stmt = select(self.model).where(
695+
func.to_tsvector("english", getattr(self.model, text_field)).match(
696+
func.plainto_tsquery("english", search_query),
697+
),
698+
)
699+
result = await session.execute(stmt)
700+
return [(model, 0.0) for model in result.scalars().all()]
701+
702+
async def bulk_upsert_with_conflict_resolution(
703+
self,
704+
records: list[dict[str, Any]],
705+
conflict_columns: list[str],
706+
update_columns: list[str] | None = None,
707+
) -> int:
708+
"""
709+
Bulk upsert using PostgreSQL's ON CONFLICT capabilities.
710+
711+
Args:
712+
records: List of record dictionaries
713+
conflict_columns: Columns that define uniqueness
714+
update_columns: Columns to update on conflict (if None, updates all)
715+
716+
Returns:
717+
Number of records processed
718+
"""
719+
if not records:
720+
return 0
721+
722+
async with self.db.session() as session:
723+
# Use PostgreSQL's INSERT ... ON CONFLICT for high-performance upserts
724+
table: Table = self.model.__table__ # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType,reportUnknownVariableType]
725+
726+
# Build the ON CONFLICT clause
727+
conflict_clause = ", ".join(conflict_columns)
728+
729+
if update_columns is None:
730+
# Update all columns except the conflict columns
731+
update_columns = [col.name for col in table.columns if col.name not in conflict_columns] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
732+
733+
update_clause = ", ".join([f"{col} = EXCLUDED.{col}" for col in update_columns])
734+
735+
# Build the SQL statement
736+
columns = ", ".join(records[0].keys())
737+
placeholders = ", ".join([f":{key}" for key in records[0]])
738+
739+
table_name_attr = getattr(table, "name", "unknown") # pyright: ignore[reportUnknownArgumentType]
740+
sql = f"""
741+
INSERT INTO {table_name_attr} ({columns})
742+
VALUES ({placeholders})
743+
ON CONFLICT ({conflict_clause})
744+
DO UPDATE SET {update_clause}
745+
"""
746+
747+
# Execute for all records
748+
await session.execute(text(sql), records)
749+
await session.commit()
750+
751+
return len(records)
752+
753+
async def get_table_statistics(self) -> dict[str, Any]:
754+
"""
755+
Get PostgreSQL table statistics for this model.
756+
757+
Based on py-pglite monitoring patterns.
758+
"""
759+
async with self.db.session() as session:
760+
table_name: str = self.model.__tablename__ # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType,reportUnknownVariableType]
761+
762+
result = await session.execute(
763+
text("""
764+
SELECT
765+
schemaname,
766+
tablename,
767+
n_tup_ins as total_inserts,
768+
n_tup_upd as total_updates,
769+
n_tup_del as total_deletes,
770+
n_live_tup as live_tuples,
771+
n_dead_tup as dead_tuples,
772+
seq_scan as sequential_scans,
773+
seq_tup_read as sequential_tuples_read,
774+
idx_scan as index_scans,
775+
idx_tup_fetch as index_tuples_fetched,
776+
n_tup_hot_upd as hot_updates,
777+
n_tup_newpage_upd as newpage_updates
778+
FROM pg_stat_user_tables
779+
WHERE tablename = :table_name
780+
"""),
781+
{"table_name": table_name},
782+
)
783+
784+
stats = result.fetchone()
785+
return dict(stats._mapping) if stats else {} # pyright: ignore[reportPrivateUsage]
786+
787+
async def explain_query_performance(
788+
self,
789+
filters: Any | None = None,
790+
order_by: Any | None = None,
791+
) -> dict[str, Any]:
792+
"""
793+
Analyze query performance using EXPLAIN ANALYZE.
794+
795+
Development utility based on py-pglite optimization patterns.
796+
"""
797+
async with self.db.session() as session:
798+
stmt = select(self.model)
799+
if filters is not None:
800+
stmt = stmt.where(filters)
801+
if order_by is not None:
802+
stmt = stmt.order_by(order_by)
803+
804+
# Get the compiled SQL
805+
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
806+
sql_query = str(compiled)
807+
808+
# Analyze with EXPLAIN
809+
explain_stmt = text(f"EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {sql_query}")
810+
result = await session.execute(explain_stmt)
811+
plan_data = result.scalar()
812+
813+
return {
814+
"query": sql_query,
815+
"plan": plan_data[0] if plan_data else {},
816+
"model": self.model.__name__,
817+
}
818+
569819
@staticmethod
570820
def safe_get_attr(obj: Any, attr: str, default: Any = None) -> Any:
571821
"""Return getattr(obj, attr, default) - keeps old helper available."""

src/tux/database/migrations/runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from alembic.config import Config
88
from loguru import logger
99

10-
from tux.database.service import DatabaseService
1110
from tux.shared.config.env import get_database_url, is_dev_mode
1211

1312

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
Revision ID: 0e3ef5ec0555
3+
Revises:
4+
Create Date: 2025-08-25 04:08:52.331369+00:00
5+
"""
6+
from __future__ import annotations
7+
8+
from typing import Union
9+
from collections.abc import Sequence
10+
11+
from alembic import op
12+
import sqlalchemy as sa
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = '0e3ef5ec0555'
16+
down_revision: str | None = None
17+
branch_labels: str | Sequence[str] | None = None
18+
depends_on: str | Sequence[str] | None = None
19+
20+
21+
def upgrade() -> None:
22+
pass
23+
24+
25+
def downgrade() -> None:
26+
pass

0 commit comments

Comments
 (0)