Skip to content

Commit 915c7ff

Browse files
authored
feat(grpc): make ChildManager call the parent work_scheduler (#2443)
1 parent 7fafa9e commit 915c7ff

File tree

4 files changed

+202
-17
lines changed

4 files changed

+202
-17
lines changed

grpc/src/client/load_balancing/child_manager.rs

Lines changed: 178 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
2828
// TODO: This is mainly provided as a fairly complex example of the current LB
2929
// policy in use. Complete tests must be written before it can be used in
30-
// production. Also, support for the work scheduler is missing.
30+
// production.
3131

3232
use std::collections::HashSet;
3333
use std::fmt::Debug;
@@ -53,6 +53,7 @@ pub(crate) struct ChildManager<T: Debug, S: ResolverUpdateSharder<T>> {
5353
pending_work: Arc<Mutex<HashSet<usize>>>,
5454
runtime: Arc<dyn Runtime>,
5555
updated: bool, // Set when any child updates its picker; cleared when accessed.
56+
work_scheduler: Arc<dyn WorkScheduler>,
5657
}
5758

5859
#[non_exhaustive]
@@ -98,13 +99,18 @@ where
9899
{
99100
/// Creates a new ChildManager LB policy. shard_update is called whenever a
100101
/// resolver_update operation occurs.
101-
pub fn new(update_sharder: S, runtime: Arc<dyn Runtime>) -> Self {
102+
pub fn new(
103+
update_sharder: S,
104+
runtime: Arc<dyn Runtime>,
105+
work_scheduler: Arc<dyn WorkScheduler>,
106+
) -> Self {
102107
Self {
103108
update_sharder,
104109
subchannel_to_child_idx: Default::default(),
105110
children: Default::default(),
106111
pending_work: Default::default(),
107112
runtime,
113+
work_scheduler,
108114
updated: false,
109115
}
110116
}
@@ -272,6 +278,7 @@ where
272278
let work_scheduler = Arc::new(ChildWorkScheduler {
273279
pending_work: self.pending_work.clone(),
274280
idx: Mutex::new(Some(new_idx)),
281+
work_scheduler: self.work_scheduler.clone(),
275282
});
276283
let policy = builder.build(LbPolicyOptions {
277284
work_scheduler: work_scheduler.clone(),
@@ -395,8 +402,9 @@ impl ChannelController for WrappedController<'_> {
395402

396403
#[derive(Debug)]
397404
struct ChildWorkScheduler {
405+
work_scheduler: Arc<dyn WorkScheduler>, // The real work scheduler of the channel.
398406
pending_work: Arc<Mutex<HashSet<usize>>>, // Must be taken first for correctness
399-
idx: Mutex<Option<usize>>, // None if the child is deleted.
407+
idx: Mutex<Option<usize>>, // None if the child is deleted.
400408
}
401409

402410
impl WorkScheduler for ChildWorkScheduler {
@@ -405,6 +413,12 @@ impl WorkScheduler for ChildWorkScheduler {
405413
if let Some(idx) = *self.idx.lock().unwrap() {
406414
pending_work.insert(idx);
407415
}
416+
// Call the real work scheduler with the lock held to avoid a scenario
417+
// where we schedule work and get called before the lock can be taken,
418+
// and to avoid the scenario where the child is called before the
419+
// schedule_work call is done due to a concurrent call to
420+
// ChildManager::work().
421+
self.work_scheduler.schedule_work();
408422
}
409423
}
410424

@@ -414,7 +428,7 @@ mod test {
414428
ChildManager, ChildUpdate, ResolverUpdateSharder,
415429
};
416430
use crate::client::load_balancing::test_utils::{
417-
self, StubPolicyData, StubPolicyFuncs, TestChannelController, TestEvent,
431+
self, StubPolicyFuncs, TestChannelController, TestEvent, TestWorkScheduler,
418432
};
419433
use crate::client::load_balancing::{
420434
ChannelController, LbPolicy, LbPolicyBuilder, LbState, QueuingPicker, Subchannel,
@@ -424,9 +438,11 @@ mod test {
424438
use crate::client::service_config::LbConfig;
425439
use crate::client::ConnectivityState;
426440
use crate::rt::default_runtime;
441+
use std::collections::HashMap;
427442
use std::error::Error;
428443
use std::panic;
429444
use std::sync::Arc;
445+
use std::sync::Mutex;
430446
use tokio::sync::mpsc;
431447

432448
// TODO: This needs to be moved to a common place that can be shared between
@@ -492,10 +508,16 @@ mod test {
492508
) {
493509
test_utils::reg_stub_policy(test_name, funcs);
494510
let (tx_events, rx_events) = mpsc::unbounded_channel::<TestEvent>();
495-
let tcc = Box::new(TestChannelController { tx_events });
511+
let tcc = Box::new(TestChannelController {
512+
tx_events: tx_events.clone(),
513+
});
496514
let builder: Arc<dyn LbPolicyBuilder> = GLOBAL_LB_REGISTRY.get_policy(test_name).unwrap();
497515
let endpoint_sharder = EndpointSharder { builder };
498-
let child_manager = ChildManager::new(endpoint_sharder, default_runtime());
516+
let child_manager = ChildManager::new(
517+
endpoint_sharder,
518+
default_runtime(),
519+
Arc::new(TestWorkScheduler { tx_events }),
520+
);
499521
(rx_events, child_manager, tcc)
500522
}
501523

@@ -567,7 +589,6 @@ mod test {
567589
// Defines the functions resolver_update and subchannel_update to test
568590
// aggregate_states.
569591
fn create_verifying_funcs_for_aggregate_tests() -> StubPolicyFuncs {
570-
let data = StubPolicyData::new();
571592
StubPolicyFuncs {
572593
// Closure for resolver_update. resolver_update should only receive
573594
// one endpoint and create one subchannel for the endpoint it
@@ -590,6 +611,7 @@ mod test {
590611
});
591612
},
592613
)),
614+
work: None,
593615
}
594616
}
595617

@@ -759,4 +781,153 @@ mod test {
759781
ConnectivityState::TransientFailure
760782
);
761783
}
784+
785+
struct ScheduleWorkStubData {
786+
requested_work: bool,
787+
}
788+
789+
fn create_funcs_for_schedule_work_tests(name: &'static str) -> StubPolicyFuncs {
790+
StubPolicyFuncs {
791+
resolver_update: Some(Arc::new(move |data, _update, lbcfg, _controller| {
792+
if data.test_data.is_none() {
793+
data.test_data = Some(Box::new(ScheduleWorkStubData {
794+
requested_work: false,
795+
}));
796+
}
797+
let stubdata = data
798+
.test_data
799+
.as_mut()
800+
.unwrap()
801+
.downcast_mut::<ScheduleWorkStubData>()
802+
.unwrap();
803+
assert!(!stubdata.requested_work);
804+
if lbcfg
805+
.unwrap()
806+
.convert_to::<Mutex<HashMap<&'static str, ()>>>()
807+
.unwrap()
808+
.lock()
809+
.unwrap()
810+
.contains_key(name)
811+
{
812+
stubdata.requested_work = true;
813+
data.lb_policy_options.work_scheduler.schedule_work();
814+
}
815+
Ok(())
816+
})),
817+
subchannel_update: None,
818+
work: Some(Arc::new(move |data, _controller| {
819+
println!("work called for {name}");
820+
let stubdata = data
821+
.test_data
822+
.as_mut()
823+
.unwrap()
824+
.downcast_mut::<ScheduleWorkStubData>()
825+
.unwrap();
826+
stubdata.requested_work = false;
827+
})),
828+
}
829+
}
830+
831+
#[derive(Debug)]
832+
struct ScheduleWorkSharder {
833+
names: Vec<&'static str>,
834+
}
835+
836+
impl ResolverUpdateSharder<()> for ScheduleWorkSharder {
837+
fn shard_update(
838+
&mut self,
839+
resolver_update: ResolverUpdate,
840+
update: Option<&LbConfig>,
841+
) -> Result<impl Iterator<Item = ChildUpdate<()>>, Box<dyn Error + Send + Sync>> {
842+
let mut res = Vec::with_capacity(self.names.len());
843+
for name in &self.names {
844+
let child_policy_builder: Arc<dyn LbPolicyBuilder> =
845+
GLOBAL_LB_REGISTRY.get_policy(name).unwrap();
846+
res.push(ChildUpdate {
847+
child_identifier: (),
848+
child_policy_builder,
849+
child_update: Some((ResolverUpdate::default(), update.cloned())),
850+
});
851+
}
852+
Ok(res.into_iter())
853+
}
854+
}
855+
856+
// Tests that the child manager properly delegates to the children that
857+
// called schedule_work when work is called.
858+
#[tokio::test]
859+
async fn childmanager_schedule_work_works() {
860+
let name1 = "childmanager_schedule_work_works-one";
861+
let name2 = "childmanager_schedule_work_works-two";
862+
test_utils::reg_stub_policy(name1, create_funcs_for_schedule_work_tests(name1));
863+
test_utils::reg_stub_policy(name2, create_funcs_for_schedule_work_tests(name2));
864+
865+
let (tx_events, mut rx_events) = mpsc::unbounded_channel::<TestEvent>();
866+
let mut tcc = TestChannelController {
867+
tx_events: tx_events.clone(),
868+
};
869+
870+
let sharder = ScheduleWorkSharder {
871+
names: vec![name1, name2],
872+
};
873+
let mut child_manager = ChildManager::new(
874+
sharder,
875+
default_runtime(),
876+
Arc::new(TestWorkScheduler { tx_events }),
877+
);
878+
879+
// Configure child one to call schedule_work during resolver_update.
880+
let cfg = LbConfig::new(Mutex::new(HashMap::<&'static str, ()>::new()));
881+
let children = cfg
882+
.convert_to::<Mutex<HashMap<&'static str, ()>>>()
883+
.unwrap();
884+
children.lock().unwrap().insert(name1, ());
885+
886+
child_manager
887+
.resolver_update(ResolverUpdate::default(), Some(&cfg), &mut tcc)
888+
.unwrap();
889+
890+
// Confirm that child one has requested work.
891+
match rx_events.recv().await.unwrap() {
892+
TestEvent::ScheduleWork => {}
893+
other => panic!("unexpected event {:?}", other),
894+
};
895+
assert_eq!(child_manager.pending_work.lock().unwrap().len(), 1);
896+
let idx = *child_manager
897+
.pending_work
898+
.lock()
899+
.unwrap()
900+
.iter()
901+
.next()
902+
.unwrap();
903+
assert_eq!(child_manager.children[idx].builder.name(), name1);
904+
905+
// Perform the work call and assert the pending_work set is empty.
906+
child_manager.work(&mut tcc);
907+
assert_eq!(child_manager.pending_work.lock().unwrap().len(), 0);
908+
909+
// Now have both children request work.
910+
children.lock().unwrap().insert(name2, ());
911+
912+
child_manager
913+
.resolver_update(ResolverUpdate::default(), Some(&cfg), &mut tcc)
914+
.unwrap();
915+
916+
// Confirm that both children requested work.
917+
match rx_events.recv().await.unwrap() {
918+
TestEvent::ScheduleWork => {}
919+
other => panic!("unexpected event {:?}", other),
920+
};
921+
assert_eq!(child_manager.pending_work.lock().unwrap().len(), 2);
922+
923+
// Perform the work call and assert the pending_work set is empty.
924+
child_manager.work(&mut tcc);
925+
assert_eq!(child_manager.pending_work.lock().unwrap().len(), 0);
926+
927+
// Perform one final call to resolver_update which asserts that both
928+
// child policies had their work methods called.
929+
child_manager
930+
.resolver_update(ResolverUpdate::default(), Some(&cfg), &mut tcc)
931+
.unwrap();
932+
}
762933
}

grpc/src/client/load_balancing/graceful_switch.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::client::load_balancing::child_manager::{
33
};
44
use crate::client::load_balancing::{
55
ChannelController, LbConfig, LbPolicy, LbPolicyBuilder, LbState, ParsedJsonLbConfig,
6-
Subchannel, SubchannelState, GLOBAL_LB_REGISTRY,
6+
Subchannel, SubchannelState, WorkScheduler, GLOBAL_LB_REGISTRY,
77
};
88
use crate::client::name_resolution::ResolverUpdate;
99
use crate::client::ConnectivityState;
@@ -150,9 +150,9 @@ enum ChildKind {
150150

151151
impl GracefulSwitchPolicy {
152152
/// Creates a new Graceful Switch policy.
153-
pub fn new(runtime: Arc<dyn Runtime>) -> Self {
153+
pub fn new(runtime: Arc<dyn Runtime>, work_scheduler: Arc<dyn WorkScheduler>) -> Self {
154154
GracefulSwitchPolicy {
155-
child_manager: ChildManager::new(UpdateSharder::new(), runtime),
155+
child_manager: ChildManager::new(UpdateSharder::new(), runtime, work_scheduler),
156156
last_update: LbState::initial(),
157157
}
158158
}
@@ -372,6 +372,7 @@ mod test {
372372
});
373373
},
374374
)),
375+
work: None,
375376
}
376377
}
377378

@@ -400,9 +401,12 @@ mod test {
400401
tx_events: tx_events.clone(),
401402
});
402403

403-
let tcc = Box::new(TestChannelController { tx_events });
404+
let tcc = Box::new(TestChannelController {
405+
tx_events: tx_events.clone(),
406+
});
404407

405-
let graceful_switch = GracefulSwitchPolicy::new(default_runtime());
408+
let graceful_switch =
409+
GracefulSwitchPolicy::new(default_runtime(), Arc::new(TestWorkScheduler { tx_events }));
406410
(rx_events, Box::new(graceful_switch), tcc)
407411
}
408412

grpc/src/client/load_balancing/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ pub(crate) use registry::GLOBAL_LB_REGISTRY;
5959

6060
/// A collection of data configured on the channel that is constructing this
6161
/// LbPolicy.
62+
#[derive(Debug)]
6263
pub(crate) struct LbPolicyOptions {
6364
/// A hook into the channel's work scheduler that allows the LbPolicy to
6465
/// request the ability to perform operations on the ChannelController.

grpc/src/client/load_balancing/test_utils.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,15 @@ type SubchannelUpdateFn = Arc<
170170
+ Sync,
171171
>;
172172

173+
type WorkFn = Arc<dyn Fn(&mut StubPolicyData, &mut dyn ChannelController) + Send + Sync>;
174+
173175
/// This struct holds `LbPolicy` trait stub functions that tests are expected to
174176
/// implement.
175177
#[derive(Clone)]
176178
pub(crate) struct StubPolicyFuncs {
177179
pub resolver_update: Option<ResolverUpdateFn>,
178180
pub subchannel_update: Option<SubchannelUpdateFn>,
181+
pub work: Option<WorkFn>,
179182
}
180183

181184
impl Debug for StubPolicyFuncs {
@@ -187,13 +190,17 @@ impl Debug for StubPolicyFuncs {
187190
/// Holds test data that will be passed to all functions in StubPolicyFuncs.
188191
#[derive(Debug)]
189192
pub(crate) struct StubPolicyData {
193+
pub lb_policy_options: LbPolicyOptions,
190194
pub test_data: Option<Box<dyn Any + Send + Sync>>,
191195
}
192196

193197
impl StubPolicyData {
194198
/// Creates an instance of StubPolicyData.
195-
pub fn new() -> Self {
196-
Self { test_data: None }
199+
pub fn new(lb_policy_options: LbPolicyOptions) -> Self {
200+
Self {
201+
test_data: None,
202+
lb_policy_options,
203+
}
197204
}
198205
}
199206

@@ -232,8 +239,10 @@ impl LbPolicy for StubPolicy {
232239
todo!("Implement exit_idle for StubPolicy")
233240
}
234241

235-
fn work(&mut self, _channel_controller: &mut dyn ChannelController) {
236-
todo!("Implement work for StubPolicy")
242+
fn work(&mut self, channel_controller: &mut dyn ChannelController) {
243+
if let Some(f) = &self.funcs.work {
244+
f(&mut self.data, channel_controller);
245+
}
237246
}
238247
}
239248

@@ -252,7 +261,7 @@ pub(super) struct MockConfig {
252261

253262
impl LbPolicyBuilder for StubPolicyBuilder {
254263
fn build(&self, options: LbPolicyOptions) -> Box<dyn LbPolicy> {
255-
let data = StubPolicyData::new();
264+
let data = StubPolicyData::new(options);
256265
Box::new(StubPolicy {
257266
funcs: self.funcs.clone(),
258267
data,

0 commit comments

Comments
 (0)