44import com .example .config .OAuth2ClientDetailProperties .Registration ;
55import com .example .service .JwtClientAssertionParametersService ;
66import jakarta .servlet .http .HttpServletRequest ;
7- import java .lang .reflect .Field ;
8- import java .util .Arrays ;
97import java .util .Base64 ;
108import java .util .HashMap ;
119import java .util .List ;
1210import java .util .Map ;
13- import org .springframework .http .HttpEntity ;
14- import org .springframework .http .HttpHeaders ;
1511import org .springframework .http .MediaType ;
1612import org .springframework .http .ResponseEntity ;
1713import org .springframework .http .converter .FormHttpMessageConverter ;
2925import org .springframework .security .oauth2 .core .endpoint .PkceParameterNames ;
3026import org .springframework .security .oauth2 .core .http .converter .OAuth2AccessTokenResponseHttpMessageConverter ;
3127import org .springframework .security .oauth2 .core .oidc .OidcScopes ;
28+ import org .springframework .security .web .servlet .util .matcher .PathPatternRequestMatcher ;
3229import org .springframework .security .web .util .UrlUtils ;
33- import org .springframework .security .web .util .matcher .AntPathRequestMatcher ;
3430import org .springframework .util .Assert ;
3531import org .springframework .util .LinkedMultiValueMap ;
3632import org .springframework .util .MultiValueMap ;
37- import org .springframework .util .ReflectionUtils ;
3833import org .springframework .util .StringUtils ;
39- import org .springframework .web .client .RestOperations ;
40- import org .springframework .web .client .RestTemplate ;
34+ import org .springframework .web .client .RestClient ;
4135import org .springframework .web .util .DefaultUriBuilderFactory ;
4236import org .springframework .web .util .UriBuilder ;
4337import org .springframework .web .util .UriComponents ;
4438import org .springframework .web .util .UriComponentsBuilder ;
4539
46-
4740public class PARAuthorizationWithPkceRequestResolver implements OAuth2AuthorizationRequestResolver {
4841
4942 private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId" ;
5043 private static final char PATH_DELIMITER = '/' ;
51- private static final StringKeyGenerator DEFAULT_STATE_GENERATOR = new Base64StringKeyGenerator (
52- Base64 .getUrlEncoder ());
44+ private static final StringKeyGenerator DEFAULT_STATE_GENERATOR =
45+ new Base64StringKeyGenerator ( Base64 .getUrlEncoder ());
5346 private final ClientRegistrationRepository clientRegistrationRepository ;
54- private final AntPathRequestMatcher authorizationRequestMatcher ;
47+ private final PathPatternRequestMatcher authorizationRequestMatcher ;
5548
56- private final RestOperations restOperations ;
5749 private final JwtClientAssertionParametersService jwtClientAssertionParametersService ;
5850 private final Map <String , OAuth2ClientDetailProperties .Registration > registrations ;
5951
52+ private final RestClient restClient =
53+ RestClient .builder ()
54+ .messageConverters (
55+ (messageConverters ) -> {
56+ messageConverters .clear ();
57+ messageConverters .add (new FormHttpMessageConverter ());
58+ messageConverters .add (new OAuth2AccessTokenResponseHttpMessageConverter ());
59+ messageConverters .add (new MappingJackson2HttpMessageConverter ());
60+ })
61+ .defaultStatusHandler (new OAuth2ErrorResponseErrorHandler ())
62+ .build ();
63+
6064 public PARAuthorizationWithPkceRequestResolver (
6165 ClientRegistrationRepository clientRegistrationRepository ,
6266 Map <String , OAuth2ClientDetailProperties .Registration > registrations ,
6367 String authorizationRequestBaseUri ) {
6468 Assert .notNull (clientRegistrationRepository , "clientRegistrationRepository cannot be null" );
6569 Assert .hasText (authorizationRequestBaseUri , "authorizationRequestBaseUri cannot be empty" );
6670 this .clientRegistrationRepository = clientRegistrationRepository ;
67- this .authorizationRequestMatcher = new AntPathRequestMatcher (
68- authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}" );
69-
70- RestTemplate restTemplate = new RestTemplate (
71- Arrays .asList (new FormHttpMessageConverter (),
72- new OAuth2AccessTokenResponseHttpMessageConverter (),
73- new MappingJackson2HttpMessageConverter ()));
74- restTemplate .setErrorHandler (new OAuth2ErrorResponseErrorHandler ());
75- this .restOperations = restTemplate ;
71+ this .authorizationRequestMatcher =
72+ PathPatternRequestMatcher .withDefaults ()
73+ .matcher (authorizationRequestBaseUri + "/{" + REGISTRATION_ID_URI_VARIABLE_NAME + "}" );
7674
7775 this .registrations = registrations ;
7876 jwtClientAssertionParametersService = new JwtClientAssertionParametersService (registrations );
@@ -97,95 +95,99 @@ public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String reg
9795 return resolve (request , registrationId , redirectUriAction );
9896 }
9997
100- private OAuth2AuthorizationRequest resolve (HttpServletRequest request , String registrationId ,
101- String redirectUriAction ) {
98+ private OAuth2AuthorizationRequest resolve (
99+ HttpServletRequest request , String registrationId , String redirectUriAction ) {
102100 if (registrationId == null ) {
103101 return null ;
104102 }
105- ClientRegistration clientRegistration = this . clientRegistrationRepository . findByRegistrationId (
106- registrationId );
103+ ClientRegistration clientRegistration =
104+ this . clientRegistrationRepository . findByRegistrationId ( registrationId );
107105 if (clientRegistration == null ) {
108106 throw new IllegalArgumentException ("Invalid Client Registration with Id: " + registrationId );
109107 }
110108
111109 String redirectUri = expandRedirectUri (request , clientRegistration , redirectUriAction );
112110 String state = DEFAULT_STATE_GENERATOR .generateKey ();
113111 Map <String , Object > pkceParameters = buildPkceParameters (clientRegistration );
114- String parRequestUri = sendParRequest ( redirectUri , state , clientRegistration ,
115- pkceParameters ).request_uri ;
112+ String parRequestUri =
113+ sendParRequest ( redirectUri , state , clientRegistration , pkceParameters ).request_uri ;
116114
117115 return buildOAuth2AuthorizationRequest (
118116 pkceParameters , clientRegistration , redirectUri , state , parRequestUri );
119117 }
120118
121119 private OAuth2AuthorizationRequest buildOAuth2AuthorizationRequest (
122- Map <String , Object > pkceParameters , ClientRegistration clientRegistration , String redirectUri ,
123- String state , String parRequestUri ) {
124- String codeVerifier = pkceParameters
125- .get (PkceParameterNames .CODE_VERIFIER ).toString ();
126-
127- String authorizationEndpoint = clientRegistration .getProviderDetails ()
128- .getConfigurationMetadata ().get ("authorization_endpoint" ).toString ();
129-
130- OAuth2AuthorizationRequest oAuth2AuthorizationRequest = OAuth2AuthorizationRequest .authorizationCode ()
131- .attributes ((attrs ) ->
132- {
133- attrs .put (OAuth2ParameterNames .REGISTRATION_ID , clientRegistration .getRegistrationId ());
134- attrs .put (PkceParameterNames .CODE_VERIFIER , codeVerifier );
135- })
120+ Map <String , Object > pkceParameters ,
121+ ClientRegistration clientRegistration ,
122+ String redirectUri ,
123+ String state ,
124+ String parRequestUri ) {
125+ String codeVerifier = pkceParameters .get (PkceParameterNames .CODE_VERIFIER ).toString ();
126+
127+ String authorizationEndpoint =
128+ clientRegistration
129+ .getProviderDetails ()
130+ .getConfigurationMetadata ()
131+ .get ("authorization_endpoint" )
132+ .toString ();
133+
134+ String authorizationRequestUri =
135+ getAuthorizationRequestUri (
136+ clientRegistration .getClientId (), authorizationEndpoint , parRequestUri );
137+
138+ return OAuth2AuthorizationRequest .authorizationCode ()
139+ .attributes (
140+ (attrs ) -> {
141+ attrs .put (
142+ OAuth2ParameterNames .REGISTRATION_ID , clientRegistration .getRegistrationId ());
143+ attrs .put (PkceParameterNames .CODE_VERIFIER , codeVerifier );
144+ })
136145 .redirectUri (redirectUri )
137146 .clientId (clientRegistration .getClientId ())
138147 .scope (OidcScopes .OPENID )
139148 .authorizationUri (authorizationEndpoint )
140- .state (state ).build ();
141-
142- String authorizationRequestUri = getAuthorizationRequestUri (parRequestUri ,
143- oAuth2AuthorizationRequest );
144-
145- Field authorizationRequestUriField = ReflectionUtils .findField (OAuth2AuthorizationRequest .class ,
146- "authorizationRequestUri" );
147- if (authorizationRequestUriField != null ) {
148- ReflectionUtils .makeAccessible (authorizationRequestUriField );
149- ReflectionUtils .setField (authorizationRequestUriField , oAuth2AuthorizationRequest ,
150- authorizationRequestUri );
151- }
152- return oAuth2AuthorizationRequest ;
149+ .authorizationRequestUri (authorizationRequestUri )
150+ .state (state )
151+ .build ();
153152 }
154153
155- private String getAuthorizationRequestUri (String parRequestUri ,
156- OAuth2AuthorizationRequest oAuth2AuthorizationRequest ) {
157- DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory ();
158- uriBuilderFactory .setEncodingMode (DefaultUriBuilderFactory .EncodingMode .NONE );
154+ private String getAuthorizationRequestUri (
155+ String clientId , String authorizationEndpoint , String parRequestUri ) {
159156
160157 MultiValueMap <String , String > queryParams = new LinkedMultiValueMap <>();
161158 queryParams .put ("request_uri" , List .of (parRequestUri ));
162- queryParams .put (OAuth2ParameterNames .CLIENT_ID ,
163- List .of (oAuth2AuthorizationRequest .getClientId ()));
164- UriBuilder uriBuilder = uriBuilderFactory .uriString (
165- oAuth2AuthorizationRequest .getAuthorizationUri ()).queryParams (queryParams );
159+ queryParams .put (OAuth2ParameterNames .CLIENT_ID , List .of (clientId ));
160+
161+ DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory ();
162+ uriBuilderFactory .setEncodingMode (DefaultUriBuilderFactory .EncodingMode .NONE );
163+ UriBuilder uriBuilder =
164+ uriBuilderFactory .uriString (authorizationEndpoint ).queryParams (queryParams );
166165 return uriBuilder .build ().toString ();
167166 }
168167
169- public ParResponse sendParRequest (String redirectUri , String state ,
170- ClientRegistration clientRegistration , Map <String , Object > pkceParameters ) {
171- String parEndpoint = clientRegistration .getProviderDetails ()
172- .getConfigurationMetadata ().get ("pushed_authorization_request_endpoint" ).toString ();
173-
174- String codeChallengeMethod = pkceParameters
175- .get (PkceParameterNames .CODE_CHALLENGE_METHOD ).toString ();
176- String codeChallenge = pkceParameters
177- .get (PkceParameterNames .CODE_CHALLENGE ).toString ();
178-
179- MultiValueMap <String , String > assertionParameters = jwtClientAssertionParametersService .buildClientAssertionParameters (
180- clientRegistration );
181-
182- String clientAssertionType = assertionParameters .get (OAuth2ParameterNames .CLIENT_ASSERTION_TYPE )
183- .getFirst ();
184- String clientAssertion = assertionParameters .get (OAuth2ParameterNames .CLIENT_ASSERTION )
185- .getFirst ();
186-
187- HttpHeaders headers = new HttpHeaders ();
188- headers .setContentType (MediaType .APPLICATION_FORM_URLENCODED );
168+ public ParResponse sendParRequest (
169+ String redirectUri ,
170+ String state ,
171+ ClientRegistration clientRegistration ,
172+ Map <String , Object > pkceParameters ) {
173+ String parEndpoint =
174+ clientRegistration
175+ .getProviderDetails ()
176+ .getConfigurationMetadata ()
177+ .get ("pushed_authorization_request_endpoint" )
178+ .toString ();
179+
180+ String codeChallengeMethod =
181+ pkceParameters .get (PkceParameterNames .CODE_CHALLENGE_METHOD ).toString ();
182+ String codeChallenge = pkceParameters .get (PkceParameterNames .CODE_CHALLENGE ).toString ();
183+
184+ MultiValueMap <String , String > assertionParameters =
185+ jwtClientAssertionParametersService .buildClientAssertionParameters (clientRegistration );
186+
187+ String clientAssertionType =
188+ assertionParameters .get (OAuth2ParameterNames .CLIENT_ASSERTION_TYPE ).getFirst ();
189+ String clientAssertion =
190+ assertionParameters .get (OAuth2ParameterNames .CLIENT_ASSERTION ).getFirst ();
189191
190192 MultiValueMap <String , String > body = new LinkedMultiValueMap <>();
191193 body .add (OAuth2ParameterNames .CLIENT_ID , clientRegistration .getClientId ());
@@ -201,17 +203,23 @@ public ParResponse sendParRequest(String redirectUri, String state,
201203
202204 Registration registration = registrations .get (clientRegistration .getRegistrationId ());
203205 if (registration != null ) {
204- if (registration .getAcrValues () != null ) {
206+ if (registration .getAcrValues () != null && ! registration . getAcrValues (). isEmpty () ) {
205207 body .add ("acr_values" , registration .getAcrValues ());
206208 }
207- if (registration .getPrompt () != null ) {
209+ if (registration .getPrompt () != null && ! registration . getPrompt (). isEmpty () ) {
208210 body .add ("prompt" , registration .getPrompt ());
209211 }
210212 }
211213
212- HttpEntity <MultiValueMap <String , String >> requestEntity = new HttpEntity <>(body , headers );
213- ResponseEntity <ParResponse > response = restOperations .postForEntity (parEndpoint , requestEntity ,
214- ParResponse .class );
214+ ResponseEntity <ParResponse > response =
215+ restClient
216+ .post ()
217+ .uri (parEndpoint )
218+ .body (body )
219+ .headers (
220+ httpHeaders -> httpHeaders .setContentType (MediaType .APPLICATION_FORM_URLENCODED ))
221+ .retrieve ()
222+ .toEntity (ParResponse .class );
215223
216224 if (response .hasBody ()) {
217225 return response .getBody ();
@@ -220,27 +228,25 @@ public ParResponse sendParRequest(String redirectUri, String state,
220228 }
221229
222230 private Map <String , Object > buildPkceParameters (ClientRegistration clientRegistration ) {
223- Builder builder = OAuth2AuthorizationRequest .authorizationCode ()
224- .clientId (clientRegistration .getClientId ())
225- .authorizationUri (clientRegistration .getProviderDetails ().getAuthorizationUri ());
231+ Builder builder =
232+ OAuth2AuthorizationRequest .authorizationCode ()
233+ .clientId (clientRegistration .getClientId ())
234+ .authorizationUri (clientRegistration .getProviderDetails ().getAuthorizationUri ());
226235
227- OAuth2AuthorizationRequestCustomizers
228- .withPkce ().accept (builder );
236+ OAuth2AuthorizationRequestCustomizers .withPkce ().accept (builder );
229237 OAuth2AuthorizationRequest build = builder .build ();
230238 Map <String , Object > additionalParameters = new HashMap <>(build .getAdditionalParameters ());
231- additionalParameters .put (PkceParameterNames . CODE_VERIFIER ,
232- build .getAttribute (PkceParameterNames .CODE_VERIFIER ));
239+ additionalParameters .put (
240+ PkceParameterNames . CODE_VERIFIER , build .getAttribute (PkceParameterNames .CODE_VERIFIER ));
233241 return additionalParameters ;
234242 }
235243
236-
237244 public static class ParResponse {
238245
239246 public String request_uri ;
240247 public String expires_in ;
241248 }
242249
243-
244250 private String getAction (HttpServletRequest request , String defaultAction ) {
245251 String action = request .getParameter ("action" );
246252 if (action == null ) {
@@ -251,24 +257,25 @@ private String getAction(HttpServletRequest request, String defaultAction) {
251257
252258 private String resolveRegistrationId (HttpServletRequest request ) {
253259 if (this .authorizationRequestMatcher .matches (request )) {
254- return this .authorizationRequestMatcher .matcher (request )
260+ return this .authorizationRequestMatcher
261+ .matcher (request )
255262 .getVariables ()
256263 .get (REGISTRATION_ID_URI_VARIABLE_NAME );
257264 }
258265 return null ;
259266 }
260267
261- private String expandRedirectUri (HttpServletRequest request ,
262- ClientRegistration clientRegistration ,
263- String action ) {
268+ private String expandRedirectUri (
269+ HttpServletRequest request , ClientRegistration clientRegistration , String action ) {
264270 Map <String , String > uriVariables = new HashMap <>();
265271 uriVariables .put ("registrationId" , clientRegistration .getRegistrationId ());
266272 // @formatter:off
267- UriComponents uriComponents = UriComponentsBuilder .fromHttpUrl (UrlUtils .buildFullRequestUrl (request ))
268- .replacePath (request .getContextPath ())
269- .replaceQuery (null )
270- .fragment (null )
271- .build ();
273+ UriComponents uriComponents =
274+ UriComponentsBuilder .fromUriString (UrlUtils .buildFullRequestUrl (request ))
275+ .replacePath (request .getContextPath ())
276+ .replaceQuery (null )
277+ .fragment (null )
278+ .build ();
272279 // @formatter:on
273280 String scheme = uriComponents .getScheme ();
274281 uriVariables .put ("baseScheme" , (scheme != null ) ? scheme : "" );
@@ -290,5 +297,4 @@ private String expandRedirectUri(HttpServletRequest request,
290297 .buildAndExpand (uriVariables )
291298 .toUriString ();
292299 }
293-
294- }
300+ }
0 commit comments