1010import ipaddress
1111import math
1212import random
13+ import struct
1314import uuid
1415
1516import asyncpg
@@ -31,6 +32,9 @@ def _timezone(offset):
3132
3233
3334type_samples = [
35+ ('bool' , 'bool' , (
36+ True , False ,
37+ )),
3438 ('smallint' , 'int2' , (
3539 - 2 ** 15 + 1 , 2 ** 15 - 1 ,
3640 - 1 , 0 , 1 ,
@@ -132,7 +136,8 @@ def _timezone(offset):
132136 bytes (range (255 , - 1 , - 1 )),
133137 b'\x00 \x00 ' ,
134138 b'foo' ,
135- b'f' * 1024 * 1024
139+ b'f' * 1024 * 1024 ,
140+ dict (input = bytearray (b'\x02 \x01 ' ), output = b'\x02 \x01 ' ),
136141 )),
137142 ('text' , 'text' , (
138143 '' ,
@@ -156,6 +161,7 @@ def _timezone(offset):
156161 datetime .date (2000 , 1 , 1 ),
157162 datetime .date (500 , 1 , 1 ),
158163 datetime .date (1 , 1 , 1 ),
164+ infinity_date ,
159165 ]),
160166 ('time' , 'time' , [
161167 datetime .time (12 , 15 , 20 ),
@@ -191,7 +197,9 @@ def _timezone(offset):
191197 ]),
192198 ('uuid' , 'uuid' , [
193199 uuid .UUID ('38a4ff5a-3a56-11e6-a6c2-c8f73323c6d4' ),
194- uuid .UUID ('00000000-0000-0000-0000-000000000000' )
200+ uuid .UUID ('00000000-0000-0000-0000-000000000000' ),
201+ {'input' : '00000000-0000-0000-0000-000000000000' ,
202+ 'output' : uuid .UUID ('00000000-0000-0000-0000-000000000000' )}
195203 ]),
196204 ('uuid[]' , 'uuid[]' , [
197205 (uuid .UUID ('38a4ff5a-3a56-11e6-a6c2-c8f73323c6d4' ),
@@ -294,11 +302,21 @@ def _timezone(offset):
294302 asyncpg .BitString (),
295303 asyncpg .BitString .frombytes (b'\x00 ' , bitlength = 3 ),
296304 asyncpg .BitString ('0000 0000 1' ),
305+ dict (input = b'\x01 ' , output = asyncpg .BitString ('0000 0001' )),
306+ dict (input = bytearray (b'\x02 ' ), output = asyncpg .BitString ('0000 0010' )),
297307 ]),
298308 ('path' , 'path' , [
299309 asyncpg .Path (asyncpg .Point (0.0 , 0.0 ), asyncpg .Point (1.0 , 1.0 )),
300310 asyncpg .Path (asyncpg .Point (0.0 , 0.0 ), asyncpg .Point (1.0 , 1.0 ),
301311 is_closed = True ),
312+ dict (input = ((0.0 , 0.0 ), (1.0 , 1.0 )),
313+ output = asyncpg .Path (asyncpg .Point (0.0 , 0.0 ),
314+ asyncpg .Point (1.0 , 1.0 ),
315+ is_closed = True )),
316+ dict (input = [(0.0 , 0.0 ), (1.0 , 1.0 )],
317+ output = asyncpg .Path (asyncpg .Point (0.0 , 0.0 ),
318+ asyncpg .Point (1.0 , 1.0 ),
319+ is_closed = False )),
302320 ]),
303321 ('point' , 'point' , [
304322 asyncpg .Point (0.0 , 0.0 ),
@@ -334,22 +352,28 @@ async def test_standard_codecs(self):
334352
335353 for sample in sample_data :
336354 with self .subTest (sample = sample , typname = typname ):
337- rsample = await st .fetchval (sample )
355+ if isinstance (sample , dict ):
356+ inputval = sample ['input' ]
357+ outputval = sample ['output' ]
358+ else :
359+ inputval = outputval = sample
360+
361+ result = await st .fetchval (inputval )
338362 err_msg = (
339- "failed to return {} object data as-is; "
340- "gave {!r}, received {!r}" .format (
341- typname , sample , rsample ))
363+ "unexpected result for {} when passing {!r}: "
364+ "received {!r}, expected {!r}" .format (
365+ typname , inputval , result , outputval ))
342366
343367 if typname .startswith ('float' ):
344- if math .isnan (sample ):
345- if not math .isnan (rsample ):
368+ if math .isnan (outputval ):
369+ if not math .isnan (result ):
346370 self .fail (err_msg )
347371 else :
348372 self .assertTrue (
349- math .isclose (rsample , sample , rel_tol = 1e-6 ),
373+ math .isclose (result , outputval , rel_tol = 1e-6 ),
350374 err_msg )
351375 else :
352- self .assertEqual (rsample , sample , err_msg )
376+ self .assertEqual (result , outputval , err_msg )
353377
354378 with self .subTest (sample = None , typname = typname ):
355379 # Test that None is handled for all types.
@@ -369,10 +393,9 @@ async def test_all_builtin_types_handled(self):
369393 'core type {} ({}) is unhandled' .format (typename , oid ))
370394
371395 async def test_void (self ):
372- stmt = await self .con .prepare ('select pg_sleep(0)' )
373- self .assertIsNone (await stmt .fetchval ())
374-
375- await self .con .fetchval ('select now($1::void)' , None )
396+ res = await self .con .fetchval ('select pg_sleep(0)' )
397+ self .assertIsNone (res )
398+ await self .con .fetchval ('select now($1::void)' , '' )
376399
377400 def test_bitstring (self ):
378401 bitlen = random .randint (0 , 1000 )
@@ -424,6 +447,10 @@ async def test_invalid_input(self):
424447 32768 ,
425448 - 32768
426449 ]),
450+ ('float4' , ValueError , 'float value too large' , [
451+ 4.1 * 10 ** 40 ,
452+ - 4.1 * 10 ** 40 ,
453+ ]),
427454 ('int4' , TypeError , 'an integer is required' , [
428455 '2' ,
429456 'aa' ,
@@ -452,7 +479,11 @@ async def test_arrays(self):
452479 (
453480 r"SELECT '{{{{{{1}}}}}}'::int[]" ,
454481 ((((((1 ,),),),),),)
455- )
482+ ),
483+ (
484+ r"SELECT '{1, 2, NULL}'::int[]::anyarray" ,
485+ (1 , 2 , None )
486+ ),
456487 ]
457488
458489 for sql , expected in cases :
@@ -464,6 +495,7 @@ async def test_arrays(self):
464495 await self .con .fetchval ("SELECT '{{{{{{{1}}}}}}}'::int[]" )
465496
466497 cases = [
498+ (None ,),
467499 (1 , 2 , 3 , 4 , 5 , 6 ),
468500 ((1 , 2 ), (4 , 5 ), (6 , 7 )),
469501 (((1 ,), (2 ,)), ((4 ,), (5 ,)), ((None ,), (7 ,))),
@@ -559,6 +591,10 @@ async def test_composites(self):
559591 self .assertEqual (at [0 ].type .name , 'test_composite' )
560592 self .assertEqual (at [0 ].type .kind , 'composite' )
561593
594+ res = await self .con .fetchval ('''
595+ SELECT $1::test_composite
596+ ''' , res )
597+
562598 finally :
563599 await self .con .execute ('DROP TYPE test_composite' )
564600
@@ -645,13 +681,29 @@ async def test_extra_codec_alias(self):
645681 await self .con .set_builtin_type_codec (
646682 'hstore' , codec_name = 'pg_contrib.hstore' )
647683
684+ cases = [
685+ {'ham' : 'spam' , 'nada' : None },
686+ {}
687+ ]
688+
648689 st = await self .con .prepare ('''
649690 SELECT $1::hstore AS result
650691 ''' )
651- res = await st .fetchrow ({'ham' : 'spam' , 'nada' : None })
652- res = res ['result' ]
653692
654- self .assertEqual (res , {'ham' : 'spam' , 'nada' : None })
693+ for case in cases :
694+ res = await st .fetchval (case )
695+ self .assertEqual (res , case )
696+
697+ res = await self .con .fetchval ('''
698+ SELECT $1::hstore AS result
699+ ''' , (('foo' , 2 ), ('bar' , 3 )))
700+
701+ self .assertEqual (res , {'foo' : '2' , 'bar' : '3' })
702+
703+ with self .assertRaisesRegex (ValueError , 'null value not allowed' ):
704+ await self .con .fetchval ('''
705+ SELECT $1::hstore AS result
706+ ''' , {None : '1' })
655707
656708 finally :
657709 await self .con .execute ('''
@@ -728,3 +780,83 @@ def hstore_encoder(obj):
728780 await self .con .execute ('''
729781 DROP EXTENSION hstore
730782 ''' )
783+
784+ async def test_custom_codec_binary (self ):
785+ """Test encoding/decoding using a custom codec in binary mode."""
786+ await self .con .execute ('''
787+ CREATE EXTENSION IF NOT EXISTS hstore
788+ ''' )
789+
790+ longstruct = struct .Struct ('!L' )
791+ ulong_unpack = lambda b : longstruct .unpack_from (b )[0 ]
792+ ulong_pack = longstruct .pack
793+
794+ def hstore_decoder (data ):
795+ result = {}
796+ n = ulong_unpack (data )
797+ view = memoryview (data )
798+ ptr = 4
799+
800+ for i in range (n ):
801+ klen = ulong_unpack (view [ptr :ptr + 4 ])
802+ ptr += 4
803+ k = bytes (view [ptr :ptr + klen ]).decode ()
804+ ptr += klen
805+ vlen = ulong_unpack (view [ptr :ptr + 4 ])
806+ ptr += 4
807+ if vlen == - 1 :
808+ v = None
809+ else :
810+ v = bytes (view [ptr :ptr + vlen ]).decode ()
811+ ptr += vlen
812+
813+ result [k ] = v
814+
815+ return result
816+
817+ def hstore_encoder (obj ):
818+ buffer = bytearray (ulong_pack (len (obj )))
819+
820+ for k , v in obj .items ():
821+ kenc = k .encode ()
822+ buffer += ulong_pack (len (kenc )) + kenc
823+
824+ if v is None :
825+ buffer += b'\xFF \xFF \xFF \xFF ' # -1
826+ else :
827+ venc = v .encode ()
828+ buffer += ulong_pack (len (venc )) + venc
829+
830+ return buffer
831+
832+ try :
833+ await self .con .set_type_codec ('hstore' , encoder = hstore_encoder ,
834+ decoder = hstore_decoder ,
835+ binary = True )
836+
837+ st = await self .con .prepare ('''
838+ SELECT $1::hstore AS result
839+ ''' )
840+
841+ res = await st .fetchrow ({'ham' : 'spam' })
842+ res = res ['result' ]
843+
844+ self .assertEqual (res , {'ham' : 'spam' })
845+
846+ pt = st .get_parameters ()
847+ self .assertTrue (isinstance (pt , tuple ))
848+ self .assertEqual (len (pt ), 1 )
849+ self .assertEqual (pt [0 ].name , 'hstore' )
850+ self .assertEqual (pt [0 ].kind , 'scalar' )
851+ self .assertEqual (pt [0 ].schema , 'public' )
852+
853+ at = st .get_attributes ()
854+ self .assertTrue (isinstance (at , tuple ))
855+ self .assertEqual (len (at ), 1 )
856+ self .assertEqual (at [0 ].name , 'result' )
857+ self .assertEqual (at [0 ].type , pt [0 ])
858+
859+ finally :
860+ await self .con .execute ('''
861+ DROP EXTENSION hstore
862+ ''' )
0 commit comments