@@ -16,6 +16,7 @@ use crate::{
1616 stats:: tikv_stats,
1717 store:: RegionStore ,
1818 transaction:: { resolve_locks, HasLocks } ,
19+ util:: iter:: FlatMapOkIterExt ,
1920 Error , Result ,
2021} ;
2122
@@ -63,6 +64,11 @@ pub struct RetryableMultiRegion<P: Plan, PdC: PdClient> {
6364 pub ( super ) inner : P ,
6465 pub pd_client : Arc < PdC > ,
6566 pub backoff : Backoff ,
67+
68+ /// Preserve all regions' results for other downstream plans to handle.
69+ /// If true, return Ok and preserve all regions' results, even if some of them are Err.
70+ /// Otherwise, return the first Err if there is any.
71+ pub preserve_region_results : bool ,
6672}
6773
6874impl < P : Plan + Shardable , PdC : PdClient > RetryableMultiRegion < P , PdC >
7682 current_plan : P ,
7783 backoff : Backoff ,
7884 permits : Arc < Semaphore > ,
85+ preserve_region_results : bool ,
7986 ) -> Result < <Self as Plan >:: Result > {
8087 let shards = current_plan. shards ( & pd_client) . collect :: < Vec < _ > > ( ) . await ;
8188 let mut handles = Vec :: new ( ) ;
@@ -89,16 +96,29 @@ where
8996 region_store,
9097 backoff. clone ( ) ,
9198 permits. clone ( ) ,
99+ preserve_region_results,
92100 ) ) ;
93101 handles. push ( handle) ;
94102 }
95- Ok ( try_join_all ( handles)
96- . await ?
97- . into_iter ( )
98- . collect :: < Result < Vec < _ > > > ( ) ?
99- . into_iter ( )
100- . flatten ( )
101- . collect ( ) )
103+
104+ let results = try_join_all ( handles) . await ?;
105+ if preserve_region_results {
106+ Ok ( results
107+ . into_iter ( )
108+ . flat_map_ok ( |x| x)
109+ . map ( |x| match x {
110+ Ok ( r) => r,
111+ Err ( e) => Err ( e) ,
112+ } )
113+ . collect ( ) )
114+ } else {
115+ Ok ( results
116+ . into_iter ( )
117+ . collect :: < Result < Vec < _ > > > ( ) ?
118+ . into_iter ( )
119+ . flatten ( )
120+ . collect ( ) )
121+ }
102122 }
103123
104124 #[ async_recursion]
@@ -108,6 +128,7 @@ where
108128 region_store : RegionStore ,
109129 mut backoff : Backoff ,
110130 permits : Arc < Semaphore > ,
131+ preserve_region_results : bool ,
111132 ) -> Result < <Self as Plan >:: Result > {
112133 // limit concurrent requests
113134 let permit = permits. acquire ( ) . await . unwrap ( ) ;
@@ -125,7 +146,14 @@ where
125146 if !region_error_resolved {
126147 futures_timer:: Delay :: new ( duration) . await ;
127148 }
128- Self :: single_plan_handler ( pd_client, plan, backoff, permits) . await
149+ Self :: single_plan_handler (
150+ pd_client,
151+ plan,
152+ backoff,
153+ permits,
154+ preserve_region_results,
155+ )
156+ . await
129157 }
130158 None => Err ( Error :: RegionError ( e) ) ,
131159 }
@@ -242,6 +270,7 @@ impl<P: Plan, PdC: PdClient> Clone for RetryableMultiRegion<P, PdC> {
242270 inner : self . inner . clone ( ) ,
243271 pd_client : self . pd_client . clone ( ) ,
244272 backoff : self . backoff . clone ( ) ,
273+ preserve_region_results : self . preserve_region_results ,
245274 }
246275 }
247276}
@@ -263,6 +292,7 @@ where
263292 self . inner . clone ( ) ,
264293 self . backoff . clone ( ) ,
265294 concurrency_permits. clone ( ) ,
295+ self . preserve_region_results ,
266296 )
267297 . await
268298 }
@@ -556,6 +586,7 @@ mod test {
556586 } ,
557587 pd_client : Arc :: new ( MockPdClient :: default ( ) ) ,
558588 backoff : Backoff :: no_backoff ( ) ,
589+ preserve_region_results : false ,
559590 } ;
560591 assert ! ( plan. execute( ) . await . is_err( ) )
561592 }
0 commit comments