Skip to content

Commit 6b1e07c

Browse files
committed
Resolve NoiseHandshake TODOs and add tests for message length checks
1 parent 228cafc commit 6b1e07c

File tree

2 files changed

+133
-4
lines changed

2 files changed

+133
-4
lines changed

src/main/java/com/eatthepath/noise/NoiseHandshake.java

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ public class NoiseHandshake {
161161

162162
private int currentPreSharedKey;
163163

164+
static final int MAX_NOISE_MESSAGE_SIZE = 65_535;
165+
164166
private static final byte[] EMPTY_BYTE_ARRAY = new byte[0];
165167
private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.wrap(EMPTY_BYTE_ARRAY);
166168

@@ -643,8 +645,9 @@ static int getPayloadLength(final HandshakePattern handshakePattern,
643645
* @see <a href="https://noiseprotocol.org/noise.html#payload-security-properties">The Noise Protocol Framework - Payload security properties</a>
644646
*/
645647
public byte[] writeMessage(@Nullable final byte[] payload) {
646-
// TODO Verify that message size is within bounds
647648
final int payloadLength = payload != null ? payload.length : 0;
649+
checkOutboundMessageSize(payloadLength);
650+
648651
final byte[] message = new byte[getOutboundMessageLength(payloadLength)];
649652

650653
try {
@@ -690,7 +693,11 @@ public int writeMessage(@Nullable final byte[] payload,
690693
final byte[] message,
691694
final int messageOffset) throws ShortBufferException {
692695

693-
// TODO Check message buffer length, or just let plumbing deeper down complain?
696+
checkOutboundMessageSize(payloadLength);
697+
698+
if (message.length - messageOffset < getOutboundMessageLength(payloadLength)) {
699+
throw new ShortBufferException("Message array after offset is not large enough to hold handshake message");
700+
}
694701

695702
if (!isExpectingWrite()) {
696703
throw new IllegalStateException("Handshake not currently expecting to write a message");
@@ -770,7 +777,10 @@ public int writeMessage(@Nullable final byte[] payload,
770777
* @see <a href="https://noiseprotocol.org/noise.html#payload-security-properties">The Noise Protocol Framework - Payload security properties</a>
771778
*/
772779
public ByteBuffer writeMessage(@Nullable final ByteBuffer payload) {
773-
final ByteBuffer message = ByteBuffer.allocate(getOutboundMessageLength(payload != null ? payload.remaining() : 0));
780+
final int payloadLength = payload != null ? payload.remaining() : 0;
781+
checkOutboundMessageSize(payloadLength);
782+
783+
final ByteBuffer message = ByteBuffer.allocate(getOutboundMessageLength(payloadLength));
774784

775785
try {
776786
writeMessage(payload, message);
@@ -813,7 +823,12 @@ public ByteBuffer writeMessage(@Nullable final ByteBuffer payload) {
813823
public int writeMessage(@Nullable final ByteBuffer payload,
814824
final ByteBuffer message) throws ShortBufferException {
815825

816-
// TODO Check message buffer length, or just let plumbing deeper down complain?
826+
final int payloadLength = payload != null ? payload.remaining() : 0;
827+
checkOutboundMessageSize(payloadLength);
828+
829+
if (message.remaining() < getOutboundMessageLength(payloadLength)) {
830+
throw new ShortBufferException("Message buffer is not large enough to hold handshake message");
831+
}
817832

818833
if (!isExpectingWrite()) {
819834
throw new IllegalStateException("Handshake not currently expecting to write a message");
@@ -869,6 +884,12 @@ public int writeMessage(@Nullable final ByteBuffer payload,
869884
return bytesWritten;
870885
}
871886

887+
private void checkOutboundMessageSize(final int payloadLength) {
888+
if (getOutboundMessageLength(payloadLength) > MAX_NOISE_MESSAGE_SIZE) {
889+
throw new IllegalArgumentException("Message containing payload would be larger than maximum allowed Noise message size");
890+
}
891+
}
892+
872893
/**
873894
* Reads the next handshake message, advancing this handshake's internal state.
874895
*
@@ -881,6 +902,8 @@ public int writeMessage(@Nullable final ByteBuffer payload,
881902
* @throws IllegalArgumentException if the given message is too short to contain the expected handshake message
882903
*/
883904
public byte[] readMessage(final byte[] message) throws AEADBadTagException {
905+
checkInboundMessageSize(message.length);
906+
884907
final byte[] payload = new byte[getPayloadLength(message.length)];
885908

886909
try {
@@ -921,6 +944,12 @@ public int readMessage(final byte[] message,
921944
final byte[] payload,
922945
final int payloadOffset) throws ShortBufferException, AEADBadTagException {
923946

947+
checkInboundMessageSize(messageLength);
948+
949+
if (payload.length - payloadOffset < getPayloadLength(messageLength)) {
950+
throw new ShortBufferException("Payload array after offset is not large enough to hold payload");
951+
}
952+
924953
if (!isExpectingRead()) {
925954
throw new IllegalStateException("Handshake not currently expecting to read a message");
926955
}
@@ -987,6 +1016,8 @@ public int readMessage(final byte[] message,
9871016
* @throws IllegalArgumentException if the given message is too short to contain the expected handshake message
9881017
*/
9891018
public ByteBuffer readMessage(final ByteBuffer message) throws AEADBadTagException {
1019+
checkInboundMessageSize(message.remaining());
1020+
9901021
final ByteBuffer payload = ByteBuffer.allocate(getPayloadLength(message.remaining()));
9911022

9921023
try {
@@ -1025,6 +1056,12 @@ public ByteBuffer readMessage(final ByteBuffer message) throws AEADBadTagExcepti
10251056
public int readMessage(final ByteBuffer message,
10261057
final ByteBuffer payload) throws ShortBufferException, AEADBadTagException {
10271058

1059+
checkInboundMessageSize(message.remaining());
1060+
1061+
if (payload.remaining() < getPayloadLength(message.remaining())) {
1062+
throw new ShortBufferException("Payload buffer is not large enough to hold payload");
1063+
}
1064+
10281065
if (!isExpectingRead()) {
10291066
throw new IllegalStateException("Handshake not currently expecting to read a message");
10301067
}
@@ -1077,6 +1114,12 @@ public int readMessage(final ByteBuffer message,
10771114
return decryptAndHash(message, payload);
10781115
}
10791116

1117+
private void checkInboundMessageSize(final int messageSize) {
1118+
if (messageSize > MAX_NOISE_MESSAGE_SIZE) {
1119+
throw new IllegalArgumentException("Message is larger than maximum allowed Noise message size");
1120+
}
1121+
}
1122+
10801123
private void handleMixKeyToken(final HandshakePattern.Token token) {
10811124
switch (token) {
10821125
case EE -> {

src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import javax.annotation.Nullable;
1818
import javax.crypto.AEADBadTagException;
19+
import javax.crypto.ShortBufferException;
1920
import java.io.IOException;
2021
import java.io.InputStream;
2122
import java.nio.ByteBuffer;
@@ -191,6 +192,91 @@ void getPayloadLength() throws NoSuchPatternException {
191192
() -> NoiseHandshake.getPayloadLength(handshakePattern, 0, publicKeyLength, 55));
192193
}
193194

195+
@Test
196+
void writeMessageOversize() throws NoSuchAlgorithmException {
197+
final NoiseKeyAgreement keyAgreement = NoiseKeyAgreement.getInstance("25519");
198+
199+
final NoiseHandshake handshake =
200+
NoiseHandshakeBuilder.forIKInitiator(keyAgreement.generateKeyPair(), keyAgreement.generateKeyPair().getPublic())
201+
.setComponentsFromProtocolName("Noise_IK_25519_AESGCM_SHA256")
202+
.build();
203+
204+
// We want to make sure we're testing the size of the resulting message (which may include key material and AEAD
205+
// tags) rather than the length of just the payload
206+
final int payloadLength = NoiseHandshake.MAX_NOISE_MESSAGE_SIZE - 1;
207+
final int messageLength = handshake.getOutboundMessageLength(payloadLength);
208+
209+
assertTrue(messageLength > NoiseHandshake.MAX_NOISE_MESSAGE_SIZE);
210+
211+
assertThrows(IllegalArgumentException.class,
212+
() -> handshake.writeMessage(new byte[payloadLength]));
213+
214+
assertThrows(IllegalArgumentException.class,
215+
() -> handshake.writeMessage(new byte[payloadLength], 0, payloadLength, new byte[messageLength], 0));
216+
217+
assertThrows(IllegalArgumentException.class,
218+
() -> handshake.writeMessage(ByteBuffer.allocate(payloadLength)));
219+
220+
assertThrows(IllegalArgumentException.class,
221+
() -> handshake.writeMessage(ByteBuffer.allocate(payloadLength), ByteBuffer.allocate(messageLength)));
222+
}
223+
224+
@Test
225+
void writeMessageShortBuffer() throws NoSuchAlgorithmException {
226+
final NoiseHandshake handshake =
227+
NoiseHandshakeBuilder.forNNInitiator()
228+
.setComponentsFromProtocolName("Noise_NN_25519_AESGCM_SHA256")
229+
.build();
230+
231+
final byte[] payload = new byte[32];
232+
final byte[] message = new byte[payload.length - 1];
233+
234+
assertThrows(ShortBufferException.class, () ->
235+
handshake.writeMessage(payload, 0, payload.length, message, 0));
236+
237+
assertThrows(ShortBufferException.class, () ->
238+
handshake.writeMessage(ByteBuffer.wrap(payload), ByteBuffer.wrap(message)));
239+
}
240+
241+
@Test
242+
void readMessageOversize() throws NoSuchAlgorithmException {
243+
final NoiseHandshake handshake =
244+
NoiseHandshakeBuilder.forNNResponder()
245+
.setComponentsFromProtocolName("Noise_NN_25519_AESGCM_SHA256")
246+
.build();
247+
248+
final int messageLength = NoiseHandshake.MAX_NOISE_MESSAGE_SIZE + 1;
249+
250+
assertThrows(IllegalArgumentException.class, () ->
251+
handshake.readMessage(new byte[messageLength]));
252+
253+
assertThrows(IllegalArgumentException.class, () ->
254+
handshake.readMessage(new byte[messageLength], 0, messageLength, new byte[messageLength], 0));
255+
256+
assertThrows(IllegalArgumentException.class, () ->
257+
handshake.readMessage(ByteBuffer.allocate(messageLength)));
258+
259+
assertThrows(IllegalArgumentException.class, () ->
260+
handshake.readMessage(ByteBuffer.allocate(messageLength), ByteBuffer.allocate(messageLength)));
261+
}
262+
263+
@Test
264+
void readMessageShortBuffer() throws NoSuchAlgorithmException {
265+
final NoiseHandshake handshake =
266+
NoiseHandshakeBuilder.forNNResponder()
267+
.setComponentsFromProtocolName("Noise_NN_25519_AESGCM_SHA256")
268+
.build();
269+
270+
final byte[] message = new byte[128];
271+
final int payloadLength = handshake.getPayloadLength(message.length);
272+
273+
assertThrows(ShortBufferException.class, () ->
274+
handshake.readMessage(message, 0, message.length, new byte[payloadLength - 1], 0));
275+
276+
assertThrows(ShortBufferException.class, () ->
277+
handshake.readMessage(ByteBuffer.wrap(message), ByteBuffer.allocate(payloadLength - 1)));
278+
}
279+
194280
@ParameterizedTest
195281
@MethodSource("cacophonyTestVectors")
196282
void cacophonyTestsWithNewByteArray(final CacophonyTestVector testVector) throws AEADBadTagException {

0 commit comments

Comments
 (0)