2121from data_diff .queries .extras import ApplyFuncAndNormalizeAsString , Checksum , NormalizeAsString
2222from data_diff .utils import ArithString , is_uuid , join_iter , safezip
2323from data_diff .queries .api import Expr , table , Select , SKIP , Explain , Code , this
24- from data_diff .queries .ast_classes import Alias , BinOp , CaseWhen , Cast , Column , Commit , Concat , ConstantTable , Count , \
25- CreateTable , Cte , \
26- CurrentTimestamp , DropTable , Func , \
27- GroupBy , \
28- ITable , In , InsertToTable , IsDistinctFrom , \
29- Join , \
30- Param , \
31- Random , \
32- Root , TableAlias , TableOp , TablePath , \
33- TimeTravel , TruncateTable , UnaryOp , WhenThen , _ResolveColumn
24+ from data_diff .queries .ast_classes import (
25+ Alias ,
26+ BinOp ,
27+ CaseWhen ,
28+ Cast ,
29+ Column ,
30+ Commit ,
31+ Concat ,
32+ ConstantTable ,
33+ Count ,
34+ CreateTable ,
35+ Cte ,
36+ CurrentTimestamp ,
37+ DropTable ,
38+ Func ,
39+ GroupBy ,
40+ ITable ,
41+ In ,
42+ InsertToTable ,
43+ IsDistinctFrom ,
44+ Join ,
45+ Param ,
46+ Random ,
47+ Root ,
48+ TableAlias ,
49+ TableOp ,
50+ TablePath ,
51+ TimeTravel ,
52+ TruncateTable ,
53+ UnaryOp ,
54+ WhenThen ,
55+ _ResolveColumn ,
56+ )
3457from data_diff .abcs .database_types import (
3558 Array ,
3659 Struct ,
@@ -67,17 +90,11 @@ class CompileError(Exception):
6790 pass
6891
6992
70- # TODO: LATER: Resolve the circular imports of databases-compiler-dialects:
71- # A database uses a compiler to render the SQL query.
72- # The compiler delegates to a dialect.
73- # The dialect renders the SQL.
74- # AS IS: The dialect requires the db to normalize table paths — leading to the back-dependency.
75- # TO BE: All the tables paths must be pre-normalized before SQL rendering.
76- # Also: c.database.is_autocommit in render_commit().
77- # After this, the Compiler can cease referring Database/Dialect at all,
78- # and be used only as a CompilingContext (a counter/data-bearing class).
79- # As a result, it becomes low-level util, and the circular dependency auto-resolves.
80- # Meanwhile, the easy fix is to simply move the Compiler here.
93+ # TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
94+ class _RuntypeHackToFixCicularRefrencedDatabase :
95+ dialect : "BaseDialect"
96+
97+
8198@dataclass
8299class Compiler (AbstractCompiler ):
83100 """
@@ -90,7 +107,7 @@ class Compiler(AbstractCompiler):
90107 # Database is needed to normalize tables. Dialect is needed for recursive compilations.
91108 # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
92109 # In practice, we currently bind the dialects to the specific database classes.
93- database : "Database"
110+ database : _RuntypeHackToFixCicularRefrencedDatabase
94111
95112 in_select : bool = False # Compilation runtime flag
96113 in_join : bool = False # Compilation runtime flag
@@ -102,7 +119,7 @@ class Compiler(AbstractCompiler):
102119 _counter : List = field (default_factory = lambda : [0 ])
103120
104121 @property
105- def dialect (self ) -> "Dialect " :
122+ def dialect (self ) -> "BaseDialect " :
106123 return self .database .dialect
107124
108125 # TODO: DEPRECATED: Remove once the dialect is used directly in all places.
@@ -223,7 +240,6 @@ class BaseDialect(abc.ABC):
223240 SUPPORTS_PRIMARY_KEY = False
224241 SUPPORTS_INDEXES = False
225242 TYPE_CLASSES : Dict [str , type ] = {}
226- MIXINS = frozenset ()
227243
228244 PLACEHOLDER_TABLE = None # Used for Oracle
229245
@@ -414,7 +430,9 @@ def render_checksum(self, c: Compiler, elem: Checksum) -> str:
414430
415431 def render_concat (self , c : Compiler , elem : Concat ) -> str :
416432 # We coalesce because on some DBs (e.g. MySQL) concat('a', NULL) is NULL
417- items = [f"coalesce({ self .compile (c , Code (self .to_string (self .compile (c , expr ))))} , '<null>')" for expr in elem .exprs ]
433+ items = [
434+ f"coalesce({ self .compile (c , Code (self .to_string (self .compile (c , expr ))))} , '<null>')" for expr in elem .exprs
435+ ]
418436 assert items
419437 if len (items ) == 1 :
420438 return items [0 ]
@@ -559,17 +577,15 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
559577 columns = columns ,
560578 group_by_exprs = [Code (k ) for k in keys ],
561579 having_exprs = elem .having_exprs ,
562- )
580+ ),
563581 )
564582
565583 keys_str = ", " .join (keys )
566584 columns_str = ", " .join (self .compile (c , x ) for x in columns )
567585 having_str = (
568586 " HAVING " + " AND " .join (map (compile_fn , elem .having_exprs )) if elem .having_exprs is not None else ""
569587 )
570- select = (
571- f"SELECT { columns_str } FROM { self .compile (c .replace (in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
572- )
588+ select = f"SELECT { columns_str } FROM { self .compile (c .replace (in_select = True ), elem .table )} GROUP BY { keys_str } { having_str } "
573589
574590 if c .in_select :
575591 select = f"({ select } ) { c .new_unique_name ()} "
@@ -601,7 +617,7 @@ def render_timetravel(self, c: Compiler, elem: TimeTravel) -> str:
601617 # TODO: why is it c.? why not self? time-trvelling is the dialect's thing, isnt't it?
602618 c .time_travel (
603619 elem .table , before = elem .before , timestamp = elem .timestamp , offset = elem .offset , statement = elem .statement
604- )
620+ ),
605621 )
606622
607623 def render_createtable (self , c : Compiler , elem : CreateTable ) -> str :
@@ -768,18 +784,6 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
768784 # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
769785 return math .floor (math .log (2 ** p , 10 ))
770786
771- @classmethod
772- def load_mixins (cls , * abstract_mixins ) -> Self :
773- "Load a list of mixins that implement the given abstract mixins"
774- mixins = {m for m in cls .MIXINS if issubclass (m , abstract_mixins )}
775-
776- class _DialectWithMixins (cls , * mixins , * abstract_mixins ):
777- pass
778-
779- _DialectWithMixins .__name__ = cls .__name__
780- return _DialectWithMixins ()
781-
782-
783787 @property
784788 @abstractmethod
785789 def name (self ) -> str :
@@ -822,7 +826,7 @@ def __getitem__(self, i):
822826 return self .rows [i ]
823827
824828
825- class Database (abc .ABC ):
829+ class Database (abc .ABC , _RuntypeHackToFixCicularRefrencedDatabase ):
826830 """Base abstract class for databases.
827831
828832 Used for providing connection code and implementation specific SQL utilities.
0 commit comments