From 46e9dff4464aeb8a530623eabb2341415797e538 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Sat, 4 Oct 2025 18:01:25 +0100 Subject: [PATCH] Add a socket receive buffer Currently an array is allocated to read each packet from the socket, followed by decryption which allocates another array for the plaintext payload. We can save one of these two allocations by adding a persistent buffer for socket receives, and allowing the cipher implementations to decrypt into the given payload array. We can save the other allocation similarly, but in a separate change. --- .../Abstractions/CryptoAbstraction.cs | 36 +++ .../Security/Cryptography/Cipher.cs | 23 ++ .../Cryptography/Ciphers/AesCipher.BclImpl.cs | 96 ++++---- .../Cryptography/Ciphers/AesCipher.CtrImpl.cs | 79 +++++-- .../Cryptography/Ciphers/AesCipher.cs | 6 + .../Cryptography/Ciphers/AesGcmCipher.cs | 30 ++- .../Ciphers/ChaCha20Poly1305Cipher.cs | 53 +++-- src/Renci.SshNet/Session.cs | 214 ++++++++++++------ 8 files changed, 365 insertions(+), 172 deletions(-) diff --git a/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs b/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs index 0081e8860..01fdfef54 100644 --- a/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs @@ -1,3 +1,7 @@ +using System; +#if !NET +using System.Runtime.CompilerServices; +#endif using System.Security.Cryptography; using Org.BouncyCastle.Crypto.Prng; @@ -80,6 +84,38 @@ public static byte[] HashSHA512(byte[] source) { return sha512.ComputeHash(source); } +#endif + } + +#if !NET + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)] +#endif + public static bool FixedTimeEquals(ReadOnlySpan left, ReadOnlySpan right) + { +#if NET + return CryptographicOperations.FixedTimeEquals(left, right); +#else + // https://github.com/dotnet/runtime/blob/1d1bf92fcf43aa6981804dc53c5174445069c9e4/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/CryptographicOperations.cs + + // NoOptimization because we want this method to be exactly as non-short-circuiting + // as written. + // + // NoInlining because the NoOptimization would get lost if the method got inlined. + + if (left.Length != right.Length) + { + return false; + } + + var length = left.Length; + var accum = 0; + + for (var i = 0; i < length; i++) + { + accum |= left[i] - right[i]; + } + + return accum == 0; #endif } } diff --git a/src/Renci.SshNet/Security/Cryptography/Cipher.cs b/src/Renci.SshNet/Security/Cryptography/Cipher.cs index 39ddaeb36..872c2131d 100644 --- a/src/Renci.SshNet/Security/Cryptography/Cipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Cipher.cs @@ -1,3 +1,6 @@ +#nullable enable +using System; + namespace Renci.SshNet.Security.Cryptography { /// @@ -73,5 +76,25 @@ public virtual byte[] Decrypt(byte[] input) /// The decrypted data. /// public abstract byte[] Decrypt(byte[] input, int offset, int length); + + /// + /// Decrypts the specified input into a given buffer. + /// + /// The input. + /// The zero-based offset in at which to begin decrypting. + /// The number of bytes to decrypt from . + /// The output buffer to write to. + /// The zero-based offset in at which to write decrypted output. + /// + /// The number of bytes written to . + /// + public virtual int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) + { + var plaintext = Decrypt(input, offset, length); + + plaintext.AsSpan().CopyTo(output.AsSpan(outputOffset)); + + return plaintext.Length; + } } } diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.BclImpl.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.BclImpl.cs index 76e43e949..c8bc62a0a 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.BclImpl.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.BclImpl.cs @@ -1,4 +1,5 @@ -using System; +#nullable enable +using System; using System.Security.Cryptography; using Renci.SshNet.Common; @@ -39,53 +40,45 @@ public BclImpl( } public override byte[] Encrypt(byte[] input, int offset, int length) + { + return Transform(_encryptor, input, offset, length, output: null, 0, out _); + } + + public override byte[] Decrypt(byte[] input, int offset, int length) + { + return Transform(_decryptor, input, offset, length, output: null, 0, out _); + } + + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) + { + _ = Transform(_decryptor, input, offset, length, output, outputOffset, out var bytesWritten); + + return bytesWritten; + } + + private byte[] Transform(ICryptoTransform transform, byte[] input, int offset, int length, byte[]? output, int outputOffset, out int bytesWritten) { if (_aes.Padding != PaddingMode.None) { // If padding has been specified, call TransformFinalBlock to apply // the padding and reset the state. - return _encryptor.TransformFinalBlock(input, offset, length); - } - var paddingLength = 0; - if (length % BlockSize > 0) - { - if (_aes.Mode is System.Security.Cryptography.CipherMode.CFB or System.Security.Cryptography.CipherMode.OFB) + var finalBlock = transform.TransformFinalBlock(input, offset, length); + + if (output is not null) { - // Manually pad the input for cfb and ofb cipher mode as BCL doesn't support partial block. - // See https://github.com/dotnet/runtime/blob/e7d837da5b1aacd9325a8b8f2214cfaf4d3f0ff6/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/SymmetricPadding.cs#L20-L21 - paddingLength = BlockSize - (length % BlockSize); - input = input.Take(offset, length); - length += paddingLength; - Array.Resize(ref input, length); - offset = 0; + finalBlock.AsSpan().CopyTo(output.AsSpan(outputOffset)); } + + bytesWritten = finalBlock.Length; + + return finalBlock; } // Otherwise, (the most important case) assume this instance is // used for one direction of an SSH connection, whereby the // encrypted data in all packets are considered a single data - // stream i.e. we do not want to reset the state between calls to Encrypt. - var output = new byte[length]; - _ = _encryptor.TransformBlock(input, offset, length, output, 0); - - if (paddingLength > 0) - { - // Manually unpad the output. - Array.Resize(ref output, output.Length - paddingLength); - } - - return output; - } - - public override byte[] Decrypt(byte[] input, int offset, int length) - { - if (_aes.Padding != PaddingMode.None) - { - // If padding has been specified, call TransformFinalBlock to apply - // the padding and reset the state. - return _decryptor.TransformFinalBlock(input, offset, length); - } + // stream i.e. we do not want to reset the state between calls to Decrypt. var paddingLength = 0; if (length % BlockSize > 0) @@ -95,24 +88,33 @@ public override byte[] Decrypt(byte[] input, int offset, int length) // Manually pad the input for cfb and ofb cipher mode as BCL doesn't support partial block. // See https://github.com/dotnet/runtime/blob/e7d837da5b1aacd9325a8b8f2214cfaf4d3f0ff6/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/SymmetricPadding.cs#L20-L21 paddingLength = BlockSize - (length % BlockSize); - input = input.Take(offset, length); - length += paddingLength; - Array.Resize(ref input, length); + + var tmp = new byte[length + paddingLength]; + + input.AsSpan(offset, length).CopyTo(tmp); + + input = tmp; offset = 0; + length = tmp.Length; } } - // Otherwise, (the most important case) assume this instance is - // used for one direction of an SSH connection, whereby the - // encrypted data in all packets are considered a single data - // stream i.e. we do not want to reset the state between calls to Decrypt. - var output = new byte[length]; - _ = _decryptor.TransformBlock(input, offset, length, output, 0); - - if (paddingLength > 0) + if (output is null) { + output = new byte[length]; + + bytesWritten = transform.TransformBlock(input, offset, length, output, outputOffset); + + bytesWritten -= paddingLength; + // Manually unpad the output. - Array.Resize(ref output, output.Length - paddingLength); + Array.Resize(ref output, bytesWritten); + } + else + { + bytesWritten = transform.TransformBlock(input, offset, length, output, outputOffset); + + bytesWritten -= paddingLength; } return output; diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.CtrImpl.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.CtrImpl.cs index 0d4dde5cd..7e569c9ad 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.CtrImpl.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.CtrImpl.cs @@ -1,8 +1,12 @@ -using System; +#nullable enable +using System; using System.Buffers.Binary; +using System.Diagnostics; using System.Numerics; using System.Security.Cryptography; +using Renci.SshNet.Common; + namespace Renci.SshNet.Security.Cryptography.Ciphers { public partial class AesCipher @@ -34,12 +38,32 @@ public CtrImpl( public override byte[] Encrypt(byte[] input, int offset, int length) { - return CTREncryptDecrypt(input, offset, length); + return Decrypt(input, offset, length); } public override byte[] Decrypt(byte[] input, int offset, int length) { - return CTREncryptDecrypt(input, offset, length); + ThrowHelper.ThrowIfNull(input); + + var buffer = CTREncryptDecrypt(input, offset, length, output: null, 0); + + // adjust output for non-blocksized lengths + if (buffer.Length > length) + { + Array.Resize(ref buffer, length); + } + + return buffer; + } + + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) + { + ThrowHelper.ThrowIfNull(input); + ThrowHelper.ThrowIfNull(output); + + _ = CTREncryptDecrypt(input, offset, length, output, outputOffset); + + return length; } public override int DecryptBlock(byte[] inputBuffer, int inputOffset, int inputCount, byte[] outputBuffer, int outputOffset) @@ -52,35 +76,46 @@ public override int EncryptBlock(byte[] inputBuffer, int inputOffset, int inputC throw new NotImplementedException($"Invalid usage of {nameof(EncryptBlock)}."); } - private byte[] CTREncryptDecrypt(byte[] data, int offset, int length) + private byte[] CTREncryptDecrypt(byte[] data, int offset, int length, byte[]? output, int outputOffset) { - var count = length / BlockSize; - if (length % BlockSize != 0) + var blockSizedLength = length; + if (blockSizedLength % BlockSize != 0) { - count++; + blockSizedLength += BlockSize - (blockSizedLength % BlockSize); } - var buffer = new byte[count * BlockSize]; - CTRCreateCounterArray(buffer); - _ = _encryptor.TransformBlock(buffer, 0, buffer.Length, buffer, 0); - ArrayXOR(buffer, data, offset, length); + Debug.Assert(blockSizedLength % BlockSize == 0); - // adjust output for non-blocksized lengths - if (buffer.Length > length) + if (output is null) { - Array.Resize(ref buffer, length); + output = new byte[blockSizedLength]; + outputOffset = 0; + } + else if (data.AsSpan(offset, length).Overlaps(output.AsSpan(outputOffset, blockSizedLength))) + { + throw new ArgumentException("Input and output buffers must not overlap"); } - return buffer; + CTRCreateCounterArray(output.AsSpan(outputOffset, blockSizedLength)); + + var bytesWritten = _encryptor.TransformBlock(output, outputOffset, blockSizedLength, output, outputOffset); + + Debug.Assert(bytesWritten == blockSizedLength); + + ArrayXOR(output, outputOffset, data, offset, length); + + return output; } // creates the Counter array filled with incrementing copies of IV - private void CTRCreateCounterArray(byte[] buffer) + private void CTRCreateCounterArray(Span buffer) { + Debug.Assert(buffer.Length % 16 == 0); + for (var i = 0; i < buffer.Length; i += 16) { - BinaryPrimitives.WriteUInt64BigEndian(buffer.AsSpan(i + 8), _ivLower); - BinaryPrimitives.WriteUInt64BigEndian(buffer.AsSpan(i), _ivUpper); + BinaryPrimitives.WriteUInt64BigEndian(buffer.Slice(i + 8), _ivLower); + BinaryPrimitives.WriteUInt64BigEndian(buffer.Slice(i), _ivUpper); _ivLower += 1; _ivUpper += (_ivLower == 0) ? 1UL : 0UL; @@ -88,20 +123,20 @@ private void CTRCreateCounterArray(byte[] buffer) } // XOR 2 arrays using Vector - private static void ArrayXOR(byte[] buffer, byte[] data, int offset, int length) + private static void ArrayXOR(byte[] buffer, int bufferOffset, byte[] data, int offset, int length) { var i = 0; var oneVectorFromEnd = length - Vector.Count; for (; i <= oneVectorFromEnd; i += Vector.Count) { - var v = new Vector(buffer, i) ^ new Vector(data, offset + i); - v.CopyTo(buffer, i); + var v = new Vector(buffer, bufferOffset + i) ^ new Vector(data, offset + i); + v.CopyTo(buffer, bufferOffset + i); } for (; i < length; i++) { - buffer[i] ^= data[offset + i]; + buffer[bufferOffset + i] ^= data[offset + i]; } } diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.cs index e0963dff2..eb335549a 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.cs @@ -77,6 +77,12 @@ public override byte[] Decrypt(byte[] input, int offset, int length) return _impl.Decrypt(input, offset, length); } + /// + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) + { + return _impl.Decrypt(input, offset, length, output, outputOffset); + } + /// public void Dispose() { diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesGcmCipher.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesGcmCipher.cs index ded36fc0d..dcd686c3f 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesGcmCipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesGcmCipher.cs @@ -110,6 +110,17 @@ public override byte[] Encrypt(byte[] input, int offset, int length) return output; } + public override byte[] Decrypt(byte[] input, int offset, int length) + { + var output = new byte[length]; + + var bytesWritten = Decrypt(input, offset, length, output, 0); + + Debug.Assert(bytesWritten == length); + + return output; + } + /// /// Decrypts the specified input. /// @@ -121,17 +132,12 @@ public override byte[] Encrypt(byte[] input, int offset, int length) /// /// The zero-based offset in at which to begin decrypting and authenticating. /// The number of bytes to decrypt and authenticate from . - /// - /// The decrypted data with below format: - /// - /// [----Plain Text----] - /// - /// - public override byte[] Decrypt(byte[] input, int offset, int length) + /// The buffer to which to write the decrypted bytes. + /// The zero-based offset in at which to write the decrypted bytes. + /// The number of plaintext bytes written to . + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) { - Debug.Assert(offset >= _aadLength, "The offset must be greater than or equals to aad length"); - - var output = new byte[length]; + Debug.Assert(offset >= _aadLength, "The offset must be greater than or equal to aad length"); _impl.Decrypt( input, @@ -140,11 +146,11 @@ public override byte[] Decrypt(byte[] input, int offset, int length) associatedDataOffset: offset - _aadLength, associatedDataLength: _aadLength, output, - plainTextOffset: 0); + outputOffset); IncrementCounter(); - return output; + return length; } /// diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs index 1559c2631..4107db278 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs @@ -112,18 +112,39 @@ public override byte[] Encrypt(byte[] input, int offset, int length) } /// - /// Decrypts the AAD. + /// Decrypts the specified bytes. /// - /// The encrypted AAD. - /// The decrypted AAD. - public override byte[] Decrypt(byte[] input) + /// + /// If a positive AAD length was specified for this instance, the bytes are + /// decrypted as if they are AAD (i.e. the packet length of an SSH packet). + /// + /// The buffer containing the ciphertext to decrypt. + /// The offset of the ciphertext in the buffer. + /// The length of the ciphertext to decrypt. + /// The decrypted plaintext. + public override byte[] Decrypt(byte[] input, int offset, int length) { - Debug.Assert(_aadCipher != null, "The aadCipher must not be null"); + byte[] output; + + if (_aadLength > 0) + { + // If we are in 'AAD mode', then put these bytes through the AAD cipher. + + Debug.Assert(_aadCipher != null); - _aadCipher.Init(forEncryption: false, new ParametersWithIV(_aadKeyParameter, _iv)); + _aadCipher.Init(forEncryption: false, new ParametersWithIV(_aadKeyParameter, _iv)); + + output = new byte[length]; + _aadCipher.ProcessBytes(input, offset, length, output, 0); + } + else + { + output = new byte[length]; - var output = new byte[input.Length]; - _aadCipher.ProcessBytes(input, 0, input.Length, output, 0); + var bytesWritten = Decrypt(input, offset, length, output, 0); + + Debug.Assert(bytesWritten == length); + } return output; } @@ -139,13 +160,10 @@ public override byte[] Decrypt(byte[] input) /// /// The zero-based offset in at which to begin decrypting and authenticating. /// The number of bytes to decrypt and authenticate from . - /// - /// The decrypted data with below format: - /// - /// [----Plain Text----] - /// - /// - public override byte[] Decrypt(byte[] input, int offset, int length) + /// The buffer to which to write the decrypted bytes. + /// The zero-based offset in at which to write the decrypted bytes. + /// The number of plaintext bytes written to . + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) { Debug.Assert(offset >= _aadLength, "The offset must be greater than or equals to aad length"); @@ -163,10 +181,9 @@ public override byte[] Decrypt(byte[] input, int offset, int length) throw new SshConnectionException("MAC error", DisconnectReason.MacError); } - var output = new byte[length]; - _cipher.ProcessBytes(input, offset, length, output, 0); + _cipher.ProcessBytes(input, offset, length, output, outputOffset); - return output; + return length; } internal override void SetSequenceNumber(uint sequenceNumber) diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 77fe9d4c2..fb6da7633 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Globalization; using System.Linq; +using System.Net; using System.Net.Sockets; using System.Security.Cryptography; #if !NET @@ -201,6 +202,8 @@ public sealed class Session : ISession /// private Socket _socket; + private ArrayBuffer _receiveBuffer = new(4 * 1024); + /// /// Gets the session semaphore that controls session channels. /// @@ -1213,7 +1216,8 @@ private Message ReceiveMessage(Socket socket) int blockSize; - // Determine the size of the first block which is 8 or cipher block size (whichever is larger) bytes, or 4 if "packet length" field is handled separately. + // Determine the size of the first block which is 8 or cipher block size (whichever is larger) bytes, + // or 4 if "packet length" field is handled separately. if (_serverEtm || _serverAead) { blockSize = (byte)4; @@ -1238,118 +1242,160 @@ private Message ReceiveMessage(Socket socket) serverMacLength = _serverMac.HashSize / 8; } - byte[] data; - uint packetLength; - - // Read first block - which starts with the packet length - var firstBlock = new byte[blockSize]; - if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0) + if (_receiveBuffer.ActiveLength < blockSize) { - // connection with SSH server was closed - return null; + var bytesNeeded = blockSize - _receiveBuffer.ActiveLength; + + _receiveBuffer.EnsureAvailableSpace(bytesNeeded); + + var bytesRead = TrySocketRead( + socket, + buffer: _receiveBuffer.DangerousGetUnderlyingBuffer(), + offset: _receiveBuffer.ActiveStartOffset + _receiveBuffer.ActiveLength, + length: _receiveBuffer.AvailableLength, + minimumLength: bytesNeeded); + + _receiveBuffer.Commit(bytesRead); + + if (bytesRead < bytesNeeded) + { + // connection with SSH server was closed + return null; + } } + var firstBlock = new ArraySegment( + _receiveBuffer.DangerousGetUnderlyingBuffer(), + _receiveBuffer.ActiveStartOffset, + blockSize); + var plainFirstBlock = firstBlock; - // First block is not encrypted in AES GCM mode. + // For ETM or AES-GCM, firstBlock holds the packet length which is + // not encrypted. Otherwise, we decrypt the first "blockSize" bytes. + // (For chacha20-poly1305, this means passing the encrypted packet + // length as AAD). if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher) { _serverCipher.SetSequenceNumber(_inboundPacketSequence); - // First block is not encrypted in ETM mode. if (_serverMac == null || !_serverEtm) { - plainFirstBlock = _serverCipher.Decrypt(firstBlock); + plainFirstBlock = new ArraySegment(_serverCipher.Decrypt( + firstBlock.Array, + firstBlock.Offset, + firstBlock.Count)); } } - packetLength = BinaryPrimitives.ReadUInt32BigEndian(plainFirstBlock); + var packetLength = BinaryPrimitives.ReadInt32BigEndian(plainFirstBlock); // Test packet minimum and maximum boundaries if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4) { - throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), - DisconnectReason.ProtocolError); + throw new SshConnectionException( + string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", (uint)packetLength), + DisconnectReason.ProtocolError); } - // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the - // "packet length" field itself - which is 4 bytes - is not included in the length of the packet - var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength; - - // Construct buffer for holding the payload and the inbound packet sequence as we need both in order - // to generate the hash. - // - // The total length of the "data" buffer is an addition of: - // - inboundPacketSequenceLength (4 bytes) - // - packetLength - // - serverMacLength - // - // We include the inbound packet sequence to allow us to have the the full SSH packet in a single - // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen - // to read the packet including server MAC in a single pass (except for the initial block). - data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength]; - BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence); + var totalPacketLength = 4 + packetLength + serverMacLength; - // Use raw packet length field to calculate the mac in AEAD mode. - if (_serverAead) - { - Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, blockSize); - } - else + if (_receiveBuffer.ActiveLength < totalPacketLength) { - Buffer.BlockCopy(plainFirstBlock, 0, data, inboundPacketSequenceLength, blockSize); - } + var bytesNeeded = totalPacketLength - _receiveBuffer.ActiveLength; - if (bytesToRead > 0) - { - if (TrySocketRead(socket, data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0) + _receiveBuffer.EnsureAvailableSpace(bytesNeeded); + + var bytesRead = TrySocketRead( + socket, + buffer: _receiveBuffer.DangerousGetUnderlyingBuffer(), + offset: _receiveBuffer.ActiveStartOffset + _receiveBuffer.ActiveLength, + length: _receiveBuffer.AvailableLength, + minimumLength: bytesNeeded); + + _receiveBuffer.Commit(bytesRead); + + if (bytesRead < bytesNeeded) { + // connection with SSH server was closed return null; } } - // validate encrypted message against MAC + // Construct buffer for holding the payload and the inbound packet sequence as we need both in order + // to generate the hash. + var data = new byte[4 + totalPacketLength - serverMacLength]; + + BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence); + + plainFirstBlock.AsSpan().CopyTo(data.AsSpan(4)); + if (_serverMac != null && _serverEtm) { - var clientHash = _serverMac.ComputeHash(data, 0, data.Length - serverMacLength); -#if NET - if (!CryptographicOperations.FixedTimeEquals(clientHash, new ReadOnlySpan(data, data.Length - serverMacLength, serverMacLength))) -#else - if (!Org.BouncyCastle.Utilities.Arrays.FixedTimeEquals(serverMacLength, clientHash, 0, data, data.Length - serverMacLength)) -#endif + // ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet) + + // sequence_number + _ = _serverMac.TransformBlock( + inputBuffer: data, + inputOffset: 0, + inputCount: 4, + outputBuffer: null, + outputOffset: 0); + + // packet_length || encrypted_packet + _ = _serverMac.TransformBlock( + inputBuffer: _receiveBuffer.DangerousGetUnderlyingBuffer(), + inputOffset: _receiveBuffer.ActiveStartOffset, + inputCount: totalPacketLength - serverMacLength, + outputBuffer: null, + outputOffset: 0); + + _ = _serverMac.TransformFinalBlock(Array.Empty(), 0, 0); + + if (!CryptoAbstraction.FixedTimeEquals(_serverMac.Hash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength))) { throw new SshConnectionException("MAC error", DisconnectReason.MacError); } } - if (_serverCipher != null) + var numberOfBytesToDecrypt = 4 + packetLength - blockSize; + + if (_serverCipher != null && numberOfBytesToDecrypt > 0) { - var numberOfBytesToDecrypt = data.Length - (blockSize + inboundPacketSequenceLength + serverMacLength); - if (numberOfBytesToDecrypt > 0) - { - var decryptedData = _serverCipher.Decrypt(data, blockSize + inboundPacketSequenceLength, numberOfBytesToDecrypt); - Buffer.BlockCopy(decryptedData, 0, data, blockSize + inboundPacketSequenceLength, decryptedData.Length); - } - } + Debug.Assert(numberOfBytesToDecrypt % blockSize == 0); - var paddingLength = data[inboundPacketSequenceLength + packetLengthFieldLength]; - var messagePayloadLength = (int)packetLength - paddingLength - paddingLengthFieldLength; - var messagePayloadOffset = inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength; + var numberOfBytesDecrypted = _serverCipher.Decrypt( + input: _receiveBuffer.DangerousGetUnderlyingBuffer(), + offset: _receiveBuffer.ActiveStartOffset + blockSize, + length: numberOfBytesToDecrypt, + output: data, + outputOffset: 4 + blockSize); + + Debug.Assert(numberOfBytesDecrypted == numberOfBytesToDecrypt); + } + else + { + _receiveBuffer.ActiveReadOnlySpan.Slice(blockSize, numberOfBytesToDecrypt).CopyTo(data.AsSpan(4 + blockSize)); + } - // validate decrypted message against MAC if (_serverMac != null && !_serverEtm) { - var clientHash = _serverMac.ComputeHash(data, 0, data.Length - serverMacLength); -#if NET - if (!CryptographicOperations.FixedTimeEquals(clientHash, new ReadOnlySpan(data, data.Length - serverMacLength, serverMacLength))) -#else - if (!Org.BouncyCastle.Utilities.Arrays.FixedTimeEquals(serverMacLength, clientHash, 0, data, data.Length - serverMacLength)) -#endif + // non-ETM mac = MAC(key, sequence_number || unencrypted_packet) + + var clientHash = _serverMac.ComputeHash(data); + + if (!CryptoAbstraction.FixedTimeEquals(clientHash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength))) { throw new SshConnectionException("MAC error", DisconnectReason.MacError); } } + _receiveBuffer.Discard(totalPacketLength); + + var paddingLength = data[inboundPacketSequenceLength + packetLengthFieldLength]; + var messagePayloadLength = packetLength - paddingLength - paddingLengthFieldLength; + var messagePayloadOffset = inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength; + if (_serverDecompression != null) { data = _serverDecompression.Decompress(data, messagePayloadOffset, messagePayloadLength); @@ -1865,20 +1911,37 @@ private static string ToHex(byte[] bytes) } /// - /// Performs a blocking read on the socket until bytes are received. + /// Performs a blocking read on the socket until at least bytes are received. /// /// The to read from. /// An array of type that is the storage location for the received data. /// The position in parameter to store the received data. - /// The number of bytes to read. + /// The maximum number of bytes to read. + /// The minimum number of bytes to read. /// /// The number of bytes read. /// - /// The read has timed-out. /// The read failed. - private static int TrySocketRead(Socket socket, byte[] buffer, int offset, int length) + private static int TrySocketRead(Socket socket, byte[] buffer, int offset, int length, int minimumLength) { - return SocketAbstraction.Read(socket, buffer, offset, length, Timeout.InfiniteTimeSpan); + Debug.Assert(offset >= 0); + Debug.Assert((uint)length <= buffer.Length - offset); + Debug.Assert(minimumLength <= length); + + var totalRead = 0; + while (totalRead < minimumLength) + { + var read = socket.Receive(buffer, offset + totalRead, length - totalRead, SocketFlags.None); + + if (read == 0) + { + return totalRead; + } + + totalRead += read; + } + + return totalRead; } /// @@ -1927,6 +1990,11 @@ private void MessageListener() { try { + if (_socket is { } s) + { + s.ReceiveTimeout = 0; + } + // remain in message loop until socket is shut down or until we're disconnecting while (true) {