|
20 | 20 | # software solely pursuant to the terms of the relevant commercial agreement. |
21 | 21 |
|
22 | 22 | from datetime import datetime |
23 | | -from unittest import TestCase |
| 23 | +from unittest import TestCase, skipIf |
24 | 24 | from unittest.mock import MagicMock, patch |
25 | 25 |
|
26 | 26 | import sqlalchemy as sa |
27 | 27 |
|
28 | 28 | from crate.client.cursor import Cursor |
| 29 | +from crate.client.sqlalchemy import SA_VERSION |
| 30 | +from crate.client.sqlalchemy.sa_version import SA_1_4, SA_2_0 |
29 | 31 | from crate.client.sqlalchemy.types import Object |
30 | 32 | from sqlalchemy import inspect |
31 | 33 | from sqlalchemy.orm import Session |
32 | 34 | try: |
33 | 35 | from sqlalchemy.orm import declarative_base |
34 | 36 | except ImportError: |
35 | 37 | from sqlalchemy.ext.declarative import declarative_base |
36 | | -from sqlalchemy.testing import eq_, in_ |
| 38 | +from sqlalchemy.testing import eq_, in_, is_true |
37 | 39 |
|
38 | 40 | FakeCursor = MagicMock(name='FakeCursor', spec=Cursor) |
39 | 41 |
|
@@ -70,6 +72,13 @@ class Character(self.base): |
70 | 72 |
|
71 | 73 | self.session = Session(bind=self.engine) |
72 | 74 |
|
| 75 | + def init_mock(self, return_value=None): |
| 76 | + self.fake_cursor.rowcount = 1 |
| 77 | + self.fake_cursor.description = ( |
| 78 | + ('foo', None, None, None, None, None, None), |
| 79 | + ) |
| 80 | + self.fake_cursor.fetchall = MagicMock(return_value=return_value) |
| 81 | + |
73 | 82 | def test_primary_keys_2_3_0(self): |
74 | 83 | insp = inspect(self.session.bind) |
75 | 84 | self.engine.dialect.server_version_info = (2, 3, 0) |
@@ -126,3 +135,22 @@ def test_get_view_names(self): |
126 | 135 | ['v1', 'v2']) |
127 | 136 | eq_(self.executed_statement, "SELECT table_name FROM information_schema.views " |
128 | 137 | "ORDER BY table_name ASC, table_schema ASC") |
| 138 | + |
| 139 | + @skipIf(SA_VERSION < SA_1_4, "Inspector.has_table only available on SQLAlchemy>=1.4") |
| 140 | + def test_has_table(self): |
| 141 | + self.init_mock(return_value=[["foo"], ["bar"]]) |
| 142 | + insp = inspect(self.session.bind) |
| 143 | + is_true(insp.has_table("bar")) |
| 144 | + eq_(self.executed_statement, |
| 145 | + "SELECT table_name FROM information_schema.tables " |
| 146 | + "WHERE table_schema = ? AND table_type = 'BASE TABLE' " |
| 147 | + "ORDER BY table_name ASC, table_schema ASC") |
| 148 | + |
| 149 | + @skipIf(SA_VERSION < SA_2_0, "Inspector.has_schema only available on SQLAlchemy>=2.0") |
| 150 | + def test_has_schema(self): |
| 151 | + self.init_mock( |
| 152 | + return_value=[["blob"], ["doc"], ["information_schema"], ["pg_catalog"], ["sys"]]) |
| 153 | + insp = inspect(self.session.bind) |
| 154 | + is_true(insp.has_schema("doc")) |
| 155 | + eq_(self.executed_statement, |
| 156 | + "select schema_name from information_schema.schemata order by schema_name asc") |
0 commit comments