From 42c24021d9605f7eb89d74d5f6b5c93e4ab03ea7 Mon Sep 17 00:00:00 2001 From: lazy404 Date: Sat, 24 Feb 2024 23:08:21 +0000 Subject: [PATCH] Introduce support for OptionNegotiationFailed error code --- examples/tftpd-dir.rs | 2 +- src/packet.rs | 8 +++ src/server/read_req.rs | 96 ++++++++++++++++++++-------------- src/server/write_req.rs | 28 ++++++---- src/tests/client.rs | 111 ++++++++++++++++++++++++++++++++++++++++ src/tests/handlers.rs | 36 ++++++------- src/tests/mod.rs | 1 + src/tests/packet.rs | 2 +- src/tests/rrq.rs | 2 +- src/utils.rs | 10 ++-- 10 files changed, 222 insertions(+), 74 deletions(-) create mode 100644 src/tests/client.rs diff --git a/examples/tftpd-dir.rs b/examples/tftpd-dir.rs index 5cd068a..3c538d7 100644 --- a/examples/tftpd-dir.rs +++ b/examples/tftpd-dir.rs @@ -11,7 +11,7 @@ async fn main() -> Result<()> { .expect("Failed to initialize logger"); let tftpd = TftpServerBuilder::with_dir_ro(".")? - .bind("0.0.0.0:6969".parse().unwrap()) + .bind("127.0.0.1:6969".parse().unwrap()) // Workaround to handle cases where client is behind VPN .block_size_limit(1024) .build() diff --git a/src/packet.rs b/src/packet.rs index f0dca80..65e51ae 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -32,6 +32,7 @@ pub enum Error { UnknownTransferId, FileAlreadyExists, NoSuchUser, + OptionNegotiationFailed, } #[derive(Debug)] @@ -192,6 +193,7 @@ impl Error { 5 => Error::UnknownTransferId, 6 => Error::FileAlreadyExists, 7 => Error::NoSuchUser, + 8 => Error::OptionNegotiationFailed, 0 | _ => match msg { Some(msg) => Error::Msg(msg.to_string()), None => Error::UnknownError, @@ -210,6 +212,7 @@ impl Error { Error::UnknownTransferId => 5, Error::FileAlreadyExists => 6, Error::NoSuchUser => 7, + Error::OptionNegotiationFailed => 8, } } @@ -224,8 +227,13 @@ impl Error { Error::UnknownTransferId => "Unknown transfer ID", Error::FileAlreadyExists => "File already exists", Error::NoSuchUser => "No such user", + Error::OptionNegotiationFailed => "Option negotiation failed", } } + + pub(crate) fn is_client_error(&self) -> bool { + matches!(self, Error::OptionNegotiationFailed) + } } impl From for Packet<'_> { diff --git a/src/server/read_req.rs b/src/server/read_req.rs index 063f57b..aafd330 100644 --- a/src/server/read_req.rs +++ b/src/server/read_req.rs @@ -77,12 +77,20 @@ where pub(crate) async fn handle(&mut self) { if let Err(e) = self.try_handle().await { trace!("RRQ request failed (peer: {}, error: {})", &self.peer, &e); - let mut buffer = BytesMut::with_capacity(DEFAULT_BLOCK_SIZE); - Packet::Error(e.into()).encode(&mut buffer); - let buf = buffer.split().freeze(); - // Errors are never retransmitted. - // We do not care if `send_to` resulted to an IO error. - let _ = self.socket.send_to(&buf[..], self.peer).await; + + if let Error::Packet( + crate::packet::Error::OptionNegotiationFailed, + ) = e + { + // client aborted the connection, nothing to do + } else { + let mut buffer = BytesMut::with_capacity(DEFAULT_BLOCK_SIZE); + Packet::Error(e.into()).encode(&mut buffer); + let buf = buffer.split().freeze(); + // Errors are never retransmitted. + // We do not care if `send_to` resulted to an IO error. + let _ = self.socket.send_to(&buf[..], self.peer).await; + } } } @@ -92,17 +100,20 @@ where let mut block_id: u16; let mut window_base: u16 = 1; let mut buf: Bytes; - let mut is_last_block: bool; + let mut is_last_block: bool = false; - (buf, is_last_block) = self.fill_data_block(window_base).await?; - window.push_back(buf); - - // Send OACK after we manage to read the first block from reader. - // - // We do this because we want to give the developers the option to - // produce an error after they construct a reader. - if let Some(opts) = self.oack_opts.as_ref() { + if let Some(opts) = self.oack_opts.take() { trace!("RRQ OACK (peer: {}, opts: {:?}", &self.peer, &opts); + // Send OACK after we manage to read the first block from the reader for + // non-transfer size probe requests (oack.transfer_size value is set). + // During transfer size probes a client aborts the connection after receiving + // oack from the server. For normal requests we do this because we want to give + // the developers the option to produce an error after they construct a reader. + if opts.transfer_size.is_none() { + (buf, is_last_block) = + self.fill_data_block(window_base).await?; + window.push_back(buf); + } let mut buff = BytesMut::with_capacity(PACKET_DATA_HEADER_LEN + 64); Packet::OAck(opts.to_owned()).encode(&mut buff); // OACK is not really part of the window, so we send it separately @@ -186,7 +197,9 @@ where ); return Ok(blocks_acked); } - Err(ref e) if e.kind() == io::ErrorKind::TimedOut => { + Err(Error::Io(ref e)) + if e.kind() == io::ErrorKind::TimedOut => + { trace!( "RRQ (peer: {}, block_id: {}) - Timeout", &self.peer, @@ -194,7 +207,7 @@ where ); continue; } - Err(e) => return Err(e.into()), + Err(e) => return Err(e), } } @@ -206,7 +219,7 @@ where &mut self, window_base: u16, window_len: u16, - ) -> io::Result { + ) -> Result { // We can not use `self` within `async_std::io::timeout` because not all // struct members implement `Sync`. So we borrow only what we need. let socket = &mut self.socket; @@ -224,30 +237,37 @@ where } // parse only valid Ack packets, the rest are ignored - if let Ok(Packet::Ack(recved_block_id)) = - Packet::decode(&buf[..len]) + // if let Ok(Packet::Ack(recved_block_id)) = + match Packet::decode(&buf[..len]) { - let window_end = window_base.wrapping_add(window_len); - - if window_end > window_base { - // window_end did not wrap - if recved_block_id >= window_base && recved_block_id < window_end { - // number of blocks acked - return Ok(recved_block_id-window_base+1u16); - } - else { - trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}"); - } - }else { - // window_end wrapped - if recved_block_id >= window_base { - return Ok(1u16 + (recved_block_id - window_base)); - } else if recved_block_id < window_end { - return Ok(1u16 + recved_block_id + (window_len - window_end)); + Ok(Packet::Ack(recved_block_id)) => { + let window_end = window_base.wrapping_add(window_len); + + if window_end > window_base { + // window_end did not wrap + if recved_block_id >= window_base && recved_block_id < window_end { + // number of blocks acked + return Ok(recved_block_id - window_base + 1u16); + } else { + trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}"); + } } else { - trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}"); + // window_end wrapped + if recved_block_id >= window_base { + return Ok(1u16 + (recved_block_id - window_base)); + } else if recved_block_id < window_end { + return Ok(1u16 + recved_block_id + (window_len - window_end)); + } else { + trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}"); + } } + }, + Ok(Packet::Error(error)) if error.is_client_error()=> { + // pass errors coming from the client + return Err(Error::Packet(error)) } + // ignore all other errors + _ => {} } } }) diff --git a/src/server/write_req.rs b/src/server/write_req.rs index 94180d0..6c21e36 100644 --- a/src/server/write_req.rs +++ b/src/server/write_req.rs @@ -75,12 +75,20 @@ where pub(crate) async fn handle(&mut self) { if let Err(e) = self.try_handle().await { trace!("WRQ request failed (peer: {}, error: {}", self.peer, &e); - - Packet::Error(e.into()).encode(&mut self.buffer); - let buf = self.buffer.split().freeze(); - // Errors are never retransmitted. - // We do not care if `send_to` resulted to an IO error. - let _ = self.socket.send_to(&buf[..], self.peer).await; + match e { + Error::Packet(client_error) + if client_error.is_client_error() => + { + //we don't have to acknowledge client errors} + } + e => { + Packet::Error(e.into()).encode(&mut self.buffer); + let buf = self.buffer.split().freeze(); + // Errors are never retransmitted. + // We do not care if `send_to` resulted to an IO error. + let _ = self.socket.send_to(&buf[..], self.peer).await; + } + } } } @@ -122,19 +130,21 @@ where self.socket.send_to(&self.ack, self.peer).await?; return Ok(data); } - Err(ref e) if e.kind() == io::ErrorKind::TimedOut => { + Err(Error::Io(ref e)) + if e.kind() == io::ErrorKind::TimedOut => + { // On timeout reply with the previous ACK packet self.socket.send_to(&self.ack, self.peer).await?; continue; } - Err(e) => return Err(e.into()), + Err(e) => return Err(e), } } Err(Error::MaxSendRetriesReached(self.peer, block_id)) } - async fn recv_data_block(&mut self, block_id: u16) -> io::Result { + async fn recv_data_block(&mut self, block_id: u16) -> Result { let socket = &mut self.socket; let peer = self.peer; diff --git a/src/tests/client.rs b/src/tests/client.rs new file mode 100644 index 0000000..6d5eecc --- /dev/null +++ b/src/tests/client.rs @@ -0,0 +1,111 @@ +use crate::packet::Error::OptionNegotiationFailed; +use crate::packet::{self, Mode, Opts, RwReq}; +use crate::server::TftpServerBuilder; +use async_executor::Executor; +use async_io::{Async, Timer}; +use futures_lite::future::block_on; +use futures_lite::{future, AsyncRead}; +use std::cell::Cell; +use std::io; +use std::net::UdpSocket; +use std::rc::Rc; +use std::sync::Arc; + +use super::handlers::*; +use super::packet::packet_to_bytes; +use std::task::Poll; +use std::time::Duration; + +struct ResultsReader { + results: Vec>>, +} + +impl AsyncRead for ResultsReader { + fn poll_read( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Poll::Ready(match self.get_mut().results.pop() { + Some(Ok(result)) => { + buf[..result.len()].copy_from_slice(&result); + Ok(result.len()) + } + Some(Err(err)) => Err(err), + None => Err(io::ErrorKind::NotFound.into()), + }) + } +} + +#[test] +fn test_abort_on_read_error() { + let ex = Arc::new(Executor::new()); + let transferred = Rc::new(Cell::new(false)); + + block_on(ex.run({ + let ex = ex.clone(); + let transferred = transferred.clone(); + + async move { + let handler = ReaderHandler::new(ResultsReader { results: vec![Err(io::ErrorKind::InvalidInput.into())] }, Some(4)); + let tftpd = TftpServerBuilder::with_handler(handler) + .bind("127.0.0.1:0".parse().unwrap()) + .timeout(Duration::from_secs(1)) + .build() + .await + .unwrap(); + let addr = tftpd.listen_addr().unwrap(); + // start server + let server_task = ex.spawn(async move { + tftpd.serve().await.unwrap(); + }); + + let socket = Async::::bind(([127, 0, 0, 1], 0)).unwrap(); + // send rrq with transfer size to 0 to simulate a transfer size probe + let req_opts = Opts { + transfer_size: Some(0), + ..Default::default() + }; + let rrq = packet::Packet::Rrq(RwReq { + filename: "abc".to_string(), + mode: Mode::Octet, + opts: req_opts, + }); + socket.send_to(&packet_to_bytes(&rrq), addr).await.unwrap(); + + // read the ack + let mut buf = [0u8; 1024]; + let (len, peer) = socket.recv_from(&mut buf).await.unwrap(); + let response = packet::Packet::decode(&buf[..len]).unwrap(); + assert!(matches!( + response, + packet::Packet::OAck(Opts {transfer_size: Some(4),..})), + "Server did not send OAck packet: {:?}", response); + + // send error packet + let abort_packet = packet::Packet::Error(OptionNegotiationFailed); + socket.send_to(&packet_to_bytes(&abort_packet), peer).await.unwrap(); + + // make sure the server doesn't send anything else + assert!( + future::race( + async move { + Timer::after(Duration::from_secs(3)).await; + true + }, + async move { + // fail if we get anything after sending OptionNegotiationFailed error + let _ = socket.recv_from(&mut buf).await.unwrap(); + false + } + ) + .await, + "Server sent data after client sent OptionNegotiationFailed error" + ); + server_task.cancel().await; + transferred.set(true); + } + })); + + assert!(transferred.get()); +} diff --git a/src/tests/handlers.rs b/src/tests/handlers.rs index 00adf62..0023ebb 100644 --- a/src/tests/handlers.rs +++ b/src/tests/handlers.rs @@ -1,31 +1,28 @@ -#![cfg(feature = "external-client-tests")] -#![cfg(target_os = "linux")] - -use async_channel::Sender; +use crate::packet; +use crate::server::Handler; use futures_lite::io::Sink; +use futures_lite::AsyncRead; use std::net::SocketAddr; use std::path::Path; -use super::random_file::RandomFile; -use crate::packet; -use crate::server::Handler; - -pub struct RandomHandler { - md5_tx: Option>, - file_size: usize, +pub struct ReaderHandler { + reader: Option, + size: Option, } -impl RandomHandler { - pub fn new(file_size: usize, md5_tx: Sender) -> Self { - RandomHandler { - md5_tx: Some(md5_tx), - file_size, +impl ReaderHandler { + pub fn new(reader: Reader, size: Option) -> Self { + ReaderHandler { + reader: Some(reader), + size, } } } -impl Handler for RandomHandler { - type Reader = RandomFile; +impl Handler + for ReaderHandler +{ + type Reader = Reader; type Writer = Sink; async fn read_req_open( @@ -33,8 +30,7 @@ impl Handler for RandomHandler { _client: &SocketAddr, _path: &Path, ) -> Result<(Self::Reader, Option), packet::Error> { - let md5_tx = self.md5_tx.take().expect("md5_tx already consumed"); - Ok((RandomFile::new(self.file_size, md5_tx), None)) + Ok((self.reader.take().expect("reader already consumed"), self.size)) } async fn write_req_open( diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 4da4d5f..dcc58b7 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,5 +1,6 @@ #![cfg(test)] +mod client; mod external_client; mod handlers; mod packet; diff --git a/src/tests/packet.rs b/src/tests/packet.rs index 17c0470..bb5bbd8 100644 --- a/src/tests/packet.rs +++ b/src/tests/packet.rs @@ -4,7 +4,7 @@ use crate::error::Error; use crate::packet::{self, Mode, Opts, Packet, RwReq}; use crate::parse::parse_opts; -fn packet_to_bytes(packet: &Packet) -> Bytes { +pub(crate) fn packet_to_bytes(packet: &Packet) -> Bytes { let mut buf = BytesMut::with_capacity(0); packet.encode(&mut buf); buf.freeze() diff --git a/src/tests/rrq.rs b/src/tests/rrq.rs index 9599778..31d7147 100644 --- a/src/tests/rrq.rs +++ b/src/tests/rrq.rs @@ -27,7 +27,7 @@ fn transfer( async move { let (md5_tx, md5_rx) = async_channel::bounded(1); - let handler = RandomHandler::new(file_size, md5_tx); + let handler = ReaderHandler::new(md5_tx, Some(file_size)); // bind let tftpd = TftpServerBuilder::with_handler(handler) diff --git a/src/utils.rs b/src/utils.rs index 3a9f2f1..12a81cf 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,16 +1,18 @@ +use crate::Error; +use crate::Result; use async_io::Timer; use futures_lite::future; use std::future::Future; -use std::io; +use std::io::ErrorKind; use std::time::Duration; pub async fn io_timeout( dur: Duration, - f: impl Future>, -) -> io::Result { + f: impl Future>, +) -> Result { future::race(f, async move { Timer::after(dur).await; - Err(io::ErrorKind::TimedOut.into()) + Err(Error::Io(ErrorKind::TimedOut.into())) }) .await }