Skip to content

Commit a09016c

Browse files
Merge pull request #4 from cheese-drawer/correct_query_return_types
Correct query return types
2 parents 6b55247 + 15af9ca commit a09016c

File tree

23 files changed

+615
-344
lines changed

23 files changed

+615
-344
lines changed

db_wrapper/client/async_client.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44
from typing import (
5+
cast,
56
Any,
67
TypeVar,
78
Union,
@@ -10,8 +11,8 @@
1011
List,
1112
Dict)
1213

13-
import aiopg # type: ignore
14-
from psycopg2.extras import register_uuid
14+
import aiopg
15+
from psycopg2.extras import register_uuid, RealDictRow
1516
from psycopg2 import sql
1617

1718
from db_wrapper.connection import ConnectionParameters, connect
@@ -20,10 +21,6 @@
2021
register_uuid()
2122

2223

23-
# Generic doesn't need a more descriptive name
24-
# pylint: disable=invalid-name
25-
T = TypeVar('T')
26-
2724
Query = Union[str, sql.Composed]
2825

2926

@@ -57,6 +54,11 @@ async def _execute_query(
5754
query: Query,
5855
params: Optional[Dict[Hashable, Any]] = None,
5956
) -> None:
57+
# aiopg type is incorrect & thinks execute only takes str
58+
# when in the query is passed through to psycopg2's
59+
# cursor.execute which does accept sql.Composed objects.
60+
query = cast(str, query)
61+
6062
if params:
6163
await cursor.execute(query, params)
6264
else:
@@ -88,7 +90,7 @@ async def execute_and_return(
8890
self,
8991
query: Query,
9092
params: Optional[Dict[Hashable, Any]] = None,
91-
) -> List[T]:
93+
) -> List[RealDictRow]:
9294
"""Execute the given SQL query & return the result.
9395
9496
Arguments:
@@ -102,5 +104,5 @@ async def execute_and_return(
102104
async with self._connection.cursor() as cursor:
103105
await self._execute_query(cursor, query, params)
104106

105-
result: List[T] = await cursor.fetchall()
107+
result: List[RealDictRow] = await cursor.fetchall()
106108
return result

db_wrapper/client/sync_client.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from __future__ import annotations
44
from typing import (
55
Any,
6-
TypeVar,
7-
Union,
8-
Optional,
6+
Dict,
97
Hashable,
108
List,
11-
Dict)
9+
Optional,
10+
Union,
11+
)
1212

13-
from psycopg2.extras import register_uuid
13+
from psycopg2.extras import register_uuid, RealDictRow
1414
from psycopg2 import sql
1515
# pylint can't seem to find the items in psycopg2 despite being available
1616
from psycopg2._psycopg import cursor # pylint: disable=no-name-in-module
@@ -24,10 +24,6 @@
2424
register_uuid()
2525

2626

27-
# Generic doesn't need a more descriptive name
28-
# pylint: disable=invalid-name
29-
T = TypeVar('T')
30-
3127
Query = Union[str, sql.Composed]
3228

3329

@@ -60,7 +56,7 @@ def _execute_query(
6056
params: Optional[Dict[Hashable, Any]] = None,
6157
) -> None:
6258
if params:
63-
db_cursor.execute(query, params) # type: ignore
59+
db_cursor.execute(query, params)
6460
else:
6561
db_cursor.execute(query)
6662

@@ -88,7 +84,7 @@ def execute_and_return(
8884
self,
8985
query: Query,
9086
params: Optional[Dict[Hashable, Any]] = None,
91-
) -> List[T]:
87+
) -> List[RealDictRow]:
9288
"""Execute the given SQL query & return the result.
9389
9490
Arguments:
@@ -102,5 +98,5 @@ def execute_and_return(
10298
with self._connection.cursor() as db_cursor:
10399
self._execute_query(db_cursor, query, params)
104100

105-
result: List[T] = db_cursor.fetchall()
101+
result: List[RealDictRow] = db_cursor.fetchall()
106102
return result

db_wrapper/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Convenience objects to simplify database interactions w/ given interface."""
22

3+
from psycopg2.extras import RealDictRow
34
from .async_model import (
45
AsyncModel,
56
AsyncCreate,

db_wrapper/model/async_model.py

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
"""Asynchronous Model objects."""
22

3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Type
44
from uuid import UUID
55

6+
from psycopg2.extras import RealDictRow
7+
68
from db_wrapper.client import AsyncClient
79
from .base import (
810
ensure_exactly_one,
11+
sql,
912
T,
1013
CreateABC,
1114
ReadABC,
1215
UpdateABC,
1316
DeleteABC,
1417
ModelABC,
15-
sql,
1618
)
1719

1820

@@ -23,16 +25,22 @@ class AsyncCreate(CreateABC[T]):
2325

2426
_client: AsyncClient
2527

26-
def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
27-
super().__init__(table)
28+
def __init__(
29+
self,
30+
client: AsyncClient,
31+
table: sql.Composable,
32+
return_constructor: Type[T]
33+
) -> None:
34+
super().__init__(table, return_constructor)
2835
self._client = client
2936

3037
async def one(self, item: T) -> T:
3138
"""Create one new record with a given item."""
32-
result: List[T] = await self._client.execute_and_return(
33-
self._query_one(item))
39+
query_result: List[RealDictRow] = \
40+
await self._client.execute_and_return(self._query_one(item))
41+
result: T = self._return_constructor(**query_result[0])
3442

35-
return result[0]
43+
return result
3644

3745

3846
class AsyncRead(ReadABC[T]):
@@ -42,19 +50,27 @@ class AsyncRead(ReadABC[T]):
4250

4351
_client: AsyncClient
4452

45-
def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
46-
super().__init__(table)
53+
def __init__(
54+
self,
55+
client: AsyncClient,
56+
table: sql.Composable,
57+
return_constructor: Type[T]
58+
) -> None:
59+
super().__init__(table, return_constructor)
4760
self._client = client
4861

4962
async def one_by_id(self, id_value: UUID) -> T:
5063
"""Read a row by it's id."""
51-
result: List[T] = await self._client.execute_and_return(
52-
self._query_one_by_id(id_value))
64+
query_result: List[RealDictRow] = \
65+
await self._client.execute_and_return(
66+
self._query_one_by_id(id_value))
5367

5468
# Should only return one item from DB
55-
ensure_exactly_one(result)
69+
ensure_exactly_one(query_result)
70+
71+
result: T = self._return_constructor(**query_result[0])
5672

57-
return result[0]
73+
return result
5874

5975

6076
class AsyncUpdate(UpdateABC[T]):
@@ -64,11 +80,16 @@ class AsyncUpdate(UpdateABC[T]):
6480

6581
_client: AsyncClient
6682

67-
def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
68-
super().__init__(table)
83+
def __init__(
84+
self,
85+
client: AsyncClient,
86+
table: sql.Composable,
87+
return_constructor: Type[T]
88+
) -> None:
89+
super().__init__(table, return_constructor)
6990
self._client = client
7091

71-
async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T:
92+
async def one_by_id(self, id_value: UUID, changes: Dict[str, Any]) -> T:
7293
"""Apply changes to row with given id.
7394
7495
Arguments:
@@ -79,12 +100,14 @@ async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T:
79100
Returns:
80101
full value of row updated
81102
"""
82-
result: List[T] = await self._client.execute_and_return(
83-
self._query_one_by_id(id_value, changes))
103+
query_result: List[RealDictRow] = \
104+
await self._client.execute_and_return(
105+
self._query_one_by_id(id_value, changes))
84106

85-
ensure_exactly_one(result)
107+
ensure_exactly_one(query_result)
108+
result: T = self._return_constructor(**query_result[0])
86109

87-
return result[0]
110+
return result
88111

89112

90113
class AsyncDelete(DeleteABC[T]):
@@ -94,19 +117,26 @@ class AsyncDelete(DeleteABC[T]):
94117

95118
_client: AsyncClient
96119

97-
def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
98-
super().__init__(table)
120+
def __init__(
121+
self,
122+
client: AsyncClient,
123+
table: sql.Composable,
124+
return_constructor: Type[T]
125+
) -> None:
126+
super().__init__(table, return_constructor)
99127
self._client = client
100128

101129
async def one_by_id(self, id_value: str) -> T:
102130
"""Delete one record with matching ID."""
103-
result: List[T] = await self._client.execute_and_return(
104-
self._query_one_by_id(id_value))
131+
query_result: List[RealDictRow] = \
132+
await self._client.execute_and_return(
133+
self._query_one_by_id(id_value))
105134

106135
# Should only return one item from DB
107-
ensure_exactly_one(result)
136+
ensure_exactly_one(query_result)
137+
result = self._return_constructor(**query_result[0])
108138

109-
return result[0]
139+
return result
110140

111141

112142
class AsyncModel(ModelABC[T]):
@@ -122,19 +152,22 @@ class AsyncModel(ModelABC[T]):
122152
_update: AsyncUpdate[T]
123153
_delete: AsyncDelete[T]
124154

125-
# PENDS python 3.9 support in pylint
126-
# pylint: disable=unsubscriptable-object
127155
def __init__(
128156
self,
129157
client: AsyncClient,
130158
table: str,
159+
return_constructor: Type[T],
131160
) -> None:
132161
super().__init__(client, table)
133162

134-
self._create = AsyncCreate[T](self.client, self.table)
135-
self._read = AsyncRead[T](self.client, self.table)
136-
self._update = AsyncUpdate[T](self.client, self.table)
137-
self._delete = AsyncDelete[T](self.client, self.table)
163+
self._create = AsyncCreate[T](
164+
self.client, self.table, return_constructor)
165+
self._read = AsyncRead[T](
166+
self.client, self.table, return_constructor)
167+
self._update = AsyncUpdate[T](
168+
self.client, self.table, return_constructor)
169+
self._delete = AsyncDelete[T](
170+
self.client, self.table, return_constructor)
138171

139172
@property
140173
def create(self) -> AsyncCreate[T]:

0 commit comments

Comments
 (0)