@@ -1617,6 +1617,8 @@ namespace jwt {
16171617 explicit rs256 (const std::string& public_key, const std::string& private_key = " " ,
16181618 const std::string& public_key_password = " " , const std::string& private_key_password = " " )
16191619 : rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha256, " RS256" ) {}
1620+
1621+ explicit rs256 (helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha256, " RS256" ) {}
16201622 };
16211623 /* *
16221624 * RS384 algorithm
@@ -1632,6 +1634,8 @@ namespace jwt {
16321634 explicit rs384 (const std::string& public_key, const std::string& private_key = " " ,
16331635 const std::string& public_key_password = " " , const std::string& private_key_password = " " )
16341636 : rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha384, " RS384" ) {}
1637+
1638+ explicit rs384 (helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha384, " RS384" ) {}
16351639 };
16361640 /* *
16371641 * RS512 algorithm
@@ -1647,6 +1651,8 @@ namespace jwt {
16471651 explicit rs512 (const std::string& public_key, const std::string& private_key = " " ,
16481652 const std::string& public_key_password = " " , const std::string& private_key_password = " " )
16491653 : rsa(public_key, private_key, public_key_password, private_key_password, EVP_sha512, " RS512" ) {}
1654+
1655+ explicit rs512 (helper::evp_pkey_handle pkey) : rsa(pkey, EVP_sha512, " RS512" ) {}
16501656 };
16511657 /* *
16521658 * ES256 algorithm
@@ -3126,6 +3132,12 @@ namespace jwt {
31263132 };
31273133 } // namespace verify_ops
31283134
3135+ using alg_name = std::string;
3136+ using alg_list = std::vector<alg_name>;
3137+ using algorithms = std::unordered_map<std::string, alg_list>;
3138+ static const algorithms supported_alg = {{" RSA" , {" RS256" , " RS384" , " RS512" , " PS256" , " PS384" , " PS512" }},
3139+ {" EC" , {" ES256" , " ES384" , " ES512" , " ES256K" }},
3140+ {" oct" , {" HS256" , " HS384" , " HS512" }}};
31293141 /* *
31303142 * \brief JSON Web Key
31313143 *
@@ -3346,6 +3358,11 @@ namespace jwt {
33463358
33473359 std::string get_oct_key () const { return key.get_symmetric_key (); }
33483360
3361+ bool supports (const std::string& alg_name) const {
3362+ const alg_list& x = supported_alg.find (get_key_type ())->second ;
3363+ return std::find (x.begin (), x.end (), alg_name) != x.end ();
3364+ }
3365+
33493366 private:
33503367 class key {
33513368 public:
@@ -3488,6 +3505,11 @@ namespace jwt {
34883505 // / Supported algorithms
34893506 std::unordered_map<std::string, std::shared_ptr<algo_base>> algs;
34903507
3508+ typedef std::vector<jwt::jwk<json_traits>> key_list;
3509+ // / https://datatracker.ietf.org/doc/html/rfc7517#section-4.5 - kid to keys
3510+ typedef std::unordered_map<std::string, key_list> keysets;
3511+ keysets keys;
3512+
34913513 void verify_claims (const decoded_jwt<json_traits>& jwt, std::error_code& ec) const {
34923514 verify_ops::verify_context<json_traits> ctx{clock.now (), jwt, default_leeway};
34933515 for (auto & c : claims) {
@@ -3497,6 +3519,52 @@ namespace jwt {
34973519 }
34983520 }
34993521
3522+ static inline std::unique_ptr<algo_base> from_key_and_alg (const jwt::jwk<json_traits>& key,
3523+ const std::string& alg_name, std::error_code& ec) {
3524+ ec.clear ();
3525+ algorithms::const_iterator it = supported_alg.find (key.get_key_type ());
3526+ if (it == supported_alg.end ()) {
3527+ ec = error::token_verification_error::wrong_algorithm;
3528+ return nullptr ;
3529+ }
3530+
3531+ const alg_list& supported_jwt_algorithms = it->second ;
3532+ if (std::find (supported_jwt_algorithms.begin (), supported_jwt_algorithms.end (), alg_name) ==
3533+ supported_jwt_algorithms.end ()) {
3534+ ec = error::token_verification_error::wrong_algorithm;
3535+ return nullptr ;
3536+ }
3537+
3538+ if (alg_name == " RS256" ) {
3539+ return std::make_unique<algo<jwt::algorithm::rs256>>(jwt::algorithm::rs256 (key.get_pkey ()));
3540+ } else if (alg_name == " RS384" ) {
3541+ return std::make_unique<algo<jwt::algorithm::rs384>>(jwt::algorithm::rs384 (key.get_pkey ()));
3542+ } else if (alg_name == " RS512" ) {
3543+ return std::make_unique<algo<jwt::algorithm::rs512>>(jwt::algorithm::rs512 (key.get_pkey ()));
3544+ } else if (alg_name == " PS256" ) {
3545+ return std::make_unique<algo<jwt::algorithm::ps256>>(jwt::algorithm::ps256 (key.get_pkey ()));
3546+ } else if (alg_name == " PS384" ) {
3547+ return std::make_unique<algo<jwt::algorithm::ps384>>(jwt::algorithm::ps384 (key.get_pkey ()));
3548+ } else if (alg_name == " PS512" ) {
3549+ return std::make_unique<algo<jwt::algorithm::ps512>>(jwt::algorithm::ps512 (key.get_pkey ()));
3550+ } else if (alg_name == " ES256" ) {
3551+ return std::make_unique<algo<jwt::algorithm::es256>>(jwt::algorithm::es256 (key.get_pkey ()));
3552+ } else if (alg_name == " ES384" ) {
3553+ return std::make_unique<algo<jwt::algorithm::es384>>(jwt::algorithm::es384 (key.get_pkey ()));
3554+ } else if (alg_name == " ES512" ) {
3555+ return std::make_unique<algo<jwt::algorithm::es512>>(jwt::algorithm::es512 (key.get_pkey ()));
3556+ } else if (alg_name == " HS256" ) {
3557+ return std::make_unique<algo<jwt::algorithm::hs256>>(jwt::algorithm::hs256 (key.get_oct_key ()));
3558+ } else if (alg_name == " HS384" ) {
3559+ return std::make_unique<algo<jwt::algorithm::hs384>>(jwt::algorithm::hs384 (key.get_oct_key ()));
3560+ } else if (alg_name == " HS512" ) {
3561+ return std::make_unique<algo<jwt::algorithm::hs512>>(jwt::algorithm::hs512 (key.get_oct_key ()));
3562+ }
3563+
3564+ ec = error::token_verification_error::wrong_algorithm;
3565+ return nullptr ;
3566+ }
3567+
35003568 public:
35013569 /* *
35023570 * Constructor for building a new verifier instance
@@ -3661,6 +3729,18 @@ namespace jwt {
36613729 return *this ;
36623730 }
36633731
3732+ verifier& allow_key (const jwt::jwk<json_traits>& key) {
3733+ std::string keyid = " " ;
3734+ if (key.has_key_id ()) {
3735+ keyid = key.get_key_id ();
3736+ typename keysets::const_iterator it = keys.find (keyid);
3737+ if (it == keys.end ()) { keys[keyid] = key_list (); }
3738+ }
3739+
3740+ keys[keyid].push_back (key);
3741+ return *this ;
3742+ }
3743+
36643744 /* *
36653745 * Verify the given token.
36663746 * \param jwt Token to check
@@ -3681,13 +3761,32 @@ namespace jwt {
36813761 const typename json_traits::string_type data = jwt.get_header_base64 () + " ." + jwt.get_payload_base64 ();
36823762 const typename json_traits::string_type sig = jwt.get_signature ();
36833763 const std::string algo = jwt.get_algorithm ();
3684- if (algs.count (algo) == 0 ) {
3685- ec = error::token_verification_error::wrong_algorithm;
3686- return ;
3764+ std::string kid (" " );
3765+ if (jwt.has_header_claim (" kid" )) { kid = jwt.get_header_claim (" kid" ).as_string (); }
3766+
3767+ typename keysets::const_iterator key_set_it = keys.find (kid);
3768+ bool key_found = false ;
3769+ if (key_set_it != keys.end ()) {
3770+ const key_list& keys = key_set_it->second ;
3771+ for (const auto & key : keys) {
3772+ if (key.supports (algo)) {
3773+ key_found = true ;
3774+ auto alg = from_key_and_alg (key, algo, ec);
3775+ alg->verify (data, sig, ec);
3776+ break ;
3777+ }
3778+ }
3779+ }
3780+
3781+ if (!key_found) {
3782+ if (algs.count (algo) == 0 ) {
3783+ ec = error::token_verification_error::wrong_algorithm;
3784+ return ;
3785+ }
3786+ algs.at (algo)->verify (data, sig, ec);
36873787 }
3688- algs.at (algo)->verify (data, sig, ec);
3689- if (ec) return ;
36903788
3789+ if (ec) return ;
36913790 verify_claims (jwt, ec);
36923791 }
36933792 };
0 commit comments