Skip to content

Commit ad09c9f

Browse files
committed
Fix jwt_required with flask_restless (refs #10)
1 parent 2d25870 commit ad09c9f

File tree

3 files changed

+57
-51
lines changed

3 files changed

+57
-51
lines changed

flask_jwt_extended/jwt_manager.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,40 @@
11
from flask import jsonify
22

3+
from flask_jwt_extended.exceptions import JWTDecodeError, NoAuthorizationError, \
4+
InvalidHeaderError, WrongTokenError, RevokedTokenError, FreshTokenRequired
5+
from jwt import ExpiredSignatureError, InvalidTokenError
6+
37

48
class JWTManager:
59
def __init__(self, app=None):
610
# Function that will be called to add custom user claims to a JWT.
711
self.user_claims_callback = lambda _: {}
812

913
# Function that will be called when an expired token is received
10-
self.expired_token_callback = lambda: (
14+
self._expired_token_callback = lambda: (
1115
jsonify({'msg': 'Token has expired'}), 401
1216
)
1317

1418
# Function that will be called when an invalid token is received
15-
self.invalid_token_callback = lambda err: (
19+
self._invalid_token_callback = lambda err: (
1620
jsonify({'msg': err}), 422
1721
)
1822

1923
# Function that will be called when attempting to access a protected
2024
# endpoint without a valid token
21-
self.unauthorized_callback = lambda: (
25+
self._unauthorized_callback = lambda: (
2226
jsonify({'msg': 'Missing Authorization Header'}), 401
2327
)
2428

2529
# Function that will be called when attempting to access a fresh_jwt_required
2630
# endpoint with a valid token that is not fresh
27-
self.needs_fresh_token_callback = lambda: (
31+
self._needs_fresh_token_callback = lambda: (
2832
jsonify({'msg': 'Fresh token required'}), 401
2933
)
3034

3135
# Function that will be called when a revoked token attempts to access
3236
# a protected endpoint
33-
self.revoked_token_callback = lambda: (
37+
self._revoked_token_callback = lambda: (
3438
jsonify({'msg': 'Token has been revoked'}), 401
3539
)
3640

@@ -45,6 +49,38 @@ def init_app(self, app):
4549
"""
4650
app.jwt_manager = self
4751

52+
@app.errorhandler(NoAuthorizationError)
53+
def handle_auth_error(e):
54+
return self._unauthorized_callback()
55+
56+
@app.errorhandler(ExpiredSignatureError)
57+
def handle_expired_error(e):
58+
return self._expired_token_callback()
59+
60+
@app.errorhandler(InvalidHeaderError)
61+
def handle_invalid_header_error(e):
62+
return self._invalid_token_callback(str(e))
63+
64+
@app.errorhandler(InvalidTokenError)
65+
def handle_invalid_token_error(e):
66+
return self._invalid_token_callback(str(e))
67+
68+
@app.errorhandler(JWTDecodeError)
69+
def handle_jwt_decode_error(e):
70+
return self._invalid_token_callback(str(e))
71+
72+
@app.errorhandler(WrongTokenError)
73+
def handle_wrong_token_error(e):
74+
return self._invalid_token_callback(str(e))
75+
76+
@app.errorhandler(RevokedTokenError)
77+
def hanlde_revoked_token_error(e):
78+
return self._revoked_token_callback()
79+
80+
@app.errorhandler(FreshTokenRequired)
81+
def handle_fresh_token_required(e):
82+
return self._needs_fresh_token_callback()
83+
4884
def user_claims_loader(self, callback):
4985
"""
5086
This sets the callback method for adding custom user claims to a JWT.
@@ -66,7 +102,7 @@ def expired_token_loader(self, callback):
66102
67103
Callback must be a function that takes zero arguments.
68104
"""
69-
self.expired_token_callback = callback
105+
self._expired_token_callback = callback
70106
return callback
71107

72108
def invalid_token_loader(self, callback):
@@ -79,7 +115,7 @@ def invalid_token_loader(self, callback):
79115
Callback must be a function that takes only one argument, which is the
80116
error message of why the token is invalid.
81117
"""
82-
self.invalid_token_callback = callback
118+
self._invalid_token_callback = callback
83119
return callback
84120

85121
def unauthorized_loader(self, callback):
@@ -92,7 +128,7 @@ def unauthorized_loader(self, callback):
92128
Callback must be a function that takes only one argument, which is the
93129
error message of why the token is invalid.
94130
"""
95-
self.unauthorized_callback = callback
131+
self._unauthorized_callback = callback
96132
return callback
97133

98134
def needs_fresh_token_loader(self, callback):
@@ -105,7 +141,7 @@ def needs_fresh_token_loader(self, callback):
105141
106142
Callback must be a function that takes no arguments.
107143
"""
108-
self.needs_fresh_token_callback = callback
144+
self._needs_fresh_token_callback = callback
109145
return callback
110146

111147
def revoked_token_loader(self, callback):
@@ -118,5 +154,5 @@ def revoked_token_loader(self, callback):
118154
119155
Callback must be a function that takes no arguments.
120156
"""
121-
self.revoked_token_callback = callback
157+
self._revoked_token_callback = callback
122158
return callback

flask_jwt_extended/utils.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
get_refresh_csrf_cookie_name, get_token_location, \
2121
get_csrf_header_name
2222
from flask_jwt_extended.exceptions import JWTEncodeError, JWTDecodeError, \
23-
InvalidHeaderError, NoAuthorizationError, WrongTokenError, RevokedTokenError, \
23+
InvalidHeaderError, NoAuthorizationError, WrongTokenError, \
2424
FreshTokenRequired
2525
from flask_jwt_extended.blacklist import check_if_token_revoked, store_token
2626

@@ -206,33 +206,6 @@ def _decode_jwt_from_request(type):
206206
return _decode_jwt_from_cookies(type)
207207

208208

209-
def _handle_callbacks_on_error(fn):
210-
"""
211-
Helper decorator that will catch any exceptions we expect to encounter
212-
when dealing with a JWT, and call the appropriate callback function for
213-
handling that error. Callback functions can be set in using the *_loader
214-
methods in jwt_manager.
215-
"""
216-
@wraps(fn)
217-
def wrapper(*args, **kwargs):
218-
m = current_app.jwt_manager
219-
220-
try:
221-
return fn(*args, **kwargs)
222-
except NoAuthorizationError:
223-
return m.unauthorized_callback()
224-
except jwt.ExpiredSignatureError:
225-
return m.expired_token_callback()
226-
except (InvalidHeaderError, jwt.InvalidTokenError, JWTDecodeError,
227-
WrongTokenError) as e:
228-
return m.invalid_token_callback(str(e))
229-
except RevokedTokenError:
230-
return m.revoked_token_callback()
231-
except FreshTokenRequired:
232-
return m.needs_fresh_token_callback()
233-
return wrapper
234-
235-
236209
def jwt_required(fn):
237210
"""
238211
If you decorate a vew with this, it will ensure that the requester has a valid
@@ -243,7 +216,6 @@ def jwt_required(fn):
243216
244217
:param fn: The view function to decorate
245218
"""
246-
@_handle_callbacks_on_error
247219
@wraps(fn)
248220
def wrapper(*args, **kwargs):
249221
# Attempt to decode the token
@@ -275,7 +247,6 @@ def fresh_jwt_required(fn):
275247
276248
:param fn: The view function to decorate
277249
"""
278-
@_handle_callbacks_on_error
279250
@wraps(fn)
280251
def wrapper(*args, **kwargs):
281252
# Attempt to decode the token
@@ -308,7 +279,6 @@ def jwt_refresh_token_required(fn):
308279
valid JWT refresh token before calling the actual view. If the token is
309280
invalid, expired, not present, etc, the appropiate callback will be called
310281
"""
311-
@_handle_callbacks_on_error
312282
@wraps(fn)
313283
def wrapper(*args, **kwargs):
314284
# Get the JWT

tests/test_jwt_manager.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_default_user_claims_callback(self):
3737
def test_default_expired_token_callback(self):
3838
with self.app.test_request_context():
3939
m = JWTManager(self.app)
40-
result = m.expired_token_callback()
40+
result = m._expired_token_callback()
4141
status_code, data = self._parse_callback_result(result)
4242

4343
self.assertEqual(status_code, 401)
@@ -47,7 +47,7 @@ def test_default_invalid_token_callback(self):
4747
with self.app.test_request_context():
4848
m = JWTManager(self.app)
4949
err = "Test error"
50-
result = m.invalid_token_callback(err)
50+
result = m._invalid_token_callback(err)
5151
status_code, data = self._parse_callback_result(result)
5252

5353
self.assertEqual(status_code, 422)
@@ -56,7 +56,7 @@ def test_default_invalid_token_callback(self):
5656
def test_default_unauthorized_callback(self):
5757
with self.app.test_request_context():
5858
m = JWTManager(self.app)
59-
result = m.unauthorized_callback()
59+
result = m._unauthorized_callback()
6060
status_code, data = self._parse_callback_result(result)
6161

6262
self.assertEqual(status_code, 401)
@@ -65,7 +65,7 @@ def test_default_unauthorized_callback(self):
6565
def test_default_needs_fresh_token_callback(self):
6666
with self.app.test_request_context():
6767
m = JWTManager(self.app)
68-
result = m.needs_fresh_token_callback()
68+
result = m._needs_fresh_token_callback()
6969
status_code, data = self._parse_callback_result(result)
7070

7171
self.assertEqual(status_code, 401)
@@ -74,7 +74,7 @@ def test_default_needs_fresh_token_callback(self):
7474
def test_default_revoked_token_callback(self):
7575
with self.app.test_request_context():
7676
m = JWTManager(self.app)
77-
result = m.revoked_token_callback()
77+
result = m._revoked_token_callback()
7878
status_code, data = self._parse_callback_result(result)
7979

8080
self.assertEqual(status_code, 401)
@@ -98,7 +98,7 @@ def test_custom_expired_token_callback(self):
9898
def custom_expired_token():
9999
return jsonify({"res": "TOKEN IS EXPIRED FOOL"}), 422
100100

101-
result = m.expired_token_callback()
101+
result = m._expired_token_callback()
102102
status_code, data = self._parse_callback_result(result)
103103

104104
self.assertEqual(status_code, 422)
@@ -113,7 +113,7 @@ def test_custom_invalid_token_callback(self):
113113
def custom_invalid_token(err):
114114
return jsonify({"err": err}), 200
115115

116-
result = m.invalid_token_callback(err)
116+
result = m._invalid_token_callback(err)
117117
status_code, data = self._parse_callback_result(result)
118118

119119
self.assertEqual(status_code, 200)
@@ -127,7 +127,7 @@ def test_custom_unauthorized_callback(self):
127127
def custom_unauthorized():
128128
return jsonify({"err": "GOTTA LOGIN FOOL"}), 200
129129

130-
result = m.unauthorized_callback()
130+
result = m._unauthorized_callback()
131131
status_code, data = self._parse_callback_result(result)
132132

133133
self.assertEqual(status_code, 200)
@@ -141,7 +141,7 @@ def test_custom_needs_fresh_token_callback(self):
141141
def custom_token_needs_refresh():
142142
return jsonify({'sub_status': 101}), 200
143143

144-
result = m.needs_fresh_token_callback()
144+
result = m._needs_fresh_token_callback()
145145
status_code, data = self._parse_callback_result(result)
146146

147147
self.assertEqual(status_code, 200)
@@ -154,7 +154,7 @@ def test_custom_revoked_token_callback(self):
154154
@m.revoked_token_loader
155155
def custom_revoken_token():
156156
return jsonify({"err": "Nice knowing you!"}), 422
157-
result = m.revoked_token_callback()
157+
result = m._revoked_token_callback()
158158
status_code, data = self._parse_callback_result(result)
159159

160160
self.assertEqual(status_code, 422)

0 commit comments

Comments
 (0)