@@ -297,6 +297,7 @@ def __init__(self, version=None, header_size=IMAGE_HEADER_SIZE,
297297 self .enctlv_len = 0
298298 self .max_align = max (DEFAULT_MAX_ALIGN , align ) if max_align is None else int (max_align )
299299 self .non_bootable = non_bootable
300+ self .key_ids = None
300301
301302 if self .max_align == DEFAULT_MAX_ALIGN :
302303 self .boot_magic = bytes ([
@@ -464,32 +465,40 @@ def ecies_hkdf(self, enckey, plainkey):
464465 format = PublicFormat .Raw )
465466 return cipherkey , ciphermac , pubk
466467
467- def create (self , key , public_key_format , enckey , dependencies = None ,
468+ def create (self , keys , public_key_format , enckey , dependencies = None ,
468469 sw_type = None , custom_tlvs = None , compression_tlvs = None ,
469470 compression_type = None , encrypt_keylen = 128 , clear = False ,
470471 fixed_sig = None , pub_key = None , vector_to_sign = None ,
471472 user_sha = 'auto' , is_pure = False , keep_comp_size = False , dont_encrypt = False ):
472473 self .enckey = enckey
473474
474- # key decides on sha, then pub_key; of both are none default is used
475- check_key = key if key is not None else pub_key
475+ # key decides on sha, then pub_key; if both are none default is used
476+ check_key = keys [ 0 ] if keys [ 0 ] is not None else pub_key
476477 hash_algorithm , hash_tlv = key_and_user_sha_to_alg_and_tlv (check_key , user_sha , is_pure )
477478
478479 # Calculate the hash of the public key
479- if key is not None :
480- pub = key .get_public_bytes ()
481- sha = hash_algorithm ()
482- sha .update (pub )
483- pubbytes = sha .digest ()
484- elif pub_key is not None :
485- if hasattr (pub_key , 'sign' ):
486- print (os .path .basename (__file__ ) + ": sign the payload" )
487- pub = pub_key .get_public_bytes ()
488- sha = hash_algorithm ()
489- sha .update (pub )
490- pubbytes = sha .digest ()
480+ pub_digests = []
481+ pub_list = []
482+
483+ if keys is None :
484+ if pub_key is not None :
485+ if hasattr (pub_key , 'sign' ):
486+ print (os .path .basename (__file__ ) + ": sign the payload" )
487+ pub = pub_key .get_public_bytes ()
488+ sha = hash_algorithm ()
489+ sha .update (pub )
490+ pubbytes = sha .digest ()
491+ else :
492+ pubbytes = bytes (hashlib .sha256 ().digest_size )
491493 else :
492- pubbytes = bytes (hashlib .sha256 ().digest_size )
494+ for key in keys or []:
495+ pub = key .get_public_bytes ()
496+ sha = hash_algorithm ()
497+ sha .update (pub )
498+ pubbytes = sha .digest ()
499+ pub_digests .append (pubbytes )
500+ pub_list .append (pub )
501+
493502
494503 protected_tlv_size = 0
495504
@@ -517,10 +526,14 @@ def create(self, key, public_key_format, enckey, dependencies=None,
517526 # value later.
518527 digest = bytes (hash_algorithm ().digest_size )
519528
529+ if pub_digests :
530+ boot_pub_digest = pub_digests [0 ]
531+ else :
532+ boot_pub_digest = pubbytes
520533 # Create CBOR encoded boot record
521534 boot_record = create_sw_component_data (sw_type , image_version ,
522535 hash_tlv , digest ,
523- pubbytes )
536+ boot_pub_digest )
524537
525538 protected_tlv_size += TLV_SIZE + len (boot_record )
526539
@@ -639,33 +652,39 @@ def create(self, key, public_key_format, enckey, dependencies=None,
639652 print (os .path .basename (__file__ ) + ': export digest' )
640653 return
641654
642- if self . key_ids is not None :
643- self . _add_key_id_tlv_to_unprotected ( tlv , self . key_ids [ 0 ] )
655+ if fixed_sig is not None and keys is not None :
656+ raise click . UsageError ( "Can not sign using key and provide fixed-signature at the same time" )
644657
645- if key is not None or fixed_sig is not None :
646- if public_key_format == 'hash' :
647- tlv .add ('KEYHASH' , pubbytes )
648- else :
649- tlv .add ('PUBKEY' , pub )
658+ if fixed_sig is not None :
659+ tlv .add (pub_key .sig_tlv (), fixed_sig ['value' ])
660+ self .signatures [0 ] = fixed_sig ['value' ]
661+ else :
662+ # Multi-signature handling: iterate through each provided key and sign.
663+ self .signatures = []
664+ for i , key in enumerate (keys ):
665+ # If key IDs are provided, and we have enough for this key, add it first.
666+ if self .key_ids is not None and len (self .key_ids ) > i :
667+ # Convert key id (an integer) to 4-byte big-endian bytes.
668+ kid_bytes = self .key_ids [i ].to_bytes (4 , 'big' )
669+ tlv .add ('KEYID' , kid_bytes ) # Using the TLV tag that corresponds to key IDs.
670+
671+ if public_key_format == 'hash' :
672+ tlv .add ('KEYHASH' , pub_digests [i ])
673+ else :
674+ tlv .add ('PUBKEY' , pub_list [i ])
650675
651- if key is not None and fixed_sig is None :
652676 # `sign` expects the full image payload (hashing done
653677 # internally), while `sign_digest` expects only the digest
654678 # of the payload
655-
656679 if hasattr (key , 'sign' ):
657680 print (os .path .basename (__file__ ) + ": sign the payload" )
658681 sig = key .sign (bytes (self .payload ))
659682 else :
660683 print (os .path .basename (__file__ ) + ": sign the digest" )
661684 sig = key .sign_digest (message )
662685 tlv .add (key .sig_tlv (), sig )
663- self .signature = sig
664- elif fixed_sig is not None and key is None :
665- tlv .add (pub_key .sig_tlv (), fixed_sig ['value' ])
666- self .signature = fixed_sig ['value' ]
667- else :
668- raise click .UsageError ("Can not sign using key and provide fixed-signature at the same time" )
686+ self .signatures .append (sig )
687+
669688
670689 # At this point the image was hashed + signed, we can remove the
671690 # protected TLVs from the payload (will be re-added later)
@@ -714,7 +733,7 @@ def get_struct_endian(self):
714733 return STRUCT_ENDIAN_DICT [self .endian ]
715734
716735 def get_signature (self ):
717- return self .signature
736+ return self .signatures
718737
719738 def get_infile_data (self ):
720739 return self .infile_data
@@ -824,75 +843,100 @@ def verify(imgfile, key):
824843 if magic != IMAGE_MAGIC :
825844 return VerifyResult .INVALID_MAGIC , None , None , None
826845
827- tlv_off = header_size + img_size
846+ # Locate the first TLV info header
847+ base_tlv_off = header_size + img_size
848+ tlv_off = base_tlv_off
828849 tlv_info = b [tlv_off :tlv_off + TLV_INFO_SIZE ]
829850 magic , tlv_tot = struct .unpack ('HH' , tlv_info )
851+
852+ # If it's the protected-TLV block, skip it
830853 if magic == TLV_PROT_INFO_MAGIC :
831- tlv_off += tlv_tot
854+ tlv_off += TLV_INFO_SIZE + tlv_tot
832855 tlv_info = b [tlv_off :tlv_off + TLV_INFO_SIZE ]
833856 magic , tlv_tot = struct .unpack ('HH' , tlv_info )
834857
835858 if magic != TLV_INFO_MAGIC :
836859 return VerifyResult .INVALID_TLV_INFO_MAGIC , None , None , None
837860
838- # This is set by existence of TLV SIG_PURE
839- is_pure = False
861+ # Define the unprotected-TLV window
862+ unprot_off = tlv_off + TLV_INFO_SIZE
863+ unprot_end = unprot_off + tlv_tot
840864
841- prot_tlv_size = tlv_off
842- hash_region = b [:prot_tlv_size ]
843- tlv_end = tlv_off + tlv_tot
844- tlv_off += TLV_INFO_SIZE # skip tlv info
865+ # Region up to the start of unprotected TLVs is hashed
866+ prot_tlv_end = unprot_off - TLV_INFO_SIZE
867+ hash_region = b [:prot_tlv_end ]
845868
846- # First scan all TLVs in search of SIG_PURE
847- while tlv_off < tlv_end :
848- tlv = b [tlv_off :tlv_off + TLV_SIZE ]
869+ # This is set by existence of TLV SIG_PURE
870+ is_pure = False
871+ scan_off = unprot_off
872+ while scan_off < unprot_end :
873+ tlv = b [scan_off :scan_off + TLV_SIZE ]
849874 tlv_type , _ , tlv_len = struct .unpack ('BBH' , tlv )
850875 if tlv_type == TLV_VALUES ['SIG_PURE' ]:
851876 is_pure = True
852877 break
853- tlv_off += TLV_SIZE + tlv_len
878+ scan_off += TLV_SIZE + tlv_len
854879
880+ if key is not None and not isinstance (key , list ):
881+ key = [key ]
882+
883+ verify_results = []
884+ scan_off = unprot_off
855885 digest = None
856- tlv_off = header_size + img_size
857- tlv_end = tlv_off + tlv_tot
858- tlv_off += TLV_INFO_SIZE # skip tlv info
859- while tlv_off < tlv_end :
860- tlv = b [tlv_off : tlv_off + TLV_SIZE ]
886+ prot_tlv_size = unprot_off - TLV_INFO_SIZE
887+
888+ # Verify hash and signatures
889+ while scan_off < unprot_end :
890+ tlv = b [scan_off : scan_off + TLV_SIZE ]
861891 tlv_type , _ , tlv_len = struct .unpack ('BBH' , tlv )
862892 if is_sha_tlv (tlv_type ):
863- if not tlv_matches_key_type (tlv_type , key ):
893+ if not tlv_matches_key_type (tlv_type , key [ 0 ] ):
864894 return VerifyResult .KEY_MISMATCH , None , None , None
865- off = tlv_off + TLV_SIZE
895+ off = scan_off + TLV_SIZE
866896 digest = get_digest (tlv_type , hash_region )
867- if digest == b [off :off + tlv_len ]:
868- if key is None :
869- return VerifyResult .OK , version , digest , None
870- else :
871- return VerifyResult .INVALID_HASH , None , None , None
872- elif not is_pure and key is not None and tlv_type == TLV_VALUES [key .sig_tlv ()]:
873- off = tlv_off + TLV_SIZE
874- tlv_sig = b [off :off + tlv_len ]
875- payload = b [:prot_tlv_size ]
876- try :
877- if hasattr (key , 'verify' ):
878- key .verify (tlv_sig , payload )
879- else :
880- key .verify_digest (tlv_sig , digest )
881- return VerifyResult .OK , version , digest , None
882- except InvalidSignature :
883- # continue to next TLV
884- pass
897+ if digest != b [off :off + tlv_len ]:
898+ verify_results .append (("Digest" , "INVALID_HASH" ))
899+
900+ elif not is_pure and key is not None and tlv_type == TLV_VALUES [key [0 ].sig_tlv ()]:
901+ for idx , k in enumerate (key ):
902+ if tlv_type == TLV_VALUES [k .sig_tlv ()]:
903+ off = scan_off + TLV_SIZE
904+ tlv_sig = b [off :off + tlv_len ]
905+ payload = b [:prot_tlv_size ]
906+ try :
907+ if hasattr (k , 'verify' ):
908+ k .verify (tlv_sig , payload )
909+ else :
910+ k .verify_digest (tlv_sig , digest )
911+ verify_results .append ((f"Key { idx } " , "OK" ))
912+ break
913+ except InvalidSignature :
914+ # continue to next TLV
915+ verify_results .append ((f"Key { idx } " , "INVALID_SIGNATURE" ))
916+ continue
917+
885918 elif is_pure and key is not None and tlv_type in ALLOWED_PURE_SIG_TLVS :
886- off = tlv_off + TLV_SIZE
919+ # pure signature verification
920+ off = scan_off + TLV_SIZE
887921 tlv_sig = b [off :off + tlv_len ]
922+ k = key [0 ]
888923 try :
889- key .verify_digest (tlv_sig , hash_region )
924+ k .verify_digest (tlv_sig , hash_region )
890925 return VerifyResult .OK , version , None , tlv_sig
891926 except InvalidSignature :
892- # continue to next TLV
893- pass
894- tlv_off += TLV_SIZE + tlv_len
895- return VerifyResult .INVALID_SIGNATURE , None , None , None
927+ return VerifyResult .INVALID_SIGNATURE , None , None , None
928+
929+ scan_off += TLV_SIZE + tlv_len
930+ # Now print out the verification results:
931+ for k , result in verify_results :
932+ print (f"{ k } : { result } " )
933+
934+ # Decide on a final return (for example, OK only if at least one signature is valid)
935+ if any (result == "OK" for _ , result in verify_results ):
936+ return VerifyResult .OK , version , digest , None
937+ else :
938+ return VerifyResult .INVALID_SIGNATURE , None , None , None
939+
896940
897941 def set_key_ids (self , key_ids ):
898942 """Set list of key IDs (integers) to be inserted before each signature."""
0 commit comments