Skip to content

Commit 76f8e4b

Browse files
samuelcolvinKludex
andauthored
fix client disconnect (#134)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent 604fe2f commit 76f8e4b

File tree

9 files changed

+98
-67
lines changed

9 files changed

+98
-67
lines changed

gateway/src/auth.ts

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import type { GatewayOptions } from '.'
22
import type { RateLimiter } from './rateLimiter'
33
import type { ApiKeyInfo } from './types'
4-
import { ResponseError, runAfter } from './utils'
4+
import { runAfter, textResponse } from './utils'
55

66
const CACHE_TTL = 86400 * 30
77

@@ -10,12 +10,12 @@ export async function apiKeyAuth(
1010
ctx: ExecutionContext,
1111
options: GatewayOptions,
1212
rateLimiter: RateLimiter,
13-
): Promise<ApiKeyInfo> {
13+
): Promise<ApiKeyInfo | Response> {
1414
const authorization = request.headers.get('authorization')
1515
const xApiKey = request.headers.get('x-api-key')
1616

1717
if (authorization && xApiKey) {
18-
throw new ResponseError(401, 'Unauthorized - Both Authorization and X-API-Key headers are present, use only one')
18+
return textResponse(401, 'Unauthorized - Both Authorization and X-API-Key headers are present, use only one')
1919
}
2020

2121
const authHeader = authorization || xApiKey
@@ -28,11 +28,11 @@ export async function apiKeyAuth(
2828
key = authHeader
2929
}
3030
} else {
31-
throw new ResponseError(401, 'Unauthorized - Missing Authorization Header')
31+
return textResponse(401, 'Unauthorized - Missing Authorization Header')
3232
}
3333
// avoid very long queries to the DB
3434
if (key.length > 100) {
35-
throw new ResponseError(401, 'Unauthorized - Key too long')
35+
return textResponse(401, 'Unauthorized - Key too long')
3636
}
3737

3838
const cacheKey = apiKeyCacheKey(key, options.kvVersion)
@@ -46,7 +46,10 @@ export async function apiKeyAuth(
4646
options.kv.get(projectStateCacheKey(apiKeyInfo.project, options.kvVersion)),
4747
rateLimiter.requestStart(apiKeyInfo),
4848
])
49-
processLimiterResult(limiterResult)
49+
const limiterResponse = processLimiterResult(limiterResult)
50+
if (limiterResponse) {
51+
return limiterResponse
52+
}
5053
// we only return a cache match if the project state is the same, so updating the project state invalidates the cache
5154
// projectState is null if we have never invalidated the cache which will only be true for the first request after a deployment
5255
if (projectState === null || projectState === cacheResult.metadata) {
@@ -58,12 +61,16 @@ export async function apiKeyAuth(
5861
const apiKeyInfo = await options.keysDb.getApiKey(key)
5962
if (apiKeyInfo) {
6063
if (!rateLimiterStarted) {
61-
processLimiterResult(await rateLimiter.requestStart(apiKeyInfo))
64+
const limiterResult = await rateLimiter.requestStart(apiKeyInfo)
65+
const limiterResponse = processLimiterResult(limiterResult)
66+
if (limiterResponse) {
67+
return limiterResponse
68+
}
6269
}
6370
runAfter(ctx, 'setApiKeyCache', setApiKeyCache(apiKeyInfo, options))
6471
return apiKeyInfo
6572
}
66-
throw new ResponseError(401, 'Unauthorized - Key not found')
73+
return textResponse(401, 'Unauthorized - Key not found')
6774
}
6875

6976
export async function setApiKeyCache(
@@ -99,6 +106,6 @@ const projectStateCacheKey = (project: number, kvVersion: string) => `projectSta
99106

100107
function processLimiterResult(limiterResult: string | null) {
101108
if (typeof limiterResult === 'string') {
102-
throw new ResponseError(429, limiterResult)
109+
return textResponse(429, limiterResult)
103110
}
104111
}

gateway/src/gateway.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ export async function gateway(
3535
}
3636

3737
const rateLimiter = options.rateLimiter ?? noopLimiter
38-
const apiKeyInfo = await apiKeyAuth(request, ctx, options, rateLimiter)
38+
const authResult = await apiKeyAuth(request, ctx, options, rateLimiter)
39+
if (authResult instanceof Response) {
40+
return authResult
41+
}
42+
const apiKeyInfo = authResult
3943
try {
4044
return await gatewayWithLimiter(request, restOfPath, apiType, apiKeyInfo, ctx, options)
4145
} finally {

gateway/src/index.ts

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import { gateway } from './gateway'
2020
import type { DefaultProviderProxy, Middleware, Next } from './providers/default'
2121
import type { RateLimiter } from './rateLimiter'
2222
import type { SubFetch } from './types'
23-
import { ctHeader, ResponseError, response405, textResponse } from './utils'
23+
import { ctHeader, response405, runAfter, textResponse } from './utils'
2424

2525
export { changeProjectState as setProjectState, deleteApiKeyCache, setApiKeyCache } from './auth'
2626
export type { DefaultProviderProxy, Middleware, Next }
@@ -56,15 +56,13 @@ export async function gatewayFetch(
5656
if (proxyPath === '/') {
5757
return index(request, options)
5858
} else {
59-
return await gateway(request, `${proxyPath}${queryString}`, ctx, options)
59+
const gatewayPromise = gateway(request, `${proxyPath}${queryString}`, ctx, options)
60+
runAfter(ctx, 'gatewayPromise', gatewayPromise)
61+
return await gatewayPromise
6062
}
6163
} catch (error) {
62-
if (error instanceof ResponseError) {
63-
logfire.reportError('ResponseError', error)
64-
return error.response()
65-
} else {
66-
throw error
67-
}
64+
logfire.reportError('ResponseError', error as Error)
65+
throw error
6866
}
6967
}
7068

gateway/src/providers/anthropic.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import type { ModelAPI } from '../api'
22
import { AnthropicAPI } from '../api/anthropic'
3-
import { DefaultProviderProxy } from './default'
3+
import { DefaultProviderProxy, type ProxyInvalidRequest } from './default'
44

55
export class AnthropicProvider extends DefaultProviderProxy {
66
protected isWhitelistedEndpoint(): boolean {
@@ -15,8 +15,9 @@ export class AnthropicProvider extends DefaultProviderProxy {
1515
}
1616

1717
// biome-ignore lint/suspicious/useAwait: required by google auth
18-
protected async requestHeaders(headers: Headers): Promise<void> {
18+
protected async requestHeaders(headers: Headers): Promise<ProxyInvalidRequest | null> {
1919
headers.set('x-api-key', this.providerProxy.credentials)
20+
return null
2021
}
2122

2223
protected responseHeaders(headers: Headers): Headers {

gateway/src/providers/default.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ export class DefaultProviderProxy {
186186
}
187187

188188
// biome-ignore lint/suspicious/useAwait: required by google auth
189-
protected async requestHeaders(headers: Headers): Promise<void> {
189+
protected async requestHeaders(headers: Headers): Promise<ProxyInvalidRequest | null> {
190190
headers.set('Authorization', `Bearer ${this.providerProxy.credentials}`)
191+
return null
191192
}
192193

193194
protected async prepRequest(): Promise<Prepare | ProxyInvalidRequest> {
@@ -289,7 +290,10 @@ export class DefaultProviderProxy {
289290
requestHeaders.set('user-agent', this.userAgent())
290291
// authorization header was used by the gateway auth, it definitely should not be forwarded to the target api
291292
requestHeaders.delete('authorization')
292-
await this.requestHeaders(requestHeaders)
293+
const requestHeadersError = await this.requestHeaders(requestHeaders)
294+
if (requestHeadersError) {
295+
return requestHeadersError
296+
}
293297

294298
const prepResult = await this.prepRequest()
295299
if ('error' in prepResult) {

gateway/src/providers/google/auth.ts

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,48 @@
1-
import { ResponseError } from '../../utils'
1+
import type { ProxyInvalidRequest } from '../default'
22

3-
export async function authToken(credentials: string, kv: KVNamespace, subFetch: typeof fetch): Promise<string> {
3+
export async function authToken(
4+
credentials: string,
5+
kv: KVNamespace,
6+
subFetch: typeof fetch,
7+
): Promise<{ token: string } | ProxyInvalidRequest> {
48
const serviceAccountHash = await hash(credentials)
59
const cacheKey = `gcp-auth:${serviceAccountHash}`
610
const cachedToken = await kv.get(cacheKey, { cacheTtl: 300 })
711
if (cachedToken) {
8-
return cachedToken
12+
return { token: cachedToken }
913
}
10-
const serviceAccount = getServiceAccount(credentials)
11-
const jwt = await jwtSign(serviceAccount)
12-
const token = await getAccessToken(jwt, subFetch)
13-
await kv.put(cacheKey, token, { expirationTtl: 3000 })
14-
return token
14+
const serviceAccountResult = getServiceAccount(credentials)
15+
if ('error' in serviceAccountResult) {
16+
return serviceAccountResult
17+
}
18+
const jwt = await jwtSign(serviceAccountResult)
19+
const tokenResult = await getAccessToken(jwt, subFetch)
20+
if ('error' in tokenResult) {
21+
return tokenResult
22+
}
23+
await kv.put(cacheKey, tokenResult.token, { expirationTtl: 3000 })
24+
return tokenResult
1525
}
1626

17-
function getServiceAccount(credentials: string): ServiceAccount {
27+
export function getServiceAccount(credentials: string): ServiceAccount | ProxyInvalidRequest {
1828
let sa: ServiceAccount
1929
try {
2030
sa = JSON.parse(credentials)
2131
} catch (error) {
22-
throw new ResponseError(400, `provider credentials are not valid JSON: ${error as Error}`)
32+
return { error: `provider credentials are not valid JSON: ${error as Error}` }
2333
}
2434
if (typeof sa.client_email !== 'string') {
25-
throw new ResponseError(400, `"client_email" should be a string, not ${typeof sa.client_email}`)
35+
return { error: `"client_email" should be a string, not ${typeof sa.client_email}` }
2636
}
2737
if (typeof sa.private_key !== 'string') {
28-
throw new ResponseError(400, `"private_key" should be a string, not ${typeof sa.private_key}`)
38+
return { error: `"private_key" should be a string, not ${typeof sa.private_key}` }
2939
}
3040
if (typeof sa.project_id !== 'string') {
31-
throw new ResponseError(400, `"project_id" should be a string, not ${typeof sa.project_id}`)
41+
return { error: `"project_id" should be a string, not ${typeof sa.project_id}` }
3242
}
3343
return { client_email: sa.client_email, private_key: sa.private_key, project_id: sa.project_id }
3444
}
3545

36-
export function getProjectId(credentials: string): string {
37-
const sa = getServiceAccount(credentials)
38-
return sa.project_id
39-
}
40-
4146
interface ServiceAccount {
4247
client_email: string
4348
private_key: string
@@ -80,7 +85,7 @@ async function jwtSign(serviceAccount: ServiceAccount): Promise<string> {
8085
return `${signingInput}.${b64UrlEncodeArray(signature)}`
8186
}
8287

83-
async function getAccessToken(jwt: string, subFetch: typeof fetch): Promise<string> {
88+
async function getAccessToken(jwt: string, subFetch: typeof fetch): Promise<{ token: string } | ProxyInvalidRequest> {
8489
const body = new URLSearchParams({ grant_type: 'urn:ietf:params:oauth:grant-type:jwt-bearer', assertion: jwt })
8590

8691
const response = await subFetch(tokenUrl, {
@@ -92,10 +97,10 @@ async function getAccessToken(jwt: string, subFetch: typeof fetch): Promise<stri
9297

9398
if (response.ok) {
9499
const responseData: TokenResponse = await response.json()
95-
return responseData.access_token
100+
return { token: responseData.access_token }
96101
} else {
97102
const text = await response.text()
98-
throw new ResponseError(400, `Failed to get GCP access token, response:\n${response.status}: ${text}`)
103+
return { error: `Failed to get GCP access token, response:\n${response.status}: ${text}` }
99104
}
100105
}
101106

gateway/src/providers/google/index.ts

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import type { ModelAPI } from '../../api'
22
import { AnthropicAPI } from '../../api/anthropic'
33
import { GoogleAPI, type GoogleRequest } from '../../api/google'
4-
import { DefaultProviderProxy } from '../default'
5-
import { authToken, getProjectId } from './auth'
4+
import { DefaultProviderProxy, type ProxyInvalidRequest } from '../default'
5+
import { authToken, getServiceAccount } from './auth'
66

77
export class GoogleVertexProvider extends DefaultProviderProxy {
88
protected usageField = 'usageMetadata'
99
flavor: 'default' | 'anthropic' = 'default'
1010

1111
url() {
1212
if (this.providerProxy.baseUrl) {
13-
const projectId = getProjectId(this.providerProxy.credentials)
13+
const serviceAccountResult = getServiceAccount(this.providerProxy.credentials)
14+
if ('error' in serviceAccountResult) {
15+
return serviceAccountResult
16+
}
17+
const projectId = serviceAccountResult.project_id
1418
const region = regionFromUrl(this.providerProxy.baseUrl)
1519
if (!region) {
1620
return { error: 'Unable to extract region from URL' }
@@ -92,9 +96,14 @@ export class GoogleVertexProvider extends DefaultProviderProxy {
9296
}
9397
}
9498

95-
async requestHeaders(headers: Headers): Promise<void> {
96-
const token = await authToken(this.providerProxy.credentials, this.options.kv, this.options.subFetch)
97-
headers.set('Authorization', `Bearer ${token}`)
99+
async requestHeaders(headers: Headers): Promise<ProxyInvalidRequest | null> {
100+
const tokenResult = await authToken(this.providerProxy.credentials, this.options.kv, this.options.subFetch)
101+
if ('error' in tokenResult) {
102+
return tokenResult
103+
} else {
104+
headers.set('Authorization', `Bearer ${tokenResult.token}`)
105+
return null
106+
}
98107
}
99108
}
100109

gateway/src/utils.ts

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,6 @@ export function getIP(request: Request): string {
3030
}
3131
}
3232

33-
export class ResponseError extends Error {
34-
status: number
35-
message: string
36-
37-
constructor(status: number, message: string) {
38-
super(message)
39-
this.status = status
40-
this.message = message
41-
}
42-
43-
response(): Response {
44-
return textResponse(this.status, this.message)
45-
}
46-
}
47-
4833
export function runAfter(ctx: ExecutionContext, name: string, promise: Promise<unknown>) {
4934
ctx.waitUntil(wrapLogfire(name, promise))
5035
}

gateway/test/auth.spec.ts

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,24 @@ class CountingKeysDb implements KeysDb {
2525
}
2626
}
2727

28+
describe('apiKeyAuth fails', () => {
29+
test('no header', async () => {
30+
const ctx = createExecutionContext()
31+
const baseOptions = buildGatewayEnv(env, [], fetch)
32+
const countingDb = new CountingKeysDb(baseOptions.keysDb)
33+
const options = { ...baseOptions, keysDb: countingDb }
34+
35+
const request = new Request('https://example.com')
36+
37+
// First call should fetch from DB
38+
const result = await apiKeyAuth(request, ctx, options, noopLimiter)
39+
expect(result).instanceOf(Response)
40+
const response = result as Response
41+
expect(response.status).toBe(401)
42+
expect(await response.text()).toEqual('Unauthorized - Missing Authorization Header')
43+
})
44+
})
45+
2846
describe('apiKeyAuth cache invalidation', () => {
2947
test('caches api key and returns cached value', async () => {
3048
const ctx = createExecutionContext()
@@ -36,7 +54,7 @@ describe('apiKeyAuth cache invalidation', () => {
3654

3755
// First call should fetch from DB
3856
const apiKey1 = await apiKeyAuth(request, ctx, options, noopLimiter)
39-
expect(apiKey1.key).toBe('healthy')
57+
expect((apiKey1 as ApiKeyInfo).key).toBe('healthy')
4058
// Wait for cache to be set (it's set asynchronously via runAfter)
4159
await waitOnExecutionContext(ctx)
4260
expect(countingDb.callCount).toBe(1)
@@ -48,7 +66,7 @@ describe('apiKeyAuth cache invalidation', () => {
4866
// Second call should use cache, not hit DB
4967
const ctx2 = createExecutionContext()
5068
const apiKey2 = await apiKeyAuth(request, ctx2, options, noopLimiter)
51-
expect(apiKey2.key).toBe('healthy')
69+
expect((apiKey2 as ApiKeyInfo).key).toBe('healthy')
5270

5371
expect(countingDb.callCount).toBe(1)
5472
})
@@ -85,7 +103,7 @@ describe('apiKeyAuth cache invalidation', () => {
85103
// Third call - cache is invalidated, should hit DB again
86104
const ctx3 = createExecutionContext()
87105
const apiKey3 = await apiKeyAuth(request, ctx3, options, noopLimiter)
88-
expect(apiKey3.key).toBe('healthy')
106+
expect((apiKey3 as ApiKeyInfo).key).toBe('healthy')
89107
await waitOnExecutionContext(ctx3)
90108

91109
expect(countingDb.callCount).toBe(2)

0 commit comments

Comments
 (0)