Skip to content

Commit 9fb60ad

Browse files
committed
Add message length checks and tests for NoiseTransportImpl
1 parent 585eb54 commit 9fb60ad

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

src/main/java/com/eatthepath/noise/NoiseTransportImpl.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ class NoiseTransportImpl implements NoiseTransport {
99
private final CipherState readerState;
1010
private final CipherState writerState;
1111

12+
static final int MAX_NOISE_MESSAGE_SIZE = 65_535;
13+
1214
NoiseTransportImpl(final CipherState readerState, final CipherState writerState) {
1315
this.readerState = readerState;
1416
this.writerState = writerState;
@@ -26,16 +28,26 @@ public int getCiphertextLength(final int plaintextLength) {
2628

2729
@Override
2830
public ByteBuffer readMessage(final ByteBuffer ciphertext) throws AEADBadTagException {
31+
checkInboundMessageSize(ciphertext.remaining());
32+
2933
return readerState.decrypt(null, ciphertext);
3034
}
3135

3236
@Override
3337
public int readMessage(final ByteBuffer ciphertext, final ByteBuffer plaintext) throws ShortBufferException, AEADBadTagException {
38+
checkInboundMessageSize(ciphertext.remaining());
39+
40+
if (plaintext.remaining() < getPlaintextLength(ciphertext.remaining())) {
41+
throw new ShortBufferException("Plaintext buffer does not have enough remaining capacity to hold plaintext");
42+
}
43+
3444
return readerState.decrypt(null, ciphertext, plaintext);
3545
}
3646

3747
@Override
3848
public byte[] readMessage(final byte[] ciphertext) throws AEADBadTagException {
49+
checkInboundMessageSize(ciphertext.length);
50+
3951
return readerState.decrypt(null, ciphertext);
4052
}
4153

@@ -46,23 +58,45 @@ public int readMessage(final byte[] ciphertext,
4658
final byte[] plaintext,
4759
final int plaintextOffset) throws ShortBufferException, AEADBadTagException {
4860

61+
checkInboundMessageSize(ciphertextLength);
62+
63+
if (plaintext.length - plaintextOffset < getPlaintextLength(ciphertextLength)) {
64+
throw new ShortBufferException("Plaintext array after offset is not large enough to hold plaintext");
65+
}
66+
4967
return readerState.decrypt(null, 0, 0,
5068
ciphertext, ciphertextOffset, ciphertextLength,
5169
plaintext, plaintextOffset);
5270
}
5371

72+
private void checkInboundMessageSize(final int ciphertextLength) {
73+
if (ciphertextLength > MAX_NOISE_MESSAGE_SIZE) {
74+
throw new IllegalArgumentException("Message is larger than maximum allowed Noise transport message size");
75+
}
76+
}
77+
5478
@Override
5579
public ByteBuffer writeMessage(final ByteBuffer plaintext) {
80+
checkOutboundMessageSize(plaintext.remaining());
81+
5682
return writerState.encrypt(null, plaintext);
5783
}
5884

5985
@Override
6086
public int writeMessage(final ByteBuffer plaintext, final ByteBuffer ciphertext) throws ShortBufferException {
87+
checkOutboundMessageSize(plaintext.remaining());
88+
89+
if (ciphertext.remaining() < getCiphertextLength(plaintext.remaining())) {
90+
throw new ShortBufferException("Ciphertext buffer does not have enough remaining capacity to hold ciphertext");
91+
}
92+
6193
return writerState.encrypt(null, plaintext, ciphertext);
6294
}
6395

6496
@Override
6597
public byte[] writeMessage(final byte[] plaintext) {
98+
checkOutboundMessageSize(plaintext.length);
99+
66100
return writerState.encrypt(null, plaintext);
67101
}
68102

@@ -73,11 +107,23 @@ public int writeMessage(final byte[] plaintext,
73107
final byte[] ciphertext,
74108
final int ciphertextOffset) throws ShortBufferException {
75109

110+
checkOutboundMessageSize(plaintextLength);
111+
112+
if (ciphertext.length - ciphertextOffset < getCiphertextLength(plaintextLength)) {
113+
throw new ShortBufferException("Ciphertext array after offset is not large enough to hold ciphertext");
114+
}
115+
76116
return writerState.encrypt(null, 0, 0,
77117
plaintext, plaintextOffset, plaintextLength,
78118
ciphertext, ciphertextOffset);
79119
}
80120

121+
void checkOutboundMessageSize(final int plaintextLength) {
122+
if (getCiphertextLength(plaintextLength) > MAX_NOISE_MESSAGE_SIZE) {
123+
throw new IllegalArgumentException("Ciphertext would be larger than maximum allowed Noise transport message size");
124+
}
125+
}
126+
81127
@Override
82128
public void rekeyReader() {
83129
readerState.rekey();
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package com.eatthepath.noise;
2+
3+
import org.junit.jupiter.api.BeforeEach;
4+
import org.junit.jupiter.api.Test;
5+
6+
import javax.crypto.AEADBadTagException;
7+
import javax.crypto.ShortBufferException;
8+
import java.nio.ByteBuffer;
9+
import java.security.NoSuchAlgorithmException;
10+
11+
import static org.junit.jupiter.api.Assertions.*;
12+
13+
class NoiseTransportImplTest {
14+
15+
private NoiseTransportImpl noiseTransport;
16+
17+
@BeforeEach
18+
void setUp() throws NoSuchAlgorithmException, AEADBadTagException {
19+
final NoiseHandshake initiatorHandshake =
20+
NoiseHandshakeBuilder.forNNInitiator()
21+
.setComponentsFromProtocolName("Noise_NN_25519_AESGCM_SHA256")
22+
.build();
23+
24+
final NoiseHandshake responderHandshake =
25+
NoiseHandshakeBuilder.forNNResponder()
26+
.setComponentsFromProtocolName("Noise_NN_25519_AESGCM_SHA256")
27+
.build();
28+
29+
responderHandshake.readMessage(initiatorHandshake.writeMessage((byte[]) null));
30+
initiatorHandshake.readMessage(responderHandshake.writeMessage((byte[]) null));
31+
32+
noiseTransport = (NoiseTransportImpl) initiatorHandshake.toTransport();
33+
}
34+
35+
@Test
36+
void getPlaintextLength() {
37+
final int ciphertextLength = 77;
38+
assertEquals(ciphertextLength - 16, noiseTransport.getPlaintextLength(ciphertextLength));
39+
}
40+
41+
@Test
42+
void getCiphertextLength() {
43+
final int plaintextLength = 83;
44+
assertEquals(plaintextLength + 16, noiseTransport.getCiphertextLength(plaintextLength));
45+
}
46+
47+
@Test
48+
void writeMessageOversize() {
49+
// We want to make sure we're testing the size of the resulting message (which may include key material and AEAD
50+
// tags) rather than the length of just the payload
51+
final int plaintextLength = NoiseTransportImpl.MAX_NOISE_MESSAGE_SIZE - 1;
52+
final int messageLength = noiseTransport.getCiphertextLength(plaintextLength);
53+
54+
assertTrue(messageLength > NoiseTransportImpl.MAX_NOISE_MESSAGE_SIZE);
55+
56+
assertThrows(IllegalArgumentException.class,
57+
() -> noiseTransport.writeMessage(new byte[plaintextLength]));
58+
59+
assertThrows(IllegalArgumentException.class,
60+
() -> noiseTransport.writeMessage(new byte[plaintextLength], 0, plaintextLength, new byte[messageLength], 0));
61+
62+
assertThrows(IllegalArgumentException.class,
63+
() -> noiseTransport.writeMessage(ByteBuffer.allocate(plaintextLength)));
64+
65+
assertThrows(IllegalArgumentException.class,
66+
() -> noiseTransport.writeMessage(ByteBuffer.allocate(plaintextLength), ByteBuffer.allocate(messageLength)));
67+
}
68+
69+
@Test
70+
void writeMessageShortBuffer() {
71+
final byte[] plaintext = new byte[32];
72+
final byte[] message = new byte[noiseTransport.getCiphertextLength(plaintext.length) - 1];
73+
74+
assertThrows(ShortBufferException.class, () ->
75+
noiseTransport.writeMessage(plaintext, 0, plaintext.length, message, 0));
76+
77+
assertThrows(ShortBufferException.class, () ->
78+
noiseTransport.writeMessage(ByteBuffer.wrap(plaintext), ByteBuffer.wrap(message)));
79+
}
80+
81+
@Test
82+
void readMessageOversize() throws NoSuchAlgorithmException {
83+
final int messageLength = NoiseTransportImpl.MAX_NOISE_MESSAGE_SIZE + 1;
84+
85+
assertThrows(IllegalArgumentException.class, () ->
86+
noiseTransport.readMessage(new byte[messageLength]));
87+
88+
assertThrows(IllegalArgumentException.class, () ->
89+
noiseTransport.readMessage(new byte[messageLength], 0, messageLength, new byte[messageLength], 0));
90+
91+
assertThrows(IllegalArgumentException.class, () ->
92+
noiseTransport.readMessage(ByteBuffer.allocate(messageLength)));
93+
94+
assertThrows(IllegalArgumentException.class, () ->
95+
noiseTransport.readMessage(ByteBuffer.allocate(messageLength), ByteBuffer.allocate(messageLength)));
96+
}
97+
98+
@Test
99+
void readMessageShortBuffer() {
100+
final byte[] message = new byte[128];
101+
final int plaintextLength = noiseTransport.getPlaintextLength(message.length);
102+
103+
assertThrows(ShortBufferException.class, () ->
104+
noiseTransport.readMessage(message, 0, message.length, new byte[plaintextLength - 1], 0));
105+
106+
assertThrows(ShortBufferException.class, () ->
107+
noiseTransport.readMessage(ByteBuffer.wrap(message), ByteBuffer.allocate(plaintextLength - 1)));
108+
}
109+
}

0 commit comments

Comments
 (0)