Skip to content

Commit cdcfc11

Browse files
Kludex and dmontagu authored
Changes (#133)
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
1 parent 76f8e4b commit cdcfc11

File tree

37 files changed

+1060
-1078
lines changed

37 files changed

+1060
-1078
lines changed

deploy/example.config.ts

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,35 +47,25 @@ export const config: Config<ProviderKeys> = {
4747
injectCost: true,
4848
// credentials are used by the ProviderProxy to authenticate the forwarded request
4949
credentials: env.OPENAI_API_KEY,
50-
apiTypes: ['chat', 'responses'],
51-
},
52-
b: {
53-
providerId: 'groq',
54-
baseUrl: 'https://api.groq.com',
55-
injectCost: true,
56-
credentials: env.GROQ_API_KEY,
57-
apiTypes: ['groq'],
5850
},
51+
b: { providerId: 'groq', baseUrl: 'https://api.groq.com', injectCost: true, credentials: env.GROQ_API_KEY },
5952
c: {
6053
providerId: 'google-vertex',
6154
baseUrl: 'https://us-central1-aiplatform.googleapis.com',
6255
injectCost: true,
6356
credentials: env.GOOGLE_SERVICE_ACCOUNT_KEY,
64-
apiTypes: ['gemini', 'anthropic'],
6557
},
6658
d: {
6759
providerId: 'anthropic',
6860
baseUrl: 'https://api.anthropic.com',
6961
injectCost: true,
7062
credentials: env.ANTHROPIC_API_KEY,
71-
apiTypes: ['anthropic'],
7263
},
7364
e: {
7465
providerId: 'bedrock',
7566
baseUrl: 'https://bedrock-runtime.us-east-1.amazonaws.com',
7667
injectCost: true,
7768
credentials: env.AWS_BEARER_TOKEN_BEDROCK,
78-
apiTypes: ['anthropic', 'converse'],
7969
},
8070
},
8171
// individual apiKeys

deploy/src/db.ts

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,19 @@ export class ConfigDB extends KeysDbD1 {
1010
const project = config.projects[keyInfo.project]!
1111
const user = keyInfo.user ? project.users[keyInfo.user] : undefined
1212

13-
let providers: ProviderProxy[]
13+
let providersWithKeys: (ProviderProxy & { key: string })[]
1414
if (keyInfo.providers === '__all__') {
15-
providers = Object.values(config.providers)
15+
providersWithKeys = Object.entries(config.providers).map(([key, provider]) => ({ ...provider, key }))
1616
} else {
17-
providers = keyInfo.providers.map((name) => config.providers[name])
17+
providersWithKeys = keyInfo.providers.map((key) => ({ ...config.providers[key], key }))
18+
}
19+
20+
// Transform routes to routingGroups
21+
const routingGroups: Record<string, { key: string }[]> = {}
22+
if (config.routes) {
23+
for (const [routeName, routeProviderKeys] of Object.entries(config.routes)) {
24+
routingGroups[routeName] = routeProviderKeys.map((providerKey) => ({ key: providerKey }))
25+
}
1826
}
1927

2028
// if keyInfo.id is unset, hash the API key to give something unique without explicitly using the key directly
@@ -49,7 +57,8 @@ export class ConfigDB extends KeysDbD1 {
4957
userSpendingLimitDaily: user?.spendingLimitDaily,
5058
userSpendingLimitWeekly: user?.spendingLimitWeekly,
5159
userSpendingLimitMonthly: user?.spendingLimitMonthly,
52-
providers,
60+
providers: providersWithKeys,
61+
routingGroups,
5362
otelSettings: user?.otel ?? project.otel,
5463
}
5564
}

deploy/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type { OtelSettings, ProviderProxy } from '@pydantic/ai-gateway'
33
/** Top-level deployment configuration, generic over the set of provider keys. */
export interface Config<ProviderKey extends string = string> {
44
/** @param project: record keys are the project ids */
55
projects: Record<number, Project>
6+
/** Optional named routes; each route name maps to a list of provider keys. */
routes?: Record<string, ProviderKey[]>
67
/** Provider proxies, indexed by provider key. */
providers: Record<ProviderKey, ProviderProxy>
78
/** API key entries, indexed by a string (presumably the API key value itself — confirm against the lookup code). */
apiKeys: Record<string, ApiKey<ProviderKey>>
89
}

deploy/test.config.ts

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,12 @@ export const config: Config<ProviderKeys> = {
3232
injectCost: true,
3333
// credentials are used by the ProviderProxy to authenticate the forwarded request
3434
credentials: env.OPENAI_API_KEY,
35-
apiTypes: ['chat'],
3635
},
3736
groq: {
3837
baseUrl: 'http://localhost:8005/groq',
3938
providerId: 'groq',
4039
injectCost: true,
4140
credentials: env.GROQ_API_KEY,
42-
apiTypes: ['groq'],
4341
},
4442
// google: {
4543
// baseUrl:
@@ -53,22 +51,14 @@ export const config: Config<ProviderKeys> = {
5351
providerId: 'anthropic',
5452
injectCost: true,
5553
credentials: env.ANTHROPIC_API_KEY,
56-
apiTypes: ['anthropic'],
5754
},
5855
bedrock: {
5956
baseUrl: 'http://localhost:8005/bedrock',
6057
providerId: 'bedrock',
6158
injectCost: true,
6259
credentials: env.AWS_BEARER_TOKEN_BEDROCK,
63-
apiTypes: ['anthropic', 'converse'],
64-
},
65-
test: {
66-
baseUrl: 'http://test.example.com/test',
67-
providerId: 'test',
68-
injectCost: true,
69-
credentials: 'test',
70-
apiTypes: ['test'],
7160
},
61+
test: { baseUrl: 'http://test.example.com/test', providerId: 'test', injectCost: true, credentials: 'test' },
7262
},
7363
// individual apiKeys
7464
apiKeys: {

deploy/test/index.spec.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ describe('deploy', () => {
7272

7373
const client = new OpenAI({
7474
apiKey: 'healthy-key',
75-
baseURL: 'https://example.com/chat',
75+
baseURL: 'https://example.com/openai',
7676
fetch: SELF.fetch.bind(SELF),
7777
})
7878

deploy/test/index.spec.ts.snap

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ exports[`deploy > should call openai via gateway > llm 1`] = `
1414
},
1515
},
1616
],
17-
"created": 1762272055,
18-
"id": "chatcmpl-CYDklwaN7x9okuWTnABMCrZykoiRj",
17+
"created": 1762861820,
18+
"id": "chatcmpl-CahB69PCs04fSZmb69YXvx65zQ0XE",
1919
"model": "gpt-5-2025-08-07",
2020
"object": "chat.completion",
2121
"service_tier": "default",
@@ -118,7 +118,7 @@ exports[`deploy > should call openai via gateway > span 1`] = `
118118
{
119119
"key": "gen_ai.response.id",
120120
"value": {
121-
"stringValue": "chatcmpl-CYDklwaN7x9okuWTnABMCrZykoiRj",
121+
"stringValue": "chatcmpl-CahB69PCs04fSZmb69YXvx65zQ0XE",
122122
},
123123
},
124124
{
@@ -293,7 +293,7 @@ exports[`deploy > should call openai via gateway > span 1`] = `
293293
{
294294
"key": "http.response.body.text",
295295
"value": {
296-
"stringValue": "{"id":"chatcmpl-CYDklwaN7x9okuWTnABMCrZykoiRj","object":"chat.completion","created":1762272055,"model":"gpt-5-2025-08-07","choices":[{"index":0,"message":{"role":"assistant","content":"Paris.","refusal":null,"annotations":[]},"finish_reason":"stop"}],"usage":{"prompt_tokens":23,"completion_tokens":75,"total_tokens":98,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":64,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0},"pydantic_ai_gateway":{"cost_estimate":0.00077875}},"service_tier":"default","system_fingerprint":null}",
296+
"stringValue": "{"id":"chatcmpl-CahB69PCs04fSZmb69YXvx65zQ0XE","object":"chat.completion","created":1762861820,"model":"gpt-5-2025-08-07","choices":[{"index":0,"message":{"role":"assistant","content":"Paris.","refusal":null,"annotations":[]},"finish_reason":"stop"}],"usage":{"prompt_tokens":23,"completion_tokens":75,"total_tokens":98,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":64,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0},"pydantic_ai_gateway":{"cost_estimate":0.00077875}},"service_tier":"default","system_fingerprint":null}",
297297
},
298298
},
299299
{

examples/pai_openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def validate_dob(cls, v: date) -> date:
2626

2727

2828
person_agent = Agent(
29-
'gateway/openai-chat:gpt-4.1-mini',
29+
'gateway/openai:gpt-4.1-mini',
3030
output_type=Person,
3131
instructions='Extract information about the person',
3232
)

examples/pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,10 @@ dependencies = [
1111
"mypy-boto3-bedrock-runtime",
1212
]
1313

14+
# Install pydantic ai from git
15+
[tool.uv.sources]
16+
# Please don't remove this line, it's useful to test branches.
17+
pydantic-ai = { git = "https://github.com/pydantic/pydantic-ai.git", rev = '395cbe73' }
18+
1419
[tool.uv]
1520
package = false

gateway/src/gateway.ts

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ import { currentScopeIntervals, type ExceededScope, endOfMonth, endOfWeek, type
55
import { OtelTrace } from './otel'
66
import { genAiOtelAttributes } from './otel/attributes'
77
import { getProvider } from './providers'
8-
import type { APIType } from './types'
9-
import { type ApiKeyInfo, apiTypesArray, guardAPIType } from './types'
8+
import type { ApiKeyInfo, ProviderProxy } from './types'
109
import { runAfter, textResponse } from './utils'
1110

1211
export async function gateway(
@@ -15,23 +14,19 @@ export async function gateway(
1514
ctx: ExecutionContext,
1615
options: GatewayOptions,
1716
): Promise<Response> {
18-
const apiTypeMatch = /^\/([^/]+)\/(.*)$/.exec(proxyPath)
19-
if (!apiTypeMatch) {
17+
const routeMatch = /^\/([^/]+)\/(.*)$/.exec(proxyPath)
18+
if (!routeMatch) {
2019
return textResponse(404, 'Path not found')
2120
}
22-
let [, apiType, restOfPath] = apiTypeMatch as unknown as [string, string, string]
23-
24-
// support for other common names for openai api types
25-
if (apiType === 'openai' || apiType === 'openai-chat') {
26-
apiType = 'chat'
27-
} else if (apiType === 'openai-responses') {
28-
apiType = 'responses'
29-
} else if (apiType === 'google-vertex') {
30-
apiType = 'gemini'
31-
}
32-
33-
if (!guardAPIType(apiType)) {
34-
return textResponse(400, `Invalid API type '${apiType}', should be one of ${apiTypesArray.join(', ')}`)
21+
let [, route, restOfPath] = routeMatch as unknown as [string, string, string]
22+
23+
// Backwards compatibility with the old route format.
24+
if (route === 'openai-responses' || route === 'openai-chat' || route === 'chat' || route === 'responses') {
25+
route = 'openai'
26+
} else if (route === 'gemini') {
27+
route = 'google-vertex'
28+
} else if (route === 'converse') {
29+
route = 'bedrock'
3530
}
3631

3732
const rateLimiter = options.rateLimiter ?? noopLimiter
@@ -41,16 +36,43 @@ export async function gateway(
4136
}
4237
const apiKeyInfo = authResult
4338
try {
44-
return await gatewayWithLimiter(request, restOfPath, apiType, apiKeyInfo, ctx, options)
39+
return await gatewayWithLimiter(request, restOfPath, route, apiKeyInfo, ctx, options)
4540
} finally {
4641
runAfter(ctx, 'options.rateLimiter.requestFinish', rateLimiter.requestFinish())
4742
}
4843
}
4944

45+
/**
 * Resolve the provider proxies that should serve a request for `route`.
 *
 * Resolution order:
 *  1. `route` is a provider key in `providerProxyMapping` -> that single provider.
 *  2. `route` is a routing group name -> every member of the group whose key is
 *     present in `providerProxyMapping`, in the group's order.
 *
 * `route` comes from the request path, so every lookup is an own-property check:
 * the `in` operator and plain indexing would also match names inherited from
 * `Object.prototype` (e.g. `constructor`, `toString`).
 *
 * @param route - route segment taken from the request path
 * @param providerProxyMapping - providers available to the API key, indexed by provider key
 * @param routingGroups - routing groups available to the API key, indexed by group name
 * @returns the matching provider proxies, or a `{ status, message }` error descriptor:
 *   404 when the route matches neither a provider nor a routing group,
 *   400 when the routing group resolves to no available provider.
 */
const getProviderProxies = (
  route: string,
  providerProxyMapping: Record<string, ProviderProxy>,
  routingGroups: ApiKeyInfo['routingGroups'],
): ProviderProxy[] | { status: number; message: string } => {
  const hasOwn = (obj: object, key: string): boolean => Object.prototype.hasOwnProperty.call(obj, key)

  if (hasOwn(providerProxyMapping, route)) {
    return [providerProxyMapping[route]!]
  }

  const routingGroup = routingGroups && hasOwn(routingGroups, route) ? routingGroups[route] : undefined
  if (!routingGroup) {
    const supportedValues = [...new Set([...Object.keys(providerProxyMapping), ...Object.keys(routingGroups ?? {})])]
      .sort()
      .join(', ')
    return { status: 404, message: `Route not found: ${route}. Supported values: ${supportedValues}` }
  }

  // Keep only group members that are actually available to this API key.
  const providerProxies = routingGroup.flatMap(({ key }) => {
    const providerProxy = hasOwn(providerProxyMapping, key) ? providerProxyMapping[key] : undefined
    return providerProxy ? [providerProxy] : []
  })
  if (providerProxies.length === 0) {
    return {
      status: 400,
      message: `No providers included in routing group '${route}'. Add one or more providers to this routing group in the Pydantic AI Gateway console.`,
    }
  }
  return providerProxies
}
71+
5072
export async function gatewayWithLimiter(
5173
request: Request,
5274
restOfPath: string,
53-
apiType: APIType,
75+
route: string,
5476
apiKeyInfo: ApiKeyInfo,
5577
ctx: ExecutionContext,
5678
options: GatewayOptions,
@@ -59,23 +81,13 @@ export async function gatewayWithLimiter(
5981
return textResponse(403, `Unauthorized - Key ${apiKeyInfo.status}`)
6082
}
6183

62-
let providerProxies = apiKeyInfo.providers.filter((p) => p.apiTypes.includes(apiType))
63-
64-
const routingGroup = request.headers.get('pydantic-ai-gateway-routing-group')
65-
if (routingGroup !== null) {
66-
providerProxies = providerProxies.filter((p) => p.routingGroup === routingGroup)
67-
}
68-
69-
const profile = request.headers.get('pydantic-ai-gateway-profile')
70-
if (profile !== null) {
71-
providerProxies = providerProxies.filter((p) => p.profile === profile)
72-
}
73-
74-
// sort providers on priority, highest first
75-
providerProxies.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0))
76-
77-
if (providerProxies.length === 0) {
78-
return textResponse(403, 'Forbidden - Provider not supported by this API Key')
84+
const { routingGroups } = apiKeyInfo
85+
const providerProxyMapping: Record<string, ProviderProxy> = Object.fromEntries(
86+
apiKeyInfo.providers.map((p) => [p.key, p]),
87+
)
88+
const providerProxies = getProviderProxies(route, providerProxyMapping, routingGroups)
89+
if (!Array.isArray(providerProxies)) {
90+
return textResponse(providerProxies.status, providerProxies.message)
7991
}
8092

8193
const otel = new OtelTrace(request, apiKeyInfo.otelSettings, options)
@@ -102,10 +114,7 @@ export async function gatewayWithLimiter(
102114
try {
103115
result = await proxy.dispatch()
104116
} catch (error) {
105-
logfire.reportError('Connection error', error as Error, {
106-
providerId: providerProxy.providerId,
107-
routingGroup: providerProxy.routingGroup,
108-
})
117+
logfire.reportError('Connection error', error as Error, { providerId: providerProxy.providerId, route })
109118
continue
110119
}
111120

@@ -120,7 +129,7 @@ export async function gatewayWithLimiter(
120129
logfire.info('Provider failed with retryable error, trying next provider', {
121130
providerId: providerProxy.providerId,
122131
status: result.unexpectedStatus,
123-
routingGroup: providerProxy.routingGroup,
132+
route,
124133
})
125134
continue
126135
}

gateway/src/types.ts

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ export type KeyStatus =
88
| 'blocked' // when we got a valid response that we couldn't calculate the cost for
99

1010
// Info about an API key for a particular provider returned by the DB during a request
11-
export interface ApiKeyInfo {
11+
export interface ApiKeyInfo<ProviderKey extends string = string> {
1212
id: number
1313
user?: number
1414
project: number
@@ -30,29 +30,27 @@ export interface ApiKeyInfo {
3030
userSpendingLimitDaily?: number
3131
userSpendingLimitWeekly?: number
3232
userSpendingLimitMonthly?: number
33-
providers: ProviderProxy[]
33+
providers: (ProviderProxy & { key: ProviderKey })[]
34+
routingGroups: Record<string, { key: ProviderKey }[]>
3435
otelSettings?: OtelSettings
3536
}
3637

3738
export type ProviderID = 'groq' | 'openai' | 'google-vertex' | 'anthropic' | 'test' | 'bedrock'
3839
// TODO | 'azure' | 'fireworks' | 'mistral' | 'cohere'
3940

40-
export type APIType = 'chat' | 'responses' | 'converse' | 'anthropic' | 'gemini' | 'groq' | 'test'
41-
42-
const apiTypes: Record<APIType, boolean> = {
43-
chat: true,
44-
responses: true,
45-
converse: true,
46-
anthropic: true,
47-
gemini: true,
41+
const providerIds: Record<ProviderID, boolean> = {
4842
groq: true,
43+
openai: true,
44+
'google-vertex': true,
45+
anthropic: true,
4946
test: true,
47+
bedrock: true,
5048
}
5149

52-
export const apiTypesArray = Object.keys(apiTypes) as APIType[]
50+
export const providerIdsArray = Object.keys(providerIds) as ProviderID[]
5351

54-
export function guardAPIType(type: string): type is APIType {
55-
return type in apiTypes
52+
/**
 * Type guard: narrows an arbitrary string to `ProviderID`.
 *
 * Uses an own-property check rather than the `in` operator, because `in` also
 * matches keys inherited from `Object.prototype` (e.g. 'toString', 'constructor'),
 * which would wrongly be accepted as provider ids.
 *
 * @param id - candidate provider id
 * @returns true when `id` is one of the keys of `providerIds`
 */
export function guardProviderID(id: string): id is ProviderID {
  return Object.prototype.hasOwnProperty.call(providerIds, id)
}
5755

5856
export interface ProviderProxy {
@@ -73,18 +71,12 @@ export interface ProviderProxy {
7371
profile?: string
7472

7573
/** Higher priority providers will be used first */
74+
// TODO(Marcelo): Remove now - this should live in the routingGroups.
7675
priority?: number
7776

7877
/** Whether to disable the key in case of error, if missing defaults to True. */
7978
disableKey?: boolean
8079

81-
/** The APIs that the provider supports. Example: ['chat', 'responses'] */
82-
apiTypes: APIType[]
83-
84-
/** A grouping of APIs that serve the same models.
85-
* @example: 'anthropic' would route the requests to Anthropic, Bedrock and Vertex AI. */
86-
routingGroup?: string
87-
8880
/** Whether the provider is managed by the platform and not by the user. */
8981
isBuiltIn?: boolean
9082
}

0 commit comments

Comments
 (0)