|
3 | 3 | from datetime import timedelta |
4 | 4 |
|
5 | 5 | from flask import Flask |
6 | | -from jwt import ExpiredSignatureError |
| 6 | +from jwt import ExpiredSignatureError, InvalidSignatureError |
7 | 7 |
|
8 | 8 | from flask_jwt_extended import ( |
9 | 9 | JWTManager, create_access_token, decode_token, create_refresh_token, |
@@ -48,7 +48,9 @@ def empty_user_loader_return(identity): |
48 | 48 | # returned via the decode_token call |
49 | 49 | with app.test_request_context(): |
50 | 50 | token = create_access_token('username') |
51 | | - pure_decoded = jwt.decode(token, config.decode_key, algorithms=[config.algorithm]) |
| 51 | + unverfied_claims = jwt.decode(token, verify=False, algorithms=[config.algorithm]) |
| 52 | + decode_key = jwtM._decode_key_callback(unverfied_claims) |
| 53 | + pure_decoded = jwt.decode(token, decode_key, algorithms=[config.algorithm]) |
52 | 54 | assert config.user_claims_key not in pure_decoded |
53 | 55 | extension_decoded = decode_token(token) |
54 | 56 | assert config.user_claims_key in extension_decoded |
@@ -117,3 +119,27 @@ def test_get_jti(app, default_access_token): |
117 | 119 |
|
118 | 120 | with app.test_request_context(): |
119 | 121 | assert default_access_token['jti'] == get_jti(token) |
| 122 | + |
| 123 | + |
| 124 | +def test_decode_key_callback(app, default_access_token): |
| 125 | + jwtM = get_jwt_manager(app) |
| 126 | + app.config['JWT_SECRET_KEY'] = 'correct secret' |
| 127 | + |
| 128 | + @jwtM.decode_key_loader |
| 129 | + def get_decode_key_1(claims): |
| 130 | + return 'different secret' |
| 131 | + |
| 132 | + assert jwtM._decode_key_callback({}) == 'different secret' |
| 133 | + |
| 134 | + with pytest.raises(InvalidSignatureError): |
| 135 | + with app.test_request_context(): |
| 136 | + token = encode_token(app, default_access_token) |
| 137 | + decode_token(token) |
| 138 | + |
| 139 | + @jwtM.decode_key_loader |
| 140 | + def get_decode_key_2(claims): |
| 141 | + return 'correct secret' |
| 142 | + |
| 143 | + with app.test_request_context(): |
| 144 | + token = encode_token(app, default_access_token) |
| 145 | + decode_token(token) |
0 commit comments