2727
2828// TODO: This is mainly provided as a fairly complex example of the current LB
2929// policy in use. Complete tests must be written before it can be used in
30- // production. Also, support for the work scheduler is missing.
30+ // production.
3131
3232use std:: collections:: HashSet ;
3333use std:: fmt:: Debug ;
@@ -53,6 +53,7 @@ pub(crate) struct ChildManager<T: Debug, S: ResolverUpdateSharder<T>> {
5353 pending_work : Arc < Mutex < HashSet < usize > > > ,
5454 runtime : Arc < dyn Runtime > ,
5555 updated : bool , // Set when any child updates its picker; cleared when accessed.
56+ work_scheduler : Arc < dyn WorkScheduler > ,
5657}
5758
5859#[ non_exhaustive]
@@ -98,13 +99,18 @@ where
9899{
99100 /// Creates a new ChildManager LB policy. shard_update is called whenever a
100101 /// resolver_update operation occurs.
101- pub fn new ( update_sharder : S , runtime : Arc < dyn Runtime > ) -> Self {
102+ pub fn new (
103+ update_sharder : S ,
104+ runtime : Arc < dyn Runtime > ,
105+ work_scheduler : Arc < dyn WorkScheduler > ,
106+ ) -> Self {
102107 Self {
103108 update_sharder,
104109 subchannel_to_child_idx : Default :: default ( ) ,
105110 children : Default :: default ( ) ,
106111 pending_work : Default :: default ( ) ,
107112 runtime,
113+ work_scheduler,
108114 updated : false ,
109115 }
110116 }
@@ -272,6 +278,7 @@ where
272278 let work_scheduler = Arc :: new ( ChildWorkScheduler {
273279 pending_work : self . pending_work . clone ( ) ,
274280 idx : Mutex :: new ( Some ( new_idx) ) ,
281+ work_scheduler : self . work_scheduler . clone ( ) ,
275282 } ) ;
276283 let policy = builder. build ( LbPolicyOptions {
277284 work_scheduler : work_scheduler. clone ( ) ,
@@ -395,8 +402,9 @@ impl ChannelController for WrappedController<'_> {
395402
396403#[ derive( Debug ) ]
397404struct ChildWorkScheduler {
405+ work_scheduler : Arc < dyn WorkScheduler > , // The real work scheduler of the channel.
398406 pending_work : Arc < Mutex < HashSet < usize > > > , // Must be taken first for correctness
399- idx : Mutex < Option < usize > > , // None if the child is deleted.
407+ idx : Mutex < Option < usize > > , // None if the child is deleted.
400408}
401409
402410impl WorkScheduler for ChildWorkScheduler {
@@ -405,6 +413,12 @@ impl WorkScheduler for ChildWorkScheduler {
405413 if let Some ( idx) = * self . idx . lock ( ) . unwrap ( ) {
406414 pending_work. insert ( idx) ;
407415 }
416+ // Call the real work scheduler with the lock held to avoid a scenario
417+ // where we schedule work and get called before the lock can be taken,
418+ // and to avoid the scenario where the child is called before the
419+ // schedule_work call is done due to a concurrent call to
420+ // ChildManager::work().
421+ self . work_scheduler . schedule_work ( ) ;
408422 }
409423}
410424
@@ -414,7 +428,7 @@ mod test {
414428 ChildManager , ChildUpdate , ResolverUpdateSharder ,
415429 } ;
416430 use crate :: client:: load_balancing:: test_utils:: {
417- self , StubPolicyData , StubPolicyFuncs , TestChannelController , TestEvent ,
431+ self , StubPolicyFuncs , TestChannelController , TestEvent , TestWorkScheduler ,
418432 } ;
419433 use crate :: client:: load_balancing:: {
420434 ChannelController , LbPolicy , LbPolicyBuilder , LbState , QueuingPicker , Subchannel ,
@@ -424,9 +438,11 @@ mod test {
424438 use crate :: client:: service_config:: LbConfig ;
425439 use crate :: client:: ConnectivityState ;
426440 use crate :: rt:: default_runtime;
441+ use std:: collections:: HashMap ;
427442 use std:: error:: Error ;
428443 use std:: panic;
429444 use std:: sync:: Arc ;
445+ use std:: sync:: Mutex ;
430446 use tokio:: sync:: mpsc;
431447
432448 // TODO: This needs to be moved to a common place that can be shared between
@@ -492,10 +508,16 @@ mod test {
492508 ) {
493509 test_utils:: reg_stub_policy ( test_name, funcs) ;
494510 let ( tx_events, rx_events) = mpsc:: unbounded_channel :: < TestEvent > ( ) ;
495- let tcc = Box :: new ( TestChannelController { tx_events } ) ;
511+ let tcc = Box :: new ( TestChannelController {
512+ tx_events : tx_events. clone ( ) ,
513+ } ) ;
496514 let builder: Arc < dyn LbPolicyBuilder > = GLOBAL_LB_REGISTRY . get_policy ( test_name) . unwrap ( ) ;
497515 let endpoint_sharder = EndpointSharder { builder } ;
498- let child_manager = ChildManager :: new ( endpoint_sharder, default_runtime ( ) ) ;
516+ let child_manager = ChildManager :: new (
517+ endpoint_sharder,
518+ default_runtime ( ) ,
519+ Arc :: new ( TestWorkScheduler { tx_events } ) ,
520+ ) ;
499521 ( rx_events, child_manager, tcc)
500522 }
501523
@@ -567,7 +589,6 @@ mod test {
567589 // Defines the functions resolver_update and subchannel_update to test
568590 // aggregate_states.
569591 fn create_verifying_funcs_for_aggregate_tests ( ) -> StubPolicyFuncs {
570- let data = StubPolicyData :: new ( ) ;
571592 StubPolicyFuncs {
572593 // Closure for resolver_update. resolver_update should only receive
573594 // one endpoint and create one subchannel for the endpoint it
@@ -590,6 +611,7 @@ mod test {
590611 } ) ;
591612 } ,
592613 ) ) ,
614+ work : None ,
593615 }
594616 }
595617
@@ -759,4 +781,153 @@ mod test {
759781 ConnectivityState :: TransientFailure
760782 ) ;
761783 }
784+
785+ struct ScheduleWorkStubData {
786+ requested_work : bool ,
787+ }
788+
789+ fn create_funcs_for_schedule_work_tests ( name : & ' static str ) -> StubPolicyFuncs {
790+ StubPolicyFuncs {
791+ resolver_update : Some ( Arc :: new ( move |data, _update, lbcfg, _controller| {
792+ if data. test_data . is_none ( ) {
793+ data. test_data = Some ( Box :: new ( ScheduleWorkStubData {
794+ requested_work : false ,
795+ } ) ) ;
796+ }
797+ let stubdata = data
798+ . test_data
799+ . as_mut ( )
800+ . unwrap ( )
801+ . downcast_mut :: < ScheduleWorkStubData > ( )
802+ . unwrap ( ) ;
803+ assert ! ( !stubdata. requested_work) ;
804+ if lbcfg
805+ . unwrap ( )
806+ . convert_to :: < Mutex < HashMap < & ' static str , ( ) > > > ( )
807+ . unwrap ( )
808+ . lock ( )
809+ . unwrap ( )
810+ . contains_key ( name)
811+ {
812+ stubdata. requested_work = true ;
813+ data. lb_policy_options . work_scheduler . schedule_work ( ) ;
814+ }
815+ Ok ( ( ) )
816+ } ) ) ,
817+ subchannel_update : None ,
818+ work : Some ( Arc :: new ( move |data, _controller| {
819+ println ! ( "work called for {name}" ) ;
820+ let stubdata = data
821+ . test_data
822+ . as_mut ( )
823+ . unwrap ( )
824+ . downcast_mut :: < ScheduleWorkStubData > ( )
825+ . unwrap ( ) ;
826+ stubdata. requested_work = false ;
827+ } ) ) ,
828+ }
829+ }
830+
831+ #[ derive( Debug ) ]
832+ struct ScheduleWorkSharder {
833+ names : Vec < & ' static str > ,
834+ }
835+
836+ impl ResolverUpdateSharder < ( ) > for ScheduleWorkSharder {
837+ fn shard_update (
838+ & mut self ,
839+ resolver_update : ResolverUpdate ,
840+ update : Option < & LbConfig > ,
841+ ) -> Result < impl Iterator < Item = ChildUpdate < ( ) > > , Box < dyn Error + Send + Sync > > {
842+ let mut res = Vec :: with_capacity ( self . names . len ( ) ) ;
843+ for name in & self . names {
844+ let child_policy_builder: Arc < dyn LbPolicyBuilder > =
845+ GLOBAL_LB_REGISTRY . get_policy ( name) . unwrap ( ) ;
846+ res. push ( ChildUpdate {
847+ child_identifier : ( ) ,
848+ child_policy_builder,
849+ child_update : Some ( ( ResolverUpdate :: default ( ) , update. cloned ( ) ) ) ,
850+ } ) ;
851+ }
852+ Ok ( res. into_iter ( ) )
853+ }
854+ }
855+
856+ // Tests that the child manager properly delegates to the children that
857+ // called schedule_work when work is called.
858+ #[ tokio:: test]
859+ async fn childmanager_schedule_work_works ( ) {
860+ let name1 = "childmanager_schedule_work_works-one" ;
861+ let name2 = "childmanager_schedule_work_works-two" ;
862+ test_utils:: reg_stub_policy ( name1, create_funcs_for_schedule_work_tests ( name1) ) ;
863+ test_utils:: reg_stub_policy ( name2, create_funcs_for_schedule_work_tests ( name2) ) ;
864+
865+ let ( tx_events, mut rx_events) = mpsc:: unbounded_channel :: < TestEvent > ( ) ;
866+ let mut tcc = TestChannelController {
867+ tx_events : tx_events. clone ( ) ,
868+ } ;
869+
870+ let sharder = ScheduleWorkSharder {
871+ names : vec ! [ name1, name2] ,
872+ } ;
873+ let mut child_manager = ChildManager :: new (
874+ sharder,
875+ default_runtime ( ) ,
876+ Arc :: new ( TestWorkScheduler { tx_events } ) ,
877+ ) ;
878+
879+ // Request that child one requests work.
880+ let cfg = LbConfig :: new ( Mutex :: new ( HashMap :: < & ' static str , ( ) > :: new ( ) ) ) ;
881+ let children = cfg
882+ . convert_to :: < Mutex < HashMap < & ' static str , ( ) > > > ( )
883+ . unwrap ( ) ;
884+ children. lock ( ) . unwrap ( ) . insert ( name1, ( ) ) ;
885+
886+ child_manager
887+ . resolver_update ( ResolverUpdate :: default ( ) , Some ( & cfg) , & mut tcc)
888+ . unwrap ( ) ;
889+
890+ // Confirm that child one has requested work.
891+ match rx_events. recv ( ) . await . unwrap ( ) {
892+ TestEvent :: ScheduleWork => { }
893+ other => panic ! ( "unexpected event {:?}" , other) ,
894+ } ;
895+ assert_eq ! ( child_manager. pending_work. lock( ) . unwrap( ) . len( ) , 1 ) ;
896+ let idx = * child_manager
897+ . pending_work
898+ . lock ( )
899+ . unwrap ( )
900+ . iter ( )
901+ . next ( )
902+ . unwrap ( ) ;
903+ assert_eq ! ( child_manager. children[ idx] . builder. name( ) , name1) ;
904+
905+ // Perform the work call and assert the pending_work set is empty.
906+ child_manager. work ( & mut tcc) ;
907+ assert_eq ! ( child_manager. pending_work. lock( ) . unwrap( ) . len( ) , 0 ) ;
908+
909+ // Now have both children request work.
910+ children. lock ( ) . unwrap ( ) . insert ( name2, ( ) ) ;
911+
912+ child_manager
913+ . resolver_update ( ResolverUpdate :: default ( ) , Some ( & cfg) , & mut tcc)
914+ . unwrap ( ) ;
915+
916+ // Confirm that both children requested work.
917+ match rx_events. recv ( ) . await . unwrap ( ) {
918+ TestEvent :: ScheduleWork => { }
919+ other => panic ! ( "unexpected event {:?}" , other) ,
920+ } ;
921+ assert_eq ! ( child_manager. pending_work. lock( ) . unwrap( ) . len( ) , 2 ) ;
922+
923+ // Perform the work call and assert the pending_work set is empty.
924+ child_manager. work ( & mut tcc) ;
925+ assert_eq ! ( child_manager. pending_work. lock( ) . unwrap( ) . len( ) , 0 ) ;
926+
927+ // Perform one final call to resolver_update which asserts that both
928+ // child policies had their work methods called.
929+ child_manager
930+ . resolver_update ( ResolverUpdate :: default ( ) , Some ( & cfg) , & mut tcc)
931+ . unwrap ( ) ;
932+ }
762933}
0 commit comments