Skip to content

Commit 116f570

Browse files
committed
Create RefreshProtectedCredentialsProvider class
To protect from several actual token retrievals to happen at the same time. This class is used in the OAuth 2 client credentials grant provider. Add also a builder to make the OAuth 2 provider easier to configure, add TLS settings and a test. (cherry picked from commit 5752e55) Conflicts: src/test/java/com/rabbitmq/client/test/ClientTests.java
1 parent 67f9912 commit 116f570

File tree

8 files changed

+495
-132
lines changed

8 files changed

+495
-132
lines changed

pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
<mockito.version>3.0.0</mockito.version>
6565
<assertj.version>3.12.2</assertj.version>
6666
<jetty.version>9.4.19.v20190610</jetty.version>
67+
<bouncycastle.version>1.61</bouncycastle.version>
6768

6869
<maven.javadoc.plugin.version>3.0.1</maven.javadoc.plugin.version>
6970
<maven.release.plugin.version>2.5.3</maven.release.plugin.version>
@@ -751,6 +752,12 @@
751752
<version>${jetty.version}</version>
752753
<scope>test</scope>
753754
</dependency>
755+
<dependency>
756+
<groupId>org.bouncycastle</groupId>
757+
<artifactId>bcpkix-jdk15on</artifactId>
758+
<version>${bouncycastle.version}</version>
759+
<scope>test</scope>
760+
</dependency>
754761
</dependencies>
755762

756763
<build>

src/main/java/com/rabbitmq/client/impl/OAuth2ClientCredentialsGrantCredentialsProvider.java

Lines changed: 119 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,51 +16,60 @@
1616
package com.rabbitmq.client.impl;
1717

1818
import com.fasterxml.jackson.databind.ObjectMapper;
19-
import org.slf4j.Logger;
20-
import org.slf4j.LoggerFactory;
2119

20+
import javax.net.ssl.HostnameVerifier;
21+
import javax.net.ssl.HttpsURLConnection;
22+
import javax.net.ssl.SSLSocketFactory;
2223
import java.io.*;
2324
import java.net.HttpURLConnection;
2425
import java.net.URL;
2526
import java.net.URLEncoder;
2627
import java.nio.charset.StandardCharsets;
2728
import java.util.*;
28-
import java.util.concurrent.CountDownLatch;
29-
import java.util.concurrent.atomic.AtomicBoolean;
30-
import java.util.concurrent.atomic.AtomicReference;
31-
import java.util.concurrent.locks.Lock;
32-
import java.util.concurrent.locks.ReentrantLock;
3329

34-
35-
public class OAuth2ClientCredentialsGrantCredentialsProvider implements CredentialsProvider {
36-
37-
private static final Logger LOGGER = LoggerFactory.getLogger(OAuth2ClientCredentialsGrantCredentialsProvider.class);
30+
/**
31+
*
32+
* @see RefreshProtectedCredentialsProvider
33+
*/
34+
public class OAuth2ClientCredentialsGrantCredentialsProvider extends RefreshProtectedCredentialsProvider<OAuth2ClientCredentialsGrantCredentialsProvider.Token> {
3835

3936
private static final String UTF_8_CHARSET = "UTF-8";
40-
private final String serverUri; // should be renamed to tokenEndpointUri?
37+
private final String tokenEndpointUri;
4138
private final String clientId;
4239
private final String clientSecret;
4340
private final String grantType;
44-
// UAA specific, to distinguish between different users
45-
private final String username, password;
41+
42+
private final Map<String, String> parameters;
4643

4744
private final ObjectMapper objectMapper = new ObjectMapper();
4845

4946
private final String id;
5047

51-
private final AtomicReference<Token> token = new AtomicReference<>();
48+
private final HostnameVerifier hostnameVerifier;
49+
private final SSLSocketFactory sslSocketFactory;
50+
51+
public OAuth2ClientCredentialsGrantCredentialsProvider(String tokenEndpointUri, String clientId, String clientSecret, String grantType) {
52+
this(tokenEndpointUri, clientId, clientSecret, grantType, new HashMap<>());
53+
}
54+
55+
public OAuth2ClientCredentialsGrantCredentialsProvider(String tokenEndpointUri, String clientId, String clientSecret, String grantType,
56+
HostnameVerifier hostnameVerifier, SSLSocketFactory sslSocketFactory) {
57+
this(tokenEndpointUri, clientId, clientSecret, grantType, new HashMap<>(), hostnameVerifier, sslSocketFactory);
58+
}
5259

53-
private final Lock refreshLock = new ReentrantLock();
54-
private final AtomicReference<CountDownLatch> latch = new AtomicReference<>();
55-
private AtomicBoolean refreshInProcess = new AtomicBoolean(false);
60+
public OAuth2ClientCredentialsGrantCredentialsProvider(String tokenEndpointUri, String clientId, String clientSecret, String grantType, Map<String, String> parameters) {
61+
this(tokenEndpointUri, clientId, clientSecret, grantType, parameters, null, null);
62+
}
5663

57-
public OAuth2ClientCredentialsGrantCredentialsProvider(String serverUri, String clientId, String clientSecret, String grantType, String username, String password) {
58-
this.serverUri = serverUri;
64+
public OAuth2ClientCredentialsGrantCredentialsProvider(String tokenEndpointUri, String clientId, String clientSecret, String grantType, Map<String, String> parameters,
65+
HostnameVerifier hostnameVerifier, SSLSocketFactory sslSocketFactory) {
66+
this.tokenEndpointUri = tokenEndpointUri;
5967
this.clientId = clientId;
6068
this.clientSecret = clientSecret;
6169
this.grantType = grantType;
62-
this.username = username;
63-
this.password = password;
70+
this.parameters = Collections.unmodifiableMap(new HashMap<>(parameters));
71+
this.hostnameVerifier = hostnameVerifier;
72+
this.sslSocketFactory = sslSocketFactory;
6473
this.id = UUID.randomUUID().toString();
6574
}
6675

@@ -90,19 +99,8 @@ public String getUsername() {
9099
}
91100

92101
@Override
93-
public String getPassword() {
94-
if (token.get() == null) {
95-
refresh();
96-
}
97-
return token.get().getAccess();
98-
}
99-
100-
@Override
101-
public Date getExpiration() {
102-
if (token.get() == null) {
103-
refresh();
104-
}
105-
return token.get().getExpiration();
102+
protected String usernameFromToken(Token token) {
103+
return "";
106104
}
107105

108106
protected Token parseToken(String response) {
@@ -118,47 +116,21 @@ protected Token parseToken(String response) {
118116
}
119117

120118
@Override
121-
public void refresh() {
122-
// refresh should happen at once. Other calls wait for the refresh to finish and move on.
123-
if (refreshLock.tryLock()) {
124-
LOGGER.debug("Refreshing token");
125-
try {
126-
latch.set(new CountDownLatch(1));
127-
refreshInProcess.set(true);
128-
token.set(retrieveToken());
129-
LOGGER.debug("Token refreshed");
130-
} finally {
131-
latch.get().countDown();
132-
refreshInProcess.set(false);
133-
refreshLock.unlock();
134-
}
135-
} else {
136-
try {
137-
LOGGER.debug("Waiting for token refresh to be finished");
138-
while (!refreshInProcess.get()) {
139-
Thread.sleep(10);
140-
}
141-
latch.get().await();
142-
LOGGER.debug("Done waiting for token refresh");
143-
} catch (InterruptedException e) {
144-
Thread.currentThread().interrupt();
145-
}
146-
}
147-
}
148-
149119
protected Token retrieveToken() {
150-
// FIXME handle TLS specific settings
151120
try {
152121
StringBuilder urlParameters = new StringBuilder();
153122
encode(urlParameters, "grant_type", grantType);
154-
encode(urlParameters, "username", username);
155-
encode(urlParameters, "password", password);
123+
for (Map.Entry<String, String> parameter : parameters.entrySet()) {
124+
encode(urlParameters, parameter.getKey(), parameter.getValue());
125+
}
156126
byte[] postData = urlParameters.toString().getBytes(StandardCharsets.UTF_8);
157127
int postDataLength = postData.length;
158-
URL url = new URL(serverUri);
128+
URL url = new URL(tokenEndpointUri);
129+
159130
// FIXME close connection?
160131
// FIXME set timeout on request
161132
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
133+
162134
conn.setDoOutput(true);
163135
conn.setInstanceFollowRedirects(false);
164136
conn.setRequestMethod("POST");
@@ -168,6 +140,9 @@ protected Token retrieveToken() {
168140
conn.setRequestProperty("accept", "application/json");
169141
conn.setRequestProperty("content-length", Integer.toString(postDataLength));
170142
conn.setUseCaches(false);
143+
144+
configureHttpConnection(conn);
145+
171146
try (DataOutputStream wr = new DataOutputStream(conn.getOutputStream())) {
172147
wr.write(postData);
173148
}
@@ -196,6 +171,28 @@ protected Token retrieveToken() {
196171
}
197172
}
198173

174+
@Override
175+
protected String passwordFromToken(Token token) {
176+
return token.getAccess();
177+
}
178+
179+
@Override
180+
protected Date expirationFromToken(Token token) {
181+
return token.getExpiration();
182+
}
183+
184+
protected void configureHttpConnection(HttpURLConnection connection) {
185+
if (connection instanceof HttpsURLConnection) {
186+
HttpsURLConnection securedConnection = (HttpsURLConnection) connection;
187+
if (this.hostnameVerifier != null) {
188+
securedConnection.setHostnameVerifier(this.hostnameVerifier);
189+
}
190+
if (this.sslSocketFactory != null) {
191+
securedConnection.setSSLSocketFactory(this.sslSocketFactory);
192+
}
193+
}
194+
}
195+
199196
@Override
200197
public boolean equals(Object o) {
201198
if (this == o) return true;
@@ -230,4 +227,59 @@ public String getAccess() {
230227
return access;
231228
}
232229
}
230+
231+
public static class OAuth2ClientCredentialsGrantCredentialsProviderBuilder {
232+
233+
private final Map<String, String> parameters = new HashMap<>();
234+
private String tokenEndpointUri;
235+
private String clientId;
236+
private String clientSecret;
237+
private String grantType = "client_credentials";
238+
private HostnameVerifier hostnameVerifier;
239+
240+
private SSLSocketFactory sslSocketFactory;
241+
242+
public OAuth2ClientCredentialsGrantCredentialsProviderBuilder tokenEndpointUri(String tokenEndpointUri) {
243+
this.tokenEndpointUri = tokenEndpointUri;
244+
return this;
245+
}
246+
247+
public OAuth2ClientCredentialsGrantCredentialsProviderBuilder clientId(String clientId) {
248+
this.clientId = clientId;
249+
return this;
250+
}
251+
252+
public OAuth2ClientCredentialsGrantCredentialsProviderBuilder clientSecret(String clientSecret) {
253+
this.clientSecret = clientSecret;
254+
return this;
255+
}
256+
257+
public OAuth2ClientCredentialsGrantCredentialsProviderBuilder grantType(String grantType) {
258+
this.grantType = grantType;
259+
return this;
260+
}
261+
262+
public OAuth2ClientCredentialsGrantCredentialsProviderBuilder parameter(String name, String value) {
263+
this.parameters.put(name, value);
264+
return this;
265+
}
266+
267+
public OAuth2ClientCredentialsGrantCredentialsProviderBuilder setHostnameVerifier(HostnameVerifier hostnameVerifier) {
268+
this.hostnameVerifier = hostnameVerifier;
269+
return this;
270+
}
271+
272+
public OAuth2ClientCredentialsGrantCredentialsProviderBuilder setSslSocketFactory(SSLSocketFactory sslSocketFactory) {
273+
this.sslSocketFactory = sslSocketFactory;
274+
return this;
275+
}
276+
277+
public OAuth2ClientCredentialsGrantCredentialsProvider build() {
278+
return new OAuth2ClientCredentialsGrantCredentialsProvider(
279+
tokenEndpointUri, clientId, clientSecret, grantType, parameters,
280+
hostnameVerifier, sslSocketFactory
281+
);
282+
}
283+
284+
}
233285
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright (c) 2019 Pivotal Software, Inc. All rights reserved.
2+
//
3+
// This software, the RabbitMQ Java client library, is triple-licensed under the
4+
// Mozilla Public License 1.1 ("MPL"), the GNU General Public License version 2
5+
// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see
6+
// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL,
7+
// please see LICENSE-APACHE2.
8+
//
9+
// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
10+
// either express or implied. See the LICENSE file for specific language governing
11+
// rights and limitations of this software.
12+
//
13+
// If you have any questions regarding licensing, please contact us at
14+
// info@rabbitmq.com.
15+
16+
package com.rabbitmq.client.impl;
17+
18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
20+
21+
import java.util.Date;
22+
import java.util.concurrent.CountDownLatch;
23+
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.concurrent.atomic.AtomicReference;
25+
import java.util.concurrent.locks.Lock;
26+
import java.util.concurrent.locks.ReentrantLock;
27+
28+
/**
29+
* An abstract {@link CredentialsProvider} that does not let token refresh happen concurrently.
30+
* <p>
31+
* A token is usually long-lived (several minutes or more), can be re-used inside the same application,
32+
* and refreshing it is a costly operation. This base class lets a first call to {@link #refresh()}
33+
* pass and block concurrent calls until the first call is over. Concurrent calls are then unblocked and
34+
* can benefit from the refresh. This avoids unnecessary refresh operations to happen if a token
35+
* is already being renewed.
36+
* <p>
37+
* Subclasses need to provide the actual token retrieval (whether is a first retrieval or a renewal is
38+
* a implementation detail) and how to extract information (username, password, expiration date) from the retrieved
39+
* token.
40+
*
41+
* @param <T> the type of token (usually specified by the subclass)
42+
*/
43+
public abstract class RefreshProtectedCredentialsProvider<T> implements CredentialsProvider {
44+
45+
private static final Logger LOGGER = LoggerFactory.getLogger(RefreshProtectedCredentialsProvider.class);
46+
47+
private final AtomicReference<T> token = new AtomicReference<>();
48+
49+
private final Lock refreshLock = new ReentrantLock();
50+
private final AtomicReference<CountDownLatch> latch = new AtomicReference<>();
51+
private AtomicBoolean refreshInProcess = new AtomicBoolean(false);
52+
53+
@Override
54+
public String getUsername() {
55+
if (token.get() == null) {
56+
refresh();
57+
}
58+
return usernameFromToken(token.get());
59+
}
60+
61+
@Override
62+
public String getPassword() {
63+
if (token.get() == null) {
64+
refresh();
65+
}
66+
return passwordFromToken(token.get());
67+
}
68+
69+
@Override
70+
public Date getExpiration() {
71+
if (token.get() == null) {
72+
refresh();
73+
}
74+
return expirationFromToken(token.get());
75+
}
76+
77+
@Override
78+
public void refresh() {
79+
// refresh should happen at once. Other calls wait for the refresh to finish and move on.
80+
if (refreshLock.tryLock()) {
81+
LOGGER.debug("Refreshing token");
82+
try {
83+
latch.set(new CountDownLatch(1));
84+
refreshInProcess.set(true);
85+
token.set(retrieveToken());
86+
LOGGER.debug("Token refreshed");
87+
} finally {
88+
latch.get().countDown();
89+
refreshInProcess.set(false);
90+
refreshLock.unlock();
91+
}
92+
} else {
93+
try {
94+
LOGGER.debug("Waiting for token refresh to be finished");
95+
while (!refreshInProcess.get()) {
96+
Thread.sleep(10);
97+
}
98+
latch.get().await();
99+
LOGGER.debug("Done waiting for token refresh");
100+
} catch (InterruptedException e) {
101+
Thread.currentThread().interrupt();
102+
}
103+
}
104+
}
105+
106+
protected abstract T retrieveToken();
107+
108+
protected abstract String usernameFromToken(T token);
109+
110+
protected abstract String passwordFromToken(T token);
111+
112+
protected abstract Date expirationFromToken(T token);
113+
}

0 commit comments

Comments
 (0)