|
10 | 10 | # See the License for the specific language governing permissions and |
11 | 11 | # limitations under the License. |
12 | 12 | from sqlalchemy.sql import compiler |
| 13 | +try: |
| 14 | + from sqlalchemy.sql.expression import ( |
| 15 | + Alias, |
| 16 | + CTE, |
| 17 | + Subquery, |
| 18 | + ) |
| 19 | +except ImportError: |
| 20 | + # For SQLAlchemy versions < 1.4, the CTE and Subquery classes did not explicitly exist |
| 21 | + from sqlalchemy.sql.expression import Alias |
| 22 | + CTE = type(None) |
| 23 | + Subquery = type(None) |
13 | 24 |
|
14 | 25 | # https://trino.io/docs/current/language/reserved.html |
15 | 26 | RESERVED_WORDS = { |
@@ -102,6 +113,31 @@ def limit_clause(self, select, **kw): |
102 | 113 | text += "\nLIMIT " + self.process(select._limit_clause, **kw) |
103 | 114 | return text |
104 | 115 |
|
| 116 | + def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, |
| 117 | + fromhints=None, use_schema=True, **kwargs): |
| 118 | + sql = super(TrinoSQLCompiler, self).visit_table( |
| 119 | + table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs |
| 120 | + ) |
| 121 | + return self.add_catalog(sql, table) |
| 122 | + |
| 123 | + @staticmethod |
| 124 | + def add_catalog(sql, table): |
| 125 | + if table is None: |
| 126 | + return sql |
| 127 | + |
| 128 | + if isinstance(table, (Alias, CTE, Subquery)): |
| 129 | + return sql |
| 130 | + |
| 131 | + if ( |
| 132 | + 'trino' not in table.dialect_options |
| 133 | + or 'catalog' not in table.dialect_options['trino'] |
| 134 | + ): |
| 135 | + return sql |
| 136 | + |
| 137 | + catalog = table.dialect_options['trino']['catalog'] |
| 138 | + sql = f'"{catalog}".{sql}' |
| 139 | + return sql |
| 140 | + |
105 | 141 |
|
106 | 142 | class TrinoDDLCompiler(compiler.DDLCompiler): |
107 | 143 | pass |
@@ -173,3 +209,7 @@ def visit_TIME(self, type_, **kw): |
173 | 209 |
|
174 | 210 | class TrinoIdentifierPreparer(compiler.IdentifierPreparer): |
175 | 211 | reserved_words = RESERVED_WORDS |
| 212 | + |
| 213 | + def format_table(self, table, use_schema=True, name=None): |
| 214 | + result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name) |
| 215 | + return TrinoSQLCompiler.add_catalog(result, table) |
0 commit comments