|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +from sqlmodel import Field, Session, SQLModel, create_engine, select |
| 4 | +from sqlmodel.pool import StaticPool |
| 5 | + |
| 6 | + |
| 7 | +def test_fields() -> None: |
| 8 | + class Hero(SQLModel, table=True): |
| 9 | + id: Optional[int] = Field(default=None, primary_key=True) |
| 10 | + name: str |
| 11 | + secret_name: str |
| 12 | + age: Optional[int] = None |
| 13 | + food: Optional[str] = None |
| 14 | + |
| 15 | + engine = create_engine( |
| 16 | + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool |
| 17 | + ) |
| 18 | + |
| 19 | + SQLModel.metadata.create_all(engine) |
| 20 | + |
| 21 | + with Session(engine) as session: |
| 22 | + session.add(Hero(name="Deadpond", secret_name="Dive Wilson")) |
| 23 | + session.add( |
| 24 | + Hero(name="Spider-Boy", secret_name="Pedro Parqueador", food="pizza") |
| 25 | + ) |
| 26 | + session.add(Hero(name="Rusty-Man", secret_name="Tommy Sharp", age=48)) |
| 27 | + |
| 28 | + session.commit() |
| 29 | + |
| 30 | + # check typing of select with 3 fields |
| 31 | + with Session(engine) as session: |
| 32 | + statement_3 = select(Hero.id, Hero.name, Hero.secret_name) |
| 33 | + results_3 = session.exec(statement_3) |
| 34 | + for hero_3 in results_3: |
| 35 | + assert len(hero_3) == 3 |
| 36 | + name_3: str = hero_3[1] |
| 37 | + assert type(name_3) is str |
| 38 | + assert type(hero_3[0]) is int |
| 39 | + assert type(hero_3[2]) is str |
| 40 | + |
| 41 | + # check typing of select with 4 fields |
| 42 | + with Session(engine) as session: |
| 43 | + statement_4 = select(Hero.id, Hero.name, Hero.secret_name, Hero.age) |
| 44 | + results_4 = session.exec(statement_4) |
| 45 | + for hero_4 in results_4: |
| 46 | + assert len(hero_4) == 4 |
| 47 | + name_4: str = hero_4[1] |
| 48 | + assert type(name_4) is str |
| 49 | + assert type(hero_4[0]) is int |
| 50 | + assert type(hero_4[2]) is str |
| 51 | + assert type(hero_4[3]) in [int, type(None)] |
| 52 | + |
| 53 | + # check typing of select with 5 fields: currently runs but doesn't pass mypy |
| 54 | + # with Session(engine) as session: |
| 55 | + # statement_5 = select(Hero.id, Hero.name, Hero.secret_name, Hero.age, Hero.food) |
| 56 | + # results_5 = session.exec(statement_5) |
| 57 | + # for hero_5 in results_5: |
| 58 | + # assert len(hero_5) == 5 |
| 59 | + # name_5: str = hero_5[1] |
| 60 | + # assert type(name_5) is str |
| 61 | + # assert type(hero_5[0]) is int |
| 62 | + # assert type(hero_5[2]) is str |
| 63 | + # assert type(hero_5[3]) in [int, type(None)] |
| 64 | + # assert type(hero_5[4]) in [str, type(None)] |
0 commit comments