1414
1515"""Client side encryption."""
1616
17+ import functools
1718import subprocess
1819import uuid
1920import weakref
3031 MongoCryptCallback = object
3132
3233from bson import _bson_to_dict , _dict_to_bson , decode , encode
33- from bson .binary import STANDARD , Binary
3434from bson .codec_options import CodecOptions
35+ from bson .binary import STANDARD , Binary
36+ from bson .errors import BSONError
3537from bson .raw_bson import (DEFAULT_RAW_BSON_OPTIONS ,
3638 RawBSONDocument ,
3739 _inflate_bson )
5658 uuid_representation = STANDARD )
5759
5860
61+ def _wrap_encryption_errors (encryption_func = None ):
62+ """Decorator to wrap encryption related errors with EncryptionError."""
63+ @functools .wraps (encryption_func )
64+ def wrap_encryption_errors (* args , ** kwargs ):
65+ try :
66+ return encryption_func (* args , ** kwargs )
67+ except BSONError :
68+ # BSON encoding/decoding errors are unrelated to encryption so
69+ # we should propagate them unchanged.
70+ raise
71+ except Exception as exc :
72+ raise EncryptionError (exc )
73+
74+ return wrap_encryption_errors
75+
76+
5977class _EncryptionIO (MongoCryptCallback ):
6078 def __init__ (self , client , key_vault_coll , mongocryptd_client , opts ):
6179 """Internal class to perform I/O on behalf of pymongocrypt."""
@@ -85,14 +103,11 @@ def kms_request(self, kms_context):
85103 opts = PoolOptions (connect_timeout = _KMS_CONNECT_TIMEOUT ,
86104 socket_timeout = _KMS_CONNECT_TIMEOUT ,
87105 ssl_context = ctx )
88- try :
89- with _configured_socket ((endpoint , _HTTPS_PORT ), opts ) as conn :
90- conn .sendall (message )
91- while kms_context .bytes_needed > 0 :
92- data = conn .recv (kms_context .bytes_needed )
93- kms_context .feed (data )
94- except Exception as exc :
95- raise MongoCryptError (str (exc ))
106+ with _configured_socket ((endpoint , _HTTPS_PORT ), opts ) as conn :
107+ conn .sendall (message )
108+ while kms_context .bytes_needed > 0 :
109+ data = conn .recv (kms_context .bytes_needed )
110+ kms_context .feed (data )
96111
97112 def collection_info (self , database , filter ):
98113 """Get the collection info for a namespace.
@@ -222,6 +237,7 @@ def __init__(self, io_callbacks, opts):
222237 opts ._kms_providers , schema_map ))
223238 self ._bypass_auto_encryption = opts ._bypass_auto_encryption
224239
240+ @_wrap_encryption_errors
225241 def encrypt (self , database , cmd , check_keys , codec_options ):
226242 """Encrypt a MongoDB command.
227243
@@ -237,16 +253,14 @@ def encrypt(self, database, cmd, check_keys, codec_options):
237253 # Workaround for $clusterTime which is incompatible with check_keys.
238254 cluster_time = check_keys and cmd .pop ('$clusterTime' , None )
239255 encoded_cmd = _dict_to_bson (cmd , check_keys , codec_options )
240- try :
241- encrypted_cmd = self ._auto_encrypter .encrypt (database , encoded_cmd )
242- except MongoCryptError as exc :
243- raise EncryptionError (exc )
256+ encrypted_cmd = self ._auto_encrypter .encrypt (database , encoded_cmd )
244257 # TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
245258 encrypt_cmd = _inflate_bson (encrypted_cmd , DEFAULT_RAW_BSON_OPTIONS )
246259 if cluster_time :
247260 encrypt_cmd ['$clusterTime' ] = cluster_time
248261 return encrypt_cmd
249262
263+ @_wrap_encryption_errors
250264 def decrypt (self , response ):
251265 """Decrypt a MongoDB command response.
252266
@@ -256,10 +270,7 @@ def decrypt(self, response):
256270 :Returns:
257271 The decrypted command response.
258272 """
259- try :
260- return self ._auto_encrypter .decrypt (response )
261- except MongoCryptError as exc :
262- raise EncryptionError (exc )
273+ return self ._auto_encrypter .decrypt (response )
263274
264275 def close (self ):
265276 """Cleanup resources."""
@@ -349,6 +360,7 @@ def __init__(self, kms_providers, key_vault_namespace, key_vault_client):
349360 self ._encryption = ExplicitEncrypter (
350361 self ._io_callbacks , MongoCryptOptions (kms_providers , None ))
351362
363+ @_wrap_encryption_errors
352364 def create_data_key (self , kms_provider , master_key = None ,
353365 key_alt_names = None ):
354366 """Create and insert a new data key into the key vault collection.
@@ -383,6 +395,7 @@ def create_data_key(self, kms_provider, master_key=None,
383395 return self ._encryption .create_data_key (
384396 kms_provider , master_key = master_key , key_alt_names = key_alt_names )
385397
398+ @_wrap_encryption_errors
386399 def encrypt (self , value , algorithm , key_id = None , key_alt_name = None ):
387400 """Encrypt a BSON value with a given key and algorithm.
388401
@@ -410,6 +423,14 @@ def encrypt(self, value, algorithm, key_id=None, key_alt_name=None):
410423 doc , algorithm , key_id = raw_key_id , key_alt_name = key_alt_name )
411424 return decode (encrypted_doc )['v' ]
412425
426+ @_wrap_encryption_errors
427+ def _decrypt (self , value ):
428+ """Internal decrypt helper."""
429+ doc = encode ({'v' : value })
430+ decrypted_doc = self ._encryption .decrypt (doc )
431+ # TODO: Add a required codec_options argument for decoding?
432+ return decode (decrypted_doc , codec_options = _DATA_KEY_OPTS )['v' ]
433+
413434 def decrypt (self , value ):
414435 """Decrypt an encrypted value.
415436
@@ -423,10 +444,8 @@ def decrypt(self, value):
423444 if not (isinstance (value , Binary ) and value .subtype == 6 ):
424445 raise TypeError (
425446 'value to decrypt must be a bson.binary.Binary with subtype 6' )
426- doc = encode ({'v' : value })
427- decrypted_doc = self ._encryption .decrypt (doc )
428- # TODO: Add a required codec_options argument for decoding?
429- return decode (decrypted_doc , codec_options = _DATA_KEY_OPTS )['v' ]
447+
448+ return self ._decrypt (value )
430449
431450 def close (self ):
432451 """Release resources."""
0 commit comments