diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index db459a446..236820730 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -157,7 +157,7 @@ def _authenticate_basic_auth(self, request): try: client_id, client_secret = map(unquote_plus, auth_string_decoded.split(":", 1)) except ValueError: - log.debug("Failed basic auth, Invalid base64 encoding.") + log.debug("Failed basic auth, Invalid base64 encoding") return False if self._load_application(client_id, request) is None: @@ -202,19 +202,32 @@ def _load_application(self, client_id, request): If request.client was not set, load application instance for given client_id and store it in request.client """ - - # we want to be sure that request has the client attribute! - assert hasattr(request, "client"), '"request" instance has no "client" attribute' - + if request.client: + # check for cached client, to save the db hit if this has already been loaded + if not isinstance(request.client, Application): + log.debug("resetting request.client (client_id=%r): not an Application, something else set request.client erroneously", client_id) + request.client = None + elif request.client.client_id != client_id: + log.debug("resetting request.client (client_id=%r): request.client client_id does not match the given client_id", client_id) + request.client = None + elif not request.client.is_usable(request): + log.debug("resetting request.client (client_id=%r): request.client is a valid Application, but is not usable", client_id) + request.client = None + else: + log.debug("request.client is a valid Application, reusing it") + return request.client try: - request.client = request.client or Application.objects.get(client_id=client_id) - # Check that the application can be used (defaults to always True) - if not request.client.is_usable(request): - log.debug("Failed body authentication: Application %r is disabled" % (client_id)) + # cache wasn't hit, load from db + log.debug("cache not hit, loading application from database for client_id %r", client_id) + client = Application.objects.get(client_id=client_id) + if not client.is_usable(request): + log.debug("Failed to load application: Application %r is not usable" % (client_id)) return None + request.client = client + log.debug("Loaded application with client_id %r from database", client.client_id) return request.client except Application.DoesNotExist: - log.debug("Failed body authentication: Application %r does not exist" % (client_id)) + log.debug("Failed to load application: Application %r does not exist" % (client_id)) return None def _set_oauth2_error_on_request(self, request, access_token, scopes): @@ -277,6 +290,7 @@ def client_authentication_required(self, request, *args, **kwargs): pass self._load_application(request.client_id, request) + log.debug("Determining if client authentication is required for client %r", request.client) if request.client: return request.client.client_type == AbstractApplication.CLIENT_CONFIDENTIAL diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 360fac957..369b1939f 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -1308,6 +1308,27 @@ def test_request_body_params(self): self.assertEqual(content["scope"], "read write") self.assertEqual(content["expires_in"], self.oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + def test_request_body_params_client_typo(self): + """ + Verify that using incorrect parameter name (client instead of client_id) returns invalid_client error + """ + self.client.login(username="test_user", password="123456") + authorization_code = self.get_auth() + + token_request_data = { + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": "http://example.org", + "client": self.application.client_id, + "client_secret": CLEARTEXT_SECRET, + } + + response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data) + self.assertEqual(response.status_code, 401) + + content = json.loads(response.content.decode("utf-8")) + self.assertEqual(content["error"], "invalid_client") + def test_public(self): """ Request an access token using client_type: public diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index 14c74506e..031ef2f29 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -210,8 +210,52 @@ def test_client_authentication_required(self): self.request.client = "" self.assertTrue(self.validator.client_authentication_required(self.request)) - def test_load_application_fails_when_request_has_no_client(self): - self.assertRaises(AssertionError, self.validator.authenticate_client_id, "client_id", {}) + def test_load_application_loads_client_id_when_request_has_no_client(self): + self.request.client = None + application = self.validator._load_application("client_id", self.request) + self.assertEqual(application, self.application) + + def test_load_application_uses_cached_when_request_has_valid_client_matching_client_id(self): + self.request.client = self.application + application = self.validator._load_application("client_id", self.request) + self.assertIs(application, self.application) + self.assertIs(self.request.client, self.application) + + def test_load_application_succeeds_when_request_has_invalid_client_valid_client_id(self): + self.request.client = 'invalid_client' + application = self.validator._load_application("client_id", self.request) + self.assertEqual(application, self.application) + self.assertEqual(self.request.client, self.application) + + def test_load_application_overwrites_client_on_client_id_mismatch(self): + another_application = Application.objects.create( + client_id="another_client_id", + client_secret=CLEARTEXT_SECRET, + user=self.user, + client_type=Application.CLIENT_PUBLIC, + authorization_grant_type=Application.GRANT_PASSWORD, + ) + self.request.client = another_application + application = self.validator._load_application("client_id", self.request) + self.assertEqual(application, self.application) + self.assertEqual(self.request.client, self.application) + another_application.delete() + + @mock.patch.object(Application, "is_usable") + def test_load_application_returns_none_when_client_not_usable_cached(self, mock_is_usable): + mock_is_usable.return_value = False + self.request.client = self.application + application = self.validator._load_application("client_id", self.request) + self.assertIsNone(application) + self.assertIsNone(self.request.client) + + @mock.patch.object(Application, "is_usable") + def test_load_application_returns_none_when_client_not_usable_db_lookup(self, mock_is_usable): + mock_is_usable.return_value = False + self.request.client = None + application = self.validator._load_application("client_id", self.request) + self.assertIsNone(application) + self.assertIsNone(self.request.client) def test_rotate_refresh_token__is_true(self): self.assertTrue(self.validator.rotate_refresh_token(mock.MagicMock()))