55"""
66
77import copy
8+ from decimal import Decimal
89import json
910import logging
1011import time
11- import hmac
12- import hashlib
1312
1413import asyncio
1514import aiohttp
@@ -50,7 +49,34 @@ def _auth_headers(self, path, method, body=''):
5049 'CB-ACCESS-PASSPHRASE' : self .passphrase ,
5150 }
5251
53- async def _get (self , path , params = None , pagination = False ):
52+ def _convert_return_fields (self , fields , decimal_fields , convert_all ):
53+ if decimal_fields is None and not convert_all :
54+ return fields
55+ if isinstance (fields , list ):
56+ return [self ._convert_return_fields (field , decimal_fields ,
57+ convert_all )
58+ for field in fields ]
59+ elif isinstance (fields , dict ):
60+ new_fields = {}
61+ for k , v in fields .items ():
62+ if (decimal_fields is not None and k in decimal_fields ) \
63+ or convert_all :
64+ if isinstance (v , list ):
65+ new_fields [k ] = self ._convert_return_fields (
66+ v , decimal_fields , convert_all )
67+ else :
68+ new_fields [k ] = Decimal (v )
69+ else :
70+ new_fields [k ] = v
71+ return new_fields
72+ else :
73+ if convert_all and not isinstance (fields , int ):
74+ return Decimal (fields )
75+ else :
76+ return fields
77+
78+ async def _get (self , path , params = None , decimal_return_fields = None ,
79+ convert_all = False , pagination = False ):
5480 if params is None :
5581 params_copy = {}
5682 else :
@@ -81,9 +107,11 @@ async def _get(self, path, params=None, pagination=False):
81107 if "cb-after" in resp_headers :
82108 params_copy ['after' ] = resp_headers ['cb-after' ]
83109 else :
84- return results
110+ return self ._convert_return_fields (
111+ results , decimal_return_fields , convert_all )
85112 else :
86- return res
113+ return self ._convert_return_fields (
114+ res , decimal_return_fields , convert_all )
87115
88116 async def _post (self , path , data = None ):
89117 json_data = json .dumps (data )
@@ -109,42 +137,57 @@ async def _delete(self, path, data=None):
109137 return await response .json ()
110138
111139 async def get_products (self ):
112- return await self ._get ('/products' )
140+ return await self ._get (
141+ '/products' ,
142+ decimal_return_fields = {'base_min_size' , 'base_max_size' ,
143+ 'quote_increment' })
113144
114145 async def get_product_ticker (self , product_id = None ):
115146 return await self ._get (
116- '/products/{}/ticker' .format (product_id or self .product_id ))
147+ '/products/{}/ticker' .format (product_id or self .product_id ),
148+ decimal_return_fields = {'price' , 'size' , 'bid' , 'ask' , 'volume' })
117149
118150 async def get_product_trades (self , product_id = None ):
119151 return await self ._get (
120- '/products/{}/trades' .format (product_id or self .product_id ))
152+ '/products/{}/trades' .format (product_id or self .product_id ),
153+ decimal_return_fields = {'price' , 'size' })
121154
122155 async def get_product_order_book (self , product_id = None , level = 1 ):
123156 params = {'level' : level }
124157 return await self ._get (
125158 '/products/{}/book' .format (product_id or self .product_id ),
126- params = params )
159+ params = params , decimal_return_fields = {'bids' , 'asks' },
160+ convert_all = True )
127161
128162 async def get_product_historic_rates (self , product_id = None , start = '' ,
129163 end = '' , granularity = '' ):
130164 payload = {}
131165 payload ["start" ] = start
132166 payload ["end" ] = end
133167 payload ["granularity" ] = granularity
134- return await self ._get (
168+ res = await self ._get (
135169 '/products/{}/candles' .format (product_id or self .product_id ),
136170 params = payload )
171+ # NOTE: there's a bug where the API returns floats instead of strings
172+ # here
173+ for row in res :
174+ for i , col in enumerate (row [1 :]):
175+ row [i + 1 ] = Decimal (str (col ))
176+ return res
137177
138178 async def get_product_24hr_stats (self , product_id = None ):
139179 return await self ._get (
140- '/products/{}/stats' .format (product_id or self .product_id ))
180+ '/products/{}/stats' .format (product_id or self .product_id ),
181+ convert_all = True )
141182
142183 async def get_currencies (self ):
143- return await self ._get ('/currencies' )
184+ return await self ._get ('/currencies' ,
185+ decimal_return_fields = {'min_size' })
144186
145187 async def get_time (self ):
146188 return await self ._get ('/time' )
147189
190+ # TODO: convert return values
148191 # authenticated API
149192 async def get_account (self , account_id = '' ):
150193 assert self .authenticated
@@ -358,6 +401,7 @@ async def main(): # pragma: no cover
358401 trader .get_products (),
359402 trader .get_product_ticker (),
360403 trader .get_time (),
404+ trader .get_product_historic_rates (),
361405 # trader.buy(type='limit', size='0.01', price='2500.12'),
362406 )
363407 logging .info (res )
0 commit comments