Skip to content

Commit 3ff178b

Browse files
Fix CRUD method return types.
Previously, the built-in methods on the CRUD objects were returning the raw data from the Client's `execute_and_return()` method, of type `RealDictRow` from `psycopg2.extras`. Now each CRUD object also requires a return type object of type `Type[T]` to be passed as a second argument on initialization, which is then used to initialize a return type object from the db query results (i.e. an instance of `ModelData`) to be returned. Also generalizes `model.base.ensure_exactly_one()` to accept any list object. Also removes unnecessary type imports.
1 parent a67d155 commit 3ff178b

File tree

3 files changed

+84
-53
lines changed

3 files changed

+84
-53
lines changed

db_wrapper/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Convenience objects to simplify database interactions w/ given interface."""
22

3+
from psycopg2.extras import RealDictRow
34
from .async_model import (
45
AsyncModel,
56
AsyncCreate,

db_wrapper/model/base.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
# std lib dependencies
44
from __future__ import annotations
55
from typing import (
6-
TypeVar,
7-
Generic,
86
Any,
9-
Tuple,
10-
List,
117
Dict,
8+
Generic,
9+
List,
10+
Tuple,
11+
Type,
12+
TypeVar,
1213
)
1314
from uuid import UUID
1415

@@ -54,23 +55,35 @@ def __init__(self) -> None:
5455
super().__init__(self, message)
5556

5657

57-
def ensure_exactly_one(result: List[T]) -> None:
58+
def ensure_exactly_one(result: List[Any]) -> None:
5859
"""Raise appropriate Exceptions if list longer than 1."""
5960
if len(result) > 1:
6061
raise UnexpectedMultipleResults(result)
6162
if len(result) == 0:
6263
raise NoResultFound()
6364

6465

65-
class CreateABC(Generic[T]):
66-
"""Encapsulate Create operations for Model.read."""
66+
class CRUDABC(Generic[T]):
67+
"""Encapsulate object creation behavior for all CRUD objects."""
6768

6869
# pylint: disable=too-few-public-methods
6970

7071
_table: sql.Composable
72+
_return_constructor: Type[T]
7173

72-
def __init__(self, table: sql.Composable) -> None:
74+
def __init__(
75+
self,
76+
table: sql.Composable,
77+
return_constructor: Type[T]
78+
) -> None:
7379
self._table = table
80+
self._return_constructor = return_constructor
81+
82+
83+
class CreateABC(CRUDABC[T]):
84+
"""Encapsulate Create operations for Model.create."""
85+
86+
# pylint: disable=too-few-public-methods
7487

7588
def _query_one(self, item: T) -> sql.Composed:
7689
"""Build query to create one new record with a given item."""
@@ -95,16 +108,11 @@ def _query_one(self, item: T) -> sql.Composed:
95108
return query
96109

97110

98-
class ReadABC(Generic[T]):
111+
class ReadABC(CRUDABC[T]):
99112
"""Encapsulate Read operations for Model.read."""
100113

101114
# pylint: disable=too-few-public-methods
102115

103-
_table: sql.Composable
104-
105-
def __init__(self, table: sql.Composable) -> None:
106-
self._table = table
107-
108116
def _query_one_by_id(self, id_value: UUID) -> sql.Composed:
109117
"""Build query to read a row by it's id."""
110118
query = sql.SQL(
@@ -119,16 +127,11 @@ def _query_one_by_id(self, id_value: UUID) -> sql.Composed:
119127
return query
120128

121129

122-
class UpdateABC(Generic[T]):
130+
class UpdateABC(CRUDABC[T]):
123131
"""Encapsulate Update operations for Model.read."""
124132

125133
# pylint: disable=too-few-public-methods
126134

127-
_table: sql.Composable
128-
129-
def __init__(self, table: sql.Composable) -> None:
130-
self._table = table
131-
132135
def _query_one_by_id(
133136
self,
134137
id_value: str,
@@ -160,16 +163,11 @@ def compose_changes(changes: Dict[str, Any]) -> sql.Composed:
160163
return query
161164

162165

163-
class DeleteABC(Generic[T]):
166+
class DeleteABC(CRUDABC[T]):
164167
"""Encapsulate Delete operations for Model.read."""
165168

166169
# pylint: disable=too-few-public-methods
167170

168-
_table: sql.Composable
169-
170-
def __init__(self, table: sql.Composable) -> None:
171-
self._table = table
172-
173171
def _query_one_by_id(self, id_value: str) -> sql.Composed:
174172
"""Build query to delete one record with matching ID."""
175173
query = sql.SQL(

db_wrapper/model/sync_model.py

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
"""Synchronous Model objects."""
22

3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Type
44
from uuid import UUID
55

6+
from psycopg2.extras import RealDictRow
7+
68
from db_wrapper.client import SyncClient
79
from .base import (
8-
ensure_exactly_one,
910
T,
1011
CreateABC,
12+
DeleteABC,
1113
ReadABC,
1214
UpdateABC,
13-
DeleteABC,
1415
ModelABC,
16+
ensure_exactly_one,
1517
sql,
1618
)
1719

@@ -23,16 +25,22 @@ class SyncCreate(CreateABC[T]):
2325

2426
_client: SyncClient
2527

26-
def __init__(self, client: SyncClient, table: sql.Composable) -> None:
27-
super().__init__(table)
28+
def __init__(
29+
self,
30+
client: SyncClient,
31+
table: sql.Composable,
32+
return_constructor: Type[T]
33+
) -> None:
34+
super().__init__(table, return_constructor)
2835
self._client = client
2936

3037
def one(self, item: T) -> T:
3138
"""Create one new record with a given item."""
32-
result: List[T] = self._client.execute_and_return(
39+
query_result: List[RealDictRow] = self._client.execute_and_return(
3340
self._query_one(item))
41+
result: T = self._return_constructor(**query_result[0])
3442

35-
return result[0]
43+
return result
3644

3745

3846
class SyncRead(ReadABC[T]):
@@ -42,19 +50,26 @@ class SyncRead(ReadABC[T]):
4250

4351
_client: SyncClient
4452

45-
def __init__(self, client: SyncClient, table: sql.Composable) -> None:
46-
super().__init__(table)
53+
def __init__(
54+
self,
55+
client: SyncClient,
56+
table: sql.Composable,
57+
return_constructor: Type[T]
58+
) -> None:
59+
super().__init__(table, return_constructor)
4760
self._client = client
4861

4962
def one_by_id(self, id_value: UUID) -> T:
5063
"""Read a row by it's id."""
51-
result: List[T] = self._client.execute_and_return(
64+
query_result: List[RealDictRow] = self._client.execute_and_return(
5265
self._query_one_by_id(id_value))
5366

5467
# Should only return one item from DB
55-
ensure_exactly_one(result)
68+
ensure_exactly_one(query_result)
5669

57-
return result[0]
70+
result: T = self._return_constructor(**query_result[0])
71+
72+
return result
5873

5974

6075
class SyncUpdate(UpdateABC[T]):
@@ -64,8 +79,13 @@ class SyncUpdate(UpdateABC[T]):
6479

6580
_client: SyncClient
6681

67-
def __init__(self, client: SyncClient, table: sql.Composable) -> None:
68-
super().__init__(table)
82+
def __init__(
83+
self,
84+
client: SyncClient,
85+
table: sql.Composable,
86+
return_constructor: Type[T]
87+
) -> None:
88+
super().__init__(table, return_constructor)
6989
self._client = client
7090

7191
def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T:
@@ -79,12 +99,14 @@ def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T:
7999
Returns:
80100
full value of row updated
81101
"""
82-
result: List[T] = self._client.execute_and_return(
102+
query_result: List[RealDictRow] = self._client.execute_and_return(
83103
self._query_one_by_id(id_value, changes))
84104

85-
ensure_exactly_one(result)
105+
ensure_exactly_one(query_result)
106+
107+
result: T = self._return_constructor(**query_result[0])
86108

87-
return result[0]
109+
return result
88110

89111

90112
class SyncDelete(DeleteABC[T]):
@@ -94,19 +116,25 @@ class SyncDelete(DeleteABC[T]):
94116

95117
_client: SyncClient
96118

97-
def __init__(self, client: SyncClient, table: sql.Composable) -> None:
98-
super().__init__(table)
119+
def __init__(
120+
self,
121+
client: SyncClient,
122+
table: sql.Composable,
123+
return_constructor: Type[T]
124+
) -> None:
125+
super().__init__(table, return_constructor)
99126
self._client = client
100127

101128
def one_by_id(self, id_value: str) -> T:
102129
"""Delete one record with matching ID."""
103-
result: List[T] = self._client.execute_and_return(
130+
query_result: List[RealDictRow] = self._client.execute_and_return(
104131
self._query_one_by_id(id_value))
105132

106-
# Should only return one item from DB
107-
ensure_exactly_one(result)
133+
ensure_exactly_one(query_result)
134+
135+
result: T = self._return_constructor(**query_result[0])
108136

109-
return result[0]
137+
return result
110138

111139

112140
class SyncModel(ModelABC[T]):
@@ -128,13 +156,17 @@ def __init__(
128156
self,
129157
client: SyncClient,
130158
table: str,
159+
return_constructor: Type[T],
131160
) -> None:
132161
super().__init__(client, table)
133162

134-
self._create = SyncCreate[T](self.client, self.table)
135-
self._read = SyncRead[T](self.client, self.table)
136-
self._update = SyncUpdate[T](self.client, self.table)
137-
self._delete = SyncDelete[T](self.client, self.table)
163+
self._create = SyncCreate[T](
164+
self.client, self.table, return_constructor)
165+
self._read = SyncRead[T](self.client, self.table, return_constructor)
166+
self._update = SyncUpdate[T](
167+
self.client, self.table, return_constructor)
168+
self._delete = SyncDelete[T](
169+
self.client, self.table, return_constructor)
138170

139171
@property
140172
def create(self) -> SyncCreate[T]:

0 commit comments

Comments
 (0)