Skip to content

Commit 84f506b

Browse files
committed
implemented throttled version of processRequest
1 parent a011a77 commit 84f506b

File tree

2 files changed

+85
-15
lines changed

2 files changed

+85
-15
lines changed

src/middleware/index.ts

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { Request, Response, NextFunction, RequestHandler } from 'express';
66
import buildTypeWeightsFromSchema, { defaultTypeWeightsConfig } from '../analysis/buildTypeWeights';
77
import setupRateLimiter from './rateLimiterSetup';
88
import getQueryTypeComplexity from '../analysis/typeComplexityAnalysis';
9-
import { RateLimiterOptions, RateLimiterSelection } from '../@types/rateLimit';
9+
import { RateLimiterOptions, RateLimiterSelection, RateLimiterResponse } from '../@types/rateLimit';
1010
import { TypeWeightConfig } from '../@types/buildTypeWeights';
1111
import { connect } from '../utils/redis';
1212

@@ -46,6 +46,42 @@ export function expressRateLimiter(
4646
const redisClient = connect(redisClientOptions); // Default port is 6379 automatically
4747
const rateLimiter = setupRateLimiter(rateLimiterAlgo, rateLimiterOptions, redisClient);
4848

49+
// stores request IDs to be processed
50+
const requestQueue: { [index: string]: string[] } = {};
51+
52+
// Throttle rateLimiter.processRequest based on user IP to prent inaccurate redis reads
53+
async function throttledProcess(
54+
userId: string,
55+
timestamp: number,
56+
tokens: number
57+
): Promise<RateLimiterResponse> {
58+
// Generate a random uuid for this request and add it to the queue
59+
// Alternatively use crypto.randomUUID() to generate a uuid
60+
const requestId = `${userId}${timestamp}${tokens}`;
61+
62+
if (!requestQueue[userId]) {
63+
requestQueue[userId] = [];
64+
}
65+
requestQueue[userId].push(requestId);
66+
67+
// Start a loop to check when this request should be processed
68+
return new Promise((resolve, reject) => {
69+
const intervalId = setInterval(async () => {
70+
console.log('in set timeout');
71+
if (requestQueue[userId][0] === requestId) {
72+
// process the request
73+
clearInterval(intervalId);
74+
const response = await rateLimiter.processRequest(userId, timestamp, tokens);
75+
// requestQueue[userId].shift();
76+
requestQueue[userId] = requestQueue[userId].slice(1);
77+
resolve(response);
78+
} else {
79+
console.log('not our turn');
80+
}
81+
}, 100);
82+
});
83+
}
84+
4985
// Sort the requests by timestamps to make sure we process in the correct order
5086
// We need to store the request, response and next object so that the correct one is used
5187
// the function we return accepts the unique request, response, next objects
@@ -64,8 +100,6 @@ export function expressRateLimiter(
64100
// r1, and r2 get processed with thin same frame on call stack
65101
// r2 call is done once r2 is added to the queue
66102

67-
const requestsInProcess: { [index: string]: Request[] } = {};
68-
69103
// return a throttled middleware. Check every 100ms? make this a setting?
70104
// how do we make sure these get queued properly?
71105
// store the requests in an array when available grab the next request for a user
@@ -81,7 +115,6 @@ export function expressRateLimiter(
81115
* Not throttling on time just queueing requests.
82116
*/
83117

84-
// return the rate limiting middleware
85118
return async (
86119
req: Request,
87120
res: Response,
@@ -95,7 +128,7 @@ export function expressRateLimiter(
95128
return next();
96129
}
97130
/**
98-
* There are numorous ways to get the ip address off of the request object.
131+
* There are numerous ways to get the ip address off of the request object.
99132
* - the header 'x-forward-for' will hold the originating ip address if a proxy is placed infront of the server. This would be commen for a production build.
100133
* - req.ips wwill hold an array of ip addresses in'x-forward-for' header. client is likely at index zero
101134
* - req.ip will have the ip address
@@ -105,7 +138,6 @@ export function expressRateLimiter(
105138
*/
106139
// check for a proxied ip address before using the ip address on request
107140
const ip: string = req.ips ? req.ips[0] : req.ip;
108-
// requestsInProcess[ip] = true;
109141

110142
// FIXME: this will only work with type complexity
111143
const queryAST = parse(query);
@@ -121,7 +153,7 @@ export function expressRateLimiter(
121153
try {
122154
// process the request and conditinoally respond to client with status code 429 or
123155
// pass the request onto the next middleware function
124-
const rateLimiterResponse = await rateLimiter.processRequest(
156+
const rateLimiterResponse = await throttledProcess(
125157
ip,
126158
requestTimestamp,
127159
queryComplexity

test/middleware/express.test.ts

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,9 @@ describe('Express Middleware tests', () => {
250250

251251
describe('Adds expected properties to res.locals', () => {
252252
test('Adds UNIX timestamp and complexity', async () => {
253+
jest.useRealTimers();
253254
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
254-
255+
jest.useFakeTimers();
255256
const expected = {
256257
complexity: expect.any(Number),
257258
timestamp: expect.any(Number),
@@ -271,6 +272,12 @@ describe('Express Middleware tests', () => {
271272
});
272273

273274
describe('Correctly limits requests', () => {
275+
beforeAll(() => {
276+
jest.useRealTimers();
277+
});
278+
afterAll(() => {
279+
jest.useFakeTimers();
280+
});
274281
describe('Allows requests', () => {
275282
test('...a single request', async () => {
276283
// successful request calls next without any arguments.
@@ -296,15 +303,35 @@ describe('Express Middleware tests', () => {
296303
});
297304

298305
test('Multiple valid requests at within one second', async () => {
306+
const requests = new Array(5).fill(0);
307+
299308
for (let i = 0; i < 3; i++) {
300-
const next: NextFunction = jest.fn();
301-
await middleware(mockRequest as Request, mockResponse as Response, next);
302-
expect(next).toBeCalledTimes(1);
303-
expect(next).toBeCalledWith();
309+
// Send 3 queries of complexity 2. These should all succeed
310+
requests.push(
311+
middleware(
312+
mockRequest as Request,
313+
mockResponse as Response,
314+
nextFunction
315+
)
316+
);
304317

305-
// advance the timers by 1 second for the next request
306-
jest.advanceTimersByTime(20);
318+
// advance the timers by 20 miliseconds for the next request
319+
// jest.advanceTimersByTime(20);
307320
}
321+
jest.runAllTimers();
322+
await Promise.all(requests);
323+
expect(nextFunction).toBeCalledTimes(3);
324+
expect(nextFunction).toBeCalledWith();
325+
326+
// for (let i = 0; i < 3; i++) {
327+
// const next: NextFunction = jest.fn();
328+
// await middleware(mockRequest as Request, mockResponse as Response, next);
329+
// expect(next).toBeCalledTimes(1);
330+
// expect(next).toBeCalledWith();
331+
332+
// // advance the timers by 20 milliseconds for the next request
333+
// jest.advanceTimersByTime(20);
334+
// }
308335
});
309336
});
310337

@@ -347,6 +374,7 @@ describe('Express Middleware tests', () => {
347374
});
348375

349376
test('Multiple queries that exceed token limit', async () => {
377+
// jest.useRealTimers();
350378
const requests = new Array(5).fill(0);
351379

352380
for (let i = 0; i < 5; i++) {
@@ -363,9 +391,19 @@ describe('Express Middleware tests', () => {
363391
jest.advanceTimersByTime(20);
364392
}
365393

394+
jest.runAllTimers();
395+
await Promise.all(requests);
366396
// Send a 6th request that should be blocked.
367397
const next: NextFunction = jest.fn();
368-
await middleware(mockRequest as Request, mockResponse as Response, next);
398+
399+
const myPromise = middleware(
400+
mockRequest as Request,
401+
mockResponse as Response,
402+
next
403+
);
404+
jest.runAllTimers();
405+
await myPromise;
406+
369407
expect(mockResponse.status).toHaveBeenCalledWith(429);
370408
expect(next).not.toBeCalled();
371409

0 commit comments

Comments
 (0)