diff --git a/packages/test-utils/lib/index.ts b/packages/test-utils/lib/index.ts index 65a5af5ab55..5f339e9a426 100644 --- a/packages/test-utils/lib/index.ts +++ b/packages/test-utils/lib/index.ts @@ -313,11 +313,10 @@ export default class TestUtils { //@ts-ignore targetPort: socketOptions.port, //@ts-ignore - targetHost: socketOptions.host, + targetHost: socketOptions.host ?? '127.0.0.1', enableLogging: true }); - await proxy.start(); const proxyClient = client.duplicate({ socket: { diff --git a/packages/test-utils/lib/redis-proxy-spec.ts b/packages/test-utils/lib/redis-proxy-spec.ts index 89b3b28c35a..d0a41204553 100644 --- a/packages/test-utils/lib/redis-proxy-spec.ts +++ b/packages/test-utils/lib/redis-proxy-spec.ts @@ -1,7 +1,7 @@ import { strict as assert } from 'node:assert'; import { Buffer } from 'node:buffer'; import { testUtils, GLOBAL } from './test-utils'; -import { RedisProxy } from './redis-proxy'; +import { InterceptorFunction, RedisProxy } from './redis-proxy'; import type { RedisClientType } from '@redis/client/lib/client/index.js'; describe('RedisSocketProxy', function () { @@ -107,5 +107,61 @@ describe('RedisSocketProxy', function () { const pingResult = await proxiedClient.ping(); assert.equal(pingResult, 'PONG', 'Client should be able to communicate with Redis through the proxy'); - }, GLOBAL.SERVERS.OPEN_RESP_3) + }, GLOBAL.SERVERS.OPEN_RESP_3); + + describe("Middleware", () => { + testUtils.testWithProxiedClient( + "Modify request/response via middleware", + async ( + proxiedClient: RedisClientType, + proxy: RedisProxy, + ) => { + + // Intercept PING commands and modify the response + const pingInterceptor: InterceptorFunction = async (data, next) => { + if (data.includes('PING')) { + return Buffer.from("+PINGINTERCEPTED\r\n"); + } + return next(data); + }; + + // Only intercept GET responses and double numeric values + // Does not modify other commands or non-numeric GET responses + const doubleNumberGetInterceptor: InterceptorFunction = async (data, next) => { + const response = await next(data); + + // Not a GET command, return original response + if (!data.includes("GET")) return response; + + const value = (response.toString().split("\r\n"))[1]; + const number = Number(value); + // Not a number, return original response + if(isNaN(number)) return response; + + const doubled = String(number * 2); + return Buffer.from(`$${doubled.length}\r\n${doubled}\r\n`); + }; + + proxy.setInterceptors([ pingInterceptor, doubleNumberGetInterceptor ]) + + const pingResponse = await proxiedClient.ping(); + assert.equal(pingResponse, 'PINGINTERCEPTED', 'Response should be modified by middleware'); + + await proxiedClient.set('foo', 1); + const getResponse1 = await proxiedClient.get('foo'); + assert.equal(getResponse1, '2', 'GET response should be doubled for numbers by middleware'); + + await proxiedClient.set('bar', 'Hi'); + const getResponse2 = await proxiedClient.get('bar'); + assert.equal(getResponse2, 'Hi', 'GET response should not be modified for strings by middleware'); + + await proxiedClient.hSet('baz', 'foo', 'dictvalue'); + const hgetResponse = await proxiedClient.hGet('baz', 'foo'); + assert.equal(hgetResponse, 'dictvalue', 'HGET response should not be modified by middleware'); + + }, + GLOBAL.SERVERS.OPEN_RESP_3, + ); + }); + }); diff --git a/packages/test-utils/lib/redis-proxy.ts b/packages/test-utils/lib/redis-proxy.ts index 217ec528a33..a4ea605285f 100644 --- a/packages/test-utils/lib/redis-proxy.ts +++ b/packages/test-utils/lib/redis-proxy.ts @@ -20,6 +20,7 @@ interface ConnectionInfo { interface ActiveConnection extends ConnectionInfo { readonly clientSocket: net.Socket; readonly serverSocket: net.Socket; + inflightRequestsCount: number } type SendResult = @@ -49,11 +50,16 @@ interface ProxyEvents { 'close': () => void; } +export type Interceptor = (data: Buffer) => Promise; +export type InterceptorFunction = (data: Buffer, next: Interceptor) => Promise; +type InterceptorInitializer = (init: Interceptor) => Interceptor; + export class RedisProxy extends EventEmitter { private readonly server: net.Server; public readonly config: Required; private readonly connections: Map; private isRunning: boolean; + private interceptorInitializer: InterceptorInitializer = (init) => init; constructor(config: ProxyConfig) { super(); @@ -113,6 +119,13 @@ export class RedisProxy extends EventEmitter { }); } + public setInterceptors(interceptors: Array) { + this.interceptorInitializer = (init) => interceptors.reduceRight( + (next, mw) => (data) => mw(data, next), + init + ); + } + public getStats(): ProxyStats { const connections = Array.from(this.connections.values()); @@ -218,19 +231,22 @@ export class RedisProxy extends EventEmitter { } private handleClientConnection(clientSocket: net.Socket): void { - const connectionId = this.generateConnectionId(); + clientSocket.pause(); const serverSocket = net.createConnection({ host: this.config.targetHost, port: this.config.targetPort }); + serverSocket.once('connect', clientSocket.resume.bind(clientSocket)); + const connectionId = this.generateConnectionId(); const connectionInfo: ActiveConnection = { id: connectionId, clientAddress: clientSocket.remoteAddress || 'unknown', clientPort: clientSocket.remotePort || 0, connectedAt: new Date(), clientSocket, - serverSocket + serverSocket, + inflightRequestsCount: 0 }; this.connections.set(connectionId, connectionInfo); @@ -243,12 +259,33 @@ export class RedisProxy extends EventEmitter { this.emit('connection', connectionInfo); }); - clientSocket.on('data', (data) => { + clientSocket.on('data', async (data) => { this.emit('data', connectionId, 'client->server', data); - serverSocket.write(data); + + connectionInfo.inflightRequestsCount++; + + // next1 -> next2 -> ... -> last -> server + // next1 <- next2 <- ... <- last <- server + const last = (data: Buffer): Promise => { + return new Promise((resolve, reject) => { + serverSocket.write(data); + serverSocket.once('data', (data) => { + connectionInfo.inflightRequestsCount--; + assert(connectionInfo.inflightRequestsCount >= 0, `inflightRequestsCount for connection ${connectionId} went below zero`); + this.emit('data', connectionId, 'server->client', data); + resolve(data); + }); + serverSocket.once('error', reject); + }); + }; + + const interceptorChain = this.interceptorInitializer(last); + const response = await interceptorChain(data); + clientSocket.write(response); }); serverSocket.on('data', (data) => { + if (connectionInfo.inflightRequestsCount > 0) return; this.emit('data', connectionId, 'server->client', data); clientSocket.write(data); }); @@ -273,6 +310,7 @@ export class RedisProxy extends EventEmitter { }); serverSocket.on('error', (error) => { + if (connectionInfo.inflightRequestsCount > 0) return; this.log(`Server error for connection ${connectionId}: ${error.message}`); this.emit('error', error, connectionId); clientSocket.destroy(); @@ -306,6 +344,7 @@ export class RedisProxy extends EventEmitter { } } import { createServer } from 'net'; +import assert from 'node:assert'; export function getFreePortNumber(): Promise { return new Promise((resolve, reject) => { @@ -326,4 +365,3 @@ export function getFreePortNumber(): Promise { export { RedisProxy as RedisTransparentProxy }; export type { ProxyConfig, ConnectionInfo, ProxyEvents, SendResult, DataDirection, ProxyStats }; -