diff --git a/sqlmodel/__init__.py b/sqlmodel/__init__.py index 9175458980..0d64d4ee58 100644 --- a/sqlmodel/__init__.py +++ b/sqlmodel/__init__.py @@ -141,3 +141,4 @@ from .sql.expression import type_coerce as type_coerce from .sql.expression import within_group as within_group from .sql.sqltypes import AutoString as AutoString +from .sql.sqltypes import IntEnum as IntEnum diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index 512daacbab..a8c2097e70 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -1,8 +1,11 @@ -from typing import Any, cast +import enum +from typing import Any, Optional, Type, TypeVar, cast from sqlalchemy import types from sqlalchemy.engine.interfaces import Dialect +_TIntEnum = TypeVar("_TIntEnum", bound=enum.IntEnum) + class AutoString(types.TypeDecorator): # type: ignore impl = types.String @@ -14,3 +17,31 @@ def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]": if impl.length is None and dialect.name == "mysql": return dialect.type_descriptor(types.String(self.mysql_default_length)) return super().load_dialect_impl(dialect) + + +class IntEnum(types.TypeDecorator[Optional[_TIntEnum]]): + impl = types.SmallInteger + cache_ok = True + + def __init__(self, enum_type: Type[_TIntEnum], *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + # validate the input enum type + if not issubclass(enum_type, enum.IntEnum): + raise TypeError("Input must be enum.IntEnum") + + self.enum_type = enum_type + + def process_result_value( + self, + value: Optional[int], + dialect: Dialect, + ) -> Optional[_TIntEnum]: + return None if (value is None) else self.enum_type(value) + + def process_bind_param( + self, + value: Optional[_TIntEnum], + dialect: Dialect, + ) -> Optional[int]: + return None if (value is None) else value.value diff --git a/tests/test_enums.py b/tests/test_enums.py index 2808f3f9a9..4a71482145 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -43,6 +43,7 @@ def test_postgres_ddl_sql(clear_sqlmodel, capsys: pytest.CaptureFixture[str]): captured = capsys.readouterr() assert "CREATE TYPE myenum1 AS ENUM ('A', 'B');" in captured.out assert "CREATE TYPE myenum2 AS ENUM ('C', 'D');" in captured.out + assert "int_enum_field SMALLINT NOT NULL" in captured.out def test_sqlite_ddl_sql(clear_sqlmodel, capsys: pytest.CaptureFixture[str]): @@ -52,6 +53,7 @@ def test_sqlite_ddl_sql(clear_sqlmodel, capsys: pytest.CaptureFixture[str]): captured = capsys.readouterr() assert "enum_field VARCHAR(1) NOT NULL" in captured.out, captured + assert "int_enum_field SMALLINT NOT NULL" in captured.out, captured assert "CREATE TYPE" not in captured.out @@ -63,15 +65,22 @@ def test_json_schema_flat_model_pydantic_v1(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/definitions/MyEnum1"}, + "int_enum_field": {"$ref": "#/definitions/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "definitions": { "MyEnum1": { "title": "MyEnum1", "description": "An enumeration.", "enum": ["A", "B"], "type": "string", - } + }, + "MyEnum3": { + "title": "MyEnum3", + "description": "An enumeration.", + "enum": [1, 2], + "type": "integer", + }, }, } @@ -84,15 +93,22 @@ def test_json_schema_inherit_model_pydantic_v1(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/definitions/MyEnum2"}, + "int_enum_field": {"$ref": "#/definitions/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "definitions": { "MyEnum2": { "title": "MyEnum2", "description": "An enumeration.", "enum": ["C", "D"], "type": "string", - } + }, + "MyEnum3": { + "title": "MyEnum3", + "description": "An enumeration.", + "enum": [1, 2], + "type": "integer", + }, }, } @@ -105,10 +121,12 @@ def test_json_schema_flat_model_pydantic_v2(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/$defs/MyEnum1"}, + "int_enum_field": {"$ref": "#/$defs/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "$defs": { - "MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"} + "MyEnum1": {"enum": ["A", "B"], "title": "MyEnum1", "type": "string"}, + "MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"}, }, } @@ -121,9 +139,11 @@ def test_json_schema_inherit_model_pydantic_v2(): "properties": { "id": {"title": "Id", "type": "string", "format": "uuid"}, "enum_field": {"$ref": "#/$defs/MyEnum2"}, + "int_enum_field": {"$ref": "#/$defs/MyEnum3"}, }, - "required": ["id", "enum_field"], + "required": ["id", "enum_field", "int_enum_field"], "$defs": { - "MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"} + "MyEnum2": {"enum": ["C", "D"], "title": "MyEnum2", "type": "string"}, + "MyEnum3": {"enum": [1, 2], "title": "MyEnum3", "type": "integer"}, }, } diff --git a/tests/test_enums_models.py b/tests/test_enums_models.py index b46ccb7d2b..5edfb534ea 100644 --- a/tests/test_enums_models.py +++ b/tests/test_enums_models.py @@ -1,7 +1,7 @@ import enum import uuid -from sqlmodel import Field, SQLModel +from sqlmodel import Field, IntEnum, SQLModel class MyEnum1(str, enum.Enum): @@ -14,14 +14,21 @@ class MyEnum2(str, enum.Enum): D = "D" +class MyEnum3(enum.IntEnum): + E = 1 + F = 2 + + class BaseModel(SQLModel): id: uuid.UUID = Field(primary_key=True) enum_field: MyEnum2 + int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3)) class FlatModel(SQLModel, table=True): id: uuid.UUID = Field(primary_key=True) enum_field: MyEnum1 + int_enum_field: MyEnum3 = Field(sa_type=IntEnum(MyEnum3)) class InheritModel(BaseModel, table=True):