Skip to content

Commit 349c36b

Browse files
committed
Generic receive into store
1 parent 41284d2 commit 349c36b

File tree

4 files changed

+186
-233
lines changed

4 files changed

+186
-233
lines changed

examples/compression.rs

Lines changed: 38 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,24 @@
99
/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or
1010
/// [n0_future::FuturesUnordered].
1111
mod common;
12-
use std::{io, path::PathBuf, time::Instant};
12+
use std::{io, path::PathBuf};
1313

1414
use anyhow::Result;
1515
use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder};
16-
use bao_tree::blake3;
1716
use clap::Parser;
1817
use common::setup_logging;
19-
use iroh::protocol::ProtocolHandler;
18+
use iroh::{endpoint::VarInt, protocol::ProtocolHandler};
2019
use iroh_blobs::{
2120
api::Store,
22-
get::fsm::{AtConnected, ConnectedNext, EndBlobNext},
23-
protocol::{ChunkRangesSeq, GetRequest, Request},
2421
provider::{
2522
events::{ClientConnected, EventSender, HasErrorCode},
26-
handle_get, AsyncReadRecvStream, AsyncWriteSendStream, StreamPair,
23+
handle_stream, AsyncReadRecvStream, AsyncWriteSendStream, RecvStreamSpecific,
24+
SendStreamSpecific, StreamPair,
2725
},
2826
store::mem::MemStore,
2927
ticket::BlobTicket,
3028
};
31-
use tokio::io::BufReader;
29+
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
3230
use tracing::debug;
3331

3432
use crate::common::get_or_generate_secret_key;
@@ -51,89 +49,35 @@ pub enum Args {
5149
},
5250
}
5351

54-
struct CompressedWriter(async_compression::tokio::write::Lz4Encoder<iroh::endpoint::SendStream>);
55-
struct CompressedReader(
56-
async_compression::tokio::bufread::Lz4Decoder<BufReader<iroh::endpoint::RecvStream>>,
57-
);
52+
struct CompressedWriteStream(Lz4Encoder<iroh::endpoint::SendStream>);
5853

59-
impl iroh_blobs::provider::SendStream for CompressedWriter {
60-
async fn send_bytes(&mut self, bytes: bytes::Bytes) -> io::Result<()> {
61-
AsyncWriteSendStream::new(self).send_bytes(bytes).await
54+
impl SendStreamSpecific for CompressedWriteStream {
55+
fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send) {
56+
&mut self.0
6257
}
6358

64-
async fn send<const L: usize>(&mut self, buf: &[u8; L]) -> io::Result<()> {
65-
AsyncWriteSendStream::new(self).send(buf).await
59+
fn reset(&mut self, code: VarInt) -> io::Result<()> {
60+
Ok(self.0.get_mut().reset(code)?)
6661
}
6762

68-
async fn sync(&mut self) -> io::Result<()> {
69-
AsyncWriteSendStream::new(self).sync().await
63+
async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
64+
Ok(self.0.get_mut().stopped().await?)
7065
}
7166
}
7267

73-
impl iroh_blobs::provider::RecvStream for CompressedReader {
74-
async fn recv_bytes(&mut self, len: usize) -> io::Result<bytes::Bytes> {
75-
AsyncReadRecvStream::new(self).recv_bytes(len).await
76-
}
77-
78-
async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<bytes::Bytes> {
79-
AsyncReadRecvStream::new(self).recv_bytes_exact(len).await
80-
}
81-
82-
async fn recv<const L: usize>(&mut self) -> io::Result<[u8; L]> {
83-
AsyncReadRecvStream::new(self).recv::<L>().await
84-
}
85-
}
86-
87-
impl tokio::io::AsyncRead for CompressedReader {
88-
fn poll_read(
89-
mut self: std::pin::Pin<&mut Self>,
90-
cx: &mut std::task::Context<'_>,
91-
buf: &mut tokio::io::ReadBuf<'_>,
92-
) -> std::task::Poll<io::Result<()>> {
93-
std::pin::Pin::new(&mut self.0).poll_read(cx, buf)
94-
}
95-
}
96-
97-
impl tokio::io::AsyncWrite for CompressedWriter {
98-
fn poll_write(
99-
mut self: std::pin::Pin<&mut Self>,
100-
cx: &mut std::task::Context<'_>,
101-
buf: &[u8],
102-
) -> std::task::Poll<io::Result<usize>> {
103-
std::pin::Pin::new(&mut self.0).poll_write(cx, buf)
104-
}
105-
106-
fn poll_flush(
107-
mut self: std::pin::Pin<&mut Self>,
108-
cx: &mut std::task::Context<'_>,
109-
) -> std::task::Poll<io::Result<()>> {
110-
std::pin::Pin::new(&mut self.0).poll_flush(cx)
111-
}
112-
113-
fn poll_shutdown(
114-
mut self: std::pin::Pin<&mut Self>,
115-
cx: &mut std::task::Context<'_>,
116-
) -> std::task::Poll<io::Result<()>> {
117-
std::pin::Pin::new(&mut self.0).poll_shutdown(cx)
118-
}
119-
}
68+
struct CompressedReadStream(Lz4Decoder<BufReader<iroh::endpoint::RecvStream>>);
12069

121-
impl iroh_blobs::provider::SendStreamSpecific for CompressedWriter {
122-
fn reset(&mut self, code: quinn::VarInt) -> io::Result<()> {
123-
self.0.get_mut().reset(code)?;
124-
Ok(())
70+
impl RecvStreamSpecific for CompressedReadStream {
71+
fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send) {
72+
&mut self.0
12573
}
12674

127-
async fn stopped(&mut self) -> io::Result<Option<quinn::VarInt>> {
128-
let res = self.0.get_mut().stopped().await?;
129-
Ok(res)
75+
fn stop(&mut self, code: VarInt) -> io::Result<()> {
76+
Ok(self.0.get_mut().get_mut().stop(code)?)
13077
}
131-
}
13278

133-
impl iroh_blobs::provider::RecvStreamSpecific for CompressedReader {
134-
fn stop(&mut self, code: quinn::VarInt) -> io::Result<()> {
135-
self.0.get_mut().get_mut().stop(code)?;
136-
Ok(())
79+
fn id(&self) -> u64 {
80+
self.0.get_ref().get_ref().id().index()
13781
}
13882
}
13983

@@ -172,19 +116,13 @@ impl ProtocolHandler for CompressedBlobsProtocol {
172116
return Ok(());
173117
}
174118
while let Ok((send, recv)) = connection.accept_bi().await {
175-
let stream_id = send.id().index();
176-
let send = CompressedWriter(Lz4Encoder::new(send));
177-
let recv = CompressedReader(Lz4Decoder::new(BufReader::new(recv)));
119+
let send = AsyncWriteSendStream::new(CompressedWriteStream(Lz4Encoder::new(send)));
120+
let recv = AsyncReadRecvStream::new(CompressedReadStream(Lz4Decoder::new(
121+
BufReader::new(recv),
122+
)));
178123
let store = self.store.clone();
179-
let mut pair =
180-
StreamPair::new(connection_id, stream_id, recv, send, self.events.clone());
181-
tokio::spawn(async move {
182-
let request = pair.read_request().await?;
183-
if let Request::Get(request) = request {
184-
handle_get(pair, store, request).await?;
185-
}
186-
anyhow::Ok(())
187-
});
124+
let pair = StreamPair::new(connection_id, recv, send, self.events.clone());
125+
tokio::spawn(handle_stream(pair, store));
188126
}
189127
Ok(())
190128
}
@@ -219,34 +157,21 @@ async fn main() -> Result<()> {
219157
router.shutdown().await?;
220158
}
221159
Args::Get { ticket, target } => {
160+
let store = MemStore::new();
222161
let conn = endpoint.connect(ticket.node_addr().clone(), ALPN).await?;
162+
let connection_id = conn.stable_id() as u64;
223163
let (send, recv) = conn.open_bi().await?;
224-
let send = CompressedWriter(Lz4Encoder::new(send));
225-
let recv = CompressedReader(Lz4Decoder::new(BufReader::new(recv)));
226-
let request = GetRequest {
227-
hash: ticket.hash(),
228-
ranges: ChunkRangesSeq::root(),
229-
};
230-
let connected =
231-
AtConnected::new(Instant::now(), recv, send, request, Default::default());
232-
let ConnectedNext::StartRoot(start) = connected.next().await? else {
233-
unreachable!("expected start root");
234-
};
235-
let (end, data) = start.next().concatenate_into_vec().await?;
236-
let EndBlobNext::Closing(closing) = end.next() else {
237-
unreachable!("expected closing");
238-
};
239-
let stats = closing.next().await?;
164+
let send = AsyncWriteSendStream::new(CompressedWriteStream(Lz4Encoder::new(send)));
165+
let recv = AsyncReadRecvStream::new(CompressedReadStream(Lz4Decoder::new(
166+
BufReader::new(recv),
167+
)));
168+
let sp = StreamPair::new(connection_id, recv, send, EventSender::DEFAULT);
169+
let stats = store.remote().fetch(sp, ticket.hash_and_format()).await?;
240170
if let Some(target) = target {
241-
tokio::fs::write(&target, &data).await?;
242-
println!(
243-
"Wrote {} bytes to {}",
244-
stats.payload_bytes_read,
245-
target.display()
246-
);
171+
let size = store.export(ticket.hash(), &target).await?;
172+
println!("Wrote {} bytes to {}", size, target.display());
247173
} else {
248-
let hash = blake3::hash(&data);
249-
println!("Hash: {hash}");
174+
println!("Hash: {}", ticket.hash());
250175
}
251176
}
252177
}

src/api/downloader.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ async fn execute_get(
456456
};
457457
match remote
458458
.execute_get_sink(
459-
&conn,
459+
conn.clone(),
460460
local.missing(),
461461
(&mut progress).with_map(move |x| DownloadProgessItem::Progress(x + local_bytes)),
462462
)

0 commit comments

Comments
 (0)