|
12 | 12 | import sqlalchemy.types as sqltypes |
13 | 13 |
|
14 | 14 | from .stmt_compiler import CockroachCompiler, CockroachIdentifierPreparer |
| 15 | +from .ddl_compiler import CockroachDDLCompiler |
15 | 16 |
|
16 | 17 | # Map type names (as returned by information_schema) to sqlalchemy type |
17 | 18 | # objects. |
@@ -97,6 +98,7 @@ class CockroachDBDialect(PGDialect_psycopg2): |
97 | 98 | supports_sequences = False |
98 | 99 | statement_compiler = CockroachCompiler |
99 | 100 | preparer = CockroachIdentifierPreparer |
| 101 | + ddl_compiler = CockroachDDLCompiler |
100 | 102 |
|
101 | 103 | def __init__(self, *args, **kwargs): |
102 | 104 | if kwargs.get("use_native_hstore", False): |
@@ -160,11 +162,22 @@ def get_columns(self, conn, table_name, schema=None, **kw): |
160 | 162 | # Oh well. Hoping 1.1 won't be around for long. |
161 | 163 | rows = conn.execute('SHOW COLUMNS FROM "%s"."%s"' % |
162 | 164 | (schema or self.default_schema_name, table_name)) |
| 165 | + elif not self._is_v191plus: |
| 166 | + # v2.x does not have is_generated or generation_expression |
| 167 | + rows = conn.execute( |
| 168 | + 'SELECT column_name, data_type, is_nullable::bool, column_default, ' |
| 169 | + 'numeric_precision, numeric_scale, character_maximum_length, ' |
| 170 | + 'NULL AS is_generated, NULL AS generation_expression ' |
| 171 | + 'FROM information_schema.columns ' |
| 172 | + 'WHERE table_schema = %s AND table_name = %s AND NOT is_hidden::bool', |
| 173 | + (schema or self.default_schema_name, table_name), |
| 174 | + ) |
163 | 175 | else: |
164 | | - # v2.0 or later. Information schema is usable. |
| 176 | + # v19.1 or later. Information schema columns are all usable. |
165 | 177 | rows = conn.execute( |
166 | 178 | 'SELECT column_name, data_type, is_nullable::bool, column_default, ' |
167 | | - 'numeric_precision, numeric_scale, character_maximum_length ' |
| 179 | + 'numeric_precision, numeric_scale, character_maximum_length, ' |
| 180 | + 'is_generated::bool, generation_expression ' |
168 | 181 | 'FROM information_schema.columns ' |
169 | 182 | 'WHERE table_schema = %s AND table_name = %s AND NOT is_hidden::bool', |
170 | 183 | (schema or self.default_schema_name, table_name), |
@@ -198,12 +211,47 @@ def get_columns(self, conn, table_name, schema=None, **kw): |
198 | 211 | typ = type_class(length=row.character_maximum_length) |
199 | 212 | else: |
200 | 213 | typ = type_class |
201 | | - res.append(dict( |
| 214 | + if row.is_generated: |
| 215 | + # Currently, all computed columns are persisted. |
| 216 | + computed = dict(sqltext=row.generation_expression, persisted=True) |
| 217 | + default = None |
| 218 | + else: |
| 219 | + computed = None |
| 220 | + # Check if a sequence is being used and adjust the default value. |
| 221 | + autoincrement = False |
| 222 | + if default is not None: |
| 223 | + nextval_match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) |
| 224 | + unique_rowid_match = re.search(r"""unique_rowid\(""", default) |
| 225 | + if nextval_match is not None or unique_rowid_match is not None: |
| 226 | + print('affinity', type_class) |
| 227 | + if issubclass(type_class, sqltypes.Integer): |
| 228 | + autoincrement = True |
| 229 | + # the default is related to a Sequence |
| 230 | + sch = schema |
| 231 | + if nextval_match is not None \ |
| 232 | + and "." not in nextval_match.group(2) \ |
| 233 | + and sch is not None: |
| 234 | + # unconditionally quote the schema name. this could |
| 235 | + # later be enhanced to obey quoting rules / |
| 236 | + # "quote schema" |
| 237 | + default = ( |
| 238 | + nextval_match.group(1) |
| 239 | + + ('"%s"' % sch) |
| 240 | + + "." |
| 241 | + + nextval_match.group(2) |
| 242 | + + nextval_match.group(3) |
| 243 | + ) |
| 244 | + |
| 245 | + column_info = dict( |
202 | 246 | name=name, |
203 | 247 | type=typ, |
204 | 248 | nullable=nullable, |
205 | 249 | default=default, |
206 | | - )) |
| 250 | + autoincrement=autoincrement, |
| 251 | + ) |
| 252 | + if computed is not None: |
| 253 | + column_info["computed"] = computed |
| 254 | + res.append(column_info) |
207 | 255 | return res |
208 | 256 |
|
209 | 257 | def get_indexes(self, conn, table_name, schema=None, **kw): |
|
0 commit comments