|
1 | | -use broadcaster::BroadcastChannel; |
2 | | - |
3 | | -use crate::sync::Mutex; |
| 1 | +use crate::sync::{Condvar,Mutex}; |
4 | 2 |
|
5 | 3 | /// A barrier enables multiple tasks to synchronize the beginning |
6 | 4 | /// of some computation. |
@@ -36,14 +34,13 @@ use crate::sync::Mutex; |
36 | 34 | #[derive(Debug)] |
37 | 35 | pub struct Barrier { |
38 | 36 | state: Mutex<BarrierState>, |
39 | | - wait: BroadcastChannel<(usize, usize)>, |
40 | | - n: usize, |
| 37 | + cvar: Condvar, |
| 38 | + num_tasks: usize, |
41 | 39 | } |
42 | 40 |
|
43 | 41 | // The inner state of a double barrier |
44 | 42 | #[derive(Debug)] |
45 | 43 | struct BarrierState { |
46 | | - waker: BroadcastChannel<(usize, usize)>, |
47 | 44 | count: usize, |
48 | 45 | generation_id: usize, |
49 | 46 | } |
@@ -81,25 +78,14 @@ impl Barrier { |
81 | 78 | /// |
82 | 79 | /// let barrier = Barrier::new(10); |
83 | 80 | /// ``` |
84 | | - pub fn new(mut n: usize) -> Barrier { |
85 | | - let waker = BroadcastChannel::new(); |
86 | | - let wait = waker.clone(); |
87 | | - |
88 | | - if n == 0 { |
89 | | - // if n is 0, it's not clear what behavior the user wants. |
90 | | - // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every |
91 | | - // .wait() immediately unblocks, so we adopt that here as well. |
92 | | - n = 1; |
93 | | - } |
94 | | - |
| 81 | + pub fn new(n: usize) -> Barrier { |
95 | 82 | Barrier { |
96 | 83 | state: Mutex::new(BarrierState { |
97 | | - waker, |
98 | 84 | count: 0, |
99 | 85 | generation_id: 1, |
100 | 86 | }), |
101 | | - n, |
102 | | - wait, |
| 87 | + cvar: Condvar::new(), |
| 88 | + num_tasks: n, |
103 | 89 | } |
104 | 90 | } |
105 | 91 |
|
@@ -143,35 +129,20 @@ impl Barrier { |
143 | 129 | /// # }); |
144 | 130 | /// ``` |
145 | 131 | pub async fn wait(&self) -> BarrierWaitResult { |
146 | | - let mut lock = self.state.lock().await; |
147 | | - let local_gen = lock.generation_id; |
148 | | - |
149 | | - lock.count += 1; |
| 132 | + let mut state = self.state.lock().await; |
| 133 | + let local_gen = state.generation_id; |
| 134 | + state.count += 1; |
150 | 135 |
|
151 | | - if lock.count < self.n { |
152 | | - let mut wait = self.wait.clone(); |
153 | | - |
154 | | - let mut generation_id = lock.generation_id; |
155 | | - let mut count = lock.count; |
156 | | - |
157 | | - drop(lock); |
158 | | - |
159 | | - while local_gen == generation_id && count < self.n { |
160 | | - let (g, c) = wait.recv().await.expect("sender has not been closed"); |
161 | | - generation_id = g; |
162 | | - count = c; |
| 136 | + if state.count < self.num_tasks { |
| 137 | + while local_gen == state.generation_id && state.count < self.num_tasks { |
| 138 | + state = self.cvar.wait(state).await; |
163 | 139 | } |
164 | 140 |
|
165 | 141 | BarrierWaitResult(false) |
166 | 142 | } else { |
167 | | - lock.count = 0; |
168 | | - lock.generation_id = lock.generation_id.wrapping_add(1); |
169 | | - |
170 | | - lock.waker |
171 | | - .send(&(lock.generation_id, lock.count)) |
172 | | - .await |
173 | | - .expect("there should be at least one receiver"); |
174 | | - |
| 143 | + state.count = 0; |
| 144 | + state.generation_id = state.generation_id.wrapping_add(1); |
| 145 | + self.cvar.notify_all(); |
175 | 146 | BarrierWaitResult(true) |
176 | 147 | } |
177 | 148 | } |
|
0 commit comments