Skip to content

Commit 01603bc

Browse files
committed
store cookies in jwt and option to csrf protect them
closes #5
1 parent e6812d5 commit 01603bc

File tree

6 files changed

+366
-46
lines changed

6 files changed

+366
-46
lines changed

examples/token_in_cookie.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,54 @@
55
set_access_cookies, set_refresh_cookie
66

77

8-
# NOTE: This is being actively worked on, and is not complete yet. At present,
9-
# this code will not work! It should be rolled out next week sometime
10-
11-
128
app = Flask(__name__)
139
app.secret_key = 'super-secret' # Change this!
1410
jwt = JWTManager(app)
1511

1612

17-
# Configure application to store jwts in cookies with double submit csrf protection
18-
app.config['JWT_TOKEN_LOCATION'] = 'cookie'
19-
app.config['JWT_COOKIE_HTTPONLY'] = True
20-
app.config['JWT_COOKIE_SECURE'] = True
13+
# Configure application to store JWTs in cookies
14+
app.config['JWT_TOKEN_LOCATION'] = 'cookies'
15+
app.config['JWT_COOKIE_SECURE'] = False # In prod this should likely be True
2116

22-
app.config['JWT_ACCESS_COOKIE_NAME'] = 'access_token_cookie'
17+
# Set the cookie paths, so that you are only sending your access token cookie
18+
# to the access endpoints, and only sending your refresh token to the refresh
19+
# endpoint.
2320
app.config['JWT_ACCESS_COOKIE_PATH'] = '/api/'
24-
25-
app.config['JWT_REFRESH_COOKIE_NAME'] = 'refresh_token_cookie'
2621
app.config['JWT_REFRESH_COOKIE_PATH'] = '/token/refresh'
2722

23+
# Enable csrf double submit protection. Check out this for a simple overview
24+
# of what this is: http://stackoverflow.com/a/37396572/272689.
2825
app.config['JWT_COOKIE_CSRF_PROTECT'] = True
29-
app.config['JWT_ACCESS_CSRF_COOKIE_NAME'] = 'x_xsrf_access_token'
30-
app.config['JWT_REFRESH_CSRF_COOKIE_NAME'] = 'x_xsrf_refresh_token'
26+
27+
28+
# Now, whenever you make a request to a protected endpoint, you will need to
29+
# send in the access or refresh JWT via a cookie, as well as a custom header
30+
# which has the same csrf token that is in the cookie. You cannot access the
31+
# csrf token from the JWT, as httponly is set to true (and javascript thus
32+
# cannot see it), but you can get the JWT from a secondary cookie (that only
33+
# javascript on your site can access), and thus verify a csrf attack isn't
34+
# happening.
35+
#
36+
# You can modify the cookie name, csrf cookie name, and csrf header name via
37+
# various app.config options. Check the options page for details.
3138

3239

3340
@app.route('/token/auth', methods=['POST'])
3441
def login():
3542
username = request.json.get('username', None)
3643
password = request.json.get('password', None)
3744
if username != 'test' and password != 'test':
38-
return jsonify({"msg": "Bad username or password"}), 401
45+
return jsonify({'login': False}), 401
3946

4047
# Create the tokens we will be sending back to the user
4148
access_token = create_access_token(identity=username)
4249
refresh_token = create_refresh_token(identity=username)
4350

4451
# Set the JWTs and the CSRF double submit protection cookies in this response
45-
resp = jsonify({'login': True}), 200
52+
resp = jsonify({'login': True})
4653
set_access_cookies(resp, access_token)
4754
set_refresh_cookie(resp, refresh_token)
48-
return resp
55+
return resp, 200
4956

5057

5158
@app.route('/token/refresh', methods=['POST'])
@@ -56,9 +63,9 @@ def refresh():
5663
access_token = create_access_token(identity=current_user)
5764

5865
# Set the access JWT and CSRF double submit protection cookies in this response
59-
resp = jsonify({'refresh': True}), 200
66+
resp = jsonify({'refresh': True})
6067
set_access_cookies(resp, access_token)
61-
return resp
68+
return resp, 200
6269

6370

6471
# We do not need to make any changes here, all of the protected endpoints will

flask_jwt_extended/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class InvalidHeaderError(JWTExtendedException):
2626
pass
2727

2828

29-
class NoAuthHeaderError(JWTExtendedException):
29+
class NoAuthorizationError(JWTExtendedException):
3030
"""
3131
An error getting header information from a request
3232
"""

flask_jwt_extended/utils.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
get_access_cookie_name, get_cookie_secure, get_access_cookie_path, \
1919
get_cookie_csrf_protect, get_access_csrf_cookie_name, \
2020
get_refresh_cookie_name, get_refresh_cookie_path, \
21-
get_refresh_csrf_cookie_name
21+
get_refresh_csrf_cookie_name, get_token_location, \
22+
get_access_csrf_header_name, get_refresh_csrf_header_name
2223
from flask_jwt_extended.exceptions import JWTEncodeError, JWTDecodeError, \
23-
InvalidHeaderError, NoAuthHeaderError, WrongTokenError, RevokedTokenError, \
24+
InvalidHeaderError, NoAuthorizationError, WrongTokenError, RevokedTokenError, \
2425
FreshTokenRequired
2526
from flask_jwt_extended.blacklist import check_if_token_revoked, store_token
2627

@@ -42,8 +43,8 @@ def get_jwt_claims():
4243

4344

4445
# TODO set csrf token in jwt when creating tokens (if enabled)
45-
def _create_xsrf_token():
46-
return binascii.hexlify(os.urandom(60))
46+
def _create_csrf_token():
47+
return str(uuid.uuid4())
4748

4849

4950
def _encode_access_token(identity, secret, algorithm, token_expire_delta,
@@ -80,6 +81,8 @@ def _encode_access_token(identity, secret, algorithm, token_expire_delta,
8081
'type': 'access',
8182
'user_claims': user_claims,
8283
}
84+
if get_token_location() == 'cookies' and get_cookie_csrf_protect():
85+
token_data['csrf'] = _create_csrf_token()
8386
encoded_token = jwt.encode(token_data, secret, algorithm).decode('utf-8')
8487

8588
# If blacklisting is enabled and configured to store access and refresh tokens,
@@ -110,6 +113,8 @@ def _encode_refresh_token(identity, secret, algorithm, token_expire_delta):
110113
'identity': identity,
111114
'type': 'refresh',
112115
}
116+
if get_token_location() == 'cookies' and get_cookie_csrf_protect():
117+
token_data['csrf'] = _create_csrf_token()
113118
encoded_token = jwt.encode(token_data, secret, algorithm).decode('utf-8')
114119

115120
# If blacklisting is enabled, store this token in our key-value store
@@ -142,14 +147,17 @@ def _decode_jwt(token, secret, algorithm):
142147
raise JWTDecodeError("Missing or invalid claim: fresh")
143148
if 'user_claims' not in data or not isinstance(data['user_claims'], dict):
144149
raise JWTDecodeError("Missing or invalid claim: user_claims")
150+
if get_token_location() == 'cookies' and get_cookie_csrf_protect():
151+
if 'csrf' not in data or not isinstance(data['csrf'], six.string_types):
152+
raise JWTDecodeError("Missing or invalid claim: csrf")
145153
return data
146154

147155

148-
def _decode_jwt_from_request():
156+
def _decode_jwt_from_headers():
149157
# Verify we have the auth header
150158
auth_header = request.headers.get('Authorization', None)
151159
if not auth_header:
152-
raise NoAuthHeaderError("Missing Authorization Header")
160+
raise NoAuthorizationError("Missing Authorization Header")
153161

154162
# Make sure the header is valid
155163
expected_header = get_jwt_header_type()
@@ -170,6 +178,37 @@ def _decode_jwt_from_request():
170178
return _decode_jwt(token, secret, algorithm)
171179

172180

181+
def _decode_jwt_from_cookies(type):
182+
if type == 'access':
183+
cookie_key = get_access_cookie_name()
184+
csrf_header_key = get_access_csrf_header_name()
185+
else:
186+
cookie_key = get_refresh_cookie_name()
187+
csrf_header_key = get_refresh_csrf_header_name()
188+
189+
token = request.cookies.get(cookie_key)
190+
if not token:
191+
raise NoAuthorizationError('Missing cookie "{}"'.format(cookie_key))
192+
secret = _get_secret_key()
193+
algorithm = get_algorithm()
194+
token = _decode_jwt(token, secret, algorithm)
195+
196+
if get_cookie_csrf_protect():
197+
csrf = request.headers.get(csrf_header_key, None)
198+
if not csrf or csrf != token['csrf']:
199+
raise NoAuthorizationError("Missing or invalid csrf double submit header")
200+
201+
return token
202+
203+
204+
def _decode_jwt_from_request(type):
205+
token_location = get_token_location()
206+
if token_location == 'headers':
207+
return _decode_jwt_from_headers()
208+
else:
209+
return _decode_jwt_from_cookies(type)
210+
211+
173212
def _handle_callbacks_on_error(fn):
174213
"""
175214
Helper decorator that will catch any exceptions we expect to encounter
@@ -183,7 +222,7 @@ def wrapper(*args, **kwargs):
183222

184223
try:
185224
return fn(*args, **kwargs)
186-
except NoAuthHeaderError:
225+
except NoAuthorizationError:
187226
return m.unauthorized_callback()
188227
except jwt.ExpiredSignatureError:
189228
return m.expired_token_callback()
@@ -211,7 +250,7 @@ def jwt_required(fn):
211250
@wraps(fn)
212251
def wrapper(*args, **kwargs):
213252
# Attempt to decode the token
214-
jwt_data = _decode_jwt_from_request()
253+
jwt_data = _decode_jwt_from_request(type='access')
215254

216255
# Verify this is an access token
217256
if jwt_data['type'] != 'access':
@@ -243,7 +282,7 @@ def fresh_jwt_required(fn):
243282
@wraps(fn)
244283
def wrapper(*args, **kwargs):
245284
# Attempt to decode the token
246-
jwt_data = _decode_jwt_from_request()
285+
jwt_data = _decode_jwt_from_request(type='access')
247286

248287
# Verify this is an access token
249288
if jwt_data['type'] != 'access':
@@ -276,7 +315,7 @@ def jwt_refresh_token_required(fn):
276315
@wraps(fn)
277316
def wrapper(*args, **kwargs):
278317
# Get the JWT
279-
jwt_data = _decode_jwt_from_request()
318+
jwt_data = _decode_jwt_from_request(type='refresh')
280319

281320
# verify this is a refresh token
282321
if jwt_data['type'] != 'refresh':
@@ -329,11 +368,7 @@ def _get_csrf_token(encoded_token):
329368
secret = _get_secret_key()
330369
algorithm = get_algorithm()
331370
token = _decode_jwt(encoded_token, secret, algorithm)
332-
try:
333-
return token['csrf']
334-
except KeyError:
335-
raise RuntimeError('JWT does not have csrf token set. Is '
336-
'JWT_COOKIE_CSRF_PROTECT set to True?')
371+
return token['csrf']
337372

338373

339374
def set_access_cookies(response, encoded_access_token):

tests/test_config.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66

77
from flask_jwt_extended.config import get_access_expires, get_refresh_expires, \
88
get_algorithm, get_blacklist_enabled, get_blacklist_store, \
9-
get_blacklist_checks, get_jwt_header_type, get_jwt_header_name
9+
get_blacklist_checks, get_jwt_header_type, get_jwt_header_name, \
10+
get_token_location, get_cookie_secure, get_access_cookie_name, \
11+
get_refresh_cookie_name, get_access_cookie_path, get_refresh_cookie_path, \
12+
get_cookie_csrf_protect, get_access_csrf_cookie_name, \
13+
get_refresh_csrf_cookie_name, get_access_csrf_header_name, \
14+
get_refresh_csrf_header_name
1015
from flask_jwt_extended import JWTManager
1116

1217

18+
1319
class TestEndpoints(unittest.TestCase):
1420

1521
def setUp(self):
@@ -18,37 +24,78 @@ def setUp(self):
1824
JWTManager(self.app)
1925
self.client = self.app.test_client()
2026

27+
#57, 69, 73, 77, 81, 85, 89, 93, 97, 101, 105
2128
def test_default_configs(self):
2229
with self.app.test_request_context():
30+
self.assertEqual(get_token_location(), 'headers')
31+
self.assertEqual(get_jwt_header_name(), 'Authorization')
32+
self.assertEqual(get_jwt_header_type(), 'Bearer')
33+
34+
self.assertEqual(get_cookie_secure(), False)
35+
self.assertEqual(get_access_cookie_name(), 'access_token_cookie')
36+
self.assertEqual(get_refresh_cookie_name(), 'refresh_token_cookie')
37+
self.assertEqual(get_access_cookie_path(), None)
38+
self.assertEqual(get_refresh_cookie_path(), None)
39+
self.assertEqual(get_cookie_csrf_protect(), True)
40+
self.assertEqual(get_access_csrf_cookie_name(), 'csrf_access_token')
41+
self.assertEqual(get_refresh_csrf_cookie_name(), 'csrf_refresh_token')
42+
self.assertEqual(get_access_csrf_header_name(), 'X-CSRF-ACCESS-TOKEN')
43+
self.assertEqual(get_refresh_csrf_header_name(), 'X-CSRF-REFRESH-TOKEN')
44+
2345
self.assertEqual(get_access_expires(), timedelta(minutes=15))
2446
self.assertEqual(get_refresh_expires(), timedelta(days=30))
2547
self.assertEqual(get_algorithm(), 'HS256')
2648
self.assertEqual(get_blacklist_enabled(), False)
2749
self.assertEqual(get_blacklist_store(), None)
2850
self.assertEqual(get_blacklist_checks(), 'refresh')
29-
self.assertEqual(get_jwt_header_name(), 'Authorization')
30-
self.assertEqual(get_jwt_header_type(), 'Bearer')
3151

3252
def test_override_configs(self):
53+
self.app.config['JWT_TOKEN_LOCATION'] = 'cookies'
54+
self.app.config['JWT_HEADER_NAME'] = 'Auth'
55+
self.app.config['JWT_HEADER_TYPE'] = 'JWT'
56+
57+
self.app.config['JWT_COOKIE_SECURE'] = True
58+
self.app.config['JWT_ACCESS_COOKIE_NAME'] = 'banana1'
59+
self.app.config['JWT_REFRESH_COOKIE_NAME'] = 'banana2'
60+
self.app.config['JWT_ACCESS_COOKIE_PATH'] = '/banana/'
61+
self.app.config['JWT_REFRESH_COOKIE_PATH'] = '/banana2/'
62+
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False
63+
self.app.config['JWT_ACCESS_CSRF_COOKIE_NAME'] = 'banana1a'
64+
self.app.config['JWT_REFRESH_CSRF_COOKIE_NAME'] = 'banana2a'
65+
self.app.config['JWT_ACCESS_CSRF_HEADER_NAME'] = 'banana1b'
66+
self.app.config['JWT_REFRESH_CSRF_HEADER_NAME'] = 'banana2b'
67+
3368
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(minutes=5)
3469
self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(days=7)
3570
self.app.config['JWT_ALGORITHM'] = 'HS512'
3671
self.app.config['JWT_BLACKLIST_ENABLED'] = True
3772
self.app.config['JWT_BLACKLIST_STORE'] = simplekv.memory.DictStore()
3873
self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'all'
39-
self.app.config['JWT_HEADER_NAME'] = 'Auth'
40-
self.app.config['JWT_HEADER_TYPE'] = 'JWT'
4174

4275
with self.app.test_request_context():
76+
self.assertEqual(get_token_location(), 'cookies')
77+
self.assertEqual(get_jwt_header_name(), 'Auth')
78+
self.assertEqual(get_jwt_header_type(), 'JWT')
79+
80+
self.assertEqual(get_cookie_secure(), True)
81+
self.assertEqual(get_access_cookie_name(), 'banana1')
82+
self.assertEqual(get_refresh_cookie_name(), 'banana2')
83+
self.assertEqual(get_access_cookie_path(), '/banana/')
84+
self.assertEqual(get_refresh_cookie_path(), '/banana2/')
85+
self.assertEqual(get_cookie_csrf_protect(), False)
86+
self.assertEqual(get_access_csrf_cookie_name(), 'banana1a')
87+
self.assertEqual(get_refresh_csrf_cookie_name(), 'banana2a')
88+
self.assertEqual(get_access_csrf_header_name(), 'banana1b')
89+
self.assertEqual(get_refresh_csrf_header_name(), 'banana2b')
90+
4391
self.assertEqual(get_access_expires(), timedelta(minutes=5))
4492
self.assertEqual(get_refresh_expires(), timedelta(days=7))
4593
self.assertEqual(get_algorithm(), 'HS512')
4694
self.assertEqual(get_blacklist_enabled(), True)
4795
self.assertIsInstance(get_blacklist_store(), simplekv.memory.DictStore)
4896
self.assertEqual(get_blacklist_checks(), 'all')
49-
self.assertEqual(get_jwt_header_name(), 'Auth')
50-
self.assertEqual(get_jwt_header_type(), 'JWT')
5197

98+
self.app.config['JWT_TOKEN_LOCATION'] = 'banana'
5299
self.app.config['JWT_HEADER_NAME'] = ''
53100
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = 'banana'
54101
self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = 'banana'
@@ -61,3 +108,5 @@ def test_override_configs(self):
61108
get_access_expires()
62109
with self.assertRaises(RuntimeError):
63110
get_refresh_expires()
111+
with self.assertRaises(RuntimeError):
112+
get_token_location()

tests/test_jwt_encode_decode.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,33 @@ def test_decode_invalid_jwt(self):
300300
}
301301
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
302302
_decode_jwt(encoded_token, 'secret', 'HS256')
303+
304+
# Missing and bad csrf tokens
305+
self.app.config['JWT_TOKEN_LOCATION'] = 'cookies'
306+
self.app.config['JWT_COOKIE_CSRF_PROTECTION'] = True
307+
with self.app.test_request_context():
308+
now = datetime.utcnow()
309+
with self.assertRaises(JWTDecodeError):
310+
token_data = {
311+
'exp': now + timedelta(minutes=5),
312+
'iat': now,
313+
'nbf': now,
314+
'jti': 'banana',
315+
'identity': 'banana',
316+
'type': 'refresh',
317+
}
318+
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
319+
_decode_jwt(encoded_token, 'secret', 'HS256')
320+
321+
with self.assertRaises(JWTDecodeError):
322+
token_data = {
323+
'exp': now + timedelta(minutes=5),
324+
'iat': now,
325+
'nbf': now,
326+
'jti': 'banana',
327+
'identity': 'banana',
328+
'type': 'refresh',
329+
'csrf': True
330+
}
331+
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
332+
_decode_jwt(encoded_token, 'secret', 'HS256')

0 commit comments

Comments
 (0)