Skip to content

Commit 2978d0f

Browse files
jimgrahamsynasius
authored andcommitted
Reuse refresh tokens
Reuse refresh tokens if it's enabled. This is based on a code sample in #304 from #304 (comment) For us, this has cleaned up a number of race conditions deleting and recreating tokens.
1 parent d51a351 commit 2978d0f

File tree

2 files changed

+209
-25
lines changed

2 files changed

+209
-25
lines changed

oauth2_provider/oauth2_validators.py

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from django.conf import settings
1111
from django.contrib.auth import authenticate
1212
from django.core.exceptions import ObjectDoesNotExist
13+
from django.db import transaction
1314
from oauthlib.oauth2 import RequestValidator
1415

1516
from .compat import unquote_plus
17+
from .exceptions import FatalClientError
1618
from .models import Grant, AccessToken, RefreshToken, get_application_model, AbstractApplication
1719
from .settings import oauth2_settings
1820

@@ -86,6 +88,9 @@ def _authenticate_basic_auth(self, request):
8688
if self._load_application(client_id, request) is None:
8789
log.debug("Failed basic auth: Application %s does not exist" % client_id)
8890
return False
91+
elif request.client.client_id != client_id:
92+
log.debug("Failed basic auth: wrong client id %s" % client_id)
93+
return False
8994
elif request.client.client_secret != client_secret:
9095
log.debug("Failed basic auth: wrong client secret %s" % client_secret)
9196
return False
@@ -292,41 +297,88 @@ def save_authorization_code(self, client_id, code, request, *args, **kwargs):
292297
scope=' '.join(request.scopes))
293298
g.save()
294299

300+
@transaction.atomic
295301
def save_bearer_token(self, token, request, *args, **kwargs):
296302
"""
297-
Save access and refresh token, If refresh token is issued, remove old refresh tokens as
298-
in rfc:`6`
303+
Save access and refresh token, If refresh token is issued, remove or
304+
reuse old refresh token as in rfc:`6`
305+
306+
@see: https://tools.ietf.org/html/draft-ietf-oauth-v2-31#page-43
299307
"""
300-
if request.refresh_token:
301-
# remove used refresh token
302-
try:
303-
RefreshToken.objects.get(token=request.refresh_token).revoke()
304-
except RefreshToken.DoesNotExist:
305-
assert() # TODO though being here would be very strange, at least log the error
308+
309+
if 'scope' not in token:
310+
raise FatalClientError(u"Failed to renew access token: missing scope")
306311

307312
expires = timezone.now() + timedelta(seconds=oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
313+
308314
if request.grant_type == 'client_credentials':
309315
request.user = None
310316

317+
# This comes from OAuthLib:
318+
# https://github.com/idan/oauthlib/blob/1.0.3/oauthlib/oauth2/rfc6749/tokens.py#L267
319+
# Its value is either a new random code; or if we are reusing
320+
# refresh tokens, then it is the same value that the request passed in
321+
# (stored in `request.refresh_token`)
322+
refresh_token_code = token.get('refresh_token', None)
323+
324+
if refresh_token_code:
325+
# an instance of `RefreshToken` that matches the old refresh code.
326+
# Set on the request in `validate_refresh_token`
327+
refresh_token_instance = getattr(request, 'refresh_token_instance', None)
328+
329+
# If we are to reuse tokens, and we can: do so
330+
if not self.rotate_refresh_token(request) and \
331+
isinstance(refresh_token_instance, RefreshToken) and \
332+
refresh_token_instance.access_token:
333+
334+
access_token = AccessToken.objects.select_for_update().get(
335+
pk=refresh_token_instance.access_token.pk
336+
)
337+
access_token.user = request.user
338+
access_token.scope = token['scope']
339+
access_token.expires = expires
340+
access_token.token = token['access_token']
341+
access_token.application = request.client
342+
access_token.save()
343+
344+
# else create fresh with access & refresh tokens
345+
else:
346+
# revoke existing tokens if possible
347+
if isinstance(refresh_token_instance, RefreshToken):
348+
try:
349+
refresh_token_instance.revoke()
350+
except (AccessToken.DoesNotExist, RefreshToken.DoesNotExist):
351+
pass
352+
else:
353+
setattr(request, 'refresh_token_instance', None)
354+
355+
access_token = self._create_access_token(expires, request, token)
356+
357+
refresh_token = RefreshToken(
358+
user=request.user,
359+
token=refresh_token_code,
360+
application=request.client,
361+
access_token=access_token
362+
)
363+
refresh_token.save()
364+
365+
# No refresh token should be created, just access token
366+
else:
367+
self._create_access_token(expires, request, token)
368+
369+
# TODO: check out a more reliable way to communicate expire time to oauthlib
370+
token['expires_in'] = oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
371+
372+
def _create_access_token(self, expires, request, token):
311373
access_token = AccessToken(
312374
user=request.user,
313375
scope=token['scope'],
314376
expires=expires,
315377
token=token['access_token'],
316-
application=request.client)
378+
application=request.client
379+
)
317380
access_token.save()
318-
319-
if 'refresh_token' in token:
320-
refresh_token = RefreshToken(
321-
user=request.user,
322-
token=token['refresh_token'],
323-
application=request.client,
324-
access_token=access_token
325-
)
326-
refresh_token.save()
327-
328-
# TODO check out a more reliable way to communicate expire time to oauthlib
329-
token['expires_in'] = oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
381+
return access_token
330382

331383
def revoke_token(self, token, token_type_hint, request, *args, **kwargs):
332384
"""

oauth2_provider/tests/test_oauth2_validators.py

Lines changed: 136 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,31 @@
1+
from datetime import timedelta
2+
13
from django.contrib.auth import get_user_model
2-
from django.test import TestCase
4+
from django.test import TransactionTestCase
5+
from django.utils import timezone
36

47
import mock
58
from oauthlib.common import Request
69

10+
from ..exceptions import FatalClientError
711
from ..oauth2_validators import OAuth2Validator
8-
from ..models import get_application_model
12+
from ..models import get_application_model, AccessToken, RefreshToken
913

1014
UserModel = get_user_model()
1115
AppModel = get_application_model()
1216

1317

14-
class TestOAuth2Validator(TestCase):
18+
class TestOAuth2Validator(TransactionTestCase):
1519
def setUp(self):
1620
self.user = UserModel.objects.create_user("user", "test@user.com", "123456")
1721
self.request = mock.MagicMock(wraps=Request)
18-
self.request.client = None
22+
self.request.user = self.user
23+
self.request.grant_type = "not client"
1924
self.validator = OAuth2Validator()
2025
self.application = AppModel.objects.create(
2126
client_id='client_id', client_secret='client_secret', user=self.user,
2227
client_type=AppModel.CLIENT_PUBLIC, authorization_grant_type=AppModel.GRANT_PASSWORD)
28+
self.request.client = self.application
2329

2430
def tearDown(self):
2531
self.application.delete()
@@ -108,3 +114,129 @@ def test_client_authentication_required(self):
108114

109115
def test_load_application_fails_when_request_has_no_client(self):
110116
self.assertRaises(AssertionError, self.validator.authenticate_client_id, 'client_id', {})
117+
118+
def test_rotate_refresh_token__is_true(self):
119+
self.assertTrue(self.validator.rotate_refresh_token(mock.MagicMock()))
120+
121+
def test_save_bearer_token__without_user__raises_fatal_client(self):
122+
token = {}
123+
124+
with self.assertRaises(FatalClientError):
125+
self.validator.save_bearer_token(token, mock.MagicMock())
126+
127+
def test_save_bearer_token__with_existing_tokens__does_not_create_new_tokens(self):
128+
129+
rotate_token_function = mock.MagicMock()
130+
rotate_token_function.return_value = False
131+
self.validator.rotate_refresh_token = rotate_token_function
132+
133+
access_token = AccessToken.objects.create(
134+
token="123",
135+
user=self.user,
136+
expires=timezone.now() + timedelta(seconds=60),
137+
application=self.application
138+
)
139+
refresh_token = RefreshToken.objects.create(
140+
access_token=access_token,
141+
token="abc",
142+
user=self.user,
143+
application=self.application
144+
)
145+
self.request.refresh_token_instance = refresh_token
146+
token = {
147+
"scope": "foo bar",
148+
"refresh_token": "abc",
149+
"access_token": "123",
150+
}
151+
152+
self.assertEqual(1, RefreshToken.objects.count())
153+
self.assertEqual(1, AccessToken.objects.count())
154+
155+
self.validator.save_bearer_token(token, self.request)
156+
157+
self.assertEqual(1, RefreshToken.objects.count())
158+
self.assertEqual(1, AccessToken.objects.count())
159+
160+
def test_save_bearer_token__checks_to_rotate_tokens(self):
161+
162+
rotate_token_function = mock.MagicMock()
163+
rotate_token_function.return_value = False
164+
self.validator.rotate_refresh_token = rotate_token_function
165+
166+
access_token = AccessToken.objects.create(
167+
token="123",
168+
user=self.user,
169+
expires=timezone.now() + timedelta(seconds=60),
170+
application=self.application
171+
)
172+
refresh_token = RefreshToken.objects.create(
173+
access_token=access_token,
174+
token="abc",
175+
user=self.user,
176+
application=self.application
177+
)
178+
self.request.refresh_token_instance = refresh_token
179+
token = {
180+
"scope": "foo bar",
181+
"refresh_token": "abc",
182+
"access_token": "123",
183+
}
184+
185+
self.validator.save_bearer_token(token, self.request)
186+
rotate_token_function.assert_called_once_with(self.request)
187+
188+
def test_save_bearer_token__with_new_token__creates_new_tokens(self):
189+
token = {
190+
"scope": "foo bar",
191+
"refresh_token": "abc",
192+
"access_token": "123",
193+
}
194+
195+
self.assertEqual(0, RefreshToken.objects.count())
196+
self.assertEqual(0, AccessToken.objects.count())
197+
198+
self.validator.save_bearer_token(token, self.request)
199+
200+
self.assertEqual(1, RefreshToken.objects.count())
201+
self.assertEqual(1, AccessToken.objects.count())
202+
203+
def test_save_bearer_token__with_new_token_equal_to_existing_token__revokes_old_tokens(self):
204+
access_token = AccessToken.objects.create(
205+
token="123",
206+
user=self.user,
207+
expires=timezone.now() + timedelta(seconds=60),
208+
application=self.application
209+
)
210+
refresh_token = RefreshToken.objects.create(
211+
access_token=access_token,
212+
token="abc",
213+
user=self.user,
214+
application=self.application
215+
)
216+
217+
self.request.refresh_token_instance = refresh_token
218+
219+
token = {
220+
"scope": "foo bar",
221+
"refresh_token": "abc",
222+
"access_token": "123",
223+
}
224+
225+
self.assertEqual(1, RefreshToken.objects.count())
226+
self.assertEqual(1, AccessToken.objects.count())
227+
228+
self.validator.save_bearer_token(token, self.request)
229+
230+
self.assertEqual(1, RefreshToken.objects.count())
231+
self.assertEqual(1, AccessToken.objects.count())
232+
233+
def test_save_bearer_token__with_no_refresh_token__creates_new_access_token_only(self):
234+
token = {
235+
"scope": "foo bar",
236+
"access_token": "123",
237+
}
238+
239+
self.validator.save_bearer_token(token, self.request)
240+
241+
self.assertEqual(0, RefreshToken.objects.count())
242+
self.assertEqual(1, AccessToken.objects.count())

0 commit comments

Comments
 (0)