Skip to content

Commit d299d9c

Browse files
committed
merge conflicts resolved
2 parents aac7685 + 06dfee5 commit d299d9c

File tree

17 files changed

+1905
-227
lines changed

17 files changed

+1905
-227
lines changed

package.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
"main": "index.js",
66
"type": "module",
77
"scripts": {
8-
"test": "jest --passWithNoTests --coverage",
8+
"test": "jest --passWithNoTests --coverage --detectOpenHandles",
99
"lint": "eslint src test",
1010
"lint:fix": "eslint --fix src test @types",
1111
"prettier": "prettier --write .",
12-
"prepare": "husky install"
12+
"prepare": "husky install",
13+
"build": "tsc"
1314
},
1415
"repository": {
1516
"type": "git",

src/@types/rateLimit.d.ts

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,20 @@ export interface RateLimiter {
1616
export interface RateLimiterResponse {
1717
success: boolean;
1818
tokens: number;
19+
retryAfter?: number;
1920
}
2021

2122
export interface RedisBucket {
2223
tokens: number;
2324
timestamp: number;
2425
}
2526

26-
export interface RedisWindow {
27+
export interface FixedWindow {
2728
currentTokens: number;
29+
fixedWindowStart: number;
30+
}
31+
export interface RedisWindow extends FixedWindow {
2832
previousTokens: number;
29-
fixedWindowStart?: number;
3033
}
3134

3235
export type RedisLog = RedisBucket[];
@@ -48,18 +51,15 @@ export interface TokenBucketOptions {
4851
}
4952

5053
/**
51-
* @type {number} windowSize - Size of each fixed window and the rolling window
52-
* @type {number} capacity - Number of tokens a window can hold
54+
* @type {number} windowSize - size of the window in milliseconds
55+
* @type {number} capacity - max number of tokens that can be used in the bucket
5356
*/
54-
export interface SlidingWindowCounterOptions {
57+
export interface WindowOptions {
5558
windowSize: number;
5659
capacity: number;
5760
}
5861

5962
// TODO: This will be a union type where we can specify Option types for other Rate Limiters
60-
// Record<string, never> represents the empty object for alogorithms that don't require settings
63+
// Record<string, never> represents the empty object for algorithms that don't require settings
6164
// and might be able to be removed in the future.
62-
export type RateLimiterOptions =
63-
| TokenBucketOptions
64-
| SlidingWindowCounterOptions
65-
| Record<string, never>;
65+
export type RateLimiterOptions = TokenBucketOptions | Record<string, never>;

src/analysis/buildTypeWeights.ts

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ import {
1313
isObjectType,
1414
isScalarType,
1515
isUnionType,
16+
isInputType,
1617
Kind,
1718
ValueNode,
1819
GraphQLUnionType,
1920
GraphQLFieldMap,
20-
GraphQLDirective,
21+
isInputObjectType,
2122
} from 'graphql';
2223
import { ObjMap } from 'graphql/jsutils/ObjMap';
2324
import { GraphQLSchema } from 'graphql/type/schema';
@@ -81,11 +82,9 @@ function parseObjectFields(
8182
// Iterate through the fields and add the required data to the result
8283
Object.keys(fields).forEach((field: string) => {
8384
// The GraphQL type that this field represents
84-
const fieldType: GraphQLOutputType = fields[field].type;
85-
if (
86-
isScalarType(fieldType) ||
87-
(isNonNullType(fieldType) && isScalarType(fieldType.ofType))
88-
) {
85+
let fieldType: GraphQLOutputType = fields[field].type;
86+
if (isNonNullType(fieldType)) fieldType = fieldType.ofType;
87+
if (isScalarType(fieldType)) {
8988
result.fields[field] = {
9089
weight: typeWeights.scalar,
9190
// resolveTo: fields[field].name.toLowerCase(),
@@ -101,7 +100,8 @@ function parseObjectFields(
101100
};
102101
} else if (isListType(fieldType)) {
103102
// 'listType' is the GraphQL type that the list resolves to
104-
const listType = fieldType.ofType;
103+
let listType = fieldType.ofType;
104+
if (isNonNullType(listType)) listType = listType.ofType;
105105
if (isScalarType(listType) && typeWeights.scalar === 0) {
106106
// list won't compound if weight is zero
107107
result.fields[field] = {
@@ -135,9 +135,7 @@ function parseObjectFields(
135135
// if no directive is supplied to list field
136136
fields[field].args.forEach((arg: GraphQLArgument) => {
137137
// If field has an argument matching one of the limiting keywords and resolves to a list
138-
// then the weight of the field should be dependent on both the weight of the
139-
// resolved type and the limiting argument.
140-
// FIXME: Can nonnull wrap list types?
138+
// then the weight of the field should be dependent on both the weight of the resolved type and the limiting argument.
141139
if (KEYWORDS.includes(arg.name)) {
142140
// Get the type that comprises the list
143141
result.fields[field] = {
@@ -207,6 +205,7 @@ function compareTypes(a: GraphQLOutputType, b: GraphQLOutputType): boolean {
207205
return (
208206
(isObjectType(b) && isObjectType(a) && a.name === b.name) ||
209207
(isUnionType(b) && isUnionType(a) && a.name === b.name) ||
208+
(isEnumType(b) && isEnumType(a) && a.name === b.name) ||
210209
(isInterfaceType(b) && isInterfaceType(a) && a.name === b.name) ||
211210
(isScalarType(b) && isScalarType(a) && a.name === b.name) ||
212211
(isListType(b) && isListType(a) && compareTypes(b.ofType, a.ofType)) ||
@@ -313,24 +312,26 @@ function parseUnionTypes(
313312
* c. objects have a resolveTo type.
314313
* */
315314

316-
const current = commonFields[field].type;
315+
let current = commonFields[field].type;
316+
if (isNonNullType(current)) current = current.ofType;
317317
if (isScalarType(current)) {
318318
fieldTypes[field] = {
319319
weight: commonFields[field].weight,
320320
};
321-
} else if (isObjectType(current) || isInterfaceType(current) || isUnionType(current)) {
321+
} else if (
322+
isObjectType(current) ||
323+
isInterfaceType(current) ||
324+
isUnionType(current) ||
325+
isEnumType(current)
326+
) {
322327
fieldTypes[field] = {
323328
resolveTo: commonFields[field].resolveTo,
324-
weight: typeWeights.object,
325329
};
326330
} else if (isListType(current)) {
327331
fieldTypes[field] = {
328332
resolveTo: commonFields[field].resolveTo,
329333
weight: commonFields[field].weight,
330334
};
331-
} else if (isNonNullType(current)) {
332-
throw new Error('non null types not supported on unions');
333-
// TODO: also a recursive data structure
334335
} else {
335336
throw new Error('Unhandled union type. Should never get here');
336337
}
@@ -374,7 +375,7 @@ function parseTypes(schema: GraphQLSchema, typeWeights: TypeWeightSet): TypeWeig
374375
};
375376
} else if (isUnionType(currentType)) {
376377
unions.push(currentType);
377-
} else if (!isScalarType(currentType)) {
378+
} else if (!isScalarType(currentType) && !isInputObjectType(currentType)) {
378379
throw new Error(`ERROR: buildTypeWeight: Unsupported type: ${currentType}`);
379380
}
380381
}

src/middleware/index.ts

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
import Redis, { RedisOptions } from 'ioredis';
1+
import EventEmitter from 'events';
2+
23
import { parse, validate } from 'graphql';
4+
import { RedisOptions } from 'ioredis';
35
import { GraphQLSchema } from 'graphql/type/schema';
46
import { Request, Response, NextFunction, RequestHandler } from 'express';
57

68
import buildTypeWeightsFromSchema, { defaultTypeWeightsConfig } from '../analysis/buildTypeWeights';
79
import setupRateLimiter from './rateLimiterSetup';
810
import getQueryTypeComplexity from '../analysis/typeComplexityAnalysis';
9-
import { RateLimiterOptions, RateLimiterSelection } from '../@types/rateLimit';
11+
import { RateLimiterOptions, RateLimiterSelection, RateLimiterResponse } from '../@types/rateLimit';
1012
import { TypeWeightConfig } from '../@types/buildTypeWeights';
13+
import { connect } from '../utils/redis';
1114

1215
// FIXME: Will the developer be responsible for first parsing the schema from a file?
1316
// Can consider accepting a string representing a the filepath to a schema
@@ -39,11 +42,72 @@ export function expressRateLimiter(
3942
*/
4043
// TODO: Throw ValidationError if schema is invalid
4144
const typeWeightObject = buildTypeWeightsFromSchema(schema, typeWeightConfig);
45+
4246
// TODO: Throw error if connection is unsuccessful
43-
const redisClient = new Redis(redisClientOptions); // Default port is 6379 automatically
47+
const redisClient = connect(redisClientOptions); // Default port is 6379 automatically
4448
const rateLimiter = setupRateLimiter(rateLimiterAlgo, rateLimiterOptions, redisClient);
4549

46-
// return the rate limiting middleware
50+
// stores request IDs to be processed
51+
const requestQueue: { [index: string]: string[] } = {};
52+
53+
// Manages processing of event queue
54+
const requestEvents = new EventEmitter();
55+
56+
// Resolves the promise created by throttledProcess
57+
async function processRequestResolver(
58+
userId: string,
59+
timestamp: number,
60+
tokens: number,
61+
resolve: (value: RateLimiterResponse | PromiseLike<RateLimiterResponse>) => void,
62+
reject: (reason: any) => void
63+
) {
64+
try {
65+
const response = await rateLimiter.processRequest(userId, timestamp, tokens);
66+
requestQueue[userId] = requestQueue[userId].slice(1);
67+
// trigger the next event
68+
resolve(response);
69+
requestEvents.emit(requestQueue[userId][0]);
70+
if (requestQueue[userId].length === 0) delete requestQueue[userId];
71+
} catch (err) {
72+
reject(err);
73+
}
74+
}
75+
76+
/**
77+
* Throttle rateLimiter.processRequest based on user IP to prevent inaccurate redis reads
78+
* Throttling is based on a event driven promise fulfillment approach.
79+
* Each time a request is received a promise is added to the user's request queue. The promise "subscribes"
80+
* to the previous request in the user's queue then calls processRequest and resolves once the previous request
81+
* is complete.
82+
* @param userId
83+
* @param timestamp
84+
* @param tokens
85+
* @returns
86+
*/
87+
async function throttledProcess(
88+
userId: string,
89+
timestamp: number,
90+
tokens: number
91+
): Promise<RateLimiterResponse> {
92+
// Alternatively use crypto.randomUUID() to generate a random uuid
93+
const requestId = `${timestamp}${tokens}`;
94+
95+
if (!requestQueue[userId]) {
96+
requestQueue[userId] = [];
97+
}
98+
requestQueue[userId].push(requestId);
99+
100+
return new Promise((resolve, reject) => {
101+
if (requestQueue[userId].length > 1) {
102+
requestEvents.once(requestId, async () => {
103+
await processRequestResolver(userId, timestamp, tokens, resolve, reject);
104+
});
105+
} else {
106+
processRequestResolver(userId, timestamp, tokens, resolve, reject);
107+
}
108+
});
109+
}
110+
47111
return async (
48112
req: Request,
49113
res: Response,
@@ -57,7 +121,7 @@ export function expressRateLimiter(
57121
return next();
58122
}
59123
/**
60-
* There are numorous ways to get the ip address off of the request object.
124+
* There are numerous ways to get the ip address off of the request object.
61125
* - the header 'x-forward-for' will hold the originating ip address if a proxy is placed infront of the server. This would be commen for a production build.
62126
* - req.ips wwill hold an array of ip addresses in'x-forward-for' header. client is likely at index zero
63127
* - req.ip will have the ip address
@@ -66,7 +130,7 @@ export function expressRateLimiter(
66130
* req.ip and req.ips will worx in express but not with other frameworks
67131
*/
68132
// check for a proxied ip address before using the ip address on request
69-
const ip: string = req.ips[0] || req.ip;
133+
const ip: string = req.ips ? req.ips[0] : req.ip;
70134

71135
// FIXME: this will only work with type complexity
72136
const queryAST = parse(query);
@@ -80,9 +144,9 @@ export function expressRateLimiter(
80144

81145
const queryComplexity = getQueryTypeComplexity(queryAST, variables, typeWeightObject);
82146
try {
83-
// process the request and conditinoally respond to client with status code 429 o
84-
// r pass the request onto the next middleware function
85-
const rateLimiterResponse = await rateLimiter.processRequest(
147+
// process the request and conditinoally respond to client with status code 429 or
148+
// pass the request onto the next middleware function
149+
const rateLimiterResponse = await throttledProcess(
86150
ip,
87151
requestTimestamp,
88152
queryComplexity

src/middleware/rateLimiterSetup.ts

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import Redis from 'ioredis';
2-
import { RateLimiterOptions, RateLimiterSelection, TokenBucketOptions } from '../@types/rateLimit';
3-
import SlidingWindowCounter from '../rateLimiters/slidingWindowCounter';
2+
import { RateLimiterOptions, RateLimiterSelection } from '../@types/rateLimit';
43
import TokenBucket from '../rateLimiters/tokenBucket';
54

65
/**
@@ -26,13 +25,12 @@ export default function setupRateLimiter(
2625
break;
2726
case 'LEAKY_BUCKET':
2827
throw new Error('Leaky Bucket algonithm has not be implemented.');
29-
break;
3028
case 'FIXED_WINDOW':
3129
throw new Error('Fixed Window algonithm has not be implemented.');
32-
break;
3330
case 'SLIDING_WINDOW_LOG':
34-
throw new Error('Sliding Window Log has not be implemented.');
35-
break;
31+
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
32+
// @ts-ignore
33+
return new SlidingWindowLog(options.windowSize, options.capacity, client);
3634
case 'SLIDING_WINDOW_COUNTER':
3735
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
3836
// @ts-ignore
@@ -41,6 +39,5 @@ export default function setupRateLimiter(
4139
default:
4240
// typescript should never let us invoke this function with anything other than the options above
4341
throw new Error('Selected rate limiting algorithm is not suppported');
44-
break;
4542
}
4643
}

0 commit comments

Comments
 (0)