|
1 | | -use futures::io::{AsyncRead, AsyncWrite}; |
2 | | -use rustls::Session; |
3 | | -use std::io::{self, Read, Write}; |
4 | | -use std::marker::Unpin; |
5 | | -use std::pin::Pin; |
6 | | -use std::task::{Context, Poll}; |
7 | 1 | pub(crate) mod tls_state; |
8 | | - |
9 | | -pub struct Stream<'a, IO, S> { |
10 | | - pub io: &'a mut IO, |
11 | | - pub session: &'a mut S, |
12 | | - pub eof: bool, |
13 | | -} |
14 | | - |
15 | | -trait WriteTls<IO: AsyncWrite, S: Session> { |
16 | | - fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize>; |
17 | | -} |
18 | | - |
19 | | -#[derive(Clone, Copy)] |
20 | | -enum Focus { |
21 | | - Empty, |
22 | | - Readable, |
23 | | - Writable, |
24 | | -} |
25 | | - |
26 | | -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { |
27 | | - pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { |
28 | | - Stream { |
29 | | - io, |
30 | | - session, |
31 | | - // The state so far is only used to detect EOF, so either Stream |
32 | | - // or EarlyData state should both be all right. |
33 | | - eof: false, |
34 | | - } |
35 | | - } |
36 | | - |
37 | | - pub fn set_eof(mut self, eof: bool) -> Self { |
38 | | - self.eof = eof; |
39 | | - self |
40 | | - } |
41 | | - |
42 | | - pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { |
43 | | - Pin::new(self) |
44 | | - } |
45 | | - |
46 | | - pub fn complete_io(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> { |
47 | | - self.complete_inner_io(cx, Focus::Empty) |
48 | | - } |
49 | | - |
50 | | - fn complete_read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { |
51 | | - struct Reader<'a, 'b, T> { |
52 | | - io: &'a mut T, |
53 | | - cx: &'a mut Context<'b>, |
54 | | - } |
55 | | - |
56 | | - impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { |
57 | | - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
58 | | - match Pin::new(&mut self.io).poll_read(self.cx, buf) { |
59 | | - Poll::Ready(result) => result, |
60 | | - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), |
61 | | - } |
62 | | - } |
63 | | - } |
64 | | - |
65 | | - let mut reader = Reader { io: self.io, cx }; |
66 | | - |
67 | | - let n = match self.session.read_tls(&mut reader) { |
68 | | - Ok(n) => n, |
69 | | - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, |
70 | | - Err(err) => return Poll::Ready(Err(err)), |
71 | | - }; |
72 | | - |
73 | | - self.session.process_new_packets().map_err(|err| { |
74 | | - // In case we have an alert to send describing this error, |
75 | | - // try a last-gasp write -- but don't predate the primary |
76 | | - // error. |
77 | | - let _ = self.write_tls(cx); |
78 | | - |
79 | | - io::Error::new(io::ErrorKind::InvalidData, err) |
80 | | - })?; |
81 | | - |
82 | | - Poll::Ready(Ok(n)) |
83 | | - } |
84 | | - |
85 | | - fn complete_write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { |
86 | | - match self.write_tls(cx) { |
87 | | - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
88 | | - result => Poll::Ready(result), |
89 | | - } |
90 | | - } |
91 | | - |
92 | | - fn complete_inner_io( |
93 | | - &mut self, |
94 | | - cx: &mut Context, |
95 | | - focus: Focus, |
96 | | - ) -> Poll<io::Result<(usize, usize)>> { |
97 | | - let mut wrlen = 0; |
98 | | - let mut rdlen = 0; |
99 | | - |
100 | | - loop { |
101 | | - let mut write_would_block = false; |
102 | | - let mut read_would_block = false; |
103 | | - |
104 | | - while self.session.wants_write() { |
105 | | - match self.complete_write_io(cx) { |
106 | | - Poll::Ready(Ok(n)) => wrlen += n, |
107 | | - Poll::Pending => { |
108 | | - write_would_block = true; |
109 | | - break; |
110 | | - } |
111 | | - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
112 | | - } |
113 | | - } |
114 | | - |
115 | | - if !self.eof && self.session.wants_read() { |
116 | | - match self.complete_read_io(cx) { |
117 | | - Poll::Ready(Ok(0)) => self.eof = true, |
118 | | - Poll::Ready(Ok(n)) => rdlen += n, |
119 | | - Poll::Pending => read_would_block = true, |
120 | | - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
121 | | - } |
122 | | - } |
123 | | - |
124 | | - let would_block = match focus { |
125 | | - Focus::Empty => write_would_block || read_would_block, |
126 | | - Focus::Readable => read_would_block, |
127 | | - Focus::Writable => write_would_block, |
128 | | - }; |
129 | | - |
130 | | - match (self.eof, self.session.is_handshaking(), would_block) { |
131 | | - (true, true, _) => { |
132 | | - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); |
133 | | - return Poll::Ready(Err(err)); |
134 | | - } |
135 | | - (_, false, true) => { |
136 | | - let would_block = match focus { |
137 | | - Focus::Empty => rdlen == 0 && wrlen == 0, |
138 | | - Focus::Readable => rdlen == 0, |
139 | | - Focus::Writable => wrlen == 0, |
140 | | - }; |
141 | | - |
142 | | - return if would_block { |
143 | | - Poll::Pending |
144 | | - } else { |
145 | | - Poll::Ready(Ok((rdlen, wrlen))) |
146 | | - }; |
147 | | - } |
148 | | - (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), |
149 | | - (_, true, true) => return Poll::Pending, |
150 | | - (..) => (), |
151 | | - } |
152 | | - } |
153 | | - } |
154 | | -} |
155 | | - |
156 | | -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls<IO, S> for Stream<'a, IO, S> { |
157 | | - fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize> { |
158 | | - // TODO writev |
159 | | - |
160 | | - struct Writer<'a, 'b, T> { |
161 | | - io: &'a mut T, |
162 | | - cx: &'a mut Context<'b>, |
163 | | - } |
164 | | - |
165 | | - impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { |
166 | | - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
167 | | - match Pin::new(&mut self.io).poll_write(self.cx, buf) { |
168 | | - Poll::Ready(result) => result, |
169 | | - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), |
170 | | - } |
171 | | - } |
172 | | - |
173 | | - fn flush(&mut self) -> io::Result<()> { |
174 | | - match Pin::new(&mut self.io).poll_flush(self.cx) { |
175 | | - Poll::Ready(result) => result, |
176 | | - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), |
177 | | - } |
178 | | - } |
179 | | - } |
180 | | - |
181 | | - let mut writer = Writer { io: self.io, cx }; |
182 | | - self.session.write_tls(&mut writer) |
183 | | - } |
184 | | -} |
185 | | - |
186 | | -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { |
187 | | - fn poll_read( |
188 | | - self: Pin<&mut Self>, |
189 | | - cx: &mut Context, |
190 | | - buf: &mut [u8], |
191 | | - ) -> Poll<io::Result<usize>> { |
192 | | - let this = self.get_mut(); |
193 | | - |
194 | | - while this.session.wants_read() { |
195 | | - match this.complete_inner_io(cx, Focus::Readable) { |
196 | | - Poll::Ready(Ok((0, _))) => break, |
197 | | - Poll::Ready(Ok(_)) => (), |
198 | | - Poll::Pending => return Poll::Pending, |
199 | | - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
200 | | - } |
201 | | - } |
202 | | - |
203 | | - match this.session.read(buf) { |
204 | | - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
205 | | - result => Poll::Ready(result), |
206 | | - } |
207 | | - } |
208 | | -} |
209 | | - |
210 | | -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { |
211 | | - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { |
212 | | - let this = self.get_mut(); |
213 | | - |
214 | | - let len = match this.session.write(buf) { |
215 | | - Ok(n) => n, |
216 | | - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, |
217 | | - Err(err) => return Poll::Ready(Err(err)), |
218 | | - }; |
219 | | - while this.session.wants_write() { |
220 | | - match this.complete_inner_io(cx, Focus::Writable) { |
221 | | - Poll::Ready(Ok(_)) => (), |
222 | | - Poll::Pending if len != 0 => break, |
223 | | - Poll::Pending => return Poll::Pending, |
224 | | - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
225 | | - } |
226 | | - } |
227 | | - |
228 | | - if len != 0 || buf.is_empty() { |
229 | | - Poll::Ready(Ok(len)) |
230 | | - } else { |
231 | | - // not write zero |
232 | | - match this.session.write(buf) { |
233 | | - Ok(0) => Poll::Pending, |
234 | | - Ok(n) => Poll::Ready(Ok(n)), |
235 | | - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
236 | | - Err(err) => Poll::Ready(Err(err)), |
237 | | - } |
238 | | - } |
239 | | - } |
240 | | - |
241 | | - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { |
242 | | - let this = self.get_mut(); |
243 | | - |
244 | | - this.session.flush()?; |
245 | | - while this.session.wants_write() { |
246 | | - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; |
247 | | - } |
248 | | - Pin::new(&mut this.io).poll_flush(cx) |
249 | | - } |
250 | | - |
251 | | - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
252 | | - let this = self.get_mut(); |
253 | | - |
254 | | - while this.session.wants_write() { |
255 | | - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; |
256 | | - } |
257 | | - Pin::new(&mut this.io).poll_close(cx) |
258 | | - } |
259 | | -} |
260 | | - |
261 | | -#[cfg(test)] |
262 | | -mod test_stream; |
0 commit comments