Skip to content

Commit 4e8387a

Browse files
committed
Refactor error and make get and provide side generic
1 parent d764dc0 commit 4e8387a

File tree

10 files changed

+428
-184
lines changed

10 files changed

+428
-184
lines changed

src/api.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ pub enum ExportBaoError {
9898
#[snafu(display("encode error: {source}"))]
9999
ExportBaoInner { source: bao_tree::io::EncodeError },
100100
#[snafu(display("client error: {source}"))]
101-
ClientError { source: ProgressError },
101+
Progress { source: ProgressError },
102102
}
103103

104104
impl From<ExportBaoError> for Error {
@@ -109,7 +109,7 @@ impl From<ExportBaoError> for Error {
109109
ExportBaoError::Request { source, .. } => Self::Io(source.into()),
110110
ExportBaoError::ExportBaoIo { source, .. } => Self::Io(source),
111111
ExportBaoError::ExportBaoInner { source, .. } => Self::Io(source.into()),
112-
ExportBaoError::ClientError { source, .. } => Self::Io(source.into()),
112+
ExportBaoError::Progress { source, .. } => Self::Io(source.into()),
113113
}
114114
}
115115
}
@@ -157,7 +157,7 @@ impl From<bao_tree::io::EncodeError> for ExportBaoError {
157157

158158
impl From<ProgressError> for ExportBaoError {
159159
fn from(value: ProgressError) -> Self {
160-
ClientSnafu.into_error(value)
160+
ProgressSnafu.into_error(value)
161161
}
162162
}
163163

src/api/downloader.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ mod tests {
563563
.download(request, Shuffled::new(vec![node1_id, node2_id]))
564564
.stream()
565565
.await?;
566-
while let Some(_) = progress.next().await {}
566+
while progress.next().await.is_some() {}
567567
assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world");
568568
assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2");
569569
Ok(())
@@ -606,7 +606,7 @@ mod tests {
606606
))
607607
.stream()
608608
.await?;
609-
while let Some(_) = progress.next().await {}
609+
while progress.next().await.is_some() {}
610610
}
611611
if false {
612612
let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?;
@@ -668,7 +668,7 @@ mod tests {
668668
))
669669
.stream()
670670
.await?;
671-
while let Some(_) = progress.next().await {}
671+
while progress.next().await.is_some() {}
672672
Ok(())
673673
}
674674
}

src/api/remote.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,7 @@ impl GetProgress {
9999

100100
pub async fn complete(self) -> GetResult<Stats> {
101101
just_result(self.stream()).await.unwrap_or_else(|| {
102-
Err(LocalFailureSnafu
103-
.into_error(anyhow::anyhow!("stream closed without result").into()))
102+
Err(LocalFailureSnafu.into_error(anyhow::anyhow!("stream closed without result")))
104103
})
105104
}
106105
}
@@ -512,15 +511,15 @@ impl Remote {
512511
let local = self
513512
.local(content)
514513
.await
515-
.map_err(|e: anyhow::Error| LocalFailureSnafu.into_error(e.into()))?;
514+
.map_err(|e: anyhow::Error| LocalFailureSnafu.into_error(e))?;
516515
if local.is_complete() {
517516
return Ok(Default::default());
518517
}
519518
let request = local.missing();
520519
let conn = conn
521520
.connection()
522521
.await
523-
.map_err(|e| LocalFailureSnafu.into_error(e.into()))?;
522+
.map_err(|e| LocalFailureSnafu.into_error(e))?;
524523
let stats = self.execute_get_sink(&conn, request, progress).await?;
525524
Ok(stats)
526525
}
@@ -914,8 +913,7 @@ async fn get_blob_ranges_impl(
914913
};
915914
let complete = async move {
916915
handle.rx.await.map_err(|e| {
917-
LocalFailureSnafu
918-
.into_error(anyhow::anyhow!("error reading from import stream: {e}").into())
916+
LocalFailureSnafu.into_error(anyhow::anyhow!("error reading from import stream: {e}"))
919917
})
920918
};
921919
let (_, end) = tokio::try_join!(complete, write)?;

src/get.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl AsyncStreamWriter for IrohStreamWriter {
5353
}
5454

5555
async fn sync(&mut self) -> io::Result<()> {
56-
Ok(self.0.flush().await?)
56+
self.0.flush().await
5757
}
5858
}
5959

@@ -716,7 +716,7 @@ pub mod fsm {
716716
DecodeError::LeafNotFound { .. } => {
717717
io::Error::new(io::ErrorKind::UnexpectedEof, cause)
718718
}
719-
DecodeError::Read { source, .. } => source.into(),
719+
DecodeError::Read { source, .. } => source,
720720
DecodeError::Write { source, .. } => source,
721721
_ => io::Error::other(cause),
722722
}

src/get/error.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl GetError {
8282
Self::ConnectedNext {
8383
source: ConnectedNextError::Write { source, .. },
8484
..
85-
} => Some(&source),
85+
} => Some(source),
8686
_ => None,
8787
}
8888
}
@@ -92,7 +92,7 @@ impl GetError {
9292
Self::InitialNext {
9393
source: InitialNextError::Open { source, .. },
9494
..
95-
} => Some(&source),
95+
} => Some(source),
9696
_ => None,
9797
}
9898
}
@@ -102,15 +102,15 @@ impl GetError {
102102
Self::AtBlobHeaderNext {
103103
source: AtBlobHeaderNextError::Read { source, .. },
104104
..
105-
} => Some(&source),
105+
} => Some(source),
106106
Self::Decode {
107107
source: DecodeError::Read { source, .. },
108108
..
109-
} => Some(&source),
109+
} => Some(source),
110110
Self::AtClosingNext {
111111
source: AtClosingNextError::Read { source, .. },
112112
..
113-
} => Some(&source),
113+
} => Some(source),
114114
_ => None,
115115
}
116116
}
@@ -120,7 +120,7 @@ impl GetError {
120120
Self::Decode {
121121
source: DecodeError::Write { source, .. },
122122
..
123-
} => Some(&source),
123+
} => Some(source),
124124
_ => None,
125125
}
126126
}

src/get/request.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ impl GetBlobResult {
5858
let mut parts = Vec::new();
5959
let stats = loop {
6060
let Some(item) = self.next().await else {
61-
return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end").into()));
61+
return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end")));
6262
};
6363
match item {
6464
GetBlobItem::Item(item) => {
@@ -238,11 +238,11 @@ pub async fn get_hash_seq_and_sizes(
238238
let (at_blob_content, size) = at_start_root.next().await?;
239239
// check the size to avoid parsing a maliciously large hash seq
240240
if size > max_size {
241-
return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large").into()));
241+
return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large")));
242242
}
243243
let (mut curr, hash_seq) = at_blob_content.concatenate_into_vec().await?;
244-
let hash_seq = HashSeq::try_from(Bytes::from(hash_seq))
245-
.map_err(|e| BadRequestSnafu.into_error(e.into()))?;
244+
let hash_seq =
245+
HashSeq::try_from(Bytes::from(hash_seq)).map_err(|e| BadRequestSnafu.into_error(e))?;
246246
let mut sizes = Vec::with_capacity(hash_seq.len());
247247
let closing = loop {
248248
match curr.next() {

src/protocol.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,15 +382,14 @@ use bao_tree::{io::round_up_to_chunks, ChunkNum};
382382
use builder::GetRequestBuilder;
383383
use derive_more::From;
384384
use iroh::endpoint::VarInt;
385-
use irpc::util::AsyncReadVarintExt;
385+
use iroh_io::AsyncStreamReader;
386386
use postcard::experimental::max_size::MaxSize;
387387
use range_collections::{range_set::RangeSetEntry, RangeSet2};
388388
use serde::{Deserialize, Serialize};
389389
mod range_spec;
390390
pub use bao_tree::ChunkRanges;
391391
pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec};
392392
use snafu::{GenerateImplicitData, Snafu};
393-
use tokio::io::AsyncReadExt;
394393

395394
use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, HashAndFormat};
396395

@@ -448,7 +447,7 @@ pub enum RequestType {
448447
}
449448

450449
impl Request {
451-
pub async fn read_async(reader: &mut iroh::endpoint::RecvStream) -> io::Result<(Self, usize)> {
450+
pub async fn read_async<R: AsyncStreamReader>(reader: &mut R) -> io::Result<(Self, usize)> {
452451
let request_type = reader.read_u8().await?;
453452
let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&request_type))
454453
.map_err(|_| {

0 commit comments

Comments
 (0)