diff --git a/sqlalchemy_firebird/base.py b/sqlalchemy_firebird/base.py index f74399e..09a7d9b 100644 --- a/sqlalchemy_firebird/base.py +++ b/sqlalchemy_firebird/base.py @@ -3,7 +3,7 @@ from packaging import version -from typing import List +from typing import Any, List, TypedDict from typing import Optional from sqlalchemy import __version__ as SQLALCHEMY_VERSION @@ -377,19 +377,18 @@ def visit_boolean(self, type_, **kw): def visit_datetime(self, type_, **kw): return self.visit_TIMESTAMP(type_, **kw) - def _render_string_type(self, type_, name, length_override=None): + def _render_firebird_string_type( + self, + name: str, + length: Optional[int]=None, + collation: Optional[str]=None, + charset: Optional[str]=None, + ) -> str: firebird_3_or_lower = ( self.dialect.server_version_info and self.dialect.server_version_info < (4,) ) - length = coalesce( - length_override, - getattr(type_, "length", None), - ) - charset = getattr(type_, "charset", None) - collation = getattr(type_, "collation", None) - if name in ["BINARY", "VARBINARY", "NCHAR", "NVARCHAR"]: charset = None collation = None @@ -432,11 +431,33 @@ def _render_string_type(self, type_, name, length_override=None): return text - def visit_BINARY(self, type_, **kw): - return self._render_string_type(type_, "BINARY") + def visit_CHAR(self, type_: fb_types.FBCHAR, **kw: Any) -> str: + return self._render_firebird_string_type( + "CHAR", + type_.length, + type_.collation, + getattr(type_, "charset", None), + ) + + def visit_NCHAR(self, type_: fb_types.FBNCHAR, **kw: Any) -> str: + return self._render_firebird_string_type("NCHAR", type_.length, type_.collation) - def visit_VARBINARY(self, type_, **kw): - return self._render_string_type(type_, "VARBINARY") + def visit_VARCHAR(self, type_: fb_types.FBVARCHAR, **kw: Any) -> str: + return self._render_firebird_string_type( + "VARCHAR", + type_.length, + type_.collation, + getattr(type_, "charset", None), + ) + + def visit_NVARCHAR(self, type_: fb_types.FBNCHAR, **kw: Any) -> str: + return self._render_firebird_string_type("NVARCHAR", type_.length, type_.collation) + + def visit_BINARY(self, type_: fb_types.FBBINARY, **kw) -> str: + return self._render_firebird_string_type("BINARY", type_.length) + + def visit_VARBINARY(self, type_: fb_types.FBVARBINARY, **kw) -> str: + return self._render_firebird_string_type("VARBINARY", type_.length) def visit_TEXT(self, type_, **kw): return self.visit_BLOB(type_, override_subtype=1, **kw) @@ -490,9 +511,11 @@ def visit_TIMESTAMP(self, type_, **kw): return super().visit_TIMESTAMP(type_, **kw) return "TIMESTAMP%s %s" % ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "", + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) @@ -501,9 +524,11 @@ def visit_TIME(self, type_, **kw): return super().visit_TIME(type_, **kw) return "TIME%s %s" % ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "", + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) @@ -528,7 +553,7 @@ def fire_sequence(self, seq, type_): ) -class ReflectedDomain(util.typing.TypedDict): +class ReflectedDomain(TypedDict): """Represents a reflected domain.""" name: str @@ -601,7 +626,7 @@ class FBDialect(default.DefaultDialect): colspecs = { sa_types.String: fb_types._FBString, - sa_types.Numeric: fb_types._FBNumeric, + sa_types.Numeric: fb_types.FBNUMERIC, sa_types.Float: fb_types.FBFLOAT, sa_types.Double: fb_types.FBDOUBLE_PRECISION, sa_types.Date: fb_types.FBDATE, @@ -611,14 +636,14 @@ class FBDialect(default.DefaultDialect): sa_types.BigInteger: fb_types.FBBIGINT, sa_types.Integer: fb_types.FBINTEGER, sa_types.SmallInteger: fb_types.FBSMALLINT, - sa_types.BINARY: fb_types._FBLargeBinary, - sa_types.VARBINARY: fb_types._FBLargeBinary, - sa_types.LargeBinary: fb_types._FBLargeBinary, + sa_types.BINARY: fb_types.FBBINARY, + sa_types.VARBINARY: fb_types.FBVARBINARY, + sa_types.LargeBinary: fb_types.FBBLOB, } # SELECT TRIM(rdb$type_name) FROM rdb$types WHERE rdb$field_name = 'RDB$FIELD_TYPE' ORDER BY 1 ischema_names = { - "BLOB": fb_types._FBLargeBinary, + "BLOB": fb_types.FBBLOB, # "BLOB_ID": unused "BOOLEAN": fb_types.FBBOOLEAN, "CSTRING": fb_types.FBVARCHAR, @@ -853,7 +878,7 @@ def get_columns( # noqa: C901 charset=row.character_set_name, collation=row.collation_name, ) - elif issubclass(colclass, fb_types._FBNumeric): + elif colclass in (fb_types.FBFLOAT, fb_types.FBDOUBLE_PRECISION, fb_types.FBDECFLOAT): # FLOAT, DOUBLE PRECISION or DECFLOAT coltype = colclass(row.field_precision) elif issubclass(colclass, fb_types._FBInteger): @@ -874,13 +899,9 @@ def get_columns( # noqa: C901 elif issubclass(colclass, sa_types.DateTime): has_timezone = "WITH TIME ZONE" in row.field_type coltype = colclass(timezone=has_timezone) - elif issubclass(colclass, fb_types._FBLargeBinary): + elif issubclass(colclass, fb_types.FBBLOB): if row.field_sub_type == 1: - coltype = fb_types.FBTEXT( - row.segment_length, - row.character_set_name, - row.collation_name, - ) + coltype = fb_types.FBTEXT(row.segment_length, row.character_set_name, row.collation_name) else: coltype = fb_types.FBBLOB(row.segment_length) else: diff --git a/sqlalchemy_firebird/firebird.py b/sqlalchemy_firebird/firebird.py index d767f6f..4d7fe3f 100644 --- a/sqlalchemy_firebird/firebird.py +++ b/sqlalchemy_firebird/firebird.py @@ -19,6 +19,7 @@ import firebird.driver from firebird.driver import driver_config from firebird.driver import get_timezone +from firebird.driver.types import DatabaseError; class FBDialect_firebird(FBDialect): @@ -63,7 +64,13 @@ def get_deferrable(self, connection): return connection.deferrable def do_terminate(self, dbapi_connection) -> None: - dbapi_connection.terminate() + try: + dbapi_connection.close() + except DatabaseError as err: + # Ignore errors during connection termination, as the connection may already be in an invalid state. + # https://github.com/pauldex/sqlalchemy-firebird/issues/72 + if not self.is_disconnect(err, None, None): + raise def create_connect_args(self, url): opts = url.translate_connect_args(username="user") diff --git a/sqlalchemy_firebird/infrastructure.py b/sqlalchemy_firebird/infrastructure.py index 3ab0e66..db2dc85 100644 --- a/sqlalchemy_firebird/infrastructure.py +++ b/sqlalchemy_firebird/infrastructure.py @@ -21,15 +21,15 @@ # if os_name == "nt": - FB50_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v5.0.0-RC2/Firebird-5.0.0.1304-RC2-windows-x64.zip" - FB40_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v4.0.4/Firebird-4.0.4.3010-0-x64.zip" - FB30_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v3.0.11/Firebird-3.0.11.33703-0_x64.zip" + FB50_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v5.0.2/Firebird-5.0.2.1613-0-windows-x64.zip" + FB40_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v4.0.5/Firebird-4.0.5.3140-0-x64.zip" + FB30_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v3.0.12/Firebird-3.0.12.33787-0-x64.zip" FB25_URL = "https://github.com/FirebirdSQL/firebird/releases/download/R2_5_9/Firebird-2.5.9.27139-0_x64_embed.zip" FB25_EXTRA_URL = "https://github.com/FirebirdSQL/firebird/releases/download/R2_5_9/Firebird-2.5.9.27139-0_x64.zip" else: - FB50_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v5.0.0-RC2/Firebird-5.0.0.1304-RC2-linux-x64.tar.gz" - FB40_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v4.0.4/Firebird-4.0.4.3010-0.amd64.tar.gz" - FB30_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v3.0.11/Firebird-3.0.11.33703-0.amd64.tar.gz" + FB50_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v5.0.2/Firebird-5.0.2.1613-0-linux-arm64.tar.gz" + FB40_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v4.0.5/Firebird-4.0.5.3140-0.amd64.tar.gz" + FB30_URL = "https://github.com/FirebirdSQL/firebird/releases/download/v3.0.12/Firebird-3.0.12.33787-0.amd64.tar.gz" FB25_URL = "https://github.com/FirebirdSQL/firebird/releases/download/R2_5_9/FirebirdCS-2.5.9.27139-0.amd64.tar.gz" TEMP_PATH = gettempdir() diff --git a/sqlalchemy_firebird/types.py b/sqlalchemy_firebird/types.py index 195b41d..f6ba7c9 100644 --- a/sqlalchemy_firebird/types.py +++ b/sqlalchemy_firebird/types.py @@ -15,103 +15,157 @@ class _FBString(sqltypes.String): render_bind_cast = True - def __init__(self, length=None, charset=None, collation=None): - super().__init__(length, collation) + def __init__(self, charset: Optional[str] = None, **kw: Any): self.charset = charset + # Only pass parameters that the parent String class accepts + string_kwargs = {} + if 'length' in kw: + string_kwargs['length'] = kw['length'] + if 'collation' in kw: + string_kwargs['collation'] = kw['collation'] + super().__init__(**string_kwargs) -class FBCHAR(_FBString): +class FBCHAR(_FBString, sqltypes.CHAR): __visit_name__ = "CHAR" - def __init__(self, length=None, charset=None, collation=None): - super().__init__(length, charset, collation) + def __init__(self, length: Optional[int] = None, **kwargs: Any): + super().__init__(length=length, **kwargs) class FBBINARY(FBCHAR): __visit_name__ = "BINARY" # Synonym for CHAR(n) CHARACTER SET OCTETS - def __init__(self, length=None, charset=None, collation=None): - super().__init__(length, BINARY_CHARSET) + def __init__(self, length: Optional[int] = None, **kwargs: Any): + kwargs["charset"] = BINARY_CHARSET + super().__init__(length=length, **kwargs) -class FBNCHAR(FBCHAR): +class FBNCHAR(FBCHAR, sqltypes.NCHAR): __visit_name__ = "NCHAR" # Synonym for CHAR(n) CHARACTER SET ISO8859_1 - def __init__(self, length=None, charset=None, collation=None): - super().__init__(length, NATIONAL_CHARSET) + def __init__(self, length: Optional[int] = None, **kwargs: Any): + kwargs["charset"] = NATIONAL_CHARSET + super().__init__(length=length, **kwargs) -class FBVARCHAR(_FBString): +class FBVARCHAR(_FBString, sqltypes.VARCHAR): __visit_name__ = "VARCHAR" - def __init__(self, length=None, charset=None, collation=None): - super().__init__(length, charset, collation) + def __init__(self, length: Optional[int] = None, **kwargs: Any): + super().__init__(length=length, **kwargs) class FBVARBINARY(FBVARCHAR): __visit_name__ = "VARBINARY" # Synonym for VARCHAR(n) CHARACTER SET OCTETS - def __init__(self, length=None, charset=None, collation=None): - super().__init__(length, BINARY_CHARSET) + def __init__(self, length: Optional[int] = None, **kwargs: Any): + kwargs["charset"] = BINARY_CHARSET + super().__init__(length=length, **kwargs) -class FBNVARCHAR(FBVARCHAR): +class FBNVARCHAR(FBVARCHAR, sqltypes.NVARCHAR): __visit_name__ = "NVARCHAR" # Synonym for VARCHAR(n) CHARACTER SET ISO8859_1 - def __init__(self, length=None, charset=None, collation=None): - super().__init__(length, NATIONAL_CHARSET) + def __init__(self, length: Optional[int] = None, **kwargs: Any): + kwargs["charset"] = NATIONAL_CHARSET + super().__init__(length=length, **kwargs) -class _FBNumeric(sqltypes.Numeric): +class FBFLOAT(sqltypes.FLOAT): + __visit_name__ = "FLOAT" render_bind_cast = True + def __init__(self, precision=None, **kwargs): + # FLOAT doesn't accept 'scale' parameter, filter it out + float_kwargs = {k: v for k, v in kwargs.items() if k != 'scale'} + # Set precision if provided + if precision is not None: + float_kwargs['precision'] = precision + # Provide defaults for required parameters + float_kwargs.setdefault("precision", None) + float_kwargs.setdefault("decimal_return_scale", None) + float_kwargs.setdefault("asdecimal", False) + super().__init__(**float_kwargs) + def bind_processor(self, dialect): return None # Dialect supports_native_decimal = True (no processor needed) -class FBFLOAT(_FBNumeric, sqltypes.FLOAT): - __visit_name__ = "FLOAT" +class FBDOUBLE_PRECISION(sqltypes.DOUBLE_PRECISION): + __visit_name__ = "DOUBLE_PRECISION" + render_bind_cast = True + def __init__(self, precision=None, **kwargs): + # DOUBLE_PRECISION doesn't accept 'scale' parameter, filter it out + float_kwargs = {k: v for k, v in kwargs.items() if k != 'scale'} + # Set precision if provided + if precision is not None: + float_kwargs['precision'] = precision + # Provide defaults for required parameters + float_kwargs.setdefault("precision", None) + float_kwargs.setdefault("decimal_return_scale", None) + float_kwargs.setdefault("asdecimal", False) + super().__init__(**float_kwargs) -class FBDOUBLE_PRECISION(_FBNumeric, sqltypes.DOUBLE_PRECISION): - __visit_name__ = "DOUBLE_PRECISION" + def bind_processor(self, dialect): + return None # Dialect supports_native_decimal = True (no processor needed) -class FBDECFLOAT(_FBNumeric): +class FBDECFLOAT(sqltypes.Numeric): __visit_name__ = "DECFLOAT" + render_bind_cast = True + + def __init__(self, precision=None, **kwargs): + # DECFLOAT (Numeric) accepts all parameters + if precision is not None: + kwargs['precision'] = precision + kwargs.setdefault("precision", None) + kwargs.setdefault("scale", None) + kwargs.setdefault("decimal_return_scale", None) + kwargs.setdefault("asdecimal", False) + super().__init__(**kwargs) + + def bind_processor(self, dialect): + return None # Dialect supports_native_decimal = True (no processor needed) class FBREAL(FBFLOAT): __visit_name__ = "REAL" - # Synonym for FLOAT - def __init__(self, precision=None, scale=None): - super().__init__(None, None) - -class _FBFixedPoint(_FBNumeric): - def __init__( - self, - precision=None, - scale=None, - decimal_return_scale=None, - asdecimal=None, - ): - super().__init__( - precision, scale, decimal_return_scale, asdecimal=True - ) +class FBDECIMAL(sqltypes.DECIMAL): + __visit_name__ = "DECIMAL" + render_bind_cast = True + def __init__(self, **kwargs: Any): + kwargs["asdecimal"] = True + kwargs.setdefault("precision", None) + kwargs.setdefault("scale", None) + kwargs.setdefault("decimal_return_scale", None) + super().__init__(**kwargs) -class FBDECIMAL(_FBFixedPoint): - __visit_name__ = "DECIMAL" + def bind_processor(self, dialect): + return None # Dialect supports_native_decimal = True (no processor needed) -class FBNUMERIC(_FBFixedPoint): +class FBNUMERIC(sqltypes.NUMERIC): __visit_name__ = "NUMERIC" + render_bind_cast = True + + def __init__(self, **kwargs: Any): + kwargs.setdefault("asdecimal", True) + kwargs.setdefault("precision", None) + kwargs.setdefault("scale", None) + kwargs.setdefault("decimal_return_scale", None) + super().__init__(**kwargs) + + def bind_processor(self, dialect): + return None # Dialect supports_native_decimal = True (no processor needed) class FBDATE(sqltypes.DATE): @@ -150,51 +204,40 @@ class FBBOOLEAN(sqltypes.BOOLEAN): render_bind_cast = True -class _FBLargeBinary(sqltypes.LargeBinary): +class FBBLOB(sqltypes.BLOB): + __visit_name__ = "BLOB" # BLOB SUB_TYPE 0 (BINARY) render_bind_cast = True - def __init__( - self, subtype=None, segment_size=None, charset=None, collation=None - ): + def __init__(self, segment_size=None): super().__init__() - self.subtype = subtype + self.subtype = 0 self.segment_size = segment_size - self.charset = charset - self.collation = collation def bind_processor(self, dialect): def process(value): - return None if value is None else bytes(value) + return bytes(value) return process -class FBBLOB(_FBLargeBinary, sqltypes.BLOB): - __visit_name__ = "BLOB" - - def __init__( - self, - segment_size=None, - ): - super().__init__(0, segment_size) - - -class FBTEXT(_FBLargeBinary, sqltypes.TEXT): - __visit_name__ = "BLOB" +class FBTEXT(sqltypes.TEXT): + __visit_name__ = "BLOB" # BLOB SUB_TYPE 1 (TEXT) + render_bind_cast = True - def __init__( - self, - segment_size=None, - charset=None, - collation=None, - ): - super().__init__(1, segment_size, charset, collation) + def __init__(self, segment_size=None, charset=None, collation=None): + super().__init__() + self.subtype = 1 + self.segment_size = segment_size + self.charset = charset + self.collation = collation -class _FBNumericInterval(_FBNumeric): +class _FBNumericInterval(FBNUMERIC): # NUMERIC(18,9) -- Used for _FBInterval storage - def __init__(self): - super().__init__(precision=18, scale=9) + def __init__(self, **kwargs: Any): + kwargs["precision"] = 18 + kwargs["scale"] = 9 + super().__init__(**kwargs) class _FBInterval(sqltypes.Interval): diff --git a/test/test_issues.py b/test/test_issues.py new file mode 100644 index 0000000..5aa8976 --- /dev/null +++ b/test/test_issues.py @@ -0,0 +1,41 @@ +from sqlalchemy import Column, select +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures + +import sqlalchemy_firebird.types as fb_types + +TEST_UNICODE = "próf-áêïôù-🗄️.fdb" +TEST_BINARY = TEST_UNICODE.encode('utf-8') + +class IssuesTest(fixtures.TestBase): + @testing.provide_metadata + @testing.combinations( + (fb_types.FBTEXT, 'hi'), + (fb_types.FBTEXT, TEST_UNICODE), + (fb_types.FBBLOB, TEST_BINARY), argnames="type_, expected" + ) + def test_issue_76(self, connection, type_, expected): + metadata = self.metadata + + the_blob = Table( + "the_blob", + metadata, + Column("the_value", type_) + ) + metadata.create_all(testing.db) + + connection.execute( + the_blob.insert() + .values(dict( + the_value=expected, + )) + ) + + eq_( + connection.execute( + select(the_blob.c.the_value) + ).scalar(), + expected, + ) diff --git a/test/test_suite.py b/test/test_suite.py index a868aec..d7c2f51 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -6,24 +6,33 @@ from packaging import version from sqlalchemy import __version__ as SQLALCHEMY_VERSION from sqlalchemy import Index -from sqlalchemy.testing import is_false +from sqlalchemy.testing import config, is_false from sqlalchemy.testing.suite import * # noqa: F401, F403 from sqlalchemy.testing.suite import ( + BizarroCharacterTest as _BizarroCharacterTest, CTETest as _CTETest, ComponentReflectionTest as _ComponentReflectionTest, ComponentReflectionTestExtra as _ComponentReflectionTestExtra, CompoundSelectTest as _CompoundSelectTest, - DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest, IdentityColumnTest as _IdentityColumnTest, IdentityReflectionTest as _IdentityReflectionTest, StringTest as _StringTest, InsertBehaviorTest as _InsertBehaviorTest, RowCountTest as _RowCountTest, SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest, + TempTableElementsTest as _TempTableElementsTest, + WindowFunctionTest as _WindowFunctionTest ) +@pytest.mark.skipif( + config.db.dialect.server_version_info < (4,), + reason="These tests rely on correct identity semantics, which were only fixed starting from Firebird 4.0." +) +class BizarroCharacterTest(_BizarroCharacterTest): + pass + @pytest.mark.skip( reason="These tests fails in Firebird because a DELETE FROM