Skip to content

Commit b062c7f

Browse files
Guillaume Pujolguillp
authored andcommitted
refactor serializers and enums into their own submodules
1 parent ee8cf16 commit b062c7f

File tree

3 files changed

+122
-69
lines changed

3 files changed

+122
-69
lines changed

requests_oauth2client/serializers.py

Lines changed: 63 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,17 @@
1212

1313
from abc import ABC, abstractmethod
1414
from datetime import datetime, timezone
15-
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, override
15+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar, override
1616

1717
import jwskate
1818
from attr import asdict, field, frozen
1919
from binapy import BinaPy
20-
from jwskate import Jwk
2120

22-
from . import RequestParameterAuthorizationRequest, RequestUriParameterAuthorizationRequest
23-
from .authorization_request import AuthorizationRequest
21+
from .authorization_request import (
22+
AuthorizationRequest,
23+
RequestParameterAuthorizationRequest,
24+
RequestUriParameterAuthorizationRequest,
25+
)
2426
from .dpop import DPoPKey, DPoPToken
2527
from .tokens import BearerToken
2628

@@ -101,8 +103,8 @@ def get_class(self, args: Mapping[str, Any]) -> type[BearerToken]:
101103
"dpop": DPoPToken,
102104
}.get(token_type.lower(), BearerToken)
103105

104-
@staticmethod
105-
def default_dumper(token: BearerToken) -> str:
106+
@classmethod
107+
def default_dumper(cls, token: BearerToken) -> str:
106108
"""Serialize a token as JSON, then compress with deflate, then encodes as base64url.
107109
108110
Args:
@@ -123,8 +125,10 @@ def default_dumper(token: BearerToken) -> str:
123125
BinaPy.serialize_to("json", {k: w for k, w in d.items() if w is not None}).to("deflate").to("b64u").ascii()
124126
)
125127

126-
@staticmethod
127-
def default_loader(serialized: str, get_class: Callable[[Mapping[str, Any]], type[BearerToken]]) -> BearerToken:
128+
@classmethod
129+
def default_loader(
130+
cls, serialized: str, get_class: Callable[[Mapping[str, Any]], type[BearerToken]]
131+
) -> BearerToken:
128132
"""Deserialize a BearerToken.
129133
130134
This does the opposite operations than `default_dumper`.
@@ -151,6 +155,48 @@ def default_loader(serialized: str, get_class: Callable[[Mapping[str, Any]], typ
151155
return token_class(**args)
152156

153157

158+
@frozen
159+
class DPoPKeySerializer(Serializer[DPoPKey]):
160+
"""A (de)serializer for `DPoPKey` instances."""
161+
162+
dumper: Callable[[DPoPKey], str] = field(factory=lambda: DPoPKeySerializer.default_dumper)
163+
loader: Callable[[str, Callable[[Mapping[str, Any]], type[DPoPKey]]], DPoPKey] = field(
164+
factory=lambda: DPoPKeySerializer.default_loader
165+
)
166+
167+
@override
168+
def get_class(self, args: Mapping[str, Any]) -> type[DPoPKey]:
169+
return DPoPKey
170+
171+
@classmethod
172+
def default_dumper(cls, dpop_key: DPoPKey) -> str:
173+
"""Provide a default dumper implementation.
174+
175+
This will not serialize jti_generator, iat_generator, and dpop_token_class!
176+
177+
"""
178+
d = dpop_key.private_key.to_dict()
179+
d.pop("jti_generator", None)
180+
d.pop("iat_generator", None)
181+
d.pop("dpop_token_class", None)
182+
return BinaPy.serialize_to("json", d).to("deflate").to("b64u").ascii()
183+
184+
@classmethod
185+
def default_loader(
186+
cls,
187+
serialized: str,
188+
get_class: Callable[[Mapping[str, Any]], type[DPoPKey]],
189+
) -> DPoPKey:
190+
"""Provide a default deserializer implementation.
191+
192+
This will not deserialize iat_generator, iat_generator, and dpop_token_class!
193+
194+
"""
195+
private_key = BinaPy(serialized).decode_from("b64u").decode_from("deflate").parse_from("json")
196+
dpop_class = get_class({})
197+
return dpop_class(private_key=private_key)
198+
199+
154200
@frozen
155201
class AuthorizationRequestSerializer(
156202
Serializer[AuthorizationRequest | RequestParameterAuthorizationRequest | RequestUriParameterAuthorizationRequest]
@@ -180,6 +226,8 @@ class AuthorizationRequestSerializer(
180226
AuthorizationRequest | RequestParameterAuthorizationRequest | RequestUriParameterAuthorizationRequest,
181227
] = field(factory=lambda: AuthorizationRequestSerializer.default_loader)
182228

229+
dpop_key_serializer: ClassVar[Serializer[DPoPKey]] = DPoPKeySerializer()
230+
183231
@override
184232
def get_class(
185233
self, args: Mapping[str, Any]
@@ -190,8 +238,9 @@ def get_class(
190238
return RequestUriParameterAuthorizationRequest
191239
return AuthorizationRequest
192240

193-
@staticmethod
241+
@classmethod
194242
def default_dumper(
243+
cls,
195244
azr: AuthorizationRequest | RequestParameterAuthorizationRequest | RequestUriParameterAuthorizationRequest,
196245
) -> str:
197246
"""Provide a default dumper implementation.
@@ -208,12 +257,13 @@ def default_dumper(
208257
"""
209258
d = asdict(azr)
210259
if azr.dpop_key:
211-
d["dpop_key"]["private_key"] = azr.dpop_key.private_key.to_dict()
260+
d["dpop_key"] = cls.dpop_key_serializer.dumps(azr.dpop_key)
212261
d.update(**d.pop("kwargs", {}))
213262
return BinaPy.serialize_to("json", d).to("deflate").to("b64u").ascii()
214263

215-
@staticmethod
264+
@classmethod
216265
def default_loader(
266+
cls,
217267
serialized: str,
218268
get_class: Callable[
219269
[Mapping[str, Any]],
@@ -234,55 +284,9 @@ def default_loader(
234284
"""
235285
args = BinaPy(serialized).decode_from("b64u").decode_from("deflate").parse_from("json")
236286

237-
if dpop_key := args.get("dpop_key"):
238-
dpop_key["private_key"] = Jwk(dpop_key["private_key"])
239-
dpop_key.pop("jti_generator", None)
240-
dpop_key.pop("iat_generator", None)
241-
dpop_key.pop("dpop_token_class", None)
242-
args["dpop_key"] = DPoPKey(**dpop_key)
287+
if args["dpop_key"]:
288+
args["dpop_key"] = cls.dpop_key_serializer.loads(args["dpop_key"])
243289

244290
azr_class = get_class(args)
245291

246292
return azr_class(**args)
247-
248-
249-
@frozen
250-
class DPoPKeySerializer(Serializer[DPoPKey]):
251-
"""A (de)serializer for `DPoPKey` instances."""
252-
253-
dumper: Callable[[DPoPKey], str] = field(factory=lambda: DPoPKeySerializer.default_dumper)
254-
loader: Callable[[str, Callable[[Mapping[str, Any]], type[DPoPKey]]], DPoPKey] = field(
255-
factory=lambda: DPoPKeySerializer.default_loader
256-
)
257-
258-
@override
259-
def get_class(self, args: Mapping[str, Any]) -> type[DPoPKey]:
260-
return DPoPKey
261-
262-
@staticmethod
263-
def default_dumper(dpop_key: DPoPKey) -> str:
264-
"""Provide a default dumper implementation.
265-
266-
This will not serialize jti_generator, iat_generator, and dpop_token_class!
267-
268-
"""
269-
d = dpop_key.private_key.to_dict()
270-
d.pop("jti_generator", None)
271-
d.pop("iat_generator", None)
272-
d.pop("dpop_token_class", None)
273-
return BinaPy.serialize_to("json", d).to("deflate").to("b64u").ascii()
274-
275-
@staticmethod
276-
def default_loader(
277-
serialized: str,
278-
get_class: Callable[[Mapping[str, Any]], type[DPoPKey]],
279-
) -> DPoPKey:
280-
"""Provide a default deserializer implementation.
281-
282-
This will not deserialize iat_generator, iat_generator, and dpop_token_class!
283-
284-
"""
285-
args = BinaPy(serialized).decode_from("b64u").decode_from("deflate").parse_from("json")
286-
args["private_key"] = Jwk(args["private_key"])
287-
cls = get_class(args)
288-
return cls(**args)

tests/unit_tests/conftest.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
ClientSecretJwt,
1919
ClientSecretPost,
2020
DPoPKey,
21-
DPoPToken,
2221
OAuth2Client,
2322
PrivateKeyJwt,
2423
PublicApp,
24+
RequestParameterAuthorizationRequest,
2525
)
2626

2727
if TYPE_CHECKING:
@@ -52,14 +52,11 @@ def bearer_auth(access_token: str) -> BearerToken:
5252
return BearerToken(access_token)
5353

5454

55-
@pytest.fixture(scope="session")
56-
def dpop_key() -> DPoPKey:
57-
return DPoPKey.generate()
58-
59-
60-
@pytest.fixture(scope="session")
61-
def dpop_token(access_token: str, dpop_key: DPoPKey) -> DPoPToken:
62-
return DPoPToken(access_token=access_token, _dpop_key=dpop_key)
55+
@pytest.fixture(scope="session", params=[None, "ES256"])
56+
def dpop_key(request: FixtureRequest) -> DPoPKey | None:
57+
if request.param is None:
58+
return None
59+
return DPoPKey.generate(alg=request.param)
6360

6461

6562
@pytest.fixture(scope="session")
@@ -391,6 +388,7 @@ def authorization_request( # noqa: C901
391388
code_challenge_method: str,
392389
expected_issuer: str | None,
393390
auth_request_kwargs: dict[str, Any],
391+
dpop_key: DPoPKey,
394392
) -> AuthorizationRequest:
395393
authorization_response_iss_parameter_supported = bool(expected_issuer)
396394

@@ -405,6 +403,7 @@ def authorization_request( # noqa: C901
405403
code_challenge_method=code_challenge_method,
406404
authorization_response_iss_parameter_supported=authorization_response_iss_parameter_supported,
407405
issuer=expected_issuer,
406+
dpop_key=dpop_key,
408407
**auth_request_kwargs,
409408
)
410409

@@ -416,6 +415,7 @@ def authorization_request( # noqa: C901
416415
assert azr.redirect_uri == redirect_uri
417416
assert azr.issuer == expected_issuer
418417
assert azr.kwargs == auth_request_kwargs
418+
assert azr.dpop_key == dpop_key
419419

420420
args = dict(url.args)
421421
expected_args = dict(
@@ -499,6 +499,9 @@ def authorization_request( # noqa: C901
499499
assert generated_code_challenge == code_verifier
500500
assert azr.code_verifier == code_verifier
501501

502+
if dpop_key:
503+
expected_args["dpop_jkt"] = dpop_key.dpop_jkt
504+
502505
assert args == expected_args
503506

504507
return azr
@@ -535,3 +538,16 @@ def authorization_response(
535538
assert auth_response.code_verifier == authorization_request.code_verifier
536539

537540
return auth_response
541+
542+
543+
@pytest.fixture(scope="session")
544+
def request_parameter_signing_key() -> Jwk:
545+
return Jwk.generate(alg="ES256")
546+
547+
548+
@pytest.fixture
549+
def request_parameter_authorization_request(
550+
authorization_request: AuthorizationRequest,
551+
request_parameter_signing_key: Jwk,
552+
) -> RequestParameterAuthorizationRequest:
553+
return authorization_request.sign(request_parameter_signing_key)

tests/unit_tests/test_serializers.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
BearerTokenSerializer,
1111
DPoPKey,
1212
DPoPToken,
13+
RequestParameterAuthorizationRequest,
14+
RequestUriParameterAuthorizationRequest,
1315
)
1416

1517

@@ -33,11 +35,42 @@ def test_token_serializer(token: BearerToken, freezer: FrozenDateTimeFactory) ->
3335
assert serializer.loads(candidate) == token
3436

3537

36-
def test_authorization_request_serializer(authorization_request: AuthorizationRequest) -> None:
38+
def test_authorization_request_serializer(
39+
authorization_request: AuthorizationRequest,
40+
request_parameter_authorization_request: RequestParameterAuthorizationRequest,
41+
) -> None:
3742
serializer = AuthorizationRequestSerializer()
3843
serialized = serializer.dumps(authorization_request)
3944
assert serializer.loads(serialized) == authorization_request
4045

46+
request_parameter_serialized = serializer.dumps(request_parameter_authorization_request)
47+
assert serializer.loads(request_parameter_serialized) == request_parameter_authorization_request
48+
49+
50+
@pytest.fixture(
51+
scope="module", params=["this_is_a_request_uri", "urn:this:is:a:request_uri", "https://foo.bar/request_uri"]
52+
)
53+
def request_uri_authorization_request(
54+
authorization_endpoint: str, client_id: str, request: pytest.FixtureRequest
55+
) -> RequestUriParameterAuthorizationRequest:
56+
request_uri = request.param
57+
return RequestUriParameterAuthorizationRequest(
58+
authorization_endpoint=authorization_endpoint,
59+
client_id=client_id,
60+
request_uri=request_uri,
61+
custom_param="custom_value",
62+
)
63+
64+
65+
def test_request_uri_authorization_request_serializer(
66+
request_uri_authorization_request: RequestUriParameterAuthorizationRequest,
67+
) -> None:
68+
serializer = AuthorizationRequestSerializer()
69+
serialized = serializer.dumps(request_uri_authorization_request)
70+
deserialized = serializer.loads(serialized)
71+
assert isinstance(deserialized, RequestUriParameterAuthorizationRequest)
72+
assert deserialized == request_uri_authorization_request
73+
4174

4275
def test_authorization_request_serializer_with_dpop_key() -> None:
4376
dpop_key = DPoPKey.generate()

0 commit comments

Comments
 (0)