diff --git a/sqlalchemy_firebird/base.py b/sqlalchemy_firebird/base.py index f74399e..a9ba973 100644 --- a/sqlalchemy_firebird/base.py +++ b/sqlalchemy_firebird/base.py @@ -3,7 +3,7 @@ from packaging import version -from typing import List +from typing import Any, List from typing import Optional from sqlalchemy import __version__ as SQLALCHEMY_VERSION @@ -377,19 +377,18 @@ def visit_boolean(self, type_, **kw): def visit_datetime(self, type_, **kw): return self.visit_TIMESTAMP(type_, **kw) - def _render_string_type(self, type_, name, length_override=None): + def _render_firebird_string_type( + self, + name: str, + length: Optional[int], + collation: Optional[str], + charset: Optional[str], + ) -> str: firebird_3_or_lower = ( self.dialect.server_version_info and self.dialect.server_version_info < (4,) ) - length = coalesce( - length_override, - getattr(type_, "length", None), - ) - charset = getattr(type_, "charset", None) - collation = getattr(type_, "collation", None) - if name in ["BINARY", "VARBINARY", "NCHAR", "NVARCHAR"]: charset = None collation = None @@ -432,11 +431,33 @@ def _render_string_type(self, type_, name, length_override=None): return text - def visit_BINARY(self, type_, **kw): - return self._render_string_type(type_, "BINARY") + def visit_CHAR(self, type_: fb_types.FBCHAR, **kw: Any) -> str: + return self._render_firebird_string_type( + "CHAR", + type_.length, + type_.collation, + getattr(type_, "charset", None), + ) + + def visit_NCHAR(self, type_: fb_types.FBNCHAR, **kw: Any) -> str: + return self._render_firebird_string_type("NCHAR", type_.length, type_.collation) - def visit_VARBINARY(self, type_, **kw): - return self._render_string_type(type_, "VARBINARY") + def visit_VARCHAR(self, type_: fb_types.FBVARCHAR, **kw: Any) -> str: + return self._render_firebird_string_type( + "VARCHAR", + type_.length, + type_.collation, + getattr(type_, "charset", None), + ) + + def visit_NVARCHAR(self, type_: fb_types.FBNCHAR, **kw: Any) -> str: + return self._render_firebird_string_type("NVARCHAR", type_.length, type_.collation) + + def visit_BINARY(self, type_: fb_types.FBBINARY, **kw) -> str: + return self._render_firebird_string_type("BINARY", type_.length) + + def visit_VARBINARY(self, type_: fb_types.FBVARBINARY, **kw) -> str: + return self._render_firebird_string_type("VARBINARY", type_.length) def visit_TEXT(self, type_, **kw): return self.visit_BLOB(type_, override_subtype=1, **kw) @@ -490,9 +511,11 @@ def visit_TIMESTAMP(self, type_, **kw): return super().visit_TIMESTAMP(type_, **kw) return "TIMESTAMP%s %s" % ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "", + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) @@ -501,9 +524,11 @@ def visit_TIME(self, type_, **kw): return super().visit_TIME(type_, **kw) return "TIME%s %s" % ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "", + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", )