44
55import hashlib
66from abc import ABCMeta , abstractmethod as abstract_method
7- from typing import Dict , Generator , Iterable , List , MutableSequence , Optional
7+ from contextlib import contextmanager as context_manager
8+ from threading import Lock
9+ from typing import Dict , Generator , Iterable , List , MutableSequence , \
10+ Optional , Tuple
811
912from iota import Address , TRITS_PER_TRYTE , TrytesCompatible
1013from iota .crypto import Curl
@@ -23,6 +26,19 @@ class BaseAddressCache(with_metaclass(ABCMeta)):
2326 """
2427 Base functionality for classes that cache generated addresses.
2528 """
29+ LockType = Lock
30+ """
31+ The type of locking mechanism used by :py:meth:`acquire_lock`.
32+
33+ Defaults to ``threading.Lock``, but you can change it if you want to
34+ use a different mechanism (e.g., multithreading or distributed).
35+ """
36+
37+ def __init__ (self ):
38+ super (BaseAddressCache , self ).__init__ ()
39+
40+ self ._lock = self .LockType ()
41+
2642 @abstract_method
2743 def get (self , seed , index ):
2844 # type: (Seed, int) -> Optional[Address]
@@ -34,6 +50,18 @@ def get(self, seed, index):
3450 'Not implemented in {cls}.' .format (cls = type (self ).__name__ ),
3551 )
3652
53+ @context_manager
54+ def acquire_lock (self ):
55+ """
56+ Acquires a lock on the cache instance, to prevent invalid cache
57+ misses when multiple threads access the cache concurrently.
58+
59+ Note: Acquire lock before checking the cache, and do not release it
60+ until after the cache hit/miss is resolved.
61+ """
62+ with self ._lock :
63+ yield
64+
3765 @abstract_method
3866 def set (self , seed , index , address ):
3967 # type: (Seed, int, Address) -> None
@@ -45,6 +73,17 @@ def set(self, seed, index, address):
4573 'Not implemented in {cls}.' .format (cls = type (self ).__name__ ),
4674 )
4775
76+ @staticmethod
77+ def _gen_cache_key (seed , index ):
78+ # type: (Seed, int) -> binary_type
79+ """
80+ Generates an obfuscated cache key so that we're not storing seeds
81+ in cleartext.
82+ """
83+ h = hashlib .new ('sha256' )
84+ h .update (binary_type (seed ) + b':' + binary_type (index ))
85+ return h .digest ()
86+
4887
4988class MemoryAddressCache (BaseAddressCache ):
5089 """
@@ -63,17 +102,6 @@ def set(self, seed, index, address):
63102 # type: (Seed, int, Address) -> None
64103 self .cache [self ._gen_cache_key (seed , index )] = address
65104
66- @staticmethod
67- def _gen_cache_key (seed , index ):
68- # type: (Seed, int) -> binary_type
69- """
70- Generates an obfuscated cache key so that we're not storing seeds
71- in cleartext.
72- """
73- h = hashlib .new ('sha256' )
74- h .update (binary_type (seed ) + b':' + binary_type (index ))
75- return h .digest ()
76-
77105
78106class AddressGenerator (Iterable [Address ]):
79107 """
@@ -213,18 +241,19 @@ def create_iterator(self, start=0, step=1):
213241
214242 while True :
215243 if self .cache :
216- address = self .cache .get (self .seed , key_iterator .current )
244+ with self .cache .acquire_lock ():
245+ address = self .cache .get (self .seed , key_iterator .current )
217246
218- if not address :
219- address = self ._generate_address (key_iterator )
220- self .cache .set (self .seed , address .key_index , address )
247+ if not address :
248+ address = self ._generate_address (key_iterator )
249+ self .cache .set (self .seed , address .key_index , address )
221250 else :
222251 address = self ._generate_address (key_iterator )
223252
224253 yield address
225254
226255 @staticmethod
227- def address_from_digest (digest_trits , key_index ):
256+ def address_from_digest_trits (digest_trits , key_index ):
228257 # type: (List[int], int) -> Address
229258 """
230259 Generates an address from a private key digest.
@@ -247,13 +276,13 @@ def _generate_address(self, key_iterator):
247276
248277 Used in the event of a cache miss.
249278 """
250- return self .address_from_digest (* self ._get_digest_params (key_iterator ))
279+ return self .address_from_digest_trits (* self ._get_digest_params (key_iterator ))
251280
252281 @staticmethod
253282 def _get_digest_params (key_iterator ):
254283 # type: (KeyIterator) -> Tuple[List[int], int]
255284 """
256- Extracts parameters for :py:meth:`address_from_digest `.
285+ Extracts parameters for :py:meth:`address_from_digest_trits `.
257286
258287 Split into a separate method so that it can be mocked during unit
259288 tests.
0 commit comments