@@ -18,25 +18,48 @@ import NIOConcurrencyHelpers
1818import NIOHTTP1
1919import NIOSSL
2020
21- public extension HTTPClient {
22- enum Body : Equatable {
23- case byteBuffer( ByteBuffer )
24- case data( Data )
25- case string( String )
26-
27- var length : Int {
28- switch self {
29- case . byteBuffer( let buffer) :
30- return buffer. readableBytes
31- case . data( let data) :
32- return data. count
33- case . string( let string) :
34- return string. utf8. count
21+ extension HTTPClient {
22+
23+ public struct Body {
24+ public struct StreamWriter {
25+ let closure : ( IOData ) -> EventLoopFuture < Void >
26+
27+ public func write( _ data: IOData ) -> EventLoopFuture < Void > {
28+ return self . closure ( data)
29+ }
30+ }
31+
32+ public var length : Int ?
33+ public var stream : ( StreamWriter ) -> EventLoopFuture < Void >
34+
35+ public static func byteBuffer( _ buffer: ByteBuffer ) -> Body {
36+ return Body ( length: buffer. readableBytes) { writer in
37+ writer. write ( . byteBuffer( buffer) )
38+ }
39+ }
40+
41+ public static func stream( length: Int ? = nil , _ stream: @escaping ( StreamWriter ) -> EventLoopFuture < Void > ) -> Body {
42+ return Body ( length: length, stream: stream)
43+ }
44+
45+ public static func data( _ data: Data ) -> Body {
46+ return Body ( length: data. count) { writer in
47+ var buffer = ByteBufferAllocator ( ) . buffer ( capacity: data. count)
48+ buffer. writeBytes ( data)
49+ return writer. write ( . byteBuffer( buffer) )
50+ }
51+ }
52+
53+ public static func string( _ string: String ) -> Body {
54+ return Body ( length: string. utf8. count) { writer in
55+ var buffer = ByteBufferAllocator ( ) . buffer ( capacity: string. utf8. count)
56+ buffer. writeString ( string)
57+ return writer. write ( . byteBuffer( buffer) )
3558 }
3659 }
3760 }
3861
39- struct Request : Equatable {
62+ public struct Request {
4063 public var version : HTTPVersion
4164 public var method : HTTPMethod
4265 public var url : URL
@@ -53,7 +76,7 @@ public extension HTTPClient {
5376 try self . init ( url: url, version: version, method: method, headers: headers, body: body)
5477 }
5578
56- public init ( url: URL , version: HTTPVersion , method: HTTPMethod = . GET, headers: HTTPHeaders = HTTPHeaders ( ) , body: Body ? = nil ) throws {
79+ public init ( url: URL , version: HTTPVersion = HTTPVersion ( major : 1 , minor : 1 ) , method: HTTPMethod = . GET, headers: HTTPHeaders = HTTPHeaders ( ) , body: Body ? = nil ) throws {
5780 guard let scheme = url. scheme else {
5881 throw HTTPClientError . emptyScheme
5982 }
@@ -88,7 +111,7 @@ public extension HTTPClient {
88111 }
89112 }
90113
91- struct Response : Equatable {
114+ public struct Response {
92115 public var host : String
93116 public var status : HTTPResponseStatus
94117 public var headers : HTTPHeaders
@@ -114,9 +137,7 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate {
114137 self . request = request
115138 }
116139
117- func didTransmitRequestBody( task: HTTPClient . Task < Response > ) { }
118-
119- func didReceiveHead( task: HTTPClient . Task < Response > , _ head: HTTPResponseHead ) {
140+ func didReceiveHead( task: HTTPClient . Task < Response > , _ head: HTTPResponseHead ) -> EventLoopFuture < Void > {
120141 switch self . state {
121142 case . idle:
122143 self . state = . head( head)
@@ -129,9 +150,10 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate {
129150 case . error:
130151 break
131152 }
153+ return task. eventLoop. makeSucceededFuture ( ( ) )
132154 }
133155
134- func didReceivePart( task: HTTPClient . Task < Response > , _ part: ByteBuffer ) {
156+ func didReceivePart( task: HTTPClient . Task < Response > , _ part: ByteBuffer ) -> EventLoopFuture < Void > {
135157 switch self . state {
136158 case . idle:
137159 preconditionFailure ( " no head received before body " )
@@ -146,6 +168,7 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate {
146168 case . error:
147169 break
148170 }
171+ return task. eventLoop. makeSucceededFuture ( ( ) )
149172 }
150173
151174 func didReceiveError( task: HTTPClient . Task < Response > , _ error: Error ) {
@@ -174,25 +197,33 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate {
174197public protocol HTTPClientResponseDelegate : AnyObject {
175198 associatedtype Response
176199
177- func didTransmitRequestBody( task: HTTPClient . Task < Response > )
200+ func didSendRequestHead( task: HTTPClient . Task < Response > , _ head: HTTPRequestHead )
201+
202+ func didSendRequestPart( task: HTTPClient . Task < Response > , _ part: IOData )
203+
204+ func didSendRequest( task: HTTPClient . Task < Response > )
178205
179- func didReceiveHead( task: HTTPClient . Task < Response > , _ head: HTTPResponseHead )
206+ func didReceiveHead( task: HTTPClient . Task < Response > , _ head: HTTPResponseHead ) -> EventLoopFuture < Void >
180207
181- func didReceivePart( task: HTTPClient . Task < Response > , _ buffer: ByteBuffer )
208+ func didReceivePart( task: HTTPClient . Task < Response > , _ buffer: ByteBuffer ) -> EventLoopFuture < Void >
182209
183210 func didReceiveError( task: HTTPClient . Task < Response > , _ error: Error )
184211
185212 func didFinishRequest( task: HTTPClient . Task < Response > ) throws -> Response
186213}
187214
188215extension HTTPClientResponseDelegate {
189- func didTransmitRequestBody ( task: HTTPClient . Task < Response > ) { }
216+ public func didSendRequestHead ( task: HTTPClient . Task < Response > , _ head : HTTPRequestHead ) { }
190217
191- func didReceiveHead ( task: HTTPClient . Task < Response > , _: HTTPResponseHead ) { }
218+ public func didSendRequestPart ( task: HTTPClient . Task < Response > , _ part : IOData ) { }
192219
193- func didReceivePart ( task: HTTPClient . Task < Response > , _ : ByteBuffer ) { }
220+ public func didSendRequest ( task: HTTPClient . Task < Response > ) { }
194221
195- func didReceiveError( task: HTTPClient . Task < Response > , _: Error ) { }
222+ public func didReceiveHead( task: HTTPClient . Task < Response > , _: HTTPResponseHead ) -> EventLoopFuture < Void > { return task. eventLoop. makeSucceededFuture ( ( ) ) }
223+
224+ public func didReceivePart( task: HTTPClient . Task < Response > , _: ByteBuffer ) -> EventLoopFuture < Void > { return task. eventLoop. makeSucceededFuture ( ( ) ) }
225+
226+ public func didReceiveError( task: HTTPClient . Task < Response > , _: Error ) { }
196227}
197228
198229internal extension URL {
@@ -207,13 +238,15 @@ internal extension URL {
207238
208239public extension HTTPClient {
209240 final class Task < Response> {
241+ public let eventLoop : EventLoop
210242 let future : EventLoopFuture < Response >
211243
212244 private var channel : Channel ?
213245 private var cancelled : Bool
214246 private let lock : Lock
215247
216- init ( future: EventLoopFuture < Response > ) {
248+ init ( eventLoop: EventLoop , future: EventLoopFuture < Response > ) {
249+ self . eventLoop = eventLoop
217250 self . future = future
218251 self . cancelled = false
219252 self . lock = Lock ( )
@@ -267,6 +300,8 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
267300 let redirectHandler : RedirectHandler < T . Response > ?
268301
269302 var state : State = . idle
303+ var pendingRead = false
304+ var mayRead = true
270305
271306 init ( task: HTTPClient . Task < T . Response > , delegate: T , promise: EventLoopPromise < T . Response > , redirectHandler: RedirectHandler < T . Response > ? ) {
272307 self . task = task
@@ -298,35 +333,52 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
298333
299334 head. headers = headers
300335
301- context. write ( wrapOutboundOut ( . head( head) ) , promise: nil )
336+ context. write ( wrapOutboundOut ( . head( head) ) ) . whenSuccess {
337+ self . delegate. didSendRequestHead ( task: self . task, head)
338+ }
302339
303- if let body = request. body {
304- let part : HTTPClientRequestPart
305- switch body {
306- case . byteBuffer( let buffer) :
307- part = HTTPClientRequestPart . body ( . byteBuffer( buffer) )
308- case . data( let data) :
309- var buffer = context. channel. allocator. buffer ( capacity: data. count)
310- buffer. writeBytes ( data)
311- part = HTTPClientRequestPart . body ( . byteBuffer( buffer) )
312- case . string( let string) :
313- var buffer = context. channel. allocator. buffer ( capacity: string. utf8. count)
314- buffer. writeString ( string)
315- part = HTTPClientRequestPart . body ( . byteBuffer( buffer) )
316- }
340+ self . writeBody ( request: request, context: context) . whenComplete { result in
341+ switch result {
342+ case . success:
343+ context. write ( self . wrapOutboundOut ( . end( nil ) ) , promise: promise)
344+ context. flush ( )
317345
318- context . write ( wrapOutboundOut ( part ) , promise : nil )
319- }
346+ self . state = . sent
347+ self . delegate . didSendRequest ( task : self . task )
320348
321- context. write ( wrapOutboundOut ( . end( nil ) ) , promise: promise)
322- context. flush ( )
349+ let channel = context. channel
350+ self . promise. futureResult. whenComplete { _ in
351+ channel. close ( promise: nil )
352+ }
353+ case . failure( let error) :
354+ self . state = . end
355+ self . delegate. didReceiveError ( task: self . task, error)
356+ self . promise. fail ( error)
357+ context. close ( promise: nil )
358+ }
359+ }
360+ }
323361
324- self . state = . sent
325- self . delegate. didTransmitRequestBody ( task: self . task)
362+ private func writeBody( request: HTTPClient . Request , context: ChannelHandlerContext ) -> EventLoopFuture < Void > {
363+ if let body = request. body {
364+ return body. stream ( HTTPClient . Body. StreamWriter { part in
365+ let future = context. writeAndFlush ( self . wrapOutboundOut ( . body( part) ) )
366+ future. whenSuccess { _ in
367+ self . delegate. didSendRequestPart ( task: self . task, part)
368+ }
369+ return future
370+ } )
371+ } else {
372+ return context. eventLoop. makeSucceededFuture ( ( ) )
373+ }
374+ }
326375
327- let channel = context. channel
328- self . promise. futureResult. whenComplete { _ in
329- channel. close ( promise: nil )
376+ public func read( context: ChannelHandlerContext ) {
377+ if self . mayRead {
378+ self . pendingRead = false
379+ context. read ( )
380+ } else {
381+ self . pendingRead = true
330382 }
331383 }
332384
@@ -338,15 +390,21 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
338390 self . state = . redirected( head, redirectURL)
339391 } else {
340392 self . state = . head
341- self . delegate. didReceiveHead ( task: self . task, head)
393+ self . mayRead = false
394+ self . delegate. didReceiveHead ( task: self . task, head) . whenComplete { result in
395+ self . handleBackpressureResult ( context: context, result: result)
396+ }
342397 }
343398 case . body( let body) :
344399 switch self . state {
345400 case . redirected:
346401 break
347402 default :
348403 self . state = . body
349- self . delegate. didReceivePart ( task: self . task, body)
404+ self . mayRead = false
405+ self . delegate. didReceivePart ( task: self . task, body) . whenComplete { result in
406+ self . handleBackpressureResult ( context: context, result: result)
407+ }
350408 }
351409 case . end:
352410 switch self . state {
@@ -365,6 +423,20 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
365423 }
366424 }
367425
426+ private func handleBackpressureResult( context: ChannelHandlerContext , result: Result < Void , Error > ) {
427+ switch result {
428+ case . success:
429+ self . mayRead = true
430+ if self . pendingRead {
431+ context. read ( )
432+ }
433+ case . failure( let error) :
434+ self . state = . end
435+ self . delegate. didReceiveError ( task: self . task, error)
436+ self . promise. fail ( error)
437+ }
438+ }
439+
368440 func userInboundEventTriggered( context: ChannelHandlerContext , event: Any ) {
369441 if ( event as? IdleStateHandler . IdleStateEvent) == . read {
370442 self . state = . end
0 commit comments