1+ from urllib .parse import parse_qs
12from urllib .parse import urlencode
3+ from urllib .parse import urlparse
24
35import pytest
46from httpx import AsyncClient
@@ -14,7 +16,11 @@ async def oauth2_workflow(get_app, idp=False, ssr=True, authorize_query="", toke
1416 response = await client .get ("/oauth2/test/authorize" + authorize_query ) # Get authorization endpoint
1517 authorization_endpoint = response .headers .get ("location" ) if ssr else response .json ().get ("url" )
1618 response = await client .get (authorization_endpoint ) # Authorize
17- response = await client .get (response .headers .get ("location" ) + token_query ) # Obtain token
19+ token_url = response .headers .get ("location" )
20+ query = {k : v [0 ] for k , v in parse_qs (urlparse (token_url ).query ).items ()}
21+ query .update ({k : v [0 ] for k , v in parse_qs (token_query ).items ()})
22+ token_url = "%s?%s" % (token_url .split ("?" )[0 ], urlencode (query ))
23+ response = await client .get (token_url ) # Obtain token
1824
1925 response = await client .get ("/user" , headers = dict (
2026 Authorization = jwt_encode (response .json (), "" ) # Set token
@@ -43,3 +49,16 @@ async def test_oauth2_pkce_workflow(get_app):
4349 tq = "&" + urlencode (dict (code_verifier = code_verifier ))
4450 await oauth2_workflow (get_app , idp = True , authorize_query = aq , token_query = tq )
4551 await oauth2_workflow (get_app , idp = True , ssr = False , authorize_query = aq , token_query = tq , use_header = True )
52+
53+
54+ @pytest .mark .anyio
55+ async def test_oauth2_csrf_workflow (get_app ):
56+ for aq , tq in [
57+ ("?state=test_state" , "&state=test_state" ),
58+ ("?state=test_state" , "&state=test_wrong_state" )
59+ ]:
60+ try :
61+ await oauth2_workflow (get_app , idp = True , authorize_query = aq , token_query = tq )
62+ await oauth2_workflow (get_app , idp = True , ssr = False , authorize_query = aq , token_query = tq , use_header = True )
63+ except AssertionError :
64+ assert aq != tq
0 commit comments