
Commit 41284d2

compression example works again
1 parent f3d02e7 commit 41284d2

6 files changed: +332 -206 lines

examples/compression.rs_ renamed to examples/compression.rs

Lines changed: 92 additions & 27 deletions
@@ -9,7 +9,7 @@
 /// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or
 /// [n0_future::FuturesUnordered].
 mod common;
-use std::{path::PathBuf, time::Instant};
+use std::{io, path::PathBuf, time::Instant};
 
 use anyhow::Result;
 use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder};
@@ -22,7 +22,8 @@ use iroh_blobs::{
     get::fsm::{AtConnected, ConnectedNext, EndBlobNext},
     protocol::{ChunkRangesSeq, GetRequest, Request},
     provider::{
-        events::{ClientConnected, EventSender, HasErrorCode}, handle_get, AsyncReadRecvStream, AsyncWriteSendStream, ErrorHandler, StreamPair
+        events::{ClientConnected, EventSender, HasErrorCode},
+        handle_get, AsyncReadRecvStream, AsyncWriteSendStream, StreamPair,
     },
     store::mem::MemStore,
     ticket::BlobTicket,
@@ -50,11 +51,91 @@ pub enum Args {
     },
 }
 
-type CompressedWriter =
-    AsyncWriteSendStream<async_compression::tokio::write::Lz4Encoder<iroh::endpoint::SendStream>>;
-type CompressedReader = AsyncReadRecvStream<
+struct CompressedWriter(async_compression::tokio::write::Lz4Encoder<iroh::endpoint::SendStream>);
+struct CompressedReader(
     async_compression::tokio::bufread::Lz4Decoder<BufReader<iroh::endpoint::RecvStream>>,
->;
+);
+
+impl iroh_blobs::provider::SendStream for CompressedWriter {
+    async fn send_bytes(&mut self, bytes: bytes::Bytes) -> io::Result<()> {
+        AsyncWriteSendStream::new(self).send_bytes(bytes).await
+    }
+
+    async fn send<const L: usize>(&mut self, buf: &[u8; L]) -> io::Result<()> {
+        AsyncWriteSendStream::new(self).send(buf).await
+    }
+
+    async fn sync(&mut self) -> io::Result<()> {
+        AsyncWriteSendStream::new(self).sync().await
+    }
+}
+
+impl iroh_blobs::provider::RecvStream for CompressedReader {
+    async fn recv_bytes(&mut self, len: usize) -> io::Result<bytes::Bytes> {
+        AsyncReadRecvStream::new(self).recv_bytes(len).await
+    }
+
+    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<bytes::Bytes> {
+        AsyncReadRecvStream::new(self).recv_bytes_exact(len).await
+    }
+
+    async fn recv<const L: usize>(&mut self) -> io::Result<[u8; L]> {
+        AsyncReadRecvStream::new(self).recv::<L>().await
+    }
+}
+
+impl tokio::io::AsyncRead for CompressedReader {
+    fn poll_read(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> std::task::Poll<io::Result<()>> {
+        std::pin::Pin::new(&mut self.0).poll_read(cx, buf)
+    }
+}
+
+impl tokio::io::AsyncWrite for CompressedWriter {
+    fn poll_write(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<io::Result<usize>> {
+        std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
+    }
+
+    fn poll_flush(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<io::Result<()>> {
+        std::pin::Pin::new(&mut self.0).poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<io::Result<()>> {
+        std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
+    }
+}
+
+impl iroh_blobs::provider::SendStreamSpecific for CompressedWriter {
+    fn reset(&mut self, code: quinn::VarInt) -> io::Result<()> {
+        self.0.get_mut().reset(code)?;
+        Ok(())
+    }
+
+    async fn stopped(&mut self) -> io::Result<Option<quinn::VarInt>> {
+        let res = self.0.get_mut().stopped().await?;
+        Ok(res)
+    }
+}
+
+impl iroh_blobs::provider::RecvStreamSpecific for CompressedReader {
+    fn stop(&mut self, code: quinn::VarInt) -> io::Result<()> {
+        self.0.get_mut().get_mut().stop(code)?;
+        Ok(())
+    }
+}
 
 #[derive(Debug, Clone)]
 struct CompressedBlobsProtocol {
@@ -71,22 +152,6 @@ impl CompressedBlobsProtocol {
     }
 }
 
-struct CompressedErrorHandler;
-
-impl ErrorHandler for CompressedErrorHandler {
-    type W = CompressedWriter;
-
-    type R = CompressedReader;
-
-    async fn stop(reader: &mut Self::R, code: quinn::VarInt) {
-        reader.0.get_mut().get_mut().stop(code).ok();
-    }
-
-    async fn reset(writer: &mut Self::W, code: quinn::VarInt) {
-        writer.0.get_mut().reset(code).ok();
-    }
-}
-
 impl ProtocolHandler for CompressedBlobsProtocol {
     async fn accept(
         &self,
@@ -108,15 +173,15 @@ impl ProtocolHandler for CompressedBlobsProtocol {
         }
         while let Ok((send, recv)) = connection.accept_bi().await {
             let stream_id = send.id().index();
-            let send = TokioStreamWriter(Lz4Encoder::new(send));
-            let recv = TokioStreamReader(Lz4Decoder::new(BufReader::new(recv)));
+            let send = CompressedWriter(Lz4Encoder::new(send));
+            let recv = CompressedReader(Lz4Decoder::new(BufReader::new(recv)));
             let store = self.store.clone();
             let mut pair =
                 StreamPair::new(connection_id, stream_id, recv, send, self.events.clone());
             tokio::spawn(async move {
                 let request = pair.read_request().await?;
                 if let Request::Get(request) = request {
-                    handle_get::<CompressedErrorHandler>(pair, store, request).await?;
+                    handle_get(pair, store, request).await?;
                 }
                 anyhow::Ok(())
             });
@@ -156,8 +221,8 @@ async fn main() -> Result<()> {
         Args::Get { ticket, target } => {
            let conn = endpoint.connect(ticket.node_addr().clone(), ALPN).await?;
            let (send, recv) = conn.open_bi().await?;
-            let send = AsyncWriteSendStream(Lz4Encoder::new(send));
-            let recv = AsyncReadRecvStream::new(Lz4Decoder::new(BufReader::new(recv)));
+            let send = CompressedWriter(Lz4Encoder::new(send));
+            let recv = CompressedReader(Lz4Decoder::new(BufReader::new(recv)));
            let request = GetRequest {
                hash: ticket.hash(),
                ranges: ChunkRangesSeq::root(),
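
With the type aliases replaced by newtype wrappers, the example now implements the provider's SendStream/RecvStream traits directly on CompressedWriter and CompressedReader, delegating each call back through AsyncWriteSendStream/AsyncReadRecvStream, so handle_get no longer needs a separate ErrorHandler. The same pattern should extend to any other codec in async_compression; the sketch below is a hypothetical gzip variant (GzipWriter/GzipReader are illustrative names, and the crate's "gzip" and "tokio" features are assumed to be enabled), with the trait impls mirroring the LZ4 ones above.

// Hypothetical gzip counterparts to CompressedWriter/CompressedReader above
// (illustrative names only, not part of the commit).
use async_compression::tokio::{bufread::GzipDecoder, write::GzipEncoder};
use tokio::io::BufReader;

struct GzipWriter(GzipEncoder<iroh::endpoint::SendStream>);
struct GzipReader(GzipDecoder<BufReader<iroh::endpoint::RecvStream>>);

// Construction on a bi-directional stream, matching the accept loop above:
//     let send = GzipWriter(GzipEncoder::new(send));
//     let recv = GzipReader(GzipDecoder::new(BufReader::new(recv)));
// The AsyncRead/AsyncWrite, SendStream/RecvStream and *Specific impls would be
// identical to the LZ4 versions, forwarding to self.0.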

src/api/blobs.rs

Lines changed: 16 additions & 11 deletions
@@ -23,7 +23,7 @@ use bao_tree::{
 };
 use bytes::Bytes;
 use genawaiter::sync::Gen;
-use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
+use iroh_io::AsyncStreamWriter;
 use irpc::channel::{mpsc, oneshot};
 use n0_future::{future, stream, Stream, StreamExt};
 use range_collections::{range_set::RangeSetRange, RangeSet2};
@@ -55,7 +55,7 @@
 };
 use crate::{
     api::proto::{BatchRequest, ImportByteStreamUpdate},
-    provider::events::ClientResult,
+    provider::{events::ClientResult, RecvStreamAsyncStreamReader},
     store::IROH_BLOCK_SIZE,
     util::temp_tag::TempTag,
     BlobFormat, Hash, HashAndFormat,
@@ -429,13 +429,13 @@ impl Blobs {
     }
 
     #[cfg_attr(feature = "hide-proto-docs", doc(hidden))]
-    pub async fn import_bao_reader<R: AsyncStreamReader>(
+    pub async fn import_bao_reader<R: crate::provider::RecvStream>(
         &self,
         hash: Hash,
         ranges: ChunkRanges,
         mut reader: R,
     ) -> RequestResult<R> {
-        let size = u64::from_le_bytes(reader.read::<8>().await.map_err(super::Error::other)?);
+        let size = u64::from_le_bytes(reader.recv::<8>().await.map_err(super::Error::other)?);
         let Some(size) = NonZeroU64::new(size) else {
             return if hash == Hash::EMPTY {
                 Ok(reader)
@@ -444,7 +444,12 @@
             };
         };
         let tree = BaoTree::new(size.get(), IROH_BLOCK_SIZE);
-        let mut decoder = ResponseDecoder::new(hash.into(), ranges, tree, reader);
+        let mut decoder = ResponseDecoder::new(
+            hash.into(),
+            ranges,
+            tree,
+            RecvStreamAsyncStreamReader::new(reader),
+        );
         let options = ImportBaoOptions { hash, size };
         let handle = self.import_bao_with_opts(options, 32).await?;
         let driver = async move {
@@ -463,7 +468,7 @@
         let fut = async move { handle.rx.await.map_err(io::Error::other)? };
         let (reader, res) = tokio::join!(driver, fut);
         res?;
-        Ok(reader?)
+        Ok(reader?.into_inner())
     }
 
     #[cfg_attr(feature = "hide-proto-docs", doc(hidden))]
@@ -1068,7 +1073,7 @@ impl ExportBaoProgress {
     }
 
     /// Write quinn variant that also feeds a progress writer.
-    pub(crate) async fn write_with_progress<W: AsyncStreamWriter>(
+    pub(crate) async fn write_with_progress<W: crate::provider::SendStream>(
         self,
         writer: &mut W,
         progress: &mut impl WriteProgress,
@@ -1080,19 +1085,19 @@
             match item {
                 EncodedItem::Size(size) => {
                     progress.send_transfer_started(index, hash, size).await;
-                    writer.write(&size.to_le_bytes()).await?;
+                    writer.send(&size.to_le_bytes()).await?;
                     progress.log_other_write(8);
                 }
                 EncodedItem::Parent(parent) => {
-                    let mut data = vec![0u8; 64];
+                    let mut data = [0u8; 64];
                     data[..32].copy_from_slice(parent.pair.0.as_bytes());
                     data[32..].copy_from_slice(parent.pair.1.as_bytes());
-                    writer.write(&data).await?;
+                    writer.send(&data).await?;
                     progress.log_other_write(64);
                 }
                 EncodedItem::Leaf(leaf) => {
                     let len = leaf.data.len();
-                    writer.write_bytes(leaf.data).await?;
+                    writer.send_bytes(leaf.data).await?;
                     progress
                         .notify_payload_write(index, leaf.offset, len)
                         .await?;
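
The blobs.rs changes swap the iroh-io reader/writer traits for the provider's own RecvStream/SendStream vocabulary (recv/send instead of read/write), while the framing itself is unchanged: an 8-byte little-endian size prefix, followed by 64-byte parent nodes (two 32-byte hashes) and the leaf bytes. A minimal stand-alone sketch of just the size-prefix step, assuming nothing beyond std:

// The sender frames the transfer with writer.send(&size.to_le_bytes()) and the
// receiver recovers it via reader.recv::<8>(); this is the same u64 <-> [u8; 8]
// round trip, shown here outside any stream.
fn size_prefix_round_trip() {
    let size: u64 = 4096;
    let prefix: [u8; 8] = size.to_le_bytes(); // first 8 bytes on the wire
    assert_eq!(u64::from_le_bytes(prefix), size);
}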

src/api/remote.rs

Lines changed: 3 additions & 4 deletions
@@ -16,7 +16,7 @@ use crate::{
     get::{
         fsm::DecodeError,
         get_error::{BadRequestSnafu, LocalFailureSnafu},
-        GetError, GetResult, IrohStreamWriter, Stats,
+        GetError, GetResult, Stats,
     },
     protocol::{
         GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType,
@@ -593,7 +593,6 @@ impl Remote {
         let mut request_ranges = request.ranges.iter_infinite();
         let root = request.hash;
         let root_ranges = request_ranges.next().expect("infinite iterator");
-        let mut send = IrohStreamWriter(send);
         if !root_ranges.is_empty() {
             self.store()
                 .export_bao(root, root_ranges.clone())
@@ -602,7 +601,7 @@
         }
         if request.ranges.is_blob() {
             // we are done
-            send.0.finish()?;
+            send.finish()?;
             return Ok(Default::default());
         }
         let hash_seq = self.store().get_bytes(root).await?;
@@ -617,7 +616,7 @@
                 .await?;
             }
         }
-        send.0.finish()?;
+        send.finish()?;
         Ok(Default::default())
     }
