1- import Redis , { RedisOptions } from 'ioredis' ;
1+ import EventEmitter from 'events' ;
2+
23import { parse , validate } from 'graphql' ;
4+ import { RedisOptions } from 'ioredis' ;
35import { GraphQLSchema } from 'graphql/type/schema' ;
46import { Request , Response , NextFunction , RequestHandler } from 'express' ;
57
68import buildTypeWeightsFromSchema , { defaultTypeWeightsConfig } from '../analysis/buildTypeWeights' ;
79import setupRateLimiter from './rateLimiterSetup' ;
810import getQueryTypeComplexity from '../analysis/typeComplexityAnalysis' ;
9- import { RateLimiterOptions , RateLimiterSelection } from '../@types/rateLimit' ;
11+ import { RateLimiterOptions , RateLimiterSelection , RateLimiterResponse } from '../@types/rateLimit' ;
1012import { TypeWeightConfig } from '../@types/buildTypeWeights' ;
13+ import { connect } from '../utils/redis' ;
1114
1215// FIXME: Will the developer be responsible for first parsing the schema from a file?
1316// Can consider accepting a string representing a the filepath to a schema
@@ -39,11 +42,72 @@ export function expressRateLimiter(
3942 */
4043 // TODO: Throw ValidationError if schema is invalid
4144 const typeWeightObject = buildTypeWeightsFromSchema ( schema , typeWeightConfig ) ;
45+
4246 // TODO: Throw error if connection is unsuccessful
43- const redisClient = new Redis ( redisClientOptions ) ; // Default port is 6379 automatically
47+ const redisClient = connect ( redisClientOptions ) ; // Default port is 6379 automatically
4448 const rateLimiter = setupRateLimiter ( rateLimiterAlgo , rateLimiterOptions , redisClient ) ;
4549
46- // return the rate limiting middleware
50+ // stores request IDs to be processed
51+ const requestQueue : { [ index : string ] : string [ ] } = { } ;
52+
53+ // Manages processing of event queue
54+ const requestEvents = new EventEmitter ( ) ;
55+
56+ // Resolves the promise created by throttledProcess
57+ async function processRequestResolver (
58+ userId : string ,
59+ timestamp : number ,
60+ tokens : number ,
61+ resolve : ( value : RateLimiterResponse | PromiseLike < RateLimiterResponse > ) => void ,
62+ reject : ( reason : any ) => void
63+ ) {
64+ try {
65+ const response = await rateLimiter . processRequest ( userId , timestamp , tokens ) ;
66+ requestQueue [ userId ] = requestQueue [ userId ] . slice ( 1 ) ;
67+ // trigger the next event
68+ resolve ( response ) ;
69+ requestEvents . emit ( requestQueue [ userId ] [ 0 ] ) ;
70+ if ( requestQueue [ userId ] . length === 0 ) delete requestQueue [ userId ] ;
71+ } catch ( err ) {
72+ reject ( err ) ;
73+ }
74+ }
75+
76+ /**
77+ * Throttle rateLimiter.processRequest based on user IP to prevent inaccurate redis reads
78+ * Throttling is based on a event driven promise fulfillment approach.
79+ * Each time a request is received a promise is added to the user's request queue. The promise "subscribes"
80+ * to the previous request in the user's queue then calls processRequest and resolves once the previous request
81+ * is complete.
82+ * @param userId
83+ * @param timestamp
84+ * @param tokens
85+ * @returns
86+ */
87+ async function throttledProcess (
88+ userId : string ,
89+ timestamp : number ,
90+ tokens : number
91+ ) : Promise < RateLimiterResponse > {
92+ // Alternatively use crypto.randomUUID() to generate a random uuid
93+ const requestId = `${ timestamp } ${ tokens } ` ;
94+
95+ if ( ! requestQueue [ userId ] ) {
96+ requestQueue [ userId ] = [ ] ;
97+ }
98+ requestQueue [ userId ] . push ( requestId ) ;
99+
100+ return new Promise ( ( resolve , reject ) => {
101+ if ( requestQueue [ userId ] . length > 1 ) {
102+ requestEvents . once ( requestId , async ( ) => {
103+ await processRequestResolver ( userId , timestamp , tokens , resolve , reject ) ;
104+ } ) ;
105+ } else {
106+ processRequestResolver ( userId , timestamp , tokens , resolve , reject ) ;
107+ }
108+ } ) ;
109+ }
110+
47111 return async (
48112 req : Request ,
49113 res : Response ,
@@ -57,7 +121,7 @@ export function expressRateLimiter(
57121 return next ( ) ;
58122 }
59123 /**
60- * There are numorous ways to get the ip address off of the request object.
124+ * There are numerous ways to get the ip address off of the request object.
61125 * - the header 'x-forward-for' will hold the originating ip address if a proxy is placed infront of the server. This would be commen for a production build.
62126 * - req.ips wwill hold an array of ip addresses in'x-forward-for' header. client is likely at index zero
63127 * - req.ip will have the ip address
@@ -66,7 +130,7 @@ export function expressRateLimiter(
66130 * req.ip and req.ips will worx in express but not with other frameworks
67131 */
68132 // check for a proxied ip address before using the ip address on request
69- const ip : string = req . ips [ 0 ] || req . ip ;
133+ const ip : string = req . ips ? req . ips [ 0 ] : req . ip ;
70134
71135 // FIXME: this will only work with type complexity
72136 const queryAST = parse ( query ) ;
@@ -80,9 +144,9 @@ export function expressRateLimiter(
80144
81145 const queryComplexity = getQueryTypeComplexity ( queryAST , variables , typeWeightObject ) ;
82146 try {
83- // process the request and conditinoally respond to client with status code 429 o
84- // r pass the request onto the next middleware function
85- const rateLimiterResponse = await rateLimiter . processRequest (
147+ // process the request and conditinoally respond to client with status code 429 or
148+ // pass the request onto the next middleware function
149+ const rateLimiterResponse = await throttledProcess (
86150 ip ,
87151 requestTimestamp ,
88152 queryComplexity
0 commit comments