Skip to content

Commit 14a6c65

Browse files
author
muha
committed
fix PARAuthorizationWithPkce
1 parent 62898fe commit 14a6c65

File tree

1 file changed

+114
-108
lines changed

1 file changed

+114
-108
lines changed

authorization-code/src/main/java/com/example/security/PARAuthorizationWithPkceRequestResolver.java

Lines changed: 114 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,10 @@
44
import com.example.config.OAuth2ClientDetailProperties.Registration;
55
import com.example.service.JwtClientAssertionParametersService;
66
import jakarta.servlet.http.HttpServletRequest;
7-
import java.lang.reflect.Field;
8-
import java.util.Arrays;
97
import java.util.Base64;
108
import java.util.HashMap;
119
import java.util.List;
1210
import java.util.Map;
13-
import org.springframework.http.HttpEntity;
14-
import org.springframework.http.HttpHeaders;
1511
import org.springframework.http.MediaType;
1612
import org.springframework.http.ResponseEntity;
1713
import org.springframework.http.converter.FormHttpMessageConverter;
@@ -29,50 +25,52 @@
2925
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
3026
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
3127
import org.springframework.security.oauth2.core.oidc.OidcScopes;
28+
import org.springframework.security.web.servlet.util.matcher.PathPatternRequestMatcher;
3229
import org.springframework.security.web.util.UrlUtils;
33-
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
3430
import org.springframework.util.Assert;
3531
import org.springframework.util.LinkedMultiValueMap;
3632
import org.springframework.util.MultiValueMap;
37-
import org.springframework.util.ReflectionUtils;
3833
import org.springframework.util.StringUtils;
39-
import org.springframework.web.client.RestOperations;
40-
import org.springframework.web.client.RestTemplate;
34+
import org.springframework.web.client.RestClient;
4135
import org.springframework.web.util.DefaultUriBuilderFactory;
4236
import org.springframework.web.util.UriBuilder;
4337
import org.springframework.web.util.UriComponents;
4438
import org.springframework.web.util.UriComponentsBuilder;
4539

46-
4740
public class PARAuthorizationWithPkceRequestResolver implements OAuth2AuthorizationRequestResolver {
4841

4942
private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
5043
private static final char PATH_DELIMITER = '/';
51-
private static final StringKeyGenerator DEFAULT_STATE_GENERATOR = new Base64StringKeyGenerator(
52-
Base64.getUrlEncoder());
44+
private static final StringKeyGenerator DEFAULT_STATE_GENERATOR =
45+
new Base64StringKeyGenerator(Base64.getUrlEncoder());
5346
private final ClientRegistrationRepository clientRegistrationRepository;
54-
private final AntPathRequestMatcher authorizationRequestMatcher;
47+
private final PathPatternRequestMatcher authorizationRequestMatcher;
5548

56-
private final RestOperations restOperations;
5749
private final JwtClientAssertionParametersService jwtClientAssertionParametersService;
5850
private final Map<String, OAuth2ClientDetailProperties.Registration> registrations;
5951

52+
private final RestClient restClient =
53+
RestClient.builder()
54+
.messageConverters(
55+
(messageConverters) -> {
56+
messageConverters.clear();
57+
messageConverters.add(new FormHttpMessageConverter());
58+
messageConverters.add(new OAuth2AccessTokenResponseHttpMessageConverter());
59+
messageConverters.add(new MappingJackson2HttpMessageConverter());
60+
})
61+
.defaultStatusHandler(new OAuth2ErrorResponseErrorHandler())
62+
.build();
63+
6064
public PARAuthorizationWithPkceRequestResolver(
6165
ClientRegistrationRepository clientRegistrationRepository,
6266
Map<String, OAuth2ClientDetailProperties.Registration> registrations,
6367
String authorizationRequestBaseUri) {
6468
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
6569
Assert.hasText(authorizationRequestBaseUri, "authorizationRequestBaseUri cannot be empty");
6670
this.clientRegistrationRepository = clientRegistrationRepository;
67-
this.authorizationRequestMatcher = new AntPathRequestMatcher(
68-
authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}");
69-
70-
RestTemplate restTemplate = new RestTemplate(
71-
Arrays.asList(new FormHttpMessageConverter(),
72-
new OAuth2AccessTokenResponseHttpMessageConverter(),
73-
new MappingJackson2HttpMessageConverter()));
74-
restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
75-
this.restOperations = restTemplate;
71+
this.authorizationRequestMatcher =
72+
PathPatternRequestMatcher.withDefaults()
73+
.matcher(authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}");
7674

7775
this.registrations = registrations;
7876
jwtClientAssertionParametersService = new JwtClientAssertionParametersService(registrations);
@@ -97,95 +95,99 @@ public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String reg
9795
return resolve(request, registrationId, redirectUriAction);
9896
}
9997

100-
private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String registrationId,
101-
String redirectUriAction) {
98+
private OAuth2AuthorizationRequest resolve(
99+
HttpServletRequest request, String registrationId, String redirectUriAction) {
102100
if (registrationId == null) {
103101
return null;
104102
}
105-
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(
106-
registrationId);
103+
ClientRegistration clientRegistration =
104+
this.clientRegistrationRepository.findByRegistrationId(registrationId);
107105
if (clientRegistration == null) {
108106
throw new IllegalArgumentException("Invalid Client Registration with Id: " + registrationId);
109107
}
110108

111109
String redirectUri = expandRedirectUri(request, clientRegistration, redirectUriAction);
112110
String state = DEFAULT_STATE_GENERATOR.generateKey();
113111
Map<String, Object> pkceParameters = buildPkceParameters(clientRegistration);
114-
String parRequestUri = sendParRequest(redirectUri, state, clientRegistration,
115-
pkceParameters).request_uri;
112+
String parRequestUri =
113+
sendParRequest(redirectUri, state, clientRegistration, pkceParameters).request_uri;
116114

117115
return buildOAuth2AuthorizationRequest(
118116
pkceParameters, clientRegistration, redirectUri, state, parRequestUri);
119117
}
120118

121119
private OAuth2AuthorizationRequest buildOAuth2AuthorizationRequest(
122-
Map<String, Object> pkceParameters, ClientRegistration clientRegistration, String redirectUri,
123-
String state, String parRequestUri) {
124-
String codeVerifier = pkceParameters
125-
.get(PkceParameterNames.CODE_VERIFIER).toString();
126-
127-
String authorizationEndpoint = clientRegistration.getProviderDetails()
128-
.getConfigurationMetadata().get("authorization_endpoint").toString();
129-
130-
OAuth2AuthorizationRequest oAuth2AuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
131-
.attributes((attrs) ->
132-
{
133-
attrs.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
134-
attrs.put(PkceParameterNames.CODE_VERIFIER, codeVerifier);
135-
})
120+
Map<String, Object> pkceParameters,
121+
ClientRegistration clientRegistration,
122+
String redirectUri,
123+
String state,
124+
String parRequestUri) {
125+
String codeVerifier = pkceParameters.get(PkceParameterNames.CODE_VERIFIER).toString();
126+
127+
String authorizationEndpoint =
128+
clientRegistration
129+
.getProviderDetails()
130+
.getConfigurationMetadata()
131+
.get("authorization_endpoint")
132+
.toString();
133+
134+
String authorizationRequestUri =
135+
getAuthorizationRequestUri(
136+
clientRegistration.getClientId(), authorizationEndpoint, parRequestUri);
137+
138+
return OAuth2AuthorizationRequest.authorizationCode()
139+
.attributes(
140+
(attrs) -> {
141+
attrs.put(
142+
OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
143+
attrs.put(PkceParameterNames.CODE_VERIFIER, codeVerifier);
144+
})
136145
.redirectUri(redirectUri)
137146
.clientId(clientRegistration.getClientId())
138147
.scope(OidcScopes.OPENID)
139148
.authorizationUri(authorizationEndpoint)
140-
.state(state).build();
141-
142-
String authorizationRequestUri = getAuthorizationRequestUri(parRequestUri,
143-
oAuth2AuthorizationRequest);
144-
145-
Field authorizationRequestUriField = ReflectionUtils.findField(OAuth2AuthorizationRequest.class,
146-
"authorizationRequestUri");
147-
if (authorizationRequestUriField != null) {
148-
ReflectionUtils.makeAccessible(authorizationRequestUriField);
149-
ReflectionUtils.setField(authorizationRequestUriField, oAuth2AuthorizationRequest,
150-
authorizationRequestUri);
151-
}
152-
return oAuth2AuthorizationRequest;
149+
.authorizationRequestUri(authorizationRequestUri)
150+
.state(state)
151+
.build();
153152
}
154153

155-
private String getAuthorizationRequestUri(String parRequestUri,
156-
OAuth2AuthorizationRequest oAuth2AuthorizationRequest) {
157-
DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory();
158-
uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE);
154+
private String getAuthorizationRequestUri(
155+
String clientId, String authorizationEndpoint, String parRequestUri) {
159156

160157
MultiValueMap<String, String> queryParams = new LinkedMultiValueMap<>();
161158
queryParams.put("request_uri", List.of(parRequestUri));
162-
queryParams.put(OAuth2ParameterNames.CLIENT_ID,
163-
List.of(oAuth2AuthorizationRequest.getClientId()));
164-
UriBuilder uriBuilder = uriBuilderFactory.uriString(
165-
oAuth2AuthorizationRequest.getAuthorizationUri()).queryParams(queryParams);
159+
queryParams.put(OAuth2ParameterNames.CLIENT_ID, List.of(clientId));
160+
161+
DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory();
162+
uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE);
163+
UriBuilder uriBuilder =
164+
uriBuilderFactory.uriString(authorizationEndpoint).queryParams(queryParams);
166165
return uriBuilder.build().toString();
167166
}
168167

169-
public ParResponse sendParRequest(String redirectUri, String state,
170-
ClientRegistration clientRegistration, Map<String, Object> pkceParameters) {
171-
String parEndpoint = clientRegistration.getProviderDetails()
172-
.getConfigurationMetadata().get("pushed_authorization_request_endpoint").toString();
173-
174-
String codeChallengeMethod = pkceParameters
175-
.get(PkceParameterNames.CODE_CHALLENGE_METHOD).toString();
176-
String codeChallenge = pkceParameters
177-
.get(PkceParameterNames.CODE_CHALLENGE).toString();
178-
179-
MultiValueMap<String, String> assertionParameters = jwtClientAssertionParametersService.buildClientAssertionParameters(
180-
clientRegistration);
181-
182-
String clientAssertionType = assertionParameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE)
183-
.getFirst();
184-
String clientAssertion = assertionParameters.get(OAuth2ParameterNames.CLIENT_ASSERTION)
185-
.getFirst();
186-
187-
HttpHeaders headers = new HttpHeaders();
188-
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
168+
public ParResponse sendParRequest(
169+
String redirectUri,
170+
String state,
171+
ClientRegistration clientRegistration,
172+
Map<String, Object> pkceParameters) {
173+
String parEndpoint =
174+
clientRegistration
175+
.getProviderDetails()
176+
.getConfigurationMetadata()
177+
.get("pushed_authorization_request_endpoint")
178+
.toString();
179+
180+
String codeChallengeMethod =
181+
pkceParameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).toString();
182+
String codeChallenge = pkceParameters.get(PkceParameterNames.CODE_CHALLENGE).toString();
183+
184+
MultiValueMap<String, String> assertionParameters =
185+
jwtClientAssertionParametersService.buildClientAssertionParameters(clientRegistration);
186+
187+
String clientAssertionType =
188+
assertionParameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE).getFirst();
189+
String clientAssertion =
190+
assertionParameters.get(OAuth2ParameterNames.CLIENT_ASSERTION).getFirst();
189191

190192
MultiValueMap<String, String> body = new LinkedMultiValueMap<>();
191193
body.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
@@ -201,17 +203,23 @@ public ParResponse sendParRequest(String redirectUri, String state,
201203

202204
Registration registration = registrations.get(clientRegistration.getRegistrationId());
203205
if (registration != null) {
204-
if (registration.getAcrValues() != null) {
206+
if (registration.getAcrValues() != null && !registration.getAcrValues().isEmpty()) {
205207
body.add("acr_values", registration.getAcrValues());
206208
}
207-
if (registration.getPrompt() != null) {
209+
if (registration.getPrompt() != null && !registration.getPrompt().isEmpty()) {
208210
body.add("prompt", registration.getPrompt());
209211
}
210212
}
211213

212-
HttpEntity<MultiValueMap<String, String>> requestEntity = new HttpEntity<>(body, headers);
213-
ResponseEntity<ParResponse> response = restOperations.postForEntity(parEndpoint, requestEntity,
214-
ParResponse.class);
214+
ResponseEntity<ParResponse> response =
215+
restClient
216+
.post()
217+
.uri(parEndpoint)
218+
.body(body)
219+
.headers(
220+
httpHeaders -> httpHeaders.setContentType(MediaType.APPLICATION_FORM_URLENCODED))
221+
.retrieve()
222+
.toEntity(ParResponse.class);
215223

216224
if (response.hasBody()) {
217225
return response.getBody();
@@ -220,27 +228,25 @@ public ParResponse sendParRequest(String redirectUri, String state,
220228
}
221229

222230
private Map<String, Object> buildPkceParameters(ClientRegistration clientRegistration) {
223-
Builder builder = OAuth2AuthorizationRequest.authorizationCode()
224-
.clientId(clientRegistration.getClientId())
225-
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri());
231+
Builder builder =
232+
OAuth2AuthorizationRequest.authorizationCode()
233+
.clientId(clientRegistration.getClientId())
234+
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri());
226235

227-
OAuth2AuthorizationRequestCustomizers
228-
.withPkce().accept(builder);
236+
OAuth2AuthorizationRequestCustomizers.withPkce().accept(builder);
229237
OAuth2AuthorizationRequest build = builder.build();
230238
Map<String, Object> additionalParameters = new HashMap<>(build.getAdditionalParameters());
231-
additionalParameters.put(PkceParameterNames.CODE_VERIFIER,
232-
build.getAttribute(PkceParameterNames.CODE_VERIFIER));
239+
additionalParameters.put(
240+
PkceParameterNames.CODE_VERIFIER, build.getAttribute(PkceParameterNames.CODE_VERIFIER));
233241
return additionalParameters;
234242
}
235243

236-
237244
public static class ParResponse {
238245

239246
public String request_uri;
240247
public String expires_in;
241248
}
242249

243-
244250
private String getAction(HttpServletRequest request, String defaultAction) {
245251
String action = request.getParameter("action");
246252
if (action == null) {
@@ -251,24 +257,25 @@ private String getAction(HttpServletRequest request, String defaultAction) {
251257

252258
private String resolveRegistrationId(HttpServletRequest request) {
253259
if (this.authorizationRequestMatcher.matches(request)) {
254-
return this.authorizationRequestMatcher.matcher(request)
260+
return this.authorizationRequestMatcher
261+
.matcher(request)
255262
.getVariables()
256263
.get(REGISTRATION_ID_URI_VARIABLE_NAME);
257264
}
258265
return null;
259266
}
260267

261-
private String expandRedirectUri(HttpServletRequest request,
262-
ClientRegistration clientRegistration,
263-
String action) {
268+
private String expandRedirectUri(
269+
HttpServletRequest request, ClientRegistration clientRegistration, String action) {
264270
Map<String, String> uriVariables = new HashMap<>();
265271
uriVariables.put("registrationId", clientRegistration.getRegistrationId());
266272
// @formatter:off
267-
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
268-
.replacePath(request.getContextPath())
269-
.replaceQuery(null)
270-
.fragment(null)
271-
.build();
273+
UriComponents uriComponents =
274+
UriComponentsBuilder.fromUriString(UrlUtils.buildFullRequestUrl(request))
275+
.replacePath(request.getContextPath())
276+
.replaceQuery(null)
277+
.fragment(null)
278+
.build();
272279
// @formatter:on
273280
String scheme = uriComponents.getScheme();
274281
uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
@@ -290,5 +297,4 @@ private String expandRedirectUri(HttpServletRequest request,
290297
.buildAndExpand(uriVariables)
291298
.toUriString();
292299
}
293-
294-
}
300+
}

0 commit comments

Comments
 (0)