1616import io
1717import logging
1818import time
19- from typing import Optional , Tuple , Any
19+ from typing import Optional , Tuple , Any , List
2020
2121from tink import aead , daead , KmsClient , kms_client_from_uri , \
2222 register_kms_client , TinkError
4545ENCRYPT_KMS_TYPE = "encrypt.kms.type"
4646ENCRYPT_DEK_ALGORITHM = "encrypt.dek.algorithm"
4747ENCRYPT_DEK_EXPIRY_DAYS = "encrypt.dek.expiry.days"
48+ ENCRYPT_ALTERNATE_KMS_KEY_IDS = "encrypt.alternate.kms.key.ids"
4849
4950MILLIS_IN_DAY = 24 * 60 * 60 * 1000
5051
@@ -279,7 +280,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek:
279280 raise RuleError (f"no dek found for { dek_id .kek_name } during consume" )
280281 encrypted_dek = None
281282 if not kek .shared :
282- primitive = self . _get_aead (self ._executor .config , self ._kek )
283+ primitive = AeadWrapper (self ._executor .config , self ._kek )
283284 raw_dek = self ._cryptor .generate_key ()
284285 encrypted_dek = primitive .encrypt (raw_dek , self ._cryptor .EMPTY_AAD )
285286 new_version = dek .version + 1 if is_expired else 1
@@ -293,7 +294,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek:
293294 key_bytes = dek .get_key_material_bytes ()
294295 if key_bytes is None :
295296 if primitive is None :
296- primitive = self . _get_aead (self ._executor .config , self ._kek )
297+ primitive = AeadWrapper (self ._executor .config , self ._kek )
297298 encrypted_dek = dek .get_encrypted_key_material_bytes ()
298299 raw_dek = primitive .decrypt (encrypted_dek , self ._cryptor .EMPTY_AAD )
299300 dek .set_key_material (raw_dek )
@@ -410,8 +411,51 @@ def _to_object(self, field_type: FieldType, value: bytes) -> Any:
410411 return value
411412 return None
412413
413- def _get_aead (self , config : dict , kek : Kek ) -> aead .Aead :
414- kek_url = kek .kms_type + "://" + kek .kms_key_id
414+
415+ class AeadWrapper (aead .Aead ):
416+ def __init__ (self , config : dict , kek : Kek ):
417+ self ._config = config
418+ self ._kek = kek
419+ self ._kms_key_ids = self ._get_kms_key_ids ()
420+
421+ def encrypt (self , plaintext : bytes , associated_data : bytes ) -> bytes :
422+ for index , kms_key_id in enumerate (self ._kms_key_ids ):
423+ try :
424+ aead = self ._get_aead (self ._config , self ._kek .kms_type , kms_key_id )
425+ return aead .encrypt (plaintext , associated_data )
426+ except Exception as e :
427+ log .warning ("failed to encrypt with kek %s and kms key id %s" ,
428+ self ._kek .name , kms_key_id )
429+ if index == len (self ._kms_key_ids ) - 1 :
430+ raise RuleError (f"failed to encrypt with all KEKs for { self ._kek .name } " ) from e
431+ raise RuleError ("No KEK found for encryption" )
432+
433+ def decrypt (self , ciphertext : bytes , associated_data : bytes ) -> bytes :
434+ for index , kms_key_id in enumerate (self ._kms_key_ids ):
435+ try :
436+ aead = self ._get_aead (self ._config , self ._kek .kms_type , kms_key_id )
437+ return aead .decrypt (ciphertext , associated_data )
438+ except Exception as e :
439+ log .warning ("failed to decrypt with kek %s and kms key id %s" ,
440+ self ._kek .name , kms_key_id )
441+ if index == len (self ._kms_key_ids ) - 1 :
442+ raise RuleError (f"failed to decrypt with all KEKs for { self ._kek .name } " ) from e
443+ raise RuleError ("No KEK found for decryption" )
444+
445+ def _get_kms_key_ids (self ) -> List [str ]:
446+ kms_key_ids = [self ._kek .kms_key_id ]
447+ alternate_kms_key_ids = None
448+ if self ._kek .kms_props is not None :
449+ alternate_kms_key_ids = self ._kek .kms_props .properties .get (ENCRYPT_ALTERNATE_KMS_KEY_IDS )
450+ if alternate_kms_key_ids is None :
451+ alternate_kms_key_ids = self ._config .get (ENCRYPT_ALTERNATE_KMS_KEY_IDS )
452+ if alternate_kms_key_ids is not None :
453+ # Split the comma-separated list of alternate KMS key IDs and append to kms_key_ids
454+ kms_key_ids .extend ([id .strip () for id in alternate_kms_key_ids .split (',' ) if id .strip ()])
455+ return kms_key_ids
456+
457+ def _get_aead (self , config : dict , kms_type : str , kms_key_id : str ) -> aead .Aead :
458+ kek_url = kms_type + "://" + kms_key_id
415459 kms_client = self ._get_kms_client (config , kek_url )
416460 return kms_client .get_aead (kek_url )
417461
0 commit comments