@@ -42,6 +42,10 @@ pub(crate) fn register_grpc_callout(token_id: u32) {
4242 DISPATCHER . with ( |dispatcher| dispatcher. register_grpc_callout ( token_id) ) ;
4343}
4444
45+ pub ( crate ) fn register_grpc_stream ( token_id : u32 ) {
46+ DISPATCHER . with ( |dispatcher| dispatcher. register_grpc_stream ( token_id) ) ;
47+ }
48+
4549struct NoopRoot ;
4650
4751impl Context for NoopRoot { }
@@ -57,6 +61,7 @@ struct Dispatcher {
5761 active_id : Cell < u32 > ,
5862 callouts : RefCell < HashMap < u32 , u32 > > ,
5963 grpc_callouts : RefCell < HashMap < u32 , u32 > > ,
64+ grpc_streams : RefCell < HashMap < u32 , u32 > > ,
6065}
6166
6267impl Dispatcher {
@@ -71,6 +76,7 @@ impl Dispatcher {
7176 active_id : Cell :: new ( 0 ) ,
7277 callouts : RefCell :: new ( HashMap :: new ( ) ) ,
7378 grpc_callouts : RefCell :: new ( HashMap :: new ( ) ) ,
79+ grpc_streams : RefCell :: new ( HashMap :: new ( ) ) ,
7480 }
7581 }
7682
@@ -97,6 +103,17 @@ impl Dispatcher {
97103 }
98104 }
99105
106+ fn register_grpc_stream ( & self , token_id : u32 ) {
107+ if self
108+ . grpc_streams
109+ . borrow_mut ( )
110+ . insert ( token_id, self . active_id . get ( ) )
111+ . is_some ( )
112+ {
113+ panic ! ( "duplicate token_id" )
114+ }
115+ }
116+
100117 fn register_grpc_callout ( & self , token_id : u32 ) {
101118 if self
102119 . grpc_callouts
@@ -399,47 +416,116 @@ impl Dispatcher {
399416 }
400417 }
401418
402- fn on_grpc_receive ( & self , token_id : u32 , response_size : usize ) {
403- let context_id = self
404- . grpc_callouts
419+ fn on_grpc_receive_initial_metadata ( & self , token_id : u32 , headers : u32 ) {
420+ let context_id = * self
421+ . grpc_streams
405422 . borrow_mut ( )
406- . remove ( & token_id)
423+ . get ( & token_id)
407424 . expect ( "invalid token_id" ) ;
408425
409426 if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
410427 self . active_id . set ( context_id) ;
411428 hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
412- http_stream. on_grpc_call_response ( token_id, 0 , response_size ) ;
429+ http_stream. on_grpc_stream_initial_metadata ( token_id, headers ) ;
413430 } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
414431 self . active_id . set ( context_id) ;
415432 hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
416- stream. on_grpc_call_response ( token_id, 0 , response_size ) ;
433+ stream. on_grpc_stream_initial_metadata ( token_id, headers ) ;
417434 } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
418435 self . active_id . set ( context_id) ;
419436 hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
420- root. on_grpc_call_response ( token_id, 0 , response_size ) ;
437+ root. on_grpc_stream_initial_metadata ( token_id, headers ) ;
421438 }
422439 }
423440
424- fn on_grpc_close ( & self , token_id : u32 , status_code : u32 ) {
425- let context_id = self
426- . grpc_callouts
441+ fn on_grpc_receive ( & self , token_id : u32 , response_size : usize ) {
442+ if let Some ( context_id) = self . grpc_callouts . borrow_mut ( ) . remove ( & token_id) {
443+ if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
444+ self . active_id . set ( context_id) ;
445+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
446+ http_stream. on_grpc_call_response ( token_id, 0 , response_size) ;
447+ } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
448+ self . active_id . set ( context_id) ;
449+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
450+ stream. on_grpc_call_response ( token_id, 0 , response_size) ;
451+ } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
452+ self . active_id . set ( context_id) ;
453+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
454+ root. on_grpc_call_response ( token_id, 0 , response_size) ;
455+ }
456+ } else if let Some ( context_id) = self . grpc_streams . borrow_mut ( ) . get ( & token_id) {
457+ let context_id = * context_id;
458+ if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
459+ self . active_id . set ( context_id) ;
460+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
461+ http_stream. on_grpc_stream_message ( token_id, response_size) ;
462+ } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
463+ self . active_id . set ( context_id) ;
464+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
465+ stream. on_grpc_stream_message ( token_id, response_size) ;
466+ } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
467+ self . active_id . set ( context_id) ;
468+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
469+ root. on_grpc_stream_message ( token_id, response_size) ;
470+ }
471+ } else {
472+ panic ! ( "invalid token_id" )
473+ }
474+ }
475+
476+ fn on_grpc_receive_trailing_metadata ( & self , token_id : u32 , trailers : u32 ) {
477+ let context_id = * self
478+ . grpc_streams
427479 . borrow_mut ( )
428- . remove ( & token_id)
480+ . get ( & token_id)
429481 . expect ( "invalid token_id" ) ;
430482
431483 if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
432484 self . active_id . set ( context_id) ;
433485 hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
434- http_stream. on_grpc_call_response ( token_id, status_code , 0 ) ;
486+ http_stream. on_grpc_stream_trailing_metadata ( token_id, trailers ) ;
435487 } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
436488 self . active_id . set ( context_id) ;
437489 hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
438- stream. on_grpc_call_response ( token_id, status_code , 0 ) ;
490+ stream. on_grpc_stream_trailing_metadata ( token_id, trailers ) ;
439491 } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
440492 self . active_id . set ( context_id) ;
441493 hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
442- root. on_grpc_call_response ( token_id, status_code, 0 ) ;
494+ root. on_grpc_stream_trailing_metadata ( token_id, trailers) ;
495+ }
496+ }
497+
498+ fn on_grpc_close ( & self , token_id : u32 , status_code : u32 ) {
499+ if let Some ( context_id) = self . grpc_callouts . borrow_mut ( ) . remove ( & token_id) {
500+ if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
501+ self . active_id . set ( context_id) ;
502+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
503+ http_stream. on_grpc_call_response ( token_id, status_code, 0 ) ;
504+ } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
505+ self . active_id . set ( context_id) ;
506+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
507+ stream. on_grpc_call_response ( token_id, status_code, 0 ) ;
508+ } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
509+ self . active_id . set ( context_id) ;
510+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
511+ root. on_grpc_call_response ( token_id, status_code, 0 ) ;
512+ }
513+ } else if let Some ( context_id) = self . grpc_streams . borrow_mut ( ) . remove ( & token_id) {
514+ if let Some ( http_stream) = self . http_streams . borrow_mut ( ) . get_mut ( & context_id) {
515+ self . active_id . set ( context_id) ;
516+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
517+ http_stream. on_grpc_stream_close ( token_id, status_code)
518+ } else if let Some ( stream) = self . streams . borrow_mut ( ) . get_mut ( & context_id) {
519+ self . active_id . set ( context_id) ;
520+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
521+ stream. on_grpc_stream_close ( token_id, status_code)
522+ } else if let Some ( root) = self . roots . borrow_mut ( ) . get_mut ( & context_id) {
523+ self . active_id . set ( context_id) ;
524+ hostcalls:: set_effective_context ( context_id) . unwrap ( ) ;
525+ root. on_grpc_stream_close ( token_id, status_code)
526+ }
527+ } else {
528+ panic ! ( "invalid token_id" )
443529 }
444530 }
445531}
@@ -571,11 +657,29 @@ pub extern "C" fn proxy_on_http_call_response(
571657 } )
572658}
573659
660+ #[ no_mangle]
661+ pub extern "C" fn proxy_on_grpc_receive_initial_metadata (
662+ _context_id : u32 ,
663+ token_id : u32 ,
664+ headers : u32 ,
665+ ) {
666+ DISPATCHER . with ( |dispatcher| dispatcher. on_grpc_receive_initial_metadata ( token_id, headers) )
667+ }
668+
574669#[ no_mangle]
575670pub extern "C" fn proxy_on_grpc_receive ( _context_id : u32 , token_id : u32 , response_size : usize ) {
576671 DISPATCHER . with ( |dispatcher| dispatcher. on_grpc_receive ( token_id, response_size) )
577672}
578673
674+ #[ no_mangle]
675+ pub extern "C" fn proxy_on_grpc_receive_trailing_metadata (
676+ _context_id : u32 ,
677+ token_id : u32 ,
678+ trailers : u32 ,
679+ ) {
680+ DISPATCHER . with ( |dispatcher| dispatcher. on_grpc_receive_trailing_metadata ( token_id, trailers) )
681+ }
682+
579683#[ no_mangle]
580684pub extern "C" fn proxy_on_grpc_close ( _context_id : u32 , token_id : u32 , status_code : u32 ) {
581685 DISPATCHER . with ( |dispatcher| dispatcher. on_grpc_close ( token_id, status_code) )
0 commit comments