Skip to content

Commit 8fea6f6

Browse files
authored
feat: implement fallback (#121)
1 parent 0bc9ab3 commit 8fea6f6

File tree

5 files changed

+210
-41
lines changed

5 files changed

+210
-41
lines changed

gateway/src/gateway.ts

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,33 +54,51 @@ export async function gateway(
5454
// Sort providers by priority, highest first.
5555
providerProxies.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0))
5656

57-
const providerProxy = providerProxies[0]
58-
if (!providerProxy) {
57+
if (providerProxies.length === 0) {
5958
return textResponse(403, 'Forbidden - Provider not supported by this API Key')
6059
}
6160

6261
const otel = new OtelTrace(request, apiKeyInfo.otelSettings, options)
6362

64-
const ProxyCls = getProvider(providerProxy.providerId)
65-
66-
const dispatchSpan = otel.startSpan()
67-
const proxy = new ProxyCls({
68-
request,
69-
gatewayOptions: options,
70-
apiKeyInfo,
71-
providerProxy,
72-
restOfPath,
73-
ctx,
74-
middlewares: options.proxyMiddlewares,
75-
otelSpan: dispatchSpan,
76-
})
77-
78-
const result = await proxy.dispatch()
79-
80-
// This doesn't work on streaming because the `result` object is returned as soon as we create the streaming response.
81-
if (!('responseStream' in result) && !('response' in result)) {
82-
const [spanName, attributes, level] = genAiOtelAttributes(result, proxy)
83-
dispatchSpan.end(spanName, attributes, { level })
63+
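// Declared before the loop so the last attempt's result is still available after it; the type is derived from the provider's `dispatch` return type. (The AI did this, but I actually find it nice.)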
64+
let result!: Awaited<ReturnType<InstanceType<ReturnType<typeof getProvider>>['dispatch']>>
65+
66+
for (const providerProxy of providerProxies) {
67+
const ProxyCls = getProvider(providerProxy.providerId)
68+
69+
const otelSpan = otel.startSpan()
70+
const proxy = new ProxyCls({
71+
// Since the body is consumed by the proxy, we need to clone the request.
72+
request: request.clone(),
73+
gatewayOptions: options,
74+
apiKeyInfo,
75+
providerProxy,
76+
restOfPath,
77+
ctx,
78+
middlewares: options.proxyMiddlewares,
79+
otelSpan,
80+
})
81+
82+
result = await proxy.dispatch()
83+
84+
// Results of those kinds take care of closing the `otelSpan` themselves; only end it here for the others.
85+
if (!('responseStream' in result) && !('response' in result) && !('unexpectedStatus' in result)) {
86+
const [spanName, attributes, level] = genAiOtelAttributes(result, proxy)
87+
otelSpan.end(spanName, attributes, { level })
88+
}
89+
90+
// Check if we should retry with the next provider.
91+
if ('unexpectedStatus' in result && isRetryableError(result.unexpectedStatus)) {
92+
logfire.info('Provider failed with retryable error, trying next provider', {
93+
providerId: providerProxy.providerId,
94+
status: result.unexpectedStatus,
95+
routingGroup: providerProxy.routingGroup,
96+
})
97+
continue
98+
}
99+
100+
// If it succeeds, or it's not a retryable error, we can break out of the loop.
101+
break
84102
}
85103

86104
let response: Response
@@ -129,7 +147,7 @@ export async function gateway(
129147
response = new Response(responseBody, { status: unexpectedStatus, headers: responseHeaders })
130148
}
131149

132-
// TODO(Marcelo): This needs a bit of refactoring. We need the `dispatchSpan` to be closed before we send the spans.
150+
// TODO(Marcelo): This needs a bit of refactoring. We need the `otelSpan` to be closed before we send the spans.
133151
if (!('responseStream' in result)) {
134152
runAfter(ctx, 'otel.send', otel.send())
135153
}
@@ -231,3 +249,7 @@ function calculateExpirationTtl(ex: ExceededScope[]): number | undefined {
231249
d.setHours(23, 59, 59)
232250
return Math.floor((d.getTime() - now.getTime()) / 1000)
233251
}
252+
253+
/**
 * Decides whether a failed provider response justifies trying the next provider.
 *
 * @param status - HTTP status code returned by the provider.
 * @returns `true` for 429 and for any status in the 500-599 range; `false` otherwise.
 */
function isRetryableError(status: number): boolean {
  return status === 429 || (status >= 500 && status <= 599)
}

gateway/src/otel/attributes.ts

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
1-
import type {
2-
DefaultProviderProxy,
3-
ProxyInvalidRequest,
4-
ProxySuccess,
5-
ProxyUnexpectedResponse,
6-
} from '../providers/default'
1+
import type { DefaultProviderProxy, ProxyInvalidRequest, ProxySuccess } from '../providers/default'
72
import type { Attributes, Level } from '.'
83
import type { InputMessages, OutputMessages, TextPart } from './genai'
94

105
export function genAiOtelAttributes(
11-
result: ProxySuccess | ProxyInvalidRequest | ProxyUnexpectedResponse,
6+
result: ProxySuccess | ProxyInvalidRequest,
127
provider: DefaultProviderProxy,
138
): [string, Attributes, Level] {
149
const { requestModel } = result
@@ -39,21 +34,11 @@ export function genAiOtelAttributes(
3934
'gen_ai.usage.cache_audio_read_tokens': usage.cache_audio_read_tokens,
4035
'gen_ai.usage.output_audio_tokens': usage.output_audio_tokens,
4136
}
42-
} else if ('error' in result) {
37+
} else {
4338
const { error } = result
4439
spanName = `chat ${requestModel ?? 'unknown-model'}, invalid request {error}`
4540
attributes = { ...attributes, error }
4641
level = 'error'
47-
} else {
48-
const { unexpectedStatus, requestBody, responseBody } = result
49-
spanName = `chat ${requestModel ?? 'unknown-model'}, unexpected response: {http.response.status_code}`
50-
attributes = {
51-
...attributes,
52-
'http.response.status_code': unexpectedStatus,
53-
'http.request.body.text': requestBody,
54-
'http.response.body.text': responseBody,
55-
}
56-
level = 'warn'
5742
}
5843
return [spanName, attributes, level]
5944
}

gateway/src/providers/default.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,17 @@ export class DefaultProviderProxy {
312312
if (!response.ok) {
313313
// CAUTION: can we be charged in any way for failed requests?
314314
const responseBody = await response.text()
315+
this.otelSpan.end(
316+
`chat ${requestModel ?? 'unknown-model'}, unexpected response: {http.response.status_code}`,
317+
{
318+
...attributesFromRequest(this.request),
319+
...attributesFromResponse(response),
320+
'http.request.body.text': requestBodyText,
321+
'http.response.body.text': responseBody,
322+
'http.response.status_code': response.status,
323+
},
324+
{ level: 'warn' },
325+
)
315326
return {
316327
requestModel,
317328
requestBody: requestBodyText,

gateway/test/gateway.spec.ts

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,125 @@ describe('custom middleware', () => {
268268
expect(responses).lengthOf(1)
269269
})
270270
})
271+
272+
// Fallback behaviour for providers that share a routing group. Every test authenticates
// with the 'fallback-test' key, requests routing group 'test-group', and injects a
// middleware that fabricates upstream failures instead of relying on a real provider error.
describe('routing group fallback', () => {
  test('should fallback to next provider on retryable error', async () => {
    let attemptCount = 0
    const providerAttempts: string[] = []

    // Answers with a fabricated 503 for the provider whose base URL contains 'provider1';
    // every other attempt is passed through to the next handler.
    class FailFirstMiddleware implements Middleware {
      dispatch(next: Next): Next {
        return async (proxy: DefaultProviderProxy) => {
          attemptCount++
          // Record which provider this attempt targets, in call order.
          const baseUrl = (proxy as unknown as { providerProxy: { baseUrl: string } }).providerProxy.baseUrl
          providerAttempts.push(baseUrl)

          // First provider should fail with 503
          if (baseUrl.includes('provider1')) {
            return {
              requestModel: 'gpt-5',
              requestBody: '{}',
              unexpectedStatus: 503,
              responseHeaders: new Headers(),
              responseBody: JSON.stringify({ error: 'Service unavailable' }),
            }
          }

          // Second provider should succeed
          return await next(proxy)
        }
      }
    }

    const ctx = createExecutionContext()
    const request = new Request<unknown, IncomingRequestCfProperties>('https://example.com/chat/gpt-5', {
      method: 'POST',
      headers: { Authorization: 'fallback-test', 'pydantic-ai-gateway-routing-group': 'test-group' },
      body: JSON.stringify({ model: 'gpt-5', messages: [{ role: 'user', content: 'Hello' }] }),
    })

    const gatewayEnv = buildGatewayEnv(env, [], fetch, undefined, [new FailFirstMiddleware()])
    const response = await gatewayFetch(request, new URL(request.url), ctx, gatewayEnv)
    await waitOnExecutionContext(ctx)

    // Both providers were attempted, in order, and the final response is the successful one.
    expect(response.status).toBe(200)
    expect(attemptCount).toBe(2)
    expect(providerAttempts).toEqual(['http://test.example.com/provider1', 'http://test.example.com/provider2'])

    // Verify the response came from the second provider
    const content = (await response.json()) as { choices: [{ message: { content: string } }] }
    expect(content.choices[0].message.content).toMatchInlineSnapshot(
      `"request URL: http://test.example.com/provider2/gpt-5"`,
    )
  })

  test('should not fallback on non-retryable error', async () => {
    let attemptCount = 0

    // Answers every attempt with a fabricated 400 and never calls the next handler.
    class FailWithBadRequestMiddleware implements Middleware {
      dispatch(_next: Next): Next {
        return (_proxy: DefaultProviderProxy) => {
          attemptCount++
          // Return 400 error (non-retryable)
          return Promise.resolve({
            requestModel: 'gpt-5',
            requestBody: '{}',
            unexpectedStatus: 400,
            responseHeaders: new Headers(),
            responseBody: JSON.stringify({ error: 'Bad request' }),
          })
        }
      }
    }

    const ctx = createExecutionContext()
    const request = new Request<unknown, IncomingRequestCfProperties>('https://example.com/chat/gpt-5', {
      method: 'POST',
      headers: { Authorization: 'fallback-test', 'pydantic-ai-gateway-routing-group': 'test-group' },
      body: JSON.stringify({ model: 'gpt-5', messages: [{ role: 'user', content: 'Hello' }] }),
    })

    const gatewayEnv = buildGatewayEnv(env, [], fetch, undefined, [new FailWithBadRequestMiddleware()])
    const response = await gatewayFetch(request, new URL(request.url), ctx, gatewayEnv)
    await waitOnExecutionContext(ctx)

    // Should fail immediately without trying fallback
    expect(response.status).toBe(400)
    expect(attemptCount).toBe(1)
  })

  test('should return error if all providers fail', async () => {
    let attemptCount = 0

    // Answers every attempt with a fabricated 503 and never calls the next handler.
    class FailAllMiddleware implements Middleware {
      dispatch(_next: Next): Next {
        return (_proxy: DefaultProviderProxy) => {
          attemptCount++
          // Always return 503
          return Promise.resolve({
            requestModel: 'gpt-5',
            requestBody: '{}',
            unexpectedStatus: 503,
            responseHeaders: new Headers(),
            responseBody: JSON.stringify({ error: 'Service unavailable' }),
          })
        }
      }
    }

    const ctx = createExecutionContext()
    const request = new Request<unknown, IncomingRequestCfProperties>('https://example.com/chat/gpt-5', {
      method: 'POST',
      headers: { Authorization: 'fallback-test', 'pydantic-ai-gateway-routing-group': 'test-group' },
      body: JSON.stringify({ model: 'gpt-5', messages: [{ role: 'user', content: 'Hello' }] }),
    })

    const gatewayEnv = buildGatewayEnv(env, [], fetch, undefined, [new FailAllMiddleware()])
    const response = await gatewayFetch(request, new URL(request.url), ctx, gatewayEnv)
    await waitOnExecutionContext(ctx)

    // Should try both providers and fail with last error
    expect(response.status).toBe(503)
    expect(attemptCount).toBe(2)
  })
})

gateway/test/worker.ts

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ export namespace IDS {
5050
export const keyHealthy = 4
5151
export const keyDisabled = 5
5252
export const keyTinyLimit = 6
53+
export const keyFallbackTest = 7
5354
}
5455

5556
class TestKeysDB extends KeysDbD1 {
@@ -153,6 +154,34 @@ class TestKeysDB extends KeysDbD1 {
153154
projectSpendingLimitMonthly: 4,
154155
providers: [this.allProviders[0]!],
155156
}
157+
case 'fallback-test':
158+
return {
159+
id: IDS.keyFallbackTest,
160+
project: IDS.projectDefault,
161+
org: IDS.orgDefault,
162+
key,
163+
status: 'active',
164+
providers: [
165+
{
166+
baseUrl: 'http://test.example.com/provider1',
167+
providerId: 'test',
168+
injectCost: true,
169+
credentials: 'test1',
170+
apiTypes: ['chat'],
171+
routingGroup: 'test-group',
172+
priority: 100,
173+
},
174+
{
175+
baseUrl: 'http://test.example.com/provider2',
176+
providerId: 'test',
177+
injectCost: true,
178+
credentials: 'test2',
179+
apiTypes: ['chat'],
180+
routingGroup: 'test-group',
181+
priority: 50,
182+
},
183+
],
184+
}
156185
default:
157186
return null
158187
}

0 commit comments

Comments
 (0)