1- import pickle
21import sys
32from contextlib import asynccontextmanager
43from typing import (
1514
1615from redis .asyncio import BlockingConnectionPool , Redis , Sentinel
1716from redis .asyncio .cluster import RedisCluster
17+ from redis .asyncio .connection import Connection
1818from taskiq import AsyncResultBackend
1919from taskiq .abc .result_backend import TaskiqResult
2020from taskiq .abc .serializer import TaskiqSerializer
21+ from taskiq .compat import model_dump , model_validate
22+ from taskiq .serializers import PickleSerializer
2123
2224from taskiq_redis .exceptions import (
2325 DuplicateExpireTimeSelectedError ,
2426 ExpireTimeMustBeMoreThanZeroError ,
2527 ResultIsMissingError ,
2628)
27- from taskiq_redis .serializer import PickleSerializer
2829
2930if sys .version_info >= (3 , 10 ):
3031 from typing import TypeAlias
3334
3435if TYPE_CHECKING :
3536 _Redis : TypeAlias = Redis [bytes ]
37+ _BlockingConnectionPool : TypeAlias = BlockingConnectionPool [Connection ]
3638else :
3739 _Redis : TypeAlias = Redis
40+ _BlockingConnectionPool : TypeAlias = BlockingConnectionPool
3841
3942_ReturnType = TypeVar ("_ReturnType" )
4043
@@ -49,6 +52,7 @@ def __init__(
4952 result_ex_time : Optional [int ] = None ,
5053 result_px_time : Optional [int ] = None ,
5154 max_connection_pool_size : Optional [int ] = None ,
55+ serializer : Optional [TaskiqSerializer ] = None ,
5256 ** connection_kwargs : Any ,
5357 ) -> None :
5458 """
@@ -66,11 +70,12 @@ def __init__(
6670 :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
6771 and result_px_time are equal zero.
6872 """
69- self .redis_pool = BlockingConnectionPool .from_url (
73+ self .redis_pool : _BlockingConnectionPool = BlockingConnectionPool .from_url (
7074 url = redis_url ,
7175 max_connections = max_connection_pool_size ,
7276 ** connection_kwargs ,
7377 )
78+ self .serializer = serializer or PickleSerializer ()
7479 self .keep_results = keep_results
7580 self .result_ex_time = result_ex_time
7681 self .result_px_time = result_px_time
@@ -110,9 +115,9 @@ async def set_result(
110115 :param task_id: ID of the task.
111116 :param result: TaskiqResult instance.
112117 """
113- redis_set_params : Dict [str , Union [str , bytes , int ]] = {
118+ redis_set_params : Dict [str , Union [str , int , bytes ]] = {
114119 "name" : task_id ,
115- "value" : pickle . dumps ( result ),
120+ "value" : self . serializer . dumpb ( model_dump ( result ) ),
116121 }
117122 if self .result_ex_time :
118123 redis_set_params ["ex" ] = self .result_ex_time
@@ -159,8 +164,9 @@ async def get_result(
159164 if result_value is None :
160165 raise ResultIsMissingError
161166
162- taskiq_result : TaskiqResult [_ReturnType ] = pickle .loads ( # noqa: S301
163- result_value ,
167+ taskiq_result = model_validate (
168+ TaskiqResult [_ReturnType ],
169+ self .serializer .loadb (result_value ),
164170 )
165171
166172 if not with_logs :
@@ -178,6 +184,7 @@ def __init__(
178184 keep_results : bool = True ,
179185 result_ex_time : Optional [int ] = None ,
180186 result_px_time : Optional [int ] = None ,
187+ serializer : Optional [TaskiqSerializer ] = None ,
181188 ** connection_kwargs : Any ,
182189 ) -> None :
183190 """
@@ -198,6 +205,7 @@ def __init__(
198205 redis_url ,
199206 ** connection_kwargs ,
200207 )
208+ self .serializer = serializer or PickleSerializer ()
201209 self .keep_results = keep_results
202210 self .result_ex_time = result_ex_time
203211 self .result_px_time = result_px_time
@@ -239,7 +247,7 @@ async def set_result(
239247 """
240248 redis_set_params : Dict [str , Union [str , bytes , int ]] = {
241249 "name" : task_id ,
242- "value" : pickle . dumps ( result ),
250+ "value" : self . serializer . dumpb ( model_dump ( result ) ),
243251 }
244252 if self .result_ex_time :
245253 redis_set_params ["ex" ] = self .result_ex_time
@@ -283,8 +291,9 @@ async def get_result(
283291 if result_value is None :
284292 raise ResultIsMissingError
285293
286- taskiq_result : TaskiqResult [_ReturnType ] = pickle .loads ( # noqa: S301
287- result_value ,
294+ taskiq_result : TaskiqResult [_ReturnType ] = model_validate (
295+ TaskiqResult [_ReturnType ],
296+ self .serializer .loadb (result_value ),
288297 )
289298
290299 if not with_logs :
@@ -331,9 +340,7 @@ def __init__(
331340 ** connection_kwargs ,
332341 )
333342 self .master_name = master_name
334- if serializer is None :
335- serializer = PickleSerializer ()
336- self .serializer = serializer
343+ self .serializer = serializer or PickleSerializer ()
337344 self .keep_results = keep_results
338345 self .result_ex_time = result_ex_time
339346 self .result_px_time = result_px_time
@@ -375,7 +382,7 @@ async def set_result(
375382 """
376383 redis_set_params : Dict [str , Union [str , bytes , int ]] = {
377384 "name" : task_id ,
378- "value" : self .serializer .dumpb (result ),
385+ "value" : self .serializer .dumpb (model_dump ( result ) ),
379386 }
380387 if self .result_ex_time :
381388 redis_set_params ["ex" ] = self .result_ex_time
@@ -422,11 +429,17 @@ async def get_result(
422429 if result_value is None :
423430 raise ResultIsMissingError
424431
425- taskiq_result : TaskiqResult [_ReturnType ] = pickle .loads ( # noqa: S301
426- result_value ,
432+ taskiq_result = model_validate (
433+ TaskiqResult [_ReturnType ],
434+ self .serializer .loadb (result_value ),
427435 )
428436
429437 if not with_logs :
430438 taskiq_result .log = None
431439
432440 return taskiq_result
441+
442+ async def shutdown (self ) -> None :
443+ """Shutdown sentinel connections."""
444+ for sentinel in self .sentinel .sentinels :
445+ await sentinel .aclose () # type: ignore[attr-defined]
0 commit comments