Skip to content

Commit 63075d3

Browse files
jeanphixvimalloc
authored andcommitted
Allow to specify user_claims at token creation (#229)
1 parent 3f37e1e commit 63075d3

File tree

3 files changed

+51
-10
lines changed

3 files changed

+51
-10
lines changed

flask_jwt_extended/jwt_manager.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -436,14 +436,12 @@ def encode_key_loader(self, callback):
436436
self._encode_key_callback = callback
437437
return callback
438438

439-
def _create_refresh_token(self, identity, expires_delta=None):
439+
def _create_refresh_token(self, identity, expires_delta=None, user_claims=None):
440440
if expires_delta is None:
441441
expires_delta = config.refresh_expires
442442

443-
if config.user_claims_in_refresh_token:
443+
if user_claims is None and config.user_claims_in_refresh_token:
444444
user_claims = self._user_claims_callback(identity)
445-
else:
446-
user_claims = None
447445

448446
refresh_token = encode_refresh_token(
449447
identity=self._user_identity_callback(identity),
@@ -458,17 +456,20 @@ def _create_refresh_token(self, identity, expires_delta=None):
458456
)
459457
return refresh_token
460458

461-
def _create_access_token(self, identity, fresh=False, expires_delta=None):
459+
def _create_access_token(self, identity, fresh=False, expires_delta=None, user_claims=None):
462460
if expires_delta is None:
463461
expires_delta = config.access_expires
464462

463+
if user_claims is None:
464+
user_claims = self._user_claims_callback(identity)
465+
465466
access_token = encode_access_token(
466467
identity=self._user_identity_callback(identity),
467468
secret=self._encode_key_callback(identity),
468469
algorithm=config.algorithm,
469470
expires_delta=expires_delta,
470471
fresh=fresh,
471-
user_claims=self._user_claims_callback(identity),
472+
user_claims=user_claims,
472473
csrf=config.csrf_protect,
473474
identity_claim_key=config.identity_claim_key,
474475
user_claims_key=config.user_claims_key,

flask_jwt_extended/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _get_jwt_manager():
114114
"application before using this method")
115115

116116

117-
def create_access_token(identity, fresh=False, expires_delta=None):
117+
def create_access_token(identity, fresh=False, expires_delta=None, user_claims=None):
118118
"""
119119
Create a new access token.
120120
@@ -134,13 +134,14 @@ def create_access_token(identity, fresh=False, expires_delta=None):
134134
expiration. If this is None, it will use the
135135
'JWT_ACCESS_TOKEN_EXPIRES` config value
136136
(see :ref:`Configuration Options`)
137+
:param user_claims: Optionnal JSON serializable to override user claims.
137138
:return: An encoded access token
138139
"""
139140
jwt_manager = _get_jwt_manager()
140-
return jwt_manager._create_access_token(identity, fresh, expires_delta)
141+
return jwt_manager._create_access_token(identity, fresh, expires_delta, user_claims)
141142

142143

143-
def create_refresh_token(identity, expires_delta=None):
144+
def create_refresh_token(identity, expires_delta=None, user_claims=None):
144145
"""
145146
Creates a new refresh token.
146147
@@ -155,10 +156,11 @@ def create_refresh_token(identity, expires_delta=None):
155156
expiration. If this is None, it will use the
156157
'JWT_REFRESH_TOKEN_EXPIRES` config value
157158
(see :ref:`Configuration Options`)
159+
:param user_claims: Optionnal JSON serializable to override user claims.
158160
:return: An encoded refresh token
159161
"""
160162
jwt_manager = _get_jwt_manager()
161-
return jwt_manager._create_refresh_token(identity, expires_delta)
163+
return jwt_manager._create_refresh_token(identity, expires_delta, user_claims)
162164

163165

164166
def has_user_loader():

tests/test_user_claims_loader.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,41 @@ def add_claims(identity):
137137
response = test_client.get('/protected2', headers=make_headers(refresh_token))
138138
assert response.get_json() == {'foo': 'bar'}
139139
assert response.status_code == 200
140+
141+
142+
def test_user_claim_in_refresh_token_specified_at_creation(app):
143+
app.config['JWT_CLAIMS_IN_REFRESH_TOKEN'] = True
144+
145+
with app.test_request_context():
146+
refresh_token = create_refresh_token('username', user_claims={'foo': 'bar'})
147+
148+
test_client = app.test_client()
149+
response = test_client.get('/protected2', headers=make_headers(refresh_token))
150+
assert response.get_json() == {'foo': 'bar'}
151+
assert response.status_code == 200
152+
153+
154+
def test_user_claims_in_access_token_specified_at_creation(app):
155+
with app.test_request_context():
156+
access_token = create_access_token('username', user_claims={'foo': 'bar'})
157+
158+
test_client = app.test_client()
159+
response = test_client.get('/protected', headers=make_headers(access_token))
160+
assert response.get_json() == {'foo': 'bar'}
161+
assert response.status_code == 200
162+
163+
164+
def test_user_claims_in_access_token_specified_at_creation_override(app):
165+
jwt = get_jwt_manager(app)
166+
167+
@jwt.user_claims_loader
168+
def add_claims(identity):
169+
return {'default': 'value'}
170+
171+
with app.test_request_context():
172+
access_token = create_access_token('username', user_claims={'foo': 'bar'})
173+
174+
test_client = app.test_client()
175+
response = test_client.get('/protected', headers=make_headers(access_token))
176+
assert response.get_json() == {'foo': 'bar'}
177+
assert response.status_code == 200

0 commit comments

Comments
 (0)