1+ use http_body_util:: StreamBody ;
2+ use hyper:: body:: Bytes ;
3+ use hyper:: body:: Frame ;
4+ use hyper:: rt:: { Read , ReadBufCursor , Write } ;
5+ use hyper:: server:: conn:: http1;
6+ use hyper:: service:: service_fn;
7+ use hyper:: { Response , StatusCode } ;
8+ use pin_project_lite:: pin_project;
9+ use std:: convert:: Infallible ;
10+ use std:: io;
11+ use std:: pin:: Pin ;
12+ use std:: task:: { ready, Context , Poll } ;
13+ use tokio:: sync:: mpsc;
14+ use tracing:: { error, info} ;
15+
16+ pin_project ! {
17+ #[ derive( Debug ) ]
18+ pub struct TxReadyStream {
19+ #[ pin]
20+ read_rx: mpsc:: UnboundedReceiver <Vec <u8 >>,
21+ write_tx: mpsc:: UnboundedSender <Vec <u8 >>,
22+ read_buffer: Vec <u8 >,
23+ poll_since_write: bool ,
24+ flush_count: usize ,
25+ }
26+ }
27+
28+ impl TxReadyStream {
29+ fn new (
30+ read_rx : mpsc:: UnboundedReceiver < Vec < u8 > > ,
31+ write_tx : mpsc:: UnboundedSender < Vec < u8 > > ,
32+ ) -> Self {
33+ Self {
34+ read_rx,
35+ write_tx,
36+ read_buffer : Vec :: new ( ) ,
37+ poll_since_write : true ,
38+ flush_count : 0 ,
39+ }
40+ }
41+
42+ /// Create a new pair of connected ReadyStreams. Returns two streams that are connected to each other.
43+ fn new_pair ( ) -> ( Self , Self ) {
44+ let ( s1_tx, s2_rx) = mpsc:: unbounded_channel ( ) ;
45+ let ( s2_tx, s1_rx) = mpsc:: unbounded_channel ( ) ;
46+ let s1 = Self :: new ( s1_rx, s1_tx) ;
47+ let s2 = Self :: new ( s2_rx, s2_tx) ;
48+ ( s1, s2)
49+ }
50+
51+ /// Send data to the other end of the stream (this will be available for reading on the other stream)
52+ fn send ( & self , data : & [ u8 ] ) -> Result < ( ) , mpsc:: error:: SendError < Vec < u8 > > > {
53+ self . write_tx . send ( data. to_vec ( ) )
54+ }
55+
56+
57+ /// Receive data written to this stream by the other end (async)
58+ async fn recv ( & mut self ) -> Option < Vec < u8 > > {
59+ self . read_rx . recv ( ) . await
60+ }
61+ }
62+
63+ impl Read for TxReadyStream {
64+ fn poll_read (
65+ mut self : Pin < & mut Self > ,
66+ cx : & mut Context < ' _ > ,
67+ mut buf : ReadBufCursor < ' _ > ,
68+ ) -> Poll < io:: Result < ( ) > > {
69+ let mut this = self . as_mut ( ) . project ( ) ;
70+
71+ // First, try to satisfy the read request from the internal buffer
72+ if !this. read_buffer . is_empty ( ) {
73+ let to_read = std:: cmp:: min ( this. read_buffer . len ( ) , buf. remaining ( ) ) ;
74+ // Copy data from internal buffer to the read buffer
75+ buf. put_slice ( & this. read_buffer [ ..to_read] ) ;
76+ // Remove the consumed data from the internal buffer
77+ this. read_buffer . drain ( ..to_read) ;
78+ return Poll :: Ready ( Ok ( ( ) ) ) ;
79+ }
80+
81+ // If internal buffer is empty, try to get data from the channel
82+ match this. read_rx . try_recv ( ) {
83+ Ok ( data) => {
84+ // Copy as much data as we can fit in the buffer
85+ let to_read = std:: cmp:: min ( data. len ( ) , buf. remaining ( ) ) ;
86+ buf. put_slice ( & data[ ..to_read] ) ;
87+
88+ // Store any remaining data in the internal buffer for next time
89+ if to_read < data. len ( ) {
90+ let remaining = & data[ to_read..] ;
91+ this. read_buffer . extend_from_slice ( remaining) ;
92+ }
93+ Poll :: Ready ( Ok ( ( ) ) )
94+ }
95+ Err ( mpsc:: error:: TryRecvError :: Empty ) => {
96+ match ready ! ( this. read_rx. poll_recv( cx) ) {
97+ Some ( data) => {
98+ // Copy as much data as we can fit in the buffer
99+ let to_read = std:: cmp:: min ( data. len ( ) , buf. remaining ( ) ) ;
100+ buf. put_slice ( & data[ ..to_read] ) ;
101+
102+ // Store any remaining data in the internal buffer for next time
103+ if to_read < data. len ( ) {
104+ let remaining = & data[ to_read..] ;
105+ this. read_buffer . extend_from_slice ( remaining) ;
106+ }
107+ Poll :: Ready ( Ok ( ( ) ) )
108+ }
109+ None => Poll :: Ready ( Ok ( ( ) ) ) ,
110+ }
111+ }
112+ Err ( mpsc:: error:: TryRecvError :: Disconnected ) => {
113+ // Channel closed, return EOF
114+ Poll :: Ready ( Ok ( ( ) ) )
115+ }
116+ }
117+ }
118+ }
119+
120+ impl Write for TxReadyStream {
121+ fn poll_write (
122+ mut self : Pin < & mut Self > ,
123+ _cx : & mut Context < ' _ > ,
124+ buf : & [ u8 ] ,
125+ ) -> Poll < io:: Result < usize > > {
126+ if !self . poll_since_write {
127+ return Poll :: Pending ;
128+ }
129+ self . poll_since_write = false ;
130+ let this = self . project ( ) ;
131+ let buf = Vec :: from ( & buf[ ..buf. len ( ) ] ) ;
132+ let len = buf. len ( ) ;
133+
134+ // Send data through the channel - this should always be ready for unbounded channels
135+ match this. write_tx . send ( buf) {
136+ Ok ( _) => {
137+ // Increment write count
138+ Poll :: Ready ( Ok ( len) )
139+ }
140+ Err ( _) => {
141+ error ! ( "ReadyStream::poll_write failed - channel closed" ) ;
142+ Poll :: Ready ( Err ( io:: Error :: new (
143+ io:: ErrorKind :: BrokenPipe ,
144+ "Write channel closed" ,
145+ ) ) )
146+ }
147+ }
148+ }
149+
150+ fn poll_flush ( mut self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
151+ self . flush_count += 1 ;
152+ // We require two flushes to complete each chunk, simulating a success at the end of the old
153+ // poll loop. After all chunks are written, we always succeed on flush to allow for finish.
154+ if self . flush_count % 2 != 0 && self . flush_count < TOTAL_CHUNKS * 2 {
155+ return Poll :: Pending ;
156+ }
157+ self . poll_since_write = true ;
158+ Poll :: Ready ( Ok ( ( ) ) )
159+ }
160+
161+ fn poll_shutdown ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
162+ Poll :: Ready ( Ok ( ( ) ) )
163+ }
164+ }
165+
166+ fn init_tracing ( ) {
167+ use std:: sync:: Once ;
168+ static INIT : Once = Once :: new ( ) ;
169+ INIT . call_once ( || {
170+ tracing_subscriber:: fmt ( )
171+ . with_max_level ( tracing:: Level :: INFO )
172+ . with_target ( true )
173+ . with_thread_ids ( true )
174+ . with_thread_names ( true )
175+ . init ( ) ;
176+ } ) ;
177+ }
178+
179+ const TOTAL_CHUNKS : usize = 16 ;
180+
181+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 2 ) ]
182+ async fn body_test ( ) {
183+ init_tracing ( ) ;
184+ // Create a pair of connected streams
185+ let ( server_stream, mut client_stream) = TxReadyStream :: new_pair ( ) ;
186+
187+ let mut http_builder = http1:: Builder :: new ( ) ;
188+ http_builder. max_buf_size ( CHUNK_SIZE ) ;
189+ const CHUNK_SIZE : usize = 64 * 1024 ;
190+ let service = service_fn ( |_| async move {
191+ info ! (
192+ "Creating payload of {} chunks of {} KiB each ({} MiB total)..." ,
193+ TOTAL_CHUNKS ,
194+ CHUNK_SIZE / 1024 ,
195+ TOTAL_CHUNKS * CHUNK_SIZE / ( 1024 * 1024 )
196+ ) ;
197+ let bytes = Bytes :: from ( vec ! [ 0 ; CHUNK_SIZE ] ) ;
198+ let data = vec ! [ bytes. clone( ) ; TOTAL_CHUNKS ] ;
199+ let stream = futures_util:: stream:: iter (
200+ data. into_iter ( )
201+ . map ( |b| Ok :: < _ , Infallible > ( Frame :: data ( b) ) ) ,
202+ ) ;
203+ let body = StreamBody :: new ( stream) ;
204+ info ! ( "Server: Sending data response..." ) ;
205+ Ok :: < _ , hyper:: Error > (
206+ Response :: builder ( )
207+ . status ( StatusCode :: OK )
208+ . header ( "content-type" , "application/octet-stream" )
209+ . header ( "content-length" , ( TOTAL_CHUNKS * CHUNK_SIZE ) . to_string ( ) )
210+ . body ( body)
211+ . unwrap ( ) ,
212+ )
213+ } ) ;
214+
215+ let server_task = tokio:: spawn ( async move {
216+ let conn = http_builder. serve_connection ( server_stream, service) ;
217+ if let Err ( e) = conn. await {
218+ error ! ( "Server connection error: {}" , e) ;
219+ }
220+ } ) ;
221+
222+ let get_request = "GET / HTTP/1.1\r \n Host: localhost\r \n Connection: close\r \n \r \n " ;
223+ client_stream. send ( get_request. as_bytes ( ) ) . unwrap ( ) ;
224+
225+ info ! ( "Client is reading response..." ) ;
226+ let mut bytes_received = 0 ;
227+ while let Some ( chunk) = client_stream. recv ( ) . await {
228+ bytes_received += chunk. len ( ) ;
229+ }
230+ // Clean up
231+ server_task. abort ( ) ;
232+
233+ info ! ( bytes_received, "Client done receiving bytes" ) ;
234+ }
0 commit comments