1212 Any ,
1313 List ,
1414 Tuple ,
15+ Type ,
1516 TypeVar ,
1617)
1718from uuid import uuid4
3435)
3536
3637
38+ T = TypeVar ('T' , bound = ModelData )
39+
40+
3741def setupAsync (
38- query_result : List [ModelData ]
39- ) -> Tuple [AsyncModel [ModelData ], AsyncClient ]:
42+ query_result : List [T ],
43+ model_data : Type [T ]
44+ ) -> Tuple [AsyncModel [T ], AsyncClient ]:
4045 """Setup helper that returns instances of both a Model & a Client.
4146
4247 Mocks the execute_and_return method on the Client instance to skip
@@ -54,17 +59,18 @@ def setupAsync(
5459
5560 # mock client's sql execution method
5661 client .execute_and_return = helpers .AsyncMock ( # type:ignore
57- return_value = query_result )
62+ return_value = [ i . dict () for i in query_result ] )
5863
5964 # init a real model with mocked client
60- model = AsyncModel [Any ](client , 'test' , ModelData )
65+ model = AsyncModel [Any ](client , 'test' , model_data )
6166
6267 return model , client
6368
6469
6570def setupSync (
66- query_result : List [ModelData ]
67- ) -> Tuple [SyncModel [ModelData ], SyncClient ]:
71+ query_result : List [T ],
72+ model_data : Type [T ]
73+ ) -> Tuple [SyncModel [T ], SyncClient ]:
6874 """Setup helper that returns instances of both a Model & a Client.
6975
7076 Mocks the execute_and_return method on the Client instance to skip
@@ -82,10 +88,10 @@ def setupSync(
8288
8389 # mock client's sql execution method
8490 client .execute_and_return = helpers .MagicMock ( # type:ignore
85- return_value = query_result )
91+ return_value = [ i . dict () for i in query_result ] )
8692
8793 # init a real model with mocked client
88- model = SyncModel [ModelData ](client , 'test' , ModelData )
94+ model = SyncModel [T ](client , 'test' , model_data )
8995
9096 return model , client
9197
@@ -96,8 +102,8 @@ class TestReadOneById(TestCase):
96102 @helpers .async_test
97103 async def test_it_correctly_builds_query_with_given_id (self ) -> None :
98104 item = ModelData (id = uuid4 ())
99- async_model , async_client = setupAsync ([item ])
100- sync_model , sync_client = setupSync ([item ])
105+ async_model , async_client = setupAsync ([item ], ModelData )
106+ sync_model , sync_client = setupSync ([item ], ModelData )
101107
102108 await async_model .read .one_by_id (item .id )
103109 sync_model .read .one_by_id (item .id )
@@ -121,8 +127,8 @@ async def test_it_correctly_builds_query_with_given_id(self) -> None:
121127 @helpers .async_test
122128 async def test_it_returns_a_single_result (self ) -> None :
123129 item = ModelData (id = uuid4 ())
124- async_model , _ = setupAsync ([item ])
125- sync_model , _ = setupSync ([item ])
130+ async_model , _ = setupAsync ([item ], ModelData )
131+ sync_model , _ = setupSync ([item ], ModelData )
126132 results = [await async_model .read .one_by_id (item .id ),
127133 sync_model .read .one_by_id (item .id )]
128134
@@ -133,8 +139,8 @@ async def test_it_returns_a_single_result(self) -> None:
133139 @helpers .async_test
134140 async def test_it_raises_exception_if_more_than_one_result (self ) -> None :
135141 item = ModelData (id = uuid4 ())
136- async_model , _ = setupAsync ([item , item ])
137- sync_model , _ = setupSync ([item , item ])
142+ async_model , _ = setupAsync ([item , item ], ModelData )
143+ sync_model , _ = setupSync ([item , item ], ModelData )
138144
139145 with self .subTest ():
140146 with self .assertRaises (UnexpectedMultipleResults ):
@@ -146,18 +152,20 @@ async def test_it_raises_exception_if_more_than_one_result(self) -> None:
146152
147153 @ helpers .async_test
148154 async def test_it_raises_exception_if_no_result_to_return (self ) -> None :
155+ empty_async : List [ModelData ] = []
156+ empty_sync : List [ModelData ] = []
149157 async_model : AsyncModel [ModelData ]
150158 sync_model : SyncModel [ModelData ]
151- async_model , _ = setupAsync ([] )
152- sync_model , _ = setupSync ([] )
159+ async_model , _ = setupAsync (empty_async , ModelData )
160+ sync_model , _ = setupSync (empty_sync , ModelData )
153161
154162 with self .subTest ():
155163 with self .assertRaises (NoResultFound ):
156- await async_model .read .one_by_id ('id' )
164+ await async_model .read .one_by_id (uuid4 () )
157165
158166 with self .subTest ():
159167 with self .assertRaises (NoResultFound ):
160- sync_model .read .one_by_id ('id' )
168+ sync_model .read .one_by_id (uuid4 () )
161169
162170
163171class TestCreateOne (TestCase ):
@@ -175,8 +183,8 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None:
175183 'a' : 'a' ,
176184 'b' : 'b' ,
177185 })
178- async_model , async_client = setupAsync ([item ])
179- sync_model , sync_client = setupSync ([item ])
186+ async_model , async_client = setupAsync ([item ], TestCreateOne . Item )
187+ sync_model , sync_client = setupSync ([item ], TestCreateOne . Item )
180188
181189 await async_model .create .one (item )
182190 sync_model .create .one (item )
@@ -203,8 +211,8 @@ async def test_it_returns_the_new_record(self) -> None:
203211 'a' : 'a' ,
204212 'b' : 'b' ,
205213 })
206- async_model , _ = setupAsync ([item ])
207- sync_model , _ = setupSync ([item ])
214+ async_model , _ = setupAsync ([item ], TestCreateOne . Item )
215+ sync_model , _ = setupSync ([item ], TestCreateOne . Item )
208216
209217 results = [await async_model .create .one (item ),
210218 sync_model .create .one (item )]
@@ -229,11 +237,11 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None:
229237 'a' : 'a' ,
230238 'b' : 'b' ,
231239 })
232- async_model , async_client = setupAsync ([item ])
233- sync_model , sync_client = setupSync ([item ])
240+ async_model , async_client = setupAsync ([item ], TestUpdateOne . Item )
241+ sync_model , sync_client = setupSync ([item ], TestUpdateOne . Item )
234242
235- await async_model .update .one_by_id (str ( item .id ) , {'b' : 'c' })
236- sync_model .update .one_by_id (str ( item .id ) , {'b' : 'c' })
243+ await async_model .update .one_by_id (item .id , {'b' : 'c' })
244+ sync_model .update .one_by_id (item .id , {'b' : 'c' })
237245
238246 async_query_composed = cast (
239247 helpers .AsyncMock , async_client .execute_and_return ).call_args [0 ][0 ]
@@ -260,8 +268,8 @@ async def test_it_returns_the_new_record(self) -> None:
260268 })
261269 # mock result
262270 updated = TestUpdateOne .Item (** {** item .dict (), 'b' : 'c' })
263- async_model , _ = setupAsync ([updated ])
264- sync_model , _ = setupSync ([updated ])
271+ async_model , _ = setupAsync ([updated ], TestUpdateOne . Item )
272+ sync_model , _ = setupSync ([updated ], TestUpdateOne . Item )
265273
266274 results = [
267275 await async_model .update .one_by_id (item .id , {'b' : 'c' }),
@@ -288,8 +296,8 @@ async def test_it_correctly_builds_query_with_given_data(self) -> None:
288296 'a' : 'a' ,
289297 'b' : 'b' ,
290298 })
291- async_model , async_client = setupAsync ([item ])
292- sync_model , sync_client = setupSync ([item ])
299+ async_model , async_client = setupAsync ([item ], TestDeleteOneById . Item )
300+ sync_model , sync_client = setupSync ([item ], TestDeleteOneById . Item )
293301
294302 await async_model .delete .one_by_id (str (item .id ))
295303 sync_model .delete .one_by_id (str (item .id ))
@@ -316,8 +324,8 @@ async def test_it_returns_the_deleted_record(self) -> None:
316324 'a' : 'a' ,
317325 'b' : 'b' ,
318326 })
319- async_model , _ = setupAsync ([item ])
320- sync_model , _ = setupSync ([item ])
327+ async_model , _ = setupAsync ([item ], TestDeleteOneById . Item )
328+ sync_model , _ = setupSync ([item ], TestDeleteOneById . Item )
321329
322330 results = [await async_model .delete .one_by_id (str (item .id )),
323331 sync_model .delete .one_by_id (str (item .id ))]
@@ -347,8 +355,8 @@ class AsyncExtendedModel(AsyncModel[Item]):
347355 read : AsyncReadExtended
348356
349357 def __init__ (self , client : AsyncClient ) -> None :
350- super ().__init__ (client , 'extended_model' )
351- self .read = AsyncReadExtended (self .client , self .table )
358+ super ().__init__ (client , 'extended_model' , Item )
359+ self .read = AsyncReadExtended (self .client , self .table , Item )
352360
353361 class SyncReadExtended (SyncRead [Item ]):
354362 """Extending Read with additional query."""
@@ -361,11 +369,11 @@ class SyncExtendedModel(SyncModel[Item]):
361369 read : SyncReadExtended
362370
363371 def __init__ (self , client : SyncClient ) -> None :
364- super ().__init__ (client , 'extended_model' )
365- self .read = SyncReadExtended (self .client , self .table )
372+ super ().__init__ (client , 'extended_model' , Item )
373+ self .read = SyncReadExtended (self .client , self .table , Item )
366374
367- _ , async_client = setupAsync ([Item (** {"id" : uuid4 ()})])
368- _ , sync_client = setupSync ([Item (** {"id" : uuid4 ()})])
375+ _ , async_client = setupAsync ([Item (** {"id" : uuid4 ()})], Item )
376+ _ , sync_client = setupSync ([Item (** {"id" : uuid4 ()})], Item )
369377 self .models = [AsyncExtendedModel (async_client ),
370378 SyncExtendedModel (sync_client )]
371379
0 commit comments