@@ -25,18 +25,22 @@ def long_or_none(r):
2525 return long (r )
2626 return r
2727
28- def json_or_none (r ):
28+ def json_or_none (d ):
2929 "Return a deserialized JSON object or None"
30- if r :
31- return json .loads (r )
32- return r
30+ def _f (r ):
31+ if r :
32+ return d (r )
33+ return r
34+ return _f
3335
34- def bulk_of_jsons (b ):
36+ def bulk_of_jsons (d ):
3537 "Replace serialized JSON values with objects in a bulk array response (list)"
36- for index , item in enumerate (b ):
37- if item is not None :
38- b [index ] = json .loads (item )
39- return b
38+ def _f (b ):
39+ for index , item in enumerate (b ):
40+ if item is not None :
41+ b [index ] = d (item )
42+ return b
43+ return _f
4044
4145class Client (StrictRedis ):
4246 """
@@ -51,49 +55,55 @@ class Client(StrictRedis):
5155 'ver' : 1
5256 }
5357
54- MODULE_CALLBACKS = {
55- 'JSON.DEL' : long ,
56- 'JSON.GET' : json_or_none ,
57- 'JSON.MGET' : bulk_of_jsons ,
58- 'JSON.SET' : lambda r : r and nativestr (r ) == 'OK' ,
59- 'JSON.NUMINCRBY' : float_or_long ,
60- 'JSON.NUMMULTBY' : float_or_long ,
61- 'JSON.STRAPPEND' : long_or_none ,
62- 'JSON.STRLEN' : long_or_none ,
63- 'JSON.ARRAPPEND' : long_or_none ,
64- 'JSON.ARRINDEX' : long_or_none ,
65- 'JSON.ARRINSERT' : long_or_none ,
66- 'JSON.ARRLEN' : long_or_none ,
67- 'JSON.ARRPOP' : json_or_none ,
68- 'JSON.ARRTRIM' : long_or_none ,
69- 'JSON.OBJLEN' : long_or_none ,
70- }
58+ _encoder = None
59+ _encode = None
60+ _decoder = None
61+ _decode = None
62+
63+ def __init__ (self , encoder = None , decoder = None , * args , ** kwargs ):
64+ """
65+ Creates a new ReJSON client
66+ """
67+ self .setEncoder (encoder )
68+ self .setDecoder (decoder )
69+ StrictRedis .__init__ (self , * args , ** kwargs )
7170
72- def __init__ (self , * args , ** kwargs ):
73- super (Client , self ).__init__ (* args , ** kwargs )
74- self .__checkPrerequirements ()
7571 # Set the module commands' callbacks
76- for k , v in self .MODULE_CALLBACKS .iteritems ():
72+ MODULE_CALLBACKS = {
73+ 'JSON.DEL' : long ,
74+ 'JSON.GET' : json_or_none (self ._decode ),
75+ 'JSON.MGET' : bulk_of_jsons (self ._decode ),
76+ 'JSON.SET' : lambda r : r and nativestr (r ) == 'OK' ,
77+ 'JSON.NUMINCRBY' : float_or_long ,
78+ 'JSON.NUMMULTBY' : float_or_long ,
79+ 'JSON.STRAPPEND' : long_or_none ,
80+ 'JSON.STRLEN' : long_or_none ,
81+ 'JSON.ARRAPPEND' : long_or_none ,
82+ 'JSON.ARRINDEX' : long_or_none ,
83+ 'JSON.ARRINSERT' : long_or_none ,
84+ 'JSON.ARRLEN' : long_or_none ,
85+ 'JSON.ARRPOP' : json_or_none (self ._decode ),
86+ 'JSON.ARRTRIM' : long_or_none ,
87+ 'JSON.OBJLEN' : long_or_none ,
88+ }
89+ for k , v in MODULE_CALLBACKS .iteritems ():
7790 self .set_response_callback (k , v )
91+
92+ def setEncoder (self , encoder ):
93+ "Sets the encoder"
94+ if not encoder :
95+ self ._encoder = json .JSONEncoder ()
96+ else :
97+ self ._encoder = encoder
98+ self ._encode = self ._encoder .encode
7899
79- def __checkPrerequirements (self ):
80- "Checks that the module is ready"
81- try :
82- reply = self .execute_command ('MODULE' , 'LIST' )
83- except exceptions .ResponseError as e :
84- if e .message .startswith ('unknown command' ):
85- raise exceptions .RedisError ('Modules are not supported '
86- 'on your Redis server - consider '
87- 'upgrading to a newer version.' )
88- finally :
89- info = self .MODULE_INFO
90- for r in reply :
91- module = dict (zip (r [0 ::2 ], r [1 ::2 ]))
92- if info ['name' ] == module ['name' ] and \
93- info ['ver' ] <= module ['ver' ]:
94- return
95- raise exceptions .RedisError ('ReJSON is not loaded - follow the '
96- 'instructions at http://rejson.io' )
100+ def setDecoder (self , decoder ):
101+ "Sets the decoder"
102+ if not decoder :
103+ self ._decoder = json .JSONDecoder ()
104+ else :
105+ self ._decoder = decoder
106+ self ._decode = self ._decoder .decode
97107
98108 def JSONDel (self , name , path = Path .rootPath ()):
99109 """
@@ -130,7 +140,8 @@ def JSONSet(self, name, path, obj, nx=False, xx=False):
130140 ``nx`` if set to True, set ``value`` only if it does not exist
131141 ``xx`` if set to True, set ``value`` only if it exists
132142 """
133- pieces = [name , str_path (path ), json .dumps (obj )]
143+ pieces = [name , str_path (path ), self ._encode (obj )]
144+
134145 # Handle existential modifiers
135146 if nx and xx :
136147 raise Exception ('nx and xx are mutually exclusive: use one, the '
@@ -152,21 +163,21 @@ def JSONNumIncrBy(self, name, path, number):
152163 Increments the numeric (integer or floating point) JSON value under
153164 ``path`` at key ``name`` by the provided ``number``
154165 """
155- return self .execute_command ('JSON.NUMINCRBY' , name , str_path (path ), json . dumps (number ))
166+ return self .execute_command ('JSON.NUMINCRBY' , name , str_path (path ), self . _encode (number ))
156167
157168 def JSONNumMultBy (self , name , path , number ):
158169 """
159170 Multiplies the numeric (integer or floating point) JSON value under
160171 ``path`` at key ``name`` with the provided ``number``
161172 """
162- return self .execute_command ('JSON.NUMMULTBY' , name , str_path (path ), json . dumps (number ))
173+ return self .execute_command ('JSON.NUMMULTBY' , name , str_path (path ), self . _encode (number ))
163174
164175 def JSONStrAppend (self , name , string , path = Path .rootPath ()):
165176 """
166177 Appends to the string JSON value under ``path`` at key ``name`` the
167178 provided ``string``
168179 """
169- return self .execute_command ('JSON.STRAPPEND' , name , str_path (path ), json . dumps (string ))
180+ return self .execute_command ('JSON.STRAPPEND' , name , str_path (path ), self . _encode (string ))
170181
171182 def JSONStrLen (self , name , path = Path .rootPath ()):
172183 """
@@ -182,7 +193,7 @@ def JSONArrAppend(self, name, path=Path.rootPath(), *args):
182193 """
183194 pieces = [name , str_path (path )]
184195 for o in args :
185- pieces .append (json . dumps (o ))
196+ pieces .append (self . _encode (o ))
186197 return self .execute_command ('JSON.ARRAPPEND' , * pieces )
187198
188199 def JSONArrIndex (self , name , path , scalar , start = 0 , stop = - 1 ):
@@ -191,7 +202,7 @@ def JSONArrIndex(self, name, path, scalar, start=0, stop=-1):
191202 ``name``. The search can be limited using the optional inclusive
192203 ``start`` and exclusive ``stop`` indices.
193204 """
194- return self .execute_command ('JSON.ARRINDEX' , name , str_path (path ), json . dumps (scalar ), start , stop )
205+ return self .execute_command ('JSON.ARRINDEX' , name , str_path (path ), self . _encode (scalar ), start , stop )
195206
196207 def JSONArrInsert (self , name , path , index , * args ):
197208 """
@@ -200,7 +211,7 @@ def JSONArrInsert(self, name, path, index, *args):
200211 """
201212 pieces = [name , str_path (path ), index ]
202213 for o in args :
203- pieces .append (json . dumps (o ))
214+ pieces .append (self . _encode (o ))
204215 return self .execute_command ('JSON.ARRINSERT' , * pieces )
205216
206217 def JSONArrLen (self , name , path = Path .rootPath ()):
@@ -246,12 +257,14 @@ def pipeline(self, transaction=True, shard_hint=None):
246257 atomic, pipelines are useful for reducing the back-and-forth overhead
247258 between the client and server.
248259 """
249- return Pipeline (
250- self .connection_pool ,
251- self .response_callbacks ,
252- transaction ,
253- shard_hint )
260+ p = Pipeline (
261+ connection_pool = self .connection_pool ,
262+ response_callbacks = self .response_callbacks ,
263+ transaction = transaction ,
264+ shard_hint = shard_hint )
265+ p .setEncoder (self ._encoder )
266+ p .setDecoder (self ._decoder )
267+ return p
254268
255269class Pipeline (BasePipeline , Client ):
256270 "Pipeline for ReJSONClient"
257- pass
0 commit comments