66
77from loguru import logger
88from pydantic import BaseModel
9- from sqlalchemy import func
9+ from sqlalchemy import Table , and_ , func , text
1010from sqlalchemy .ext .asyncio import AsyncSession
1111from sqlalchemy .orm import selectinload
1212from sqlmodel import SQLModel , delete , select , update
@@ -58,8 +58,9 @@ async def find_one(self, filters: Any | None = None, order_by: Any | None = None
5858 """Find one record."""
5959 async with self .db .session () as session :
6060 stmt = select (self .model )
61- if filters is not None :
62- stmt = stmt .where (filters )
61+ filter_expr = self ._build_filters (filters )
62+ if filter_expr is not None :
63+ stmt = stmt .where (filter_expr )
6364 if order_by is not None :
6465 stmt = stmt .order_by (order_by )
6566 result = await session .execute (stmt )
@@ -75,8 +76,9 @@ async def find_all(
7576 """Find all records with performance optimizations."""
7677 async with self .db .session () as session :
7778 stmt = select (self .model )
78- if filters is not None :
79- stmt = stmt .where (filters )
79+ filter_expr = self ._build_filters (filters )
80+ if filter_expr is not None :
81+ stmt = stmt .where (filter_expr )
8082 if order_by is not None :
8183 stmt = stmt .order_by (order_by )
8284 if limit is not None :
@@ -474,6 +476,18 @@ async def with_transaction[R](self, operation: Callable[[AsyncSession], Awaitabl
474476 # Utility Methods
475477 # ------------------------------------------------------------------
476478
479+ def _build_filters (self , filters : Any ) -> Any :
480+ """Convert dictionary filters to SQLAlchemy filter expressions."""
481+ if filters is None :
482+ return None
483+
484+ if isinstance (filters , dict ):
485+ filter_expressions : list [Any ] = [getattr (self .model , key ) == value for key , value in filters .items ()] # type: ignore[reportUnknownArgumentType]
486+ return and_ (* filter_expressions ) if filter_expressions else None # type: ignore[arg-type]
487+
488+ # If it's already a proper filter expression, return as-is
489+ return filters
490+
477491 async def get_or_create (self , defaults : dict [str , Any ] | None = None , ** filters : Any ) -> tuple [ModelT , bool ]:
478492 """Get a record by filters, or create it if it doesn't exist.
479493
@@ -566,6 +580,242 @@ async def bulk_delete(self, record_ids: list[Any]) -> int:
566580 await session .commit ()
567581 return len (record_ids )
568582
583+ # ------------------------------------------------------------------
584+ # PostgreSQL-Specific Features - Based on py-pglite Examples
585+ # ------------------------------------------------------------------
586+
587+ async def find_with_json_query (
588+ self ,
589+ json_field : str ,
590+ json_path : str ,
591+ value : Any ,
592+ order_by : Any | None = None ,
593+ ) -> list [ModelT ]:
594+ """
595+ Query records using PostgreSQL JSON operators.
596+
597+ Args:
598+ json_field: Name of the JSON field to query
599+ json_path: JSON path expression (e.g., "$.metadata.key")
600+ value: Value to match
601+ order_by: Optional ordering clause
602+
603+ Example:
604+ guilds = await controller.find_with_json_query(
605+ "metadata", "$.settings.auto_mod", True
606+ )
607+ """
608+ async with self .db .session () as session :
609+ # Use PostgreSQL JSON path operators
610+ stmt = select (self .model ).where (
611+ text (f"{ json_field } ::jsonb @> :value::jsonb" ),
612+ )
613+
614+ if order_by is not None :
615+ stmt = stmt .order_by (order_by )
616+
617+ result = await session .execute (stmt , {"value" : f'{{"{ json_path .replace ("$." , "" )} ": { value } }}' })
618+ return list (result .scalars ().all ())
619+
620+ async def find_with_array_contains (
621+ self ,
622+ array_field : str ,
623+ value : str | list [str ],
624+ order_by : Any | None = None ,
625+ ) -> list [ModelT ]:
626+ """
627+ Query records where array field contains specific value(s).
628+
629+ Args:
630+ array_field: Name of the array field
631+ value: Single value or list of values to check for
632+ order_by: Optional ordering clause
633+
634+ Example:
635+ guilds = await controller.find_with_array_contains("tags", "gaming")
636+ """
637+ async with self .db .session () as session :
638+ if isinstance (value , str ):
639+ # Single value containment check
640+ stmt = select (self .model ).where (
641+ text (f":value = ANY({ array_field } )" ),
642+ )
643+ params = {"value" : value }
644+ else :
645+ # Multiple values overlap check
646+ stmt = select (self .model ).where (
647+ text (f"{ array_field } && :values" ),
648+ )
649+ params = {"values" : value }
650+
651+ if order_by is not None :
652+ stmt = stmt .order_by (order_by )
653+
654+ result = await session .execute (stmt , params )
655+ return list (result .scalars ().all ())
656+
657+ async def find_with_full_text_search (
658+ self ,
659+ text_field : str ,
660+ search_query : str ,
661+ rank_order : bool = True ,
662+ ) -> list [tuple [ModelT , float ]]:
663+ """
664+ Perform full-text search using PostgreSQL's built-in capabilities.
665+
666+ Args:
667+ text_field: Field to search in
668+ search_query: Search query
669+ rank_order: Whether to order by relevance rank
670+
671+ Returns:
672+ List of tuples (model, rank) if rank_order=True, else just models
673+ """
674+ async with self .db .session () as session :
675+ if rank_order :
676+ stmt = (
677+ select (
678+ self .model ,
679+ func .ts_rank (
680+ func .to_tsvector ("english" , getattr (self .model , text_field )),
681+ func .plainto_tsquery ("english" , search_query ),
682+ ).label ("rank" ),
683+ )
684+ .where (
685+ func .to_tsvector ("english" , getattr (self .model , text_field )).match (
686+ func .plainto_tsquery ("english" , search_query ),
687+ ),
688+ )
689+ .order_by (text ("rank DESC" ))
690+ )
691+
692+ result = await session .execute (stmt )
693+ return [(row [0 ], float (row [1 ])) for row in result .fetchall ()]
694+ stmt = select (self .model ).where (
695+ func .to_tsvector ("english" , getattr (self .model , text_field )).match (
696+ func .plainto_tsquery ("english" , search_query ),
697+ ),
698+ )
699+ result = await session .execute (stmt )
700+ return [(model , 0.0 ) for model in result .scalars ().all ()]
701+
702+ async def bulk_upsert_with_conflict_resolution (
703+ self ,
704+ records : list [dict [str , Any ]],
705+ conflict_columns : list [str ],
706+ update_columns : list [str ] | None = None ,
707+ ) -> int :
708+ """
709+ Bulk upsert using PostgreSQL's ON CONFLICT capabilities.
710+
711+ Args:
712+ records: List of record dictionaries
713+ conflict_columns: Columns that define uniqueness
714+ update_columns: Columns to update on conflict (if None, updates all)
715+
716+ Returns:
717+ Number of records processed
718+ """
719+ if not records :
720+ return 0
721+
722+ async with self .db .session () as session :
723+ # Use PostgreSQL's INSERT ... ON CONFLICT for high-performance upserts
724+ table : Table = self .model .__table__ # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType,reportUnknownVariableType]
725+
726+ # Build the ON CONFLICT clause
727+ conflict_clause = ", " .join (conflict_columns )
728+
729+ if update_columns is None :
730+ # Update all columns except the conflict columns
731+ update_columns = [col .name for col in table .columns if col .name not in conflict_columns ] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
732+
733+ update_clause = ", " .join ([f"{ col } = EXCLUDED.{ col } " for col in update_columns ])
734+
735+ # Build the SQL statement
736+ columns = ", " .join (records [0 ].keys ())
737+ placeholders = ", " .join ([f":{ key } " for key in records [0 ]])
738+
739+ table_name_attr = getattr (table , "name" , "unknown" ) # pyright: ignore[reportUnknownArgumentType]
740+ sql = f"""
741+ INSERT INTO { table_name_attr } ({ columns } )
742+ VALUES ({ placeholders } )
743+ ON CONFLICT ({ conflict_clause } )
744+ DO UPDATE SET { update_clause }
745+ """
746+
747+ # Execute for all records
748+ await session .execute (text (sql ), records )
749+ await session .commit ()
750+
751+ return len (records )
752+
753+ async def get_table_statistics (self ) -> dict [str , Any ]:
754+ """
755+ Get PostgreSQL table statistics for this model.
756+
757+ Based on py-pglite monitoring patterns.
758+ """
759+ async with self .db .session () as session :
760+ table_name : str = self .model .__tablename__ # pyright: ignore[reportAttributeAccessIssue,reportUnknownMemberType,reportUnknownVariableType]
761+
762+ result = await session .execute (
763+ text ("""
764+ SELECT
765+ schemaname,
766+ tablename,
767+ n_tup_ins as total_inserts,
768+ n_tup_upd as total_updates,
769+ n_tup_del as total_deletes,
770+ n_live_tup as live_tuples,
771+ n_dead_tup as dead_tuples,
772+ seq_scan as sequential_scans,
773+ seq_tup_read as sequential_tuples_read,
774+ idx_scan as index_scans,
775+ idx_tup_fetch as index_tuples_fetched,
776+ n_tup_hot_upd as hot_updates,
777+ n_tup_newpage_upd as newpage_updates
778+ FROM pg_stat_user_tables
779+ WHERE tablename = :table_name
780+ """ ),
781+ {"table_name" : table_name },
782+ )
783+
784+ stats = result .fetchone ()
785+ return dict (stats ._mapping ) if stats else {} # pyright: ignore[reportPrivateUsage]
786+
787+ async def explain_query_performance (
788+ self ,
789+ filters : Any | None = None ,
790+ order_by : Any | None = None ,
791+ ) -> dict [str , Any ]:
792+ """
793+ Analyze query performance using EXPLAIN ANALYZE.
794+
795+ Development utility based on py-pglite optimization patterns.
796+ """
797+ async with self .db .session () as session :
798+ stmt = select (self .model )
799+ if filters is not None :
800+ stmt = stmt .where (filters )
801+ if order_by is not None :
802+ stmt = stmt .order_by (order_by )
803+
804+ # Get the compiled SQL
805+ compiled = stmt .compile (compile_kwargs = {"literal_binds" : True })
806+ sql_query = str (compiled )
807+
808+ # Analyze with EXPLAIN
809+ explain_stmt = text (f"EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) { sql_query } " )
810+ result = await session .execute (explain_stmt )
811+ plan_data = result .scalar ()
812+
813+ return {
814+ "query" : sql_query ,
815+ "plan" : plan_data [0 ] if plan_data else {},
816+ "model" : self .model .__name__ ,
817+ }
818+
569819 @staticmethod
570820 def safe_get_attr (obj : Any , attr : str , default : Any = None ) -> Any :
571821 """Return getattr(obj, attr, default) - keeps old helper available."""
0 commit comments