@@ -299,6 +299,7 @@ impl Client {
299299 }
300300
301301 let stream_description = connection. stream_description ( ) ?;
302+ let is_sharded = stream_description. initial_server_type == ServerType :: Mongos ;
302303 let mut cmd = op. build ( stream_description) ?;
303304 self . inner
304305 . topology
@@ -337,15 +338,22 @@ impl Client {
337338 cmd. set_start_transaction ( ) ;
338339 cmd. set_autocommit ( ) ;
339340 cmd. set_txn_read_concern ( * session) ?;
340- if stream_description . initial_server_type == ServerType :: Mongos {
341+ if is_sharded {
341342 session. pin_mongos ( connection. address ( ) . clone ( ) ) ;
342343 }
343344 session. transaction . state = TransactionState :: InProgress ;
344345 }
345- TransactionState :: InProgress
346- | TransactionState :: Committed { .. }
347- | TransactionState :: Aborted => {
346+ TransactionState :: InProgress => cmd. set_autocommit ( ) ,
347+ TransactionState :: Committed { .. } | TransactionState :: Aborted => {
348348 cmd. set_autocommit ( ) ;
349+
350+ // Append the recovery token to the command if we are committing or aborting
351+ // on a sharded transaction.
352+ if is_sharded {
353+ if let Some ( ref recovery_token) = session. transaction . recovery_token {
354+ cmd. set_recovery_token ( recovery_token) ;
355+ }
356+ }
349357 }
350358 _ => { }
351359 }
@@ -414,6 +422,9 @@ impl Client {
414422 Ok ( r) => {
415423 self . update_cluster_time ( & r, session) . await ;
416424 if r. is_success ( ) {
425+ // Retrieve recovery token from successful response.
426+ Client :: update_recovery_token ( is_sharded, & r, session) . await ;
427+
417428 Ok ( CommandResult {
418429 raw : response,
419430 deserialized : r. into_body ( ) ,
@@ -458,7 +469,15 @@ impl Client {
458469 } ) )
459470 }
460471 // for ok: 1 just return the original deserialization error.
461- _ => Err ( deserialize_error) ,
472+ _ => {
473+ Client :: update_recovery_token (
474+ is_sharded,
475+ & error_response,
476+ session,
477+ )
478+ . await ;
479+ Err ( deserialize_error)
480+ }
462481 }
463482 }
464483 // We failed to deserialize even that, so just return the original
@@ -635,6 +654,18 @@ impl Client {
635654 }
636655 }
637656 }
657+
658+ async fn update_recovery_token < T : Response > (
659+ is_sharded : bool ,
660+ response : & T ,
661+ session : & mut Option < & mut ClientSession > ,
662+ ) {
663+ if let Some ( ref mut session) = session {
664+ if is_sharded && session. in_transaction ( ) {
665+ session. transaction . recovery_token = response. recovery_token ( ) . cloned ( ) ;
666+ }
667+ }
668+ }
638669}
639670
640671impl Error {
0 commit comments