44using System . IO ;
55using System . Linq ;
66using System . Net ;
7+ using System . Net . Security ;
78using System . Net . Sockets ;
89using System . Text ;
910using System . Threading ;
@@ -18,16 +19,19 @@ public partial class PooledSocket : IDisposable
1819 private readonly ILogger _logger ;
1920
2021 private bool _isAlive ;
22+ private bool _useSslStream ;
2123 private Socket _socket ;
2224 private readonly EndPoint _endpoint ;
2325 private readonly int _connectionTimeout ;
2426
2527 private NetworkStream _inputStream ;
28+ private SslStream _sslStream ;
2629
27- public PooledSocket ( EndPoint endpoint , TimeSpan connectionTimeout , TimeSpan receiveTimeout , ILogger logger )
30+ public PooledSocket ( EndPoint endpoint , TimeSpan connectionTimeout , TimeSpan receiveTimeout , ILogger logger , bool useSslStream )
2831 {
2932 _logger = logger ;
3033 _isAlive = true ;
34+ _useSslStream = useSslStream ;
3135
3236 var socket = new Socket ( AddressFamily . InterNetwork , SocketType . Stream , ProtocolType . Tcp ) ;
3337 socket . SetSocketOption ( SocketOptionLevel . Socket , SocketOptionName . KeepAlive , true ) ;
@@ -90,7 +94,15 @@ void Cancel()
9094
9195 if ( success )
9296 {
93- _inputStream = new NetworkStream ( _socket ) ;
97+ if ( _useSslStream )
98+ {
99+ _sslStream = new SslStream ( new NetworkStream ( _socket ) ) ;
100+ _sslStream . AuthenticateAsClient ( ( ( DnsEndPoint ) _endpoint ) . Host ) ;
101+ }
102+ else
103+ {
104+ _inputStream = new NetworkStream ( _socket ) ;
105+ }
94106 }
95107 else
96108 {
@@ -141,7 +153,15 @@ public async Task ConnectAsync()
141153
142154 if ( success )
143155 {
144- _inputStream = new NetworkStream ( _socket ) ;
156+ if ( _useSslStream )
157+ {
158+ _sslStream = new SslStream ( new NetworkStream ( _socket ) ) ;
159+ await _sslStream . AuthenticateAsClientAsync ( ( ( DnsEndPoint ) _endpoint ) . Host ) ;
160+ }
161+ else
162+ {
163+ _inputStream = new NetworkStream ( _socket ) ;
164+ }
145165 }
146166 else
147167 {
@@ -251,7 +271,13 @@ protected void Dispose(bool disposing)
251271 _inputStream . Dispose ( ) ;
252272 }
253273
274+ if ( _sslStream != null )
275+ {
276+ _sslStream . Dispose ( ) ;
277+ }
278+
254279 _inputStream = null ;
280+ _sslStream = null ;
255281 _socket = null ;
256282 this . CleanupCallback = null ;
257283 }
@@ -290,7 +316,7 @@ public int ReadByte()
290316
291317 try
292318 {
293- return _inputStream . ReadByte ( ) ;
319+ return ( _useSslStream ? _sslStream . ReadByte ( ) : _inputStream . ReadByte ( ) ) ;
294320 }
295321 catch ( Exception ex )
296322 {
@@ -309,7 +335,7 @@ public int ReadByteAsync()
309335
310336 try
311337 {
312- return _inputStream . ReadByte ( ) ;
338+ return ( _useSslStream ? _sslStream . ReadByte ( ) : _inputStream . ReadByte ( ) ) ;
313339 }
314340 catch ( Exception ex )
315341 {
@@ -332,7 +358,7 @@ public async Task ReadAsync(byte[] buffer, int offset, int count)
332358 {
333359 try
334360 {
335- int currentRead = await _inputStream . ReadAsync ( buffer , offset , shouldRead ) ;
361+ int currentRead = ( _useSslStream ? await _sslStream . ReadAsync ( buffer , offset , shouldRead ) : await _inputStream . ReadAsync ( buffer , offset , shouldRead ) ) ;
336362 if ( currentRead == count )
337363 break ;
338364 if ( currentRead < 1 )
@@ -372,7 +398,7 @@ public void Read(byte[] buffer, int offset, int count)
372398 {
373399 try
374400 {
375- int currentRead = _inputStream . Read ( buffer , offset , shouldRead ) ;
401+ int currentRead = ( _useSslStream ? _sslStream . Read ( buffer , offset , shouldRead ) : _inputStream . Read ( buffer , offset , shouldRead ) ) ;
376402 if ( currentRead == count )
377403 break ;
378404 if ( currentRead < 1 )
@@ -397,15 +423,34 @@ public void Write(byte[] data, int offset, int length)
397423 {
398424 this . CheckDisposed ( ) ;
399425
400- SocketError status ;
426+ if ( _useSslStream )
427+ {
428+ try
429+ {
430+ _inputStream . Write ( data , offset , length ) ;
431+ _inputStream . Flush ( ) ;
432+ }
433+ catch ( Exception ex )
434+ {
435+ if ( ex is IOException || ex is SocketException )
436+ {
437+ _isAlive = false ;
438+ }
439+ throw ;
440+ }
441+ }
442+ else
443+ {
444+ SocketError status ;
401445
402- _socket . Send ( data , offset , length , SocketFlags . None , out status ) ;
446+ _socket . Send ( data , offset , length , SocketFlags . None , out status ) ;
403447
404- if ( status != SocketError . Success )
405- {
406- _isAlive = false ;
448+ if ( status != SocketError . Success )
449+ {
450+ _isAlive = false ;
407451
408- ThrowHelper . ThrowSocketWriteError ( _endpoint , status ) ;
452+ ThrowHelper . ThrowSocketWriteError ( _endpoint , status ) ;
453+ }
409454 }
410455 }
411456
@@ -417,11 +462,22 @@ public void Write(IList<ArraySegment<byte>> buffers)
417462
418463 try
419464 {
420- _socket . Send ( buffers , SocketFlags . None , out status ) ;
421- if ( status != SocketError . Success )
465+ if ( _useSslStream )
422466 {
423- _isAlive = false ;
424- ThrowHelper . ThrowSocketWriteError ( _endpoint , status ) ;
467+ foreach ( var buf in buffers )
468+ {
469+ _sslStream . Write ( buf . Array ) ;
470+ }
471+ _sslStream . Flush ( ) ;
472+ }
473+ else
474+ {
475+ _socket . Send ( buffers , SocketFlags . None , out status ) ;
476+ if ( status != SocketError . Success )
477+ {
478+ _isAlive = false ;
479+ ThrowHelper . ThrowSocketWriteError ( _endpoint , status ) ;
480+ }
425481 }
426482 }
427483 catch ( Exception ex )
@@ -441,12 +497,23 @@ public async Task WriteAsync(IList<ArraySegment<byte>> buffers)
441497
442498 try
443499 {
444- var bytesTransferred = await _socket . SendAsync ( buffers , SocketFlags . None ) ;
445- if ( bytesTransferred <= 0 )
500+ if ( _useSslStream )
446501 {
447- _isAlive = false ;
448- _logger . LogError ( $ "Failed to { nameof ( PooledSocket . WriteAsync ) } . bytesTransferred: { bytesTransferred } ") ;
449- ThrowHelper . ThrowSocketWriteError ( _endpoint ) ;
502+ foreach ( var buf in buffers )
503+ {
504+ await _sslStream . WriteAsync ( buf . Array , 0 , buf . Count ) ;
505+ }
506+ await _sslStream . FlushAsync ( ) ;
507+ }
508+ else
509+ {
510+ var bytesTransferred = await _socket . SendAsync ( buffers , SocketFlags . None ) ;
511+ if ( bytesTransferred <= 0 )
512+ {
513+ _isAlive = false ;
514+ _logger . LogError ( $ "Failed to { nameof ( PooledSocket . WriteAsync ) } . bytesTransferred: { bytesTransferred } ") ;
515+ ThrowHelper . ThrowSocketWriteError ( _endpoint ) ;
516+ }
450517 }
451518 }
452519 catch ( Exception ex )
0 commit comments