11import { enableCors } from './util' ;
22import { Match , Router } from '../../../util/router' ;
3- import { listToUint8 } from '../../../util/buffers/concat' ;
43import { IncomingBatchMessage , RpcMessageBatchProcessor } from '../../common/rpc/RpcMessageBatchProcessor' ;
54import { RpcError , RpcErrorCodes , RpcErrorType } from '../../common/rpc/caller/error' ;
65import { ConnectionContext } from '../context' ;
@@ -11,15 +10,7 @@ import {RpcMessageFormat} from '../../common/codec/constants';
1110import { RpcCodecs } from '../../common/codec/RpcCodecs' ;
1211import { type ReactiveRpcMessage , RpcMessageStreamProcessor , ReactiveRpcClientMessage } from '../../common' ;
1312import type { Codecs } from '../../../json-pack/codecs/Codecs' ;
14- import type {
15- TemplatedApp ,
16- HttpRequest ,
17- HttpResponse ,
18- HttpMethodPermissive ,
19- JsonRouteHandler ,
20- WebSocket ,
21- RpcWebSocket ,
22- } from './types' ;
13+ import type * as types from './types' ;
2314import type { RouteHandler } from './types' ;
2415import type { RpcCaller } from '../../common/rpc/caller/RpcCaller' ;
2516import type { JsonValueCodec } from '../../../json-pack/codecs/types' ;
@@ -31,15 +22,16 @@ const ERR_NOT_FOUND = RpcError.fromCode(RpcErrorCodes.NOT_FOUND, 'Not Found');
3122const ERR_INTERNAL = RpcError . internal ( ) ;
3223
3324export interface RpcAppOptions {
34- uws : TemplatedApp ;
25+ uws : types . TemplatedApp ;
3526 maxRequestBodySize : number ;
3627 codecs : Codecs ;
37- caller : RpcCaller ;
28+ caller : RpcCaller < any > ;
29+ augmentContext : ( ctx : ConnectionContext ) => void ;
3830}
3931
4032export class RpcApp < Ctx extends ConnectionContext > {
4133 public readonly codecs : RpcCodecs ;
42- protected readonly app : TemplatedApp ;
34+ protected readonly app : types . TemplatedApp ;
4335 protected readonly maxRequestBodySize : number ;
4436 protected readonly router = new Router ( ) ;
4537 protected readonly batchProcessor : RpcMessageBatchProcessor < Ctx > ;
@@ -55,12 +47,12 @@ export class RpcApp<Ctx extends ConnectionContext> {
5547 enableCors ( this . options . uws ) ;
5648 }
5749
58- public routeRaw ( method : HttpMethodPermissive , path : string , handler : RouteHandler < Ctx > ) : void {
59- method = method . toLowerCase ( ) as HttpMethodPermissive ;
50+ public routeRaw ( method : types . HttpMethodPermissive , path : string , handler : RouteHandler < Ctx > ) : void {
51+ method = method . toLowerCase ( ) as types . HttpMethodPermissive ;
6052 this . router . add ( method + path , handler ) ;
6153 }
6254
63- public route ( method : HttpMethodPermissive , path : string , handler : JsonRouteHandler < Ctx > ) : void {
55+ public route ( method : types . HttpMethodPermissive , path : string , handler : types . JsonRouteHandler < Ctx > ) : void {
6456 this . routeRaw ( method , path , async ( ctx : Ctx ) => {
6557 const result = await handler ( ctx ) ;
6658 const res = ctx . res ! ;
@@ -112,6 +104,7 @@ export class RpcApp<Ctx extends ConnectionContext> {
112104
113105 public enableWsRpc ( path : string = '/rpc' ) : this {
114106 const maxBackpressure = 4 * 1024 * 1024 ;
107+ const augmentContext = this . options . augmentContext ;
115108 this . app . ws ( path , {
116109 idleTimeout : 0 ,
117110 maxPayloadLength : 4 * 1024 * 1024 ,
@@ -120,11 +113,12 @@ export class RpcApp<Ctx extends ConnectionContext> {
120113 const secWebSocketProtocol = req . getHeader ( 'sec-websocket-protocol' ) ;
121114 const secWebSocketExtensions = req . getHeader ( 'sec-websocket-extensions' ) ;
122115 const ctx = ConnectionContext . fromReqRes ( req , res , null , this ) ;
116+ augmentContext ( ctx ) ;
123117 /* This immediately calls open handler, you must not use res after this call */
124118 res . upgrade ( { ctx} , secWebSocketKey , secWebSocketProtocol , secWebSocketExtensions , context ) ;
125119 } ,
126- open : ( ws_ : WebSocket ) => {
127- const ws = ws_ as RpcWebSocket < Ctx > ;
120+ open : ( ws_ : types . WebSocket ) => {
121+ const ws = ws_ as types . RpcWebSocket < Ctx > ;
128122 const ctx = ws . ctx ;
129123 const resCodec = ctx . resCodec ;
130124 const msgCodec = ctx . msgCodec ;
@@ -144,8 +138,8 @@ export class RpcApp<Ctx extends ConnectionContext> {
144138 bufferTime : 0 ,
145139 } ) ;
146140 } ,
147- message : ( ws_ : WebSocket , buf : ArrayBuffer , isBinary : boolean ) => {
148- const ws = ws_ as RpcWebSocket < Ctx > ;
141+ message : ( ws_ : types . WebSocket , buf : ArrayBuffer , isBinary : boolean ) => {
142+ const ws = ws_ as types . RpcWebSocket < Ctx > ;
149143 const ctx = ws . ctx ;
150144 const reqCodec = ctx . reqCodec ;
151145 const msgCodec = ctx . msgCodec ;
@@ -158,8 +152,8 @@ export class RpcApp<Ctx extends ConnectionContext> {
158152 rpc . sendNotification ( '.err' , RpcError . value ( RpcError . invalidRequest ( ) ) ) ;
159153 }
160154 } ,
161- close : ( ws_ : WebSocket , code : number , message : ArrayBuffer ) => {
162- const ws = ws_ as RpcWebSocket < Ctx > ;
155+ close : ( ws_ : types . WebSocket , code : number , message : ArrayBuffer ) => {
156+ const ws = ws_ as types . RpcWebSocket < Ctx > ;
163157 ws . rpc ! . stop ( ) ;
164158 } ,
165159 } ) ;
@@ -170,7 +164,8 @@ export class RpcApp<Ctx extends ConnectionContext> {
170164 const matcher = this . router . compile ( ) ;
171165 const codecs = this . codecs ;
172166 let responseCodec : JsonValueCodec = codecs . value . json ;
173- this . app . any ( '/*' , async ( res : HttpResponse , req : HttpRequest ) => {
167+ const augmentContext = this . options . augmentContext ;
168+ this . app . any ( '/*' , async ( res : types . HttpResponse , req : types . HttpRequest ) => {
174169 res . onAborted ( ( ) => {
175170 res . aborted = true ;
176171 } ) ;
@@ -189,6 +184,7 @@ export class RpcApp<Ctx extends ConnectionContext> {
189184 const params = match . params ;
190185 const ctx = ConnectionContext . fromReqRes ( req , res , params , this ) as Ctx ;
191186 responseCodec = ctx . resCodec ;
187+ augmentContext ( ctx ) ;
192188 await handler ( ctx ) ;
193189 } catch ( err ) {
194190 if ( err instanceof RpcError ) {
0 commit comments