@@ -211,22 +211,20 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider)
211211
212212 private void ProcessNeedKmsState ( CryptContext context , CancellationToken cancellationToken )
213213 {
214- var requests = context . GetKmsMessageRequests ( ) ;
215- foreach ( var request in requests )
214+ while ( context . GetNextKmsMessageRequest ( ) is { } request )
216215 {
217216 SendKmsRequest ( request , cancellationToken ) ;
218217 }
219- requests . MarkDone ( ) ;
218+ context . MarkKmsDone ( ) ;
220219 }
221220
222221 private async Task ProcessNeedKmsStateAsync ( CryptContext context , CancellationToken cancellationToken )
223222 {
224- var requests = context . GetKmsMessageRequests ( ) ;
225- foreach ( var request in requests )
223+ while ( context . GetNextKmsMessageRequest ( ) is { } request )
226224 {
227225 await SendKmsRequestAsync ( request , cancellationToken ) . ConfigureAwait ( false ) ;
228226 }
229- requests . MarkDone ( ) ;
227+ context . MarkKmsDone ( ) ;
230228 }
231229
232230 private void ProcessNeedMongoKeysState ( CryptContext context , CancellationToken cancellationToken )
@@ -278,48 +276,90 @@ private static byte[] ProcessReadyState(CryptContext context)
278276
279277 private void SendKmsRequest ( KmsRequest request , CancellationToken cancellation )
280278 {
281- var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
282-
283- var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
284- var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
285- using ( var sslStream = sslStreamFactory . CreateStream ( endpoint , cancellation ) )
286- using ( var binary = request . GetMessage ( ) )
279+ try
287280 {
281+ var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
282+
283+ var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
284+ var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
285+ using var sslStream = sslStreamFactory . CreateStream ( endpoint , cancellation ) ;
286+
287+ var sleepMs = request . Sleep ;
288+ if ( sleepMs > 0 )
289+ {
290+ Thread . Sleep ( sleepMs ) ;
291+ }
292+
293+ using var binary = request . GetMessage ( ) ;
288294 var requestBytes = binary . ToArray ( ) ;
289295 sslStream . Write ( requestBytes , 0 , requestBytes . Length ) ;
290296
291297 while ( request . BytesNeeded > 0 )
292298 {
293299 var buffer = new byte [ request . BytesNeeded ] ; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
294300 var count = sslStream . Read ( buffer , 0 , buffer . Length ) ;
301+
302+ if ( count == 0 )
303+ {
304+ throw new IOException ( "Unexpected end of stream. No data was read from the SSL stream." ) ;
305+ }
306+
295307 var responseBytes = new byte [ count ] ;
296308 Buffer . BlockCopy ( buffer , 0 , responseBytes , 0 , count ) ;
297309 request . Feed ( responseBytes ) ;
298310 }
299311 }
312+ catch ( Exception ex ) when ( ex is IOException or SocketException )
313+ {
314+ if ( ! request . Fail ( ) )
315+ {
316+ throw ;
317+ }
318+ }
300319 }
301320
302321 private async Task SendKmsRequestAsync ( KmsRequest request , CancellationToken cancellation )
303322 {
304- var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
305-
306- var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
307- var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
308- using ( var sslStream = await sslStreamFactory . CreateStreamAsync ( endpoint , cancellation ) . ConfigureAwait ( false ) )
309- using ( var binary = request . GetMessage ( ) )
323+ try
310324 {
325+ var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
326+
327+ var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
328+ var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
329+ using var sslStream = await sslStreamFactory . CreateStreamAsync ( endpoint , cancellation ) . ConfigureAwait ( false ) ;
330+
331+ var sleepMs = request . Sleep ;
332+ if ( sleepMs > 0 )
333+ {
334+ await Task . Delay ( sleepMs , cancellation ) . ConfigureAwait ( false ) ;
335+ }
336+
337+ using var binary = request . GetMessage ( ) ;
311338 var requestBytes = binary . ToArray ( ) ;
312339 await sslStream . WriteAsync ( requestBytes , 0 , requestBytes . Length ) . ConfigureAwait ( false ) ;
313340
314341 while ( request . BytesNeeded > 0 )
315342 {
316343 var buffer = new byte [ request . BytesNeeded ] ; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
317344 var count = await sslStream . ReadAsync ( buffer , 0 , buffer . Length ) . ConfigureAwait ( false ) ;
345+
346+ if ( count == 0 )
347+ {
348+ throw new IOException ( "Unexpected end of stream. No data was read from the SSL stream." ) ;
349+ }
350+
318351 var responseBytes = new byte [ count ] ;
319352 Buffer . BlockCopy ( buffer , 0 , responseBytes , 0 , count ) ;
320353 request . Feed ( responseBytes ) ;
321354 }
322355 }
356+ catch ( Exception ex ) when ( ex is IOException or SocketException )
357+ {
358+ if ( ! request . Fail ( ) )
359+ {
360+ throw ;
361+ }
362+ }
323363 }
324364
325365 // nested type
0 commit comments