Skip to content

Commit 599a2e0

Browse files
author
Simon MacMullen
committed
Merged from default
2 parents 83656b3 + b83a79a commit 599a2e0

File tree

6 files changed

+269
-15
lines changed

6 files changed

+269
-15
lines changed

src/com/rabbitmq/client/ConnectionFactory.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import java.io.IOException;
3434
import java.security.KeyManagementException;
3535
import java.security.NoSuchAlgorithmException;
36-
import java.util.HashMap;
3736
import java.util.Map;
3837

3938
import java.net.Socket;
@@ -100,6 +99,7 @@ public class ConnectionFactory implements Cloneable {
10099
private int requestedHeartbeat = DEFAULT_HEARTBEAT;
101100
private Map<String, Object> _clientProperties = AMQConnection.defaultClientProperties();
102101
private SocketFactory factory = SocketFactory.getDefault();
102+
private SaslConfig saslConfig = new DefaultSaslConfig(this);
103103

104104
/**
105105
* Instantiate a ConnectionFactory with a default set of parameters.
@@ -261,6 +261,22 @@ public void setClientProperties(Map<String, Object> clientProperties) {
261261
_clientProperties = clientProperties;
262262
}
263263

264+
/**
265+
* Gets the sasl config to use when authenticating
266+
* @return
267+
*/
268+
public SaslConfig getSaslConfig() {
269+
return saslConfig;
270+
}
271+
272+
/**
273+
* Sets the sasl config to use when authenticating
274+
* @param saslConfig
275+
*/
276+
public void setSaslConfig(SaslConfig saslConfig) {
277+
this.saslConfig = saslConfig;
278+
}
279+
264280
/**
265281
* Retrieve the socket factory used to make connections with.
266282
*/
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.rabbitmq.client;
2+
3+
import javax.security.auth.callback.CallbackHandler;
4+
import javax.security.sasl.Sasl;
5+
import javax.security.sasl.SaslClient;
6+
import javax.security.sasl.SaslException;
7+
import java.util.Map;
8+
9+
/**
10+
*
11+
*/
12+
public class DefaultSaslConfig implements SaslConfig {
13+
private ConnectionFactory factory;
14+
private String authorizationId;
15+
private Map<String,?> mechanismProperties;
16+
private CallbackHandler callbackHandler;
17+
18+
public DefaultSaslConfig(ConnectionFactory factory) {
19+
this.factory = factory;
20+
callbackHandler = new UsernamePasswordCallbackHandler(factory);
21+
}
22+
23+
public void setAuthorizationId(String authorizationId) {
24+
this.authorizationId = authorizationId;
25+
}
26+
27+
public void setMechanismProperties(Map<String, ?> mechanismProperties) {
28+
this.mechanismProperties = mechanismProperties;
29+
}
30+
31+
public void setCallbackHandler(CallbackHandler callbackHandler) {
32+
this.callbackHandler = callbackHandler;
33+
}
34+
35+
public SaslClient getSaslClient(String[] mechanisms) throws SaslException {
36+
return Sasl.createSaslClient(mechanisms, authorizationId, "AMQP",
37+
factory.getHost(), mechanismProperties, callbackHandler);
38+
}
39+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package com.rabbitmq.client;
2+
3+
import javax.security.sasl.SaslClient;
4+
import javax.security.sasl.SaslException;
5+
6+
/**
7+
*
8+
*/
9+
public interface SaslConfig {
10+
SaslClient getSaslClient(String[] mechanisms) throws SaslException;
11+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package com.rabbitmq.client;
2+
3+
import javax.security.auth.callback.Callback;
4+
import javax.security.auth.callback.CallbackHandler;
5+
import javax.security.auth.callback.NameCallback;
6+
import javax.security.auth.callback.PasswordCallback;
7+
import javax.security.auth.callback.UnsupportedCallbackException;
8+
import java.io.IOException;
9+
10+
public class UsernamePasswordCallbackHandler implements CallbackHandler {
11+
private ConnectionFactory factory;
12+
public UsernamePasswordCallbackHandler(ConnectionFactory factory) {
13+
this.factory = factory;
14+
}
15+
16+
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
17+
for (Callback callback: callbacks) {
18+
if (callback instanceof NameCallback) {
19+
NameCallback nc = (NameCallback)callback;
20+
nc.setName(factory.getUsername());
21+
22+
} else if (callback instanceof PasswordCallback) {
23+
PasswordCallback pc = (PasswordCallback)callback;
24+
pc.setPassword(factory.getPassword().toCharArray());
25+
26+
} else {
27+
throw new UnsupportedCallbackException
28+
(callback, "Unrecognized Callback");
29+
}
30+
}
31+
}
32+
}

src/com/rabbitmq/client/impl/AMQConnection.java

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
import com.rabbitmq.utility.BlockingCell;
5252
import com.rabbitmq.utility.Utility;
5353

54+
import javax.security.sasl.SaslClient;
55+
5456
/**
5557
* Concrete class representing and managing an AMQP connection to a broker.
5658
* <p>
@@ -152,7 +154,7 @@ public void ensureIsOpen()
152154
*/
153155
private int _heartbeat;
154156

155-
private final String _username, _password, _virtualHost;
157+
private final String _virtualHost;
156158
private final int _requestedChannelMax, _requestedFrameMax, _requestedHeartbeat;
157159
private final Map<String, Object> _clientProperties;
158160

@@ -200,8 +202,6 @@ public AMQConnection(ConnectionFactory factory,
200202
{
201203
checkPreconditions();
202204

203-
_username = factory.getUsername();
204-
_password = factory.getPassword();
205205
_virtualHost = factory.getVirtualHost();
206206
_requestedChannelMax = factory.getRequestedChannelMax();
207207
_requestedFrameMax = factory.getRequestedFrameMax();
@@ -255,8 +255,9 @@ public void start()
255255
ml.setName("AMQP Connection " + getHost() + ":" + getPort());
256256
ml.start();
257257

258+
AMQP.Connection.Start connStart = null;
258259
try {
259-
AMQP.Connection.Start connStart =
260+
connStart =
260261
(AMQP.Connection.Start) connStartBlocker.getReply().getMethod();
261262

262263
_serverProperties = connStart.getServerProperties();
@@ -274,18 +275,42 @@ public void start()
274275
throw AMQChannel.wrap(sse);
275276
}
276277

277-
LongString saslResponse = LongStringHelper.asLongString("\0" + _username +
278-
"\0" + _password);
279-
AMQImpl.Connection.StartOk startOk =
280-
new AMQImpl.Connection.StartOk(_clientProperties, "PLAIN",
281-
saslResponse, "en_US");
278+
String[] mechanisms = connStart.getMechanisms().toString().split(" ");
279+
SaslClient sc = _factory.getSaslConfig().getSaslClient(mechanisms);
280+
if (sc == null) {
281+
throw new IOException("No compatible authentication mechanism found - " +
282+
"server offered [" + connStart.getMechanisms() + "]");
283+
}
282284

285+
LongString challenge = null;
286+
LongString response = LongStringHelper.asLongString(
287+
sc.hasInitialResponse() ? sc.evaluateChallenge(new byte[0]) : null);
283288
AMQP.Connection.Tune connTune = null;
289+
do {
290+
Method method = (challenge == null)
291+
? new AMQImpl.Connection.StartOk(_clientProperties,
292+
sc.getMechanismName(),
293+
response, "en_US")
294+
: new AMQImpl.Connection.SecureOk(response);
284295

285-
try {
286-
connTune = (AMQP.Connection.Tune) _channel0.rpc(startOk).getMethod();
287-
} catch (ShutdownSignalException e) {
288-
throw new PossibleAuthenticationFailureException(e);
296+
try {
297+
Method serverResponse = _channel0.rpc(method).getMethod();
298+
if (serverResponse instanceof AMQP.Connection.Tune) {
299+
connTune = (AMQP.Connection.Tune) serverResponse;
300+
} else {
301+
challenge = ((AMQP.Connection.Secure) serverResponse).getChallenge();
302+
response = LongStringHelper.asLongString(sc.evaluateChallenge(challenge.getBytes()));
303+
}
304+
} catch (ShutdownSignalException e) {
305+
throw new PossibleAuthenticationFailureException(e);
306+
}
307+
} while (connTune == null);
308+
309+
sc.dispose();
310+
311+
if (!sc.isComplete()) {
312+
throw new RuntimeException(sc.getMechanismName() +
313+
" did not complete, server thought it did");
289314
}
290315

291316
int channelMax =
@@ -714,6 +739,6 @@ public void close(int closeCode,
714739
}
715740

716741
@Override public String toString() {
717-
return "amqp://" + _username + "@" + getHost() + ":" + getPort() + _virtualHost;
742+
return "amqp://" + _factory.getUsername() + "@" + getHost() + ":" + getPort() + _virtualHost;
718743
}
719744
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package com.rabbitmq.client.impl;
2+
3+
import com.rabbitmq.client.ConnectionFactory;
4+
import com.rabbitmq.client.SaslConfig;
5+
import com.rabbitmq.client.UsernamePasswordCallbackHandler;
6+
7+
import javax.security.auth.callback.Callback;
8+
import javax.security.auth.callback.CallbackHandler;
9+
import javax.security.auth.callback.NameCallback;
10+
import javax.security.auth.callback.PasswordCallback;
11+
import javax.security.auth.callback.UnsupportedCallbackException;
12+
import javax.security.sasl.SaslClient;
13+
import javax.security.sasl.SaslException;
14+
import java.io.IOException;
15+
import java.io.UnsupportedEncodingException;
16+
import java.security.MessageDigest;
17+
import java.security.NoSuchAlgorithmException;
18+
import java.util.Arrays;
19+
20+
/**
21+
START-OK: Username
22+
SECURE: {Salt1, Salt2} (where Salt1 is the salt from the db and
23+
Salt2 differs every time)
24+
SECURE-OK: md5(Salt2 ++ md5(Salt1 ++ Password))
25+
26+
The second salt is there to defend against replay attacks. The
27+
first is needed since the passwords are salted in the db.
28+
29+
This is only somewhat improved security over PLAIN (if you can
30+
break MD5 you can still replay attack) but it's better than nothing
31+
and mostly there to prove the use of SECURE / SECURE-OK frames.
32+
*/
33+
34+
public class ScramMD5SaslClient implements SaslClient {
35+
private static final String NAME = "RABBIT-SCRAM-MD5";
36+
37+
private CallbackHandler handler;
38+
private int round = 0;
39+
40+
public ScramMD5SaslClient(CallbackHandler handler) {
41+
this.handler = handler;
42+
}
43+
44+
public String getMechanismName() {
45+
return NAME;
46+
}
47+
48+
public boolean hasInitialResponse() {
49+
return true;
50+
}
51+
52+
public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
53+
byte[] resp;
54+
try {
55+
if (round == 0) {
56+
NameCallback nc = new NameCallback("Name:");
57+
handler.handle(new Callback[]{nc});
58+
resp = nc.getName().getBytes("utf-8");
59+
} else {
60+
byte[] salt1 = Arrays.copyOfRange(challenge, 0, 4);
61+
byte[] salt2 = Arrays.copyOfRange(challenge, 4, 8);
62+
PasswordCallback pc = new PasswordCallback("Password:", false);
63+
handler.handle(new Callback[]{pc});
64+
byte[] pw = new String(pc.getPassword()).getBytes("utf-8");
65+
resp = digest(salt2, digest(salt1, pw));
66+
}
67+
} catch (UnsupportedEncodingException e) {
68+
throw new RuntimeException(e);
69+
} catch (UnsupportedCallbackException e) {
70+
throw new SaslException("Bad callback", e);
71+
} catch (IOException e) {
72+
throw new SaslException("IO Exception", e);
73+
}
74+
75+
round++;
76+
return resp;
77+
}
78+
79+
public boolean isComplete() {
80+
return round == 2;
81+
}
82+
83+
public byte[] unwrap(byte[] bytes, int i, int i1) throws SaslException {
84+
throw new UnsupportedOperationException();
85+
}
86+
87+
public byte[] wrap(byte[] bytes, int i, int i1) throws SaslException {
88+
throw new UnsupportedOperationException();
89+
}
90+
91+
public Object getNegotiatedProperty(String s) {
92+
throw new UnsupportedOperationException();
93+
}
94+
95+
public void dispose() throws SaslException {
96+
// NOOP
97+
}
98+
99+
private static byte[] digest(byte[] arr1, byte[] arr2) {
100+
try {
101+
MessageDigest digest = MessageDigest.getInstance("MD5");
102+
return digest.digest(concat(arr1, arr2));
103+
104+
} catch (NoSuchAlgorithmException e) {
105+
throw new RuntimeException(e);
106+
}
107+
}
108+
109+
private static byte[] concat(byte[] first, byte[] second) {
110+
byte[] result = Arrays.copyOf(first, first.length + second.length);
111+
System.arraycopy(second, 0, result, first.length, second.length);
112+
return result;
113+
}
114+
115+
public static class ScramMD5SaslConfig implements SaslConfig {
116+
private ConnectionFactory factory;
117+
118+
public ScramMD5SaslConfig(ConnectionFactory factory) {
119+
this.factory = factory;
120+
}
121+
122+
public SaslClient getSaslClient(String[] mechanisms) throws SaslException {
123+
if (Arrays.asList(mechanisms).contains(NAME)) {
124+
return new ScramMD5SaslClient(new UsernamePasswordCallbackHandler(factory));
125+
}
126+
else {
127+
return null;
128+
}
129+
}
130+
}
131+
}

0 commit comments

Comments
 (0)