Skip to content

Commit 77aa156

Browse files
committed
Introduce support for OptionNegotiationFailed error code
1 parent 5f1df99 commit 77aa156

File tree

10 files changed

+220
-72
lines changed

10 files changed

+220
-72
lines changed

examples/tftpd-dir.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ async fn main() -> Result<()> {
1111
.expect("Failed to initialize logger");
1212

1313
let tftpd = TftpServerBuilder::with_dir_ro(".")?
14-
.bind("0.0.0.0:6969".parse().unwrap())
14+
.bind("127.0.0.1:6969".parse().unwrap())
1515
// Workaround to handle cases where client is behind VPN
1616
.block_size_limit(1024)
1717
.build()

src/packet.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub enum Error {
3232
UnknownTransferId,
3333
FileAlreadyExists,
3434
NoSuchUser,
35+
OptionNegotiationFailed,
3536
}
3637

3738
#[derive(Debug)]
@@ -192,6 +193,7 @@ impl Error {
192193
5 => Error::UnknownTransferId,
193194
6 => Error::FileAlreadyExists,
194195
7 => Error::NoSuchUser,
196+
8 => Error::OptionNegotiationFailed,
195197
0 | _ => match msg {
196198
Some(msg) => Error::Msg(msg.to_string()),
197199
None => Error::UnknownError,
@@ -210,6 +212,7 @@ impl Error {
210212
Error::UnknownTransferId => 5,
211213
Error::FileAlreadyExists => 6,
212214
Error::NoSuchUser => 7,
215+
Error::OptionNegotiationFailed => 8,
213216
}
214217
}
215218

@@ -224,8 +227,13 @@ impl Error {
224227
Error::UnknownTransferId => "Unknown transfer ID",
225228
Error::FileAlreadyExists => "File already exists",
226229
Error::NoSuchUser => "No such user",
230+
Error::OptionNegotiationFailed => "Option negotiation failed",
227231
}
228232
}
233+
234+
pub(crate) fn is_client_error(&self) -> bool {
235+
matches!(self, Error::OptionNegotiationFailed)
236+
}
229237
}
230238

231239
impl From<Error> for Packet<'_> {

src/server/read_req.rs

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,20 @@ where
7777
pub(crate) async fn handle(&mut self) {
7878
if let Err(e) = self.try_handle().await {
7979
trace!("RRQ request failed (peer: {}, error: {})", &self.peer, &e);
80-
let mut buffer = BytesMut::with_capacity(DEFAULT_BLOCK_SIZE);
81-
Packet::Error(e.into()).encode(&mut buffer);
82-
let buf = buffer.split().freeze();
83-
// Errors are never retransmitted.
84-
// We do not care if `send_to` resulted to an IO error.
85-
let _ = self.socket.send_to(&buf[..], self.peer).await;
80+
81+
if let Error::Packet(
82+
crate::packet::Error::OptionNegotiationFailed,
83+
) = e
84+
{
85+
// client aborted the connection, nothing to do
86+
} else {
87+
let mut buffer = BytesMut::with_capacity(DEFAULT_BLOCK_SIZE);
88+
Packet::Error(e.into()).encode(&mut buffer);
89+
let buf = buffer.split().freeze();
90+
// Errors are never retransmitted.
91+
// We do not care if `send_to` resulted to an IO error.
92+
let _ = self.socket.send_to(&buf[..], self.peer).await;
93+
}
8694
}
8795
}
8896

@@ -92,17 +100,20 @@ where
92100
let mut block_id: u16;
93101
let mut window_base: u16 = 1;
94102
let mut buf: Bytes;
95-
let mut is_last_block: bool;
103+
let mut is_last_block: bool = false;
96104

97-
(buf, is_last_block) = self.fill_data_block(window_base).await?;
98-
window.push_back(buf);
99-
100-
// Send OACK after we manage to read the first block from reader.
101-
//
102-
// We do this because we want to give the developers the option to
103-
// produce an error after they construct a reader.
104-
if let Some(opts) = self.oack_opts.as_ref() {
105+
if let Some(opts) = self.oack_opts.take() {
105106
trace!("RRQ OACK (peer: {}, opts: {:?}", &self.peer, &opts);
107+
// Send OACK after we manage to read the first block from the reader for
108+
// non-transfer size probe requests (oack.transfer_size value is set).
109+
// During transfer size probes a client aborts the connection after receiving
110+
// oack from the server. For normal requests we do this because we want to give
111+
// the developers the option to produce an error after they construct a reader.
112+
if opts.transfer_size.is_none() {
113+
(buf, is_last_block) =
114+
self.fill_data_block(window_base).await?;
115+
window.push_back(buf);
116+
}
106117
let mut buff = BytesMut::with_capacity(PACKET_DATA_HEADER_LEN + 64);
107118
Packet::OAck(opts.to_owned()).encode(&mut buff);
108119
// OACK is not really part of the window, so we send it separately
@@ -186,7 +197,9 @@ where
186197
);
187198
return Ok(blocks_acked);
188199
}
189-
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
200+
Err(Error::Io(ref e))
201+
if e.kind() == io::ErrorKind::TimedOut =>
202+
{
190203
trace!(
191204
"RRQ (peer: {}, block_id: {}) - Timeout",
192205
&self.peer,
@@ -206,7 +219,7 @@ where
206219
&mut self,
207220
window_base: u16,
208221
window_len: u16,
209-
) -> io::Result<u16> {
222+
) -> Result<u16> {
210223
// We can not use `self` within `async_std::io::timeout` because not all
211224
// struct members implement `Sync`. So we borrow only what we need.
212225
let socket = &mut self.socket;
@@ -224,30 +237,37 @@ where
224237
}
225238

226239
// parse only valid Ack packets, the rest are ignored
227-
if let Ok(Packet::Ack(recved_block_id)) =
228-
Packet::decode(&buf[..len])
240+
// if let Ok(Packet::Ack(recved_block_id)) =
241+
match Packet::decode(&buf[..len])
229242
{
230-
let window_end = window_base.wrapping_add(window_len);
231-
232-
if window_end > window_base {
233-
// window_end did not wrap
234-
if recved_block_id >= window_base && recved_block_id < window_end {
235-
// number of blocks acked
236-
return Ok(recved_block_id-window_base+1u16);
237-
}
238-
else {
239-
trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}");
240-
}
241-
}else {
242-
// window_end wrapped
243-
if recved_block_id >= window_base {
244-
return Ok(1u16 + (recved_block_id - window_base));
245-
} else if recved_block_id < window_end {
246-
return Ok(1u16 + recved_block_id + (window_len - window_end));
243+
Ok(Packet::Ack(recved_block_id)) => {
244+
let window_end = window_base.wrapping_add(window_len);
245+
246+
if window_end > window_base {
247+
// window_end did not wrap
248+
if recved_block_id >= window_base && recved_block_id < window_end {
249+
// number of blocks acked
250+
return Ok(recved_block_id - window_base + 1u16);
251+
} else {
252+
trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}");
253+
}
247254
} else {
248-
trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}");
255+
// window_end wrapped
256+
if recved_block_id >= window_base {
257+
return Ok(1u16 + (recved_block_id - window_base));
258+
} else if recved_block_id < window_end {
259+
return Ok(1u16 + recved_block_id + (window_len - window_end));
260+
} else {
261+
trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}");
262+
}
249263
}
264+
},
265+
Ok(Packet::Error(error)) if error.is_client_error()=> {
266+
// pass errors coming from the client
267+
return Err(Error::Packet(error))
250268
}
269+
// ignore all other errors
270+
_ => {}
251271
}
252272
}
253273
})

src/server/write_req.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,20 @@ where
7575
pub(crate) async fn handle(&mut self) {
7676
if let Err(e) = self.try_handle().await {
7777
trace!("WRQ request failed (peer: {}, error: {}", self.peer, &e);
78-
79-
Packet::Error(e.into()).encode(&mut self.buffer);
80-
let buf = self.buffer.split().freeze();
81-
// Errors are never retransmitted.
82-
// We do not care if `send_to` resulted to an IO error.
83-
let _ = self.socket.send_to(&buf[..], self.peer).await;
78+
match e {
79+
Error::Packet(client_error)
80+
if client_error.is_client_error() =>
81+
{
82+
//we don't have to acknowledge client errors}
83+
}
84+
e => {
85+
Packet::Error(e.into()).encode(&mut self.buffer);
86+
let buf = self.buffer.split().freeze();
87+
// Errors are never retransmitted.
88+
// We do not care if `send_to` resulted to an IO error.
89+
let _ = self.socket.send_to(&buf[..], self.peer).await;
90+
}
91+
}
8492
}
8593
}
8694

@@ -122,7 +130,9 @@ where
122130
self.socket.send_to(&self.ack, self.peer).await?;
123131
return Ok(data);
124132
}
125-
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
133+
Err(Error::Io(ref e))
134+
if e.kind() == io::ErrorKind::TimedOut =>
135+
{
126136
// On timeout reply with the previous ACK packet
127137
self.socket.send_to(&self.ack, self.peer).await?;
128138
continue;
@@ -134,7 +144,7 @@ where
134144
Err(Error::MaxSendRetriesReached(self.peer, block_id))
135145
}
136146

137-
async fn recv_data_block(&mut self, block_id: u16) -> io::Result<Bytes> {
147+
async fn recv_data_block(&mut self, block_id: u16) -> Result<Bytes> {
138148
let socket = &mut self.socket;
139149
let peer = self.peer;
140150

src/tests/client.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
use crate::packet::Error::OptionNegotiationFailed;
2+
use crate::packet::{self, Mode, Opts, RwReq};
3+
use crate::server::TftpServerBuilder;
4+
use async_executor::Executor;
5+
use async_io::{Async, Timer};
6+
use futures_lite::future::block_on;
7+
use futures_lite::{future, AsyncRead};
8+
use std::cell::Cell;
9+
use std::io;
10+
use std::net::UdpSocket;
11+
use std::rc::Rc;
12+
use std::sync::Arc;
13+
14+
use super::handlers::*;
15+
use super::packet::packet_to_bytes;
16+
use std::task::Poll;
17+
use std::time::Duration;
18+
19+
struct ResultsReader {
20+
results: Vec<io::Result<Vec<u8>>>,
21+
}
22+
23+
impl AsyncRead for ResultsReader {
24+
fn poll_read(
25+
self: std::pin::Pin<&mut Self>,
26+
_cx: &mut std::task::Context<'_>,
27+
buf: &mut [u8],
28+
) -> Poll<io::Result<usize>> {
29+
Poll::Ready(match self.get_mut().results.pop() {
30+
Some(Ok(result)) => {
31+
buf[..result.len()].copy_from_slice(&result);
32+
Ok(result.len())
33+
}
34+
Some(Err(err)) => Err(err),
35+
None => Err(io::ErrorKind::NotFound.into()),
36+
})
37+
}
38+
}
39+
40+
#[test]
41+
fn test_abort_on_read_error() {
42+
let ex = Arc::new(Executor::new());
43+
let transferred = Rc::new(Cell::new(false));
44+
45+
block_on(ex.run({
46+
let ex = ex.clone();
47+
let transferred = transferred.clone();
48+
49+
async move {
50+
let handler = ReaderHandler::new(ResultsReader { results: vec![Err(io::ErrorKind::InvalidInput.into())] }, Some(4));
51+
let tftpd = TftpServerBuilder::with_handler(handler)
52+
.bind("127.0.0.1:0".parse().unwrap())
53+
.timeout(Duration::from_secs(1))
54+
.build()
55+
.await
56+
.unwrap();
57+
let addr = tftpd.listen_addr().unwrap();
58+
// start server
59+
let server_task = ex.spawn(async move {
60+
tftpd.serve().await.unwrap();
61+
});
62+
63+
let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 0)).unwrap();
64+
// send rrq with transfer size to 0 to simulate a transfer size probe
65+
let req_opts = Opts {
66+
transfer_size: Some(0),
67+
..Default::default()
68+
};
69+
let rrq = packet::Packet::Rrq(RwReq {
70+
filename: "abc".to_string(),
71+
mode: Mode::Octet,
72+
opts: req_opts,
73+
});
74+
socket.send_to(&packet_to_bytes(&rrq), addr).await.unwrap();
75+
76+
// read the ack
77+
let mut buf = [0u8; 1024];
78+
let (len, peer) = socket.recv_from(&mut buf).await.unwrap();
79+
let response = packet::Packet::decode(&buf[..len]).unwrap();
80+
assert!(matches!(
81+
response,
82+
packet::Packet::OAck(Opts {transfer_size: Some(4),..})),
83+
"Server did not send OAck packet: {:?}", response);
84+
85+
// send error packet
86+
let abort_packet = packet::Packet::Error(OptionNegotiationFailed);
87+
socket.send_to(&packet_to_bytes(&abort_packet), peer).await.unwrap();
88+
89+
// make sure the server doesn't send anything else
90+
assert!(
91+
future::race(
92+
async move {
93+
Timer::after(Duration::from_secs(3)).await;
94+
true
95+
},
96+
async move {
97+
// fail if we get anything after sending OptionNegotiationFailed error
98+
let _ = socket.recv_from(&mut buf).await.unwrap();
99+
false
100+
}
101+
)
102+
.await,
103+
"Server sent data after client sent OptionNegotiationFailed error"
104+
);
105+
server_task.cancel().await;
106+
transferred.set(true);
107+
}
108+
}));
109+
110+
assert!(transferred.get());
111+
}

src/tests/handlers.rs

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,36 @@
1-
#![cfg(feature = "external-client-tests")]
2-
#![cfg(target_os = "linux")]
3-
4-
use async_channel::Sender;
1+
use crate::packet;
2+
use crate::server::Handler;
53
use futures_lite::io::Sink;
4+
use futures_lite::AsyncRead;
65
use std::net::SocketAddr;
76
use std::path::Path;
87

9-
use super::random_file::RandomFile;
10-
use crate::packet;
11-
use crate::server::Handler;
12-
13-
pub struct RandomHandler {
14-
md5_tx: Option<Sender<md5::Digest>>,
15-
file_size: usize,
8+
pub struct ReaderHandler<Reader> {
9+
reader: Option<Reader>,
10+
size: Option<u64>,
1611
}
1712

18-
impl RandomHandler {
19-
pub fn new(file_size: usize, md5_tx: Sender<md5::Digest>) -> Self {
20-
RandomHandler {
21-
md5_tx: Some(md5_tx),
22-
file_size,
13+
impl<Reader> ReaderHandler<Reader> {
14+
pub fn new(reader: Reader, size: Option<u64>) -> Self {
15+
ReaderHandler {
16+
reader: Some(reader),
17+
size,
2318
}
2419
}
2520
}
2621

27-
impl Handler for RandomHandler {
28-
type Reader = RandomFile;
22+
impl<Reader: Send + AsyncRead + Unpin + 'static> Handler
23+
for ReaderHandler<Reader>
24+
{
25+
type Reader = Reader;
2926
type Writer = Sink;
3027

3128
async fn read_req_open(
3229
&mut self,
3330
_client: &SocketAddr,
3431
_path: &Path,
3532
) -> Result<(Self::Reader, Option<u64>), packet::Error> {
36-
let md5_tx = self.md5_tx.take().expect("md5_tx already consumed");
37-
Ok((RandomFile::new(self.file_size, md5_tx), None))
33+
Ok((self.reader.take().expect("reader already consumed"), self.size))
3834
}
3935

4036
async fn write_req_open(

0 commit comments

Comments
 (0)