diff --git a/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs b/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs index 0081e8860..01fdfef54 100644 --- a/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs @@ -1,3 +1,7 @@ +using System; +#if !NET +using System.Runtime.CompilerServices; +#endif using System.Security.Cryptography; using Org.BouncyCastle.Crypto.Prng; @@ -80,6 +84,38 @@ public static byte[] HashSHA512(byte[] source) { return sha512.ComputeHash(source); } +#endif + } + +#if !NET + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.NoOptimization)] +#endif + public static bool FixedTimeEquals(ReadOnlySpan left, ReadOnlySpan right) + { +#if NET + return CryptographicOperations.FixedTimeEquals(left, right); +#else + // https://github.com/dotnet/runtime/blob/1d1bf92fcf43aa6981804dc53c5174445069c9e4/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/CryptographicOperations.cs + + // NoOptimization because we want this method to be exactly as non-short-circuiting + // as written. + // + // NoInlining because the NoOptimization would get lost if the method got inlined. + + if (left.Length != right.Length) + { + return false; + } + + var length = left.Length; + var accum = 0; + + for (var i = 0; i < length; i++) + { + accum |= left[i] - right[i]; + } + + return accum == 0; #endif } } diff --git a/src/Renci.SshNet/Security/Cryptography/Cipher.cs b/src/Renci.SshNet/Security/Cryptography/Cipher.cs index 39ddaeb36..872c2131d 100644 --- a/src/Renci.SshNet/Security/Cryptography/Cipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Cipher.cs @@ -1,3 +1,6 @@ +#nullable enable +using System; + namespace Renci.SshNet.Security.Cryptography { /// @@ -73,5 +76,25 @@ public virtual byte[] Decrypt(byte[] input) /// The decrypted data. /// public abstract byte[] Decrypt(byte[] input, int offset, int length); + + /// + /// Decrypts the specified input into a given buffer. + /// + /// The input. + /// The zero-based offset in at which to begin decrypting. + /// The number of bytes to decrypt from . + /// The output buffer to write to. + /// The zero-based offset in at which to write decrypted output. + /// + /// The number of bytes written to . + /// + public virtual int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) + { + var plaintext = Decrypt(input, offset, length); + + plaintext.AsSpan().CopyTo(output.AsSpan(outputOffset)); + + return plaintext.Length; + } } } diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.BclImpl.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.BclImpl.cs index 76e43e949..c8bc62a0a 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.BclImpl.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.BclImpl.cs @@ -1,4 +1,5 @@ -using System; +#nullable enable +using System; using System.Security.Cryptography; using Renci.SshNet.Common; @@ -39,53 +40,45 @@ public BclImpl( } public override byte[] Encrypt(byte[] input, int offset, int length) + { + return Transform(_encryptor, input, offset, length, output: null, 0, out _); + } + + public override byte[] Decrypt(byte[] input, int offset, int length) + { + return Transform(_decryptor, input, offset, length, output: null, 0, out _); + } + + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) + { + _ = Transform(_decryptor, input, offset, length, output, outputOffset, out var bytesWritten); + + return bytesWritten; + } + + private byte[] Transform(ICryptoTransform transform, byte[] input, int offset, int length, byte[]? output, int outputOffset, out int bytesWritten) { if (_aes.Padding != PaddingMode.None) { // If padding has been specified, call TransformFinalBlock to apply // the padding and reset the state. - return _encryptor.TransformFinalBlock(input, offset, length); - } - var paddingLength = 0; - if (length % BlockSize > 0) - { - if (_aes.Mode is System.Security.Cryptography.CipherMode.CFB or System.Security.Cryptography.CipherMode.OFB) + var finalBlock = transform.TransformFinalBlock(input, offset, length); + + if (output is not null) { - // Manually pad the input for cfb and ofb cipher mode as BCL doesn't support partial block. - // See https://github.com/dotnet/runtime/blob/e7d837da5b1aacd9325a8b8f2214cfaf4d3f0ff6/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/SymmetricPadding.cs#L20-L21 - paddingLength = BlockSize - (length % BlockSize); - input = input.Take(offset, length); - length += paddingLength; - Array.Resize(ref input, length); - offset = 0; + finalBlock.AsSpan().CopyTo(output.AsSpan(outputOffset)); } + + bytesWritten = finalBlock.Length; + + return finalBlock; } // Otherwise, (the most important case) assume this instance is // used for one direction of an SSH connection, whereby the // encrypted data in all packets are considered a single data - // stream i.e. we do not want to reset the state between calls to Encrypt. - var output = new byte[length]; - _ = _encryptor.TransformBlock(input, offset, length, output, 0); - - if (paddingLength > 0) - { - // Manually unpad the output. - Array.Resize(ref output, output.Length - paddingLength); - } - - return output; - } - - public override byte[] Decrypt(byte[] input, int offset, int length) - { - if (_aes.Padding != PaddingMode.None) - { - // If padding has been specified, call TransformFinalBlock to apply - // the padding and reset the state. - return _decryptor.TransformFinalBlock(input, offset, length); - } + // stream i.e. we do not want to reset the state between calls to Decrypt. var paddingLength = 0; if (length % BlockSize > 0) @@ -95,24 +88,33 @@ public override byte[] Decrypt(byte[] input, int offset, int length) // Manually pad the input for cfb and ofb cipher mode as BCL doesn't support partial block. // See https://github.com/dotnet/runtime/blob/e7d837da5b1aacd9325a8b8f2214cfaf4d3f0ff6/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/SymmetricPadding.cs#L20-L21 paddingLength = BlockSize - (length % BlockSize); - input = input.Take(offset, length); - length += paddingLength; - Array.Resize(ref input, length); + + var tmp = new byte[length + paddingLength]; + + input.AsSpan(offset, length).CopyTo(tmp); + + input = tmp; offset = 0; + length = tmp.Length; } } - // Otherwise, (the most important case) assume this instance is - // used for one direction of an SSH connection, whereby the - // encrypted data in all packets are considered a single data - // stream i.e. we do not want to reset the state between calls to Decrypt. - var output = new byte[length]; - _ = _decryptor.TransformBlock(input, offset, length, output, 0); - - if (paddingLength > 0) + if (output is null) { + output = new byte[length]; + + bytesWritten = transform.TransformBlock(input, offset, length, output, outputOffset); + + bytesWritten -= paddingLength; + // Manually unpad the output. - Array.Resize(ref output, output.Length - paddingLength); + Array.Resize(ref output, bytesWritten); + } + else + { + bytesWritten = transform.TransformBlock(input, offset, length, output, outputOffset); + + bytesWritten -= paddingLength; } return output; diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.CtrImpl.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.CtrImpl.cs index 0d4dde5cd..7e569c9ad 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.CtrImpl.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.CtrImpl.cs @@ -1,8 +1,12 @@ -using System; +#nullable enable +using System; using System.Buffers.Binary; +using System.Diagnostics; using System.Numerics; using System.Security.Cryptography; +using Renci.SshNet.Common; + namespace Renci.SshNet.Security.Cryptography.Ciphers { public partial class AesCipher @@ -34,12 +38,32 @@ public CtrImpl( public override byte[] Encrypt(byte[] input, int offset, int length) { - return CTREncryptDecrypt(input, offset, length); + return Decrypt(input, offset, length); } public override byte[] Decrypt(byte[] input, int offset, int length) { - return CTREncryptDecrypt(input, offset, length); + ThrowHelper.ThrowIfNull(input); + + var buffer = CTREncryptDecrypt(input, offset, length, output: null, 0); + + // adjust output for non-blocksized lengths + if (buffer.Length > length) + { + Array.Resize(ref buffer, length); + } + + return buffer; + } + + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) + { + ThrowHelper.ThrowIfNull(input); + ThrowHelper.ThrowIfNull(output); + + _ = CTREncryptDecrypt(input, offset, length, output, outputOffset); + + return length; } public override int DecryptBlock(byte[] inputBuffer, int inputOffset, int inputCount, byte[] outputBuffer, int outputOffset) @@ -52,35 +76,46 @@ public override int EncryptBlock(byte[] inputBuffer, int inputOffset, int inputC throw new NotImplementedException($"Invalid usage of {nameof(EncryptBlock)}."); } - private byte[] CTREncryptDecrypt(byte[] data, int offset, int length) + private byte[] CTREncryptDecrypt(byte[] data, int offset, int length, byte[]? output, int outputOffset) { - var count = length / BlockSize; - if (length % BlockSize != 0) + var blockSizedLength = length; + if (blockSizedLength % BlockSize != 0) { - count++; + blockSizedLength += BlockSize - (blockSizedLength % BlockSize); } - var buffer = new byte[count * BlockSize]; - CTRCreateCounterArray(buffer); - _ = _encryptor.TransformBlock(buffer, 0, buffer.Length, buffer, 0); - ArrayXOR(buffer, data, offset, length); + Debug.Assert(blockSizedLength % BlockSize == 0); - // adjust output for non-blocksized lengths - if (buffer.Length > length) + if (output is null) { - Array.Resize(ref buffer, length); + output = new byte[blockSizedLength]; + outputOffset = 0; + } + else if (data.AsSpan(offset, length).Overlaps(output.AsSpan(outputOffset, blockSizedLength))) + { + throw new ArgumentException("Input and output buffers must not overlap"); } - return buffer; + CTRCreateCounterArray(output.AsSpan(outputOffset, blockSizedLength)); + + var bytesWritten = _encryptor.TransformBlock(output, outputOffset, blockSizedLength, output, outputOffset); + + Debug.Assert(bytesWritten == blockSizedLength); + + ArrayXOR(output, outputOffset, data, offset, length); + + return output; } // creates the Counter array filled with incrementing copies of IV - private void CTRCreateCounterArray(byte[] buffer) + private void CTRCreateCounterArray(Span buffer) { + Debug.Assert(buffer.Length % 16 == 0); + for (var i = 0; i < buffer.Length; i += 16) { - BinaryPrimitives.WriteUInt64BigEndian(buffer.AsSpan(i + 8), _ivLower); - BinaryPrimitives.WriteUInt64BigEndian(buffer.AsSpan(i), _ivUpper); + BinaryPrimitives.WriteUInt64BigEndian(buffer.Slice(i + 8), _ivLower); + BinaryPrimitives.WriteUInt64BigEndian(buffer.Slice(i), _ivUpper); _ivLower += 1; _ivUpper += (_ivLower == 0) ? 1UL : 0UL; @@ -88,20 +123,20 @@ private void CTRCreateCounterArray(byte[] buffer) } // XOR 2 arrays using Vector - private static void ArrayXOR(byte[] buffer, byte[] data, int offset, int length) + private static void ArrayXOR(byte[] buffer, int bufferOffset, byte[] data, int offset, int length) { var i = 0; var oneVectorFromEnd = length - Vector.Count; for (; i <= oneVectorFromEnd; i += Vector.Count) { - var v = new Vector(buffer, i) ^ new Vector(data, offset + i); - v.CopyTo(buffer, i); + var v = new Vector(buffer, bufferOffset + i) ^ new Vector(data, offset + i); + v.CopyTo(buffer, bufferOffset + i); } for (; i < length; i++) { - buffer[i] ^= data[offset + i]; + buffer[bufferOffset + i] ^= data[offset + i]; } } diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.cs index e0963dff2..eb335549a 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesCipher.cs @@ -77,6 +77,12 @@ public override byte[] Decrypt(byte[] input, int offset, int length) return _impl.Decrypt(input, offset, length); } + /// + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) + { + return _impl.Decrypt(input, offset, length, output, outputOffset); + } + /// public void Dispose() { diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesGcmCipher.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesGcmCipher.cs index ded36fc0d..dcd686c3f 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/AesGcmCipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/AesGcmCipher.cs @@ -110,6 +110,17 @@ public override byte[] Encrypt(byte[] input, int offset, int length) return output; } + public override byte[] Decrypt(byte[] input, int offset, int length) + { + var output = new byte[length]; + + var bytesWritten = Decrypt(input, offset, length, output, 0); + + Debug.Assert(bytesWritten == length); + + return output; + } + /// /// Decrypts the specified input. /// @@ -121,17 +132,12 @@ public override byte[] Encrypt(byte[] input, int offset, int length) /// /// The zero-based offset in at which to begin decrypting and authenticating. /// The number of bytes to decrypt and authenticate from . - /// - /// The decrypted data with below format: - /// - /// [----Plain Text----] - /// - /// - public override byte[] Decrypt(byte[] input, int offset, int length) + /// The buffer to which to write the decrypted bytes. + /// The zero-based offset in at which to write the decrypted bytes. + /// The number of plaintext bytes written to . + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) { - Debug.Assert(offset >= _aadLength, "The offset must be greater than or equals to aad length"); - - var output = new byte[length]; + Debug.Assert(offset >= _aadLength, "The offset must be greater than or equal to aad length"); _impl.Decrypt( input, @@ -140,11 +146,11 @@ public override byte[] Decrypt(byte[] input, int offset, int length) associatedDataOffset: offset - _aadLength, associatedDataLength: _aadLength, output, - plainTextOffset: 0); + outputOffset); IncrementCounter(); - return output; + return length; } /// diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs index 1559c2631..4107db278 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs @@ -112,18 +112,39 @@ public override byte[] Encrypt(byte[] input, int offset, int length) } /// - /// Decrypts the AAD. + /// Decrypts the specified bytes. /// - /// The encrypted AAD. - /// The decrypted AAD. - public override byte[] Decrypt(byte[] input) + /// + /// If a positive AAD length was specified for this instance, the bytes are + /// decrypted as if they are AAD (i.e. the packet length of an SSH packet). + /// + /// The buffer containing the ciphertext to decrypt. + /// The offset of the ciphertext in the buffer. + /// The length of the ciphertext to decrypt. + /// The decrypted plaintext. + public override byte[] Decrypt(byte[] input, int offset, int length) { - Debug.Assert(_aadCipher != null, "The aadCipher must not be null"); + byte[] output; + + if (_aadLength > 0) + { + // If we are in 'AAD mode', then put these bytes through the AAD cipher. + + Debug.Assert(_aadCipher != null); - _aadCipher.Init(forEncryption: false, new ParametersWithIV(_aadKeyParameter, _iv)); + _aadCipher.Init(forEncryption: false, new ParametersWithIV(_aadKeyParameter, _iv)); + + output = new byte[length]; + _aadCipher.ProcessBytes(input, offset, length, output, 0); + } + else + { + output = new byte[length]; - var output = new byte[input.Length]; - _aadCipher.ProcessBytes(input, 0, input.Length, output, 0); + var bytesWritten = Decrypt(input, offset, length, output, 0); + + Debug.Assert(bytesWritten == length); + } return output; } @@ -139,13 +160,10 @@ public override byte[] Decrypt(byte[] input) /// /// The zero-based offset in at which to begin decrypting and authenticating. /// The number of bytes to decrypt and authenticate from . - /// - /// The decrypted data with below format: - /// - /// [----Plain Text----] - /// - /// - public override byte[] Decrypt(byte[] input, int offset, int length) + /// The buffer to which to write the decrypted bytes. + /// The zero-based offset in at which to write the decrypted bytes. + /// The number of plaintext bytes written to . + public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) { Debug.Assert(offset >= _aadLength, "The offset must be greater than or equals to aad length"); @@ -163,10 +181,9 @@ public override byte[] Decrypt(byte[] input, int offset, int length) throw new SshConnectionException("MAC error", DisconnectReason.MacError); } - var output = new byte[length]; - _cipher.ProcessBytes(input, offset, length, output, 0); + _cipher.ProcessBytes(input, offset, length, output, outputOffset); - return output; + return length; } internal override void SetSequenceNumber(uint sequenceNumber) diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 77fe9d4c2..fb6da7633 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Globalization; using System.Linq; +using System.Net; using System.Net.Sockets; using System.Security.Cryptography; #if !NET @@ -201,6 +202,8 @@ public sealed class Session : ISession /// private Socket _socket; + private ArrayBuffer _receiveBuffer = new(4 * 1024); + /// /// Gets the session semaphore that controls session channels. /// @@ -1213,7 +1216,8 @@ private Message ReceiveMessage(Socket socket) int blockSize; - // Determine the size of the first block which is 8 or cipher block size (whichever is larger) bytes, or 4 if "packet length" field is handled separately. + // Determine the size of the first block which is 8 or cipher block size (whichever is larger) bytes, + // or 4 if "packet length" field is handled separately. if (_serverEtm || _serverAead) { blockSize = (byte)4; @@ -1238,118 +1242,160 @@ private Message ReceiveMessage(Socket socket) serverMacLength = _serverMac.HashSize / 8; } - byte[] data; - uint packetLength; - - // Read first block - which starts with the packet length - var firstBlock = new byte[blockSize]; - if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0) + if (_receiveBuffer.ActiveLength < blockSize) { - // connection with SSH server was closed - return null; + var bytesNeeded = blockSize - _receiveBuffer.ActiveLength; + + _receiveBuffer.EnsureAvailableSpace(bytesNeeded); + + var bytesRead = TrySocketRead( + socket, + buffer: _receiveBuffer.DangerousGetUnderlyingBuffer(), + offset: _receiveBuffer.ActiveStartOffset + _receiveBuffer.ActiveLength, + length: _receiveBuffer.AvailableLength, + minimumLength: bytesNeeded); + + _receiveBuffer.Commit(bytesRead); + + if (bytesRead < bytesNeeded) + { + // connection with SSH server was closed + return null; + } } + var firstBlock = new ArraySegment( + _receiveBuffer.DangerousGetUnderlyingBuffer(), + _receiveBuffer.ActiveStartOffset, + blockSize); + var plainFirstBlock = firstBlock; - // First block is not encrypted in AES GCM mode. + // For ETM or AES-GCM, firstBlock holds the packet length which is + // not encrypted. Otherwise, we decrypt the first "blockSize" bytes. + // (For chacha20-poly1305, this means passing the encrypted packet + // length as AAD). if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher) { _serverCipher.SetSequenceNumber(_inboundPacketSequence); - // First block is not encrypted in ETM mode. if (_serverMac == null || !_serverEtm) { - plainFirstBlock = _serverCipher.Decrypt(firstBlock); + plainFirstBlock = new ArraySegment(_serverCipher.Decrypt( + firstBlock.Array, + firstBlock.Offset, + firstBlock.Count)); } } - packetLength = BinaryPrimitives.ReadUInt32BigEndian(plainFirstBlock); + var packetLength = BinaryPrimitives.ReadInt32BigEndian(plainFirstBlock); // Test packet minimum and maximum boundaries if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4) { - throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), - DisconnectReason.ProtocolError); + throw new SshConnectionException( + string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", (uint)packetLength), + DisconnectReason.ProtocolError); } - // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the - // "packet length" field itself - which is 4 bytes - is not included in the length of the packet - var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength; - - // Construct buffer for holding the payload and the inbound packet sequence as we need both in order - // to generate the hash. - // - // The total length of the "data" buffer is an addition of: - // - inboundPacketSequenceLength (4 bytes) - // - packetLength - // - serverMacLength - // - // We include the inbound packet sequence to allow us to have the the full SSH packet in a single - // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen - // to read the packet including server MAC in a single pass (except for the initial block). - data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength]; - BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence); + var totalPacketLength = 4 + packetLength + serverMacLength; - // Use raw packet length field to calculate the mac in AEAD mode. - if (_serverAead) - { - Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, blockSize); - } - else + if (_receiveBuffer.ActiveLength < totalPacketLength) { - Buffer.BlockCopy(plainFirstBlock, 0, data, inboundPacketSequenceLength, blockSize); - } + var bytesNeeded = totalPacketLength - _receiveBuffer.ActiveLength; - if (bytesToRead > 0) - { - if (TrySocketRead(socket, data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0) + _receiveBuffer.EnsureAvailableSpace(bytesNeeded); + + var bytesRead = TrySocketRead( + socket, + buffer: _receiveBuffer.DangerousGetUnderlyingBuffer(), + offset: _receiveBuffer.ActiveStartOffset + _receiveBuffer.ActiveLength, + length: _receiveBuffer.AvailableLength, + minimumLength: bytesNeeded); + + _receiveBuffer.Commit(bytesRead); + + if (bytesRead < bytesNeeded) { + // connection with SSH server was closed return null; } } - // validate encrypted message against MAC + // Construct buffer for holding the payload and the inbound packet sequence as we need both in order + // to generate the hash. + var data = new byte[4 + totalPacketLength - serverMacLength]; + + BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence); + + plainFirstBlock.AsSpan().CopyTo(data.AsSpan(4)); + if (_serverMac != null && _serverEtm) { - var clientHash = _serverMac.ComputeHash(data, 0, data.Length - serverMacLength); -#if NET - if (!CryptographicOperations.FixedTimeEquals(clientHash, new ReadOnlySpan(data, data.Length - serverMacLength, serverMacLength))) -#else - if (!Org.BouncyCastle.Utilities.Arrays.FixedTimeEquals(serverMacLength, clientHash, 0, data, data.Length - serverMacLength)) -#endif + // ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet) + + // sequence_number + _ = _serverMac.TransformBlock( + inputBuffer: data, + inputOffset: 0, + inputCount: 4, + outputBuffer: null, + outputOffset: 0); + + // packet_length || encrypted_packet + _ = _serverMac.TransformBlock( + inputBuffer: _receiveBuffer.DangerousGetUnderlyingBuffer(), + inputOffset: _receiveBuffer.ActiveStartOffset, + inputCount: totalPacketLength - serverMacLength, + outputBuffer: null, + outputOffset: 0); + + _ = _serverMac.TransformFinalBlock(Array.Empty(), 0, 0); + + if (!CryptoAbstraction.FixedTimeEquals(_serverMac.Hash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength))) { throw new SshConnectionException("MAC error", DisconnectReason.MacError); } } - if (_serverCipher != null) + var numberOfBytesToDecrypt = 4 + packetLength - blockSize; + + if (_serverCipher != null && numberOfBytesToDecrypt > 0) { - var numberOfBytesToDecrypt = data.Length - (blockSize + inboundPacketSequenceLength + serverMacLength); - if (numberOfBytesToDecrypt > 0) - { - var decryptedData = _serverCipher.Decrypt(data, blockSize + inboundPacketSequenceLength, numberOfBytesToDecrypt); - Buffer.BlockCopy(decryptedData, 0, data, blockSize + inboundPacketSequenceLength, decryptedData.Length); - } - } + Debug.Assert(numberOfBytesToDecrypt % blockSize == 0); - var paddingLength = data[inboundPacketSequenceLength + packetLengthFieldLength]; - var messagePayloadLength = (int)packetLength - paddingLength - paddingLengthFieldLength; - var messagePayloadOffset = inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength; + var numberOfBytesDecrypted = _serverCipher.Decrypt( + input: _receiveBuffer.DangerousGetUnderlyingBuffer(), + offset: _receiveBuffer.ActiveStartOffset + blockSize, + length: numberOfBytesToDecrypt, + output: data, + outputOffset: 4 + blockSize); + + Debug.Assert(numberOfBytesDecrypted == numberOfBytesToDecrypt); + } + else + { + _receiveBuffer.ActiveReadOnlySpan.Slice(blockSize, numberOfBytesToDecrypt).CopyTo(data.AsSpan(4 + blockSize)); + } - // validate decrypted message against MAC if (_serverMac != null && !_serverEtm) { - var clientHash = _serverMac.ComputeHash(data, 0, data.Length - serverMacLength); -#if NET - if (!CryptographicOperations.FixedTimeEquals(clientHash, new ReadOnlySpan(data, data.Length - serverMacLength, serverMacLength))) -#else - if (!Org.BouncyCastle.Utilities.Arrays.FixedTimeEquals(serverMacLength, clientHash, 0, data, data.Length - serverMacLength)) -#endif + // non-ETM mac = MAC(key, sequence_number || unencrypted_packet) + + var clientHash = _serverMac.ComputeHash(data); + + if (!CryptoAbstraction.FixedTimeEquals(clientHash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength))) { throw new SshConnectionException("MAC error", DisconnectReason.MacError); } } + _receiveBuffer.Discard(totalPacketLength); + + var paddingLength = data[inboundPacketSequenceLength + packetLengthFieldLength]; + var messagePayloadLength = packetLength - paddingLength - paddingLengthFieldLength; + var messagePayloadOffset = inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength; + if (_serverDecompression != null) { data = _serverDecompression.Decompress(data, messagePayloadOffset, messagePayloadLength); @@ -1865,20 +1911,37 @@ private static string ToHex(byte[] bytes) } /// - /// Performs a blocking read on the socket until bytes are received. + /// Performs a blocking read on the socket until at least bytes are received. /// /// The to read from. /// An array of type that is the storage location for the received data. /// The position in parameter to store the received data. - /// The number of bytes to read. + /// The maximum number of bytes to read. + /// The minimum number of bytes to read. /// /// The number of bytes read. /// - /// The read has timed-out. /// The read failed. - private static int TrySocketRead(Socket socket, byte[] buffer, int offset, int length) + private static int TrySocketRead(Socket socket, byte[] buffer, int offset, int length, int minimumLength) { - return SocketAbstraction.Read(socket, buffer, offset, length, Timeout.InfiniteTimeSpan); + Debug.Assert(offset >= 0); + Debug.Assert((uint)length <= buffer.Length - offset); + Debug.Assert(minimumLength <= length); + + var totalRead = 0; + while (totalRead < minimumLength) + { + var read = socket.Receive(buffer, offset + totalRead, length - totalRead, SocketFlags.None); + + if (read == 0) + { + return totalRead; + } + + totalRead += read; + } + + return totalRead; } /// @@ -1927,6 +1990,11 @@ private void MessageListener() { try { + if (_socket is { } s) + { + s.ReceiveTimeout = 0; + } + // remain in message loop until socket is shut down or until we're disconnecting while (true) {