Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/tftpd-dir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ async fn main() -> Result<()> {
.expect("Failed to initialize logger");

let tftpd = TftpServerBuilder::with_dir_ro(".")?
.bind("0.0.0.0:6969".parse().unwrap())
.bind("127.0.0.1:6969".parse().unwrap())
// Workaround to handle cases where client is behind VPN
.block_size_limit(1024)
.build()
Expand Down
8 changes: 8 additions & 0 deletions src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub enum Error {
UnknownTransferId,
FileAlreadyExists,
NoSuchUser,
OptionNegotiationFailed,
}

#[derive(Debug)]
Expand Down Expand Up @@ -192,6 +193,7 @@ impl Error {
5 => Error::UnknownTransferId,
6 => Error::FileAlreadyExists,
7 => Error::NoSuchUser,
8 => Error::OptionNegotiationFailed,
0 | _ => match msg {
Some(msg) => Error::Msg(msg.to_string()),
None => Error::UnknownError,
Expand All @@ -210,6 +212,7 @@ impl Error {
Error::UnknownTransferId => 5,
Error::FileAlreadyExists => 6,
Error::NoSuchUser => 7,
Error::OptionNegotiationFailed => 8,
}
}

Expand All @@ -224,8 +227,13 @@ impl Error {
Error::UnknownTransferId => "Unknown transfer ID",
Error::FileAlreadyExists => "File already exists",
Error::NoSuchUser => "No such user",
Error::OptionNegotiationFailed => "Option negotiation failed",
}
}

pub(crate) fn is_client_error(&self) -> bool {
matches!(self, Error::OptionNegotiationFailed)
}
}

impl From<Error> for Packet<'_> {
Expand Down
96 changes: 58 additions & 38 deletions src/server/read_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,20 @@ where
pub(crate) async fn handle(&mut self) {
if let Err(e) = self.try_handle().await {
trace!("RRQ request failed (peer: {}, error: {})", &self.peer, &e);
let mut buffer = BytesMut::with_capacity(DEFAULT_BLOCK_SIZE);
Packet::Error(e.into()).encode(&mut buffer);
let buf = buffer.split().freeze();
// Errors are never retransmitted.
// We do not care if `send_to` resulted to an IO error.
let _ = self.socket.send_to(&buf[..], self.peer).await;

if let Error::Packet(
crate::packet::Error::OptionNegotiationFailed,
) = e
{
// client aborted the connection, nothing to do
} else {
let mut buffer = BytesMut::with_capacity(DEFAULT_BLOCK_SIZE);
Packet::Error(e.into()).encode(&mut buffer);
let buf = buffer.split().freeze();
// Errors are never retransmitted.
// We do not care if `send_to` resulted to an IO error.
let _ = self.socket.send_to(&buf[..], self.peer).await;
}
}
}

Expand All @@ -92,17 +100,20 @@ where
let mut block_id: u16;
let mut window_base: u16 = 1;
let mut buf: Bytes;
let mut is_last_block: bool;
let mut is_last_block: bool = false;

(buf, is_last_block) = self.fill_data_block(window_base).await?;
window.push_back(buf);

// Send OACK after we manage to read the first block from reader.
//
// We do this because we want to give the developers the option to
// produce an error after they construct a reader.
if let Some(opts) = self.oack_opts.as_ref() {
if let Some(opts) = self.oack_opts.take() {
trace!("RRQ OACK (peer: {}, opts: {:?}", &self.peer, &opts);
// Send OACK after we manage to read the first block from the reader for
// non-transfer size probe requests (oack.transfer_size value is set).
// During transfer size probes a client aborts the connection after receiving
// oack from the server. For normal requests we do this because we want to give
// the developers the option to produce an error after they construct a reader.
if opts.transfer_size.is_none() {
(buf, is_last_block) =
self.fill_data_block(window_base).await?;
window.push_back(buf);
}
let mut buff = BytesMut::with_capacity(PACKET_DATA_HEADER_LEN + 64);
Packet::OAck(opts.to_owned()).encode(&mut buff);
// OACK is not really part of the window, so we send it separately
Expand Down Expand Up @@ -186,15 +197,17 @@ where
);
return Ok(blocks_acked);
}
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
Err(Error::Io(ref e))
if e.kind() == io::ErrorKind::TimedOut =>
{
trace!(
"RRQ (peer: {}, block_id: {}) - Timeout",
&self.peer,
window_base
);
continue;
}
Err(e) => return Err(e.into()),
Err(e) => return Err(e),
}
}

Expand All @@ -206,7 +219,7 @@ where
&mut self,
window_base: u16,
window_len: u16,
) -> io::Result<u16> {
) -> Result<u16> {
// We can not use `self` within `async_std::io::timeout` because not all
// struct members implement `Sync`. So we borrow only what we need.
let socket = &mut self.socket;
Expand All @@ -224,30 +237,37 @@ where
}

// parse only valid Ack packets, the rest are ignored
if let Ok(Packet::Ack(recved_block_id)) =
Packet::decode(&buf[..len])
// if let Ok(Packet::Ack(recved_block_id)) =
match Packet::decode(&buf[..len])
{
let window_end = window_base.wrapping_add(window_len);

if window_end > window_base {
// window_end did not wrap
if recved_block_id >= window_base && recved_block_id < window_end {
// number of blocks acked
return Ok(recved_block_id-window_base+1u16);
}
else {
trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}");
}
}else {
// window_end wrapped
if recved_block_id >= window_base {
return Ok(1u16 + (recved_block_id - window_base));
} else if recved_block_id < window_end {
return Ok(1u16 + recved_block_id + (window_len - window_end));
Ok(Packet::Ack(recved_block_id)) => {
let window_end = window_base.wrapping_add(window_len);

if window_end > window_base {
// window_end did not wrap
if recved_block_id >= window_base && recved_block_id < window_end {
// number of blocks acked
return Ok(recved_block_id - window_base + 1u16);
} else {
trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}");
}
} else {
trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}");
// window_end wrapped
if recved_block_id >= window_base {
return Ok(1u16 + (recved_block_id - window_base));
} else if recved_block_id < window_end {
return Ok(1u16 + recved_block_id + (window_len - window_end));
} else {
trace!("Unexpected ack packet {recved_block_id}, window_base: {window_base}, window_len: {window_len}");
}
}
},
Ok(Packet::Error(error)) if error.is_client_error()=> {
// pass errors coming from the client
return Err(Error::Packet(error))
}
// ignore all other errors
_ => {}
}
}
})
Expand Down
28 changes: 19 additions & 9 deletions src/server/write_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,20 @@ where
pub(crate) async fn handle(&mut self) {
if let Err(e) = self.try_handle().await {
trace!("WRQ request failed (peer: {}, error: {}", self.peer, &e);

Packet::Error(e.into()).encode(&mut self.buffer);
let buf = self.buffer.split().freeze();
// Errors are never retransmitted.
// We do not care if `send_to` resulted to an IO error.
let _ = self.socket.send_to(&buf[..], self.peer).await;
match e {
Error::Packet(client_error)
if client_error.is_client_error() =>
{
//we don't have to acknowledge client errors}
}
e => {
Packet::Error(e.into()).encode(&mut self.buffer);
let buf = self.buffer.split().freeze();
// Errors are never retransmitted.
// We do not care if `send_to` resulted to an IO error.
let _ = self.socket.send_to(&buf[..], self.peer).await;
}
}
}
}

Expand Down Expand Up @@ -122,19 +130,21 @@ where
self.socket.send_to(&self.ack, self.peer).await?;
return Ok(data);
}
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
Err(Error::Io(ref e))
if e.kind() == io::ErrorKind::TimedOut =>
{
// On timeout reply with the previous ACK packet
self.socket.send_to(&self.ack, self.peer).await?;
continue;
}
Err(e) => return Err(e.into()),
Err(e) => return Err(e),
}
}

Err(Error::MaxSendRetriesReached(self.peer, block_id))
}

async fn recv_data_block(&mut self, block_id: u16) -> io::Result<Bytes> {
async fn recv_data_block(&mut self, block_id: u16) -> Result<Bytes> {
let socket = &mut self.socket;
let peer = self.peer;

Expand Down
111 changes: 111 additions & 0 deletions src/tests/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use crate::packet::Error::OptionNegotiationFailed;
use crate::packet::{self, Mode, Opts, RwReq};
use crate::server::TftpServerBuilder;
use async_executor::Executor;
use async_io::{Async, Timer};
use futures_lite::future::block_on;
use futures_lite::{future, AsyncRead};
use std::cell::Cell;
use std::io;
use std::net::UdpSocket;
use std::rc::Rc;
use std::sync::Arc;

use super::handlers::*;
use super::packet::packet_to_bytes;
use std::task::Poll;
use std::time::Duration;

struct ResultsReader {
results: Vec<io::Result<Vec<u8>>>,
}

impl AsyncRead for ResultsReader {
fn poll_read(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(match self.get_mut().results.pop() {
Some(Ok(result)) => {
buf[..result.len()].copy_from_slice(&result);
Ok(result.len())
}
Some(Err(err)) => Err(err),
None => Err(io::ErrorKind::NotFound.into()),
})
}
}

#[test]
fn test_abort_on_read_error() {
let ex = Arc::new(Executor::new());
let transferred = Rc::new(Cell::new(false));

block_on(ex.run({
let ex = ex.clone();
let transferred = transferred.clone();

async move {
let handler = ReaderHandler::new(ResultsReader { results: vec![Err(io::ErrorKind::InvalidInput.into())] }, Some(4));
let tftpd = TftpServerBuilder::with_handler(handler)
.bind("127.0.0.1:0".parse().unwrap())
.timeout(Duration::from_secs(1))
.build()
.await
.unwrap();
let addr = tftpd.listen_addr().unwrap();
// start server
let server_task = ex.spawn(async move {
tftpd.serve().await.unwrap();
});

let socket = Async::<UdpSocket>::bind(([127, 0, 0, 1], 0)).unwrap();
// send rrq with transfer size to 0 to simulate a transfer size probe
let req_opts = Opts {
transfer_size: Some(0),
..Default::default()
};
let rrq = packet::Packet::Rrq(RwReq {
filename: "abc".to_string(),
mode: Mode::Octet,
opts: req_opts,
});
socket.send_to(&packet_to_bytes(&rrq), addr).await.unwrap();

// read the ack
let mut buf = [0u8; 1024];
let (len, peer) = socket.recv_from(&mut buf).await.unwrap();
let response = packet::Packet::decode(&buf[..len]).unwrap();
assert!(matches!(
response,
packet::Packet::OAck(Opts {transfer_size: Some(4),..})),
"Server did not send OAck packet: {:?}", response);

// send error packet
let abort_packet = packet::Packet::Error(OptionNegotiationFailed);
socket.send_to(&packet_to_bytes(&abort_packet), peer).await.unwrap();

// make sure the server doesn't send anything else
assert!(
future::race(
async move {
Timer::after(Duration::from_secs(3)).await;
true
},
async move {
// fail if we get anything after sending OptionNegotiationFailed error
let _ = socket.recv_from(&mut buf).await.unwrap();
false
}
)
.await,
"Server sent data after client sent OptionNegotiationFailed error"
);
server_task.cancel().await;
transferred.set(true);
}
}));

assert!(transferred.get());
}
Loading
Loading