Skip to content

Commit 922b13e

Browse files
authored
Add apiTypes and routingGroups (#112)
1 parent 6017ac3 commit 922b13e

File tree

11 files changed

+84
-33
lines changed

11 files changed

+84
-33
lines changed

deploy/example.config.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,28 @@ export const config: Config<ProviderKeys> = {
4747
injectCost: true,
4848
// credentials are used by the ProviderProxy to authenticate the forwarded request
4949
credentials: env.OPENAI_API_KEY,
50+
apiTypes: ['chat', 'responses'],
51+
},
52+
b: {
53+
providerId: 'groq',
54+
baseUrl: 'https://api.groq.com',
55+
injectCost: true,
56+
credentials: env.GROQ_API_KEY,
57+
apiTypes: ['groq'],
5058
},
51-
b: { providerId: 'groq', baseUrl: 'https://api.groq.com', injectCost: true, credentials: env.GROQ_API_KEY },
5259
c: {
5360
providerId: 'google-vertex',
5461
baseUrl: 'https://us-central1-aiplatform.googleapis.com',
5562
injectCost: true,
5663
credentials: env.GOOGLE_SERVICE_ACCOUNT_KEY,
64+
apiTypes: ['gemini', 'anthropic'],
5765
},
5866
d: {
5967
providerId: 'anthropic',
6068
baseUrl: 'https://api.anthropic.com',
6169
injectCost: true,
6270
credentials: env.ANTHROPIC_API_KEY,
71+
apiTypes: ['anthropic'],
6372
},
6473
},
6574
// individual apiKeys

deploy/test.config.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@ export const config: Config<ProviderKeys> = {
3232
injectCost: true,
3333
// credentials are used by the ProviderProxy to authenticate the forwarded request
3434
credentials: env.OPENAI_API_KEY,
35+
apiTypes: ['chat'],
3536
},
3637
groq: {
3738
baseUrl: 'http://localhost:8005/groq',
3839
providerId: 'groq',
3940
injectCost: true,
4041
credentials: env.GROQ_API_KEY,
42+
apiTypes: ['groq'],
4143
},
4244
// google: {
4345
// baseUrl:
@@ -51,14 +53,22 @@ export const config: Config<ProviderKeys> = {
5153
providerId: 'anthropic',
5254
injectCost: true,
5355
credentials: env.ANTHROPIC_API_KEY,
56+
apiTypes: ['anthropic'],
5457
},
5558
bedrock: {
5659
baseUrl: 'http://localhost:8005/bedrock',
5760
providerId: 'bedrock',
5861
injectCost: true,
5962
credentials: env.AWS_BEARER_TOKEN_BEDROCK,
63+
apiTypes: ['anthropic', 'converse'],
64+
},
65+
test: {
66+
baseUrl: 'http://test.example.com/test',
67+
providerId: 'test',
68+
injectCost: true,
69+
credentials: 'test',
70+
apiTypes: ['test'],
6071
},
61-
test: { baseUrl: 'http://test.example.com/test', providerId: 'test', injectCost: true, credentials: 'test' },
6272
},
6373
// individual apiKeys
6474
apiKeys: {

deploy/test/index.spec.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ describe('deploy', () => {
7272

7373
const client = new OpenAI({
7474
apiKey: 'healthy-key',
75-
baseURL: 'https://example.com/openai',
75+
baseURL: 'https://example.com/chat',
7676
fetch: SELF.fetch.bind(SELF),
7777
})
7878

gateway/src/gateway.ts

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { currentScopeIntervals, type ExceededScope, endOfMonth, endOfWeek, type
55
import { OtelTrace } from './otel'
66
import { genAiOtelAttributes } from './otel/attributes'
77
import { getProvider } from './providers'
8-
import { type ApiKeyInfo, guardProviderID, providerIdArray } from './types'
8+
import { type ApiKeyInfo, apiTypesArray, guardAPIType } from './types'
99
import { runAfter, textResponse } from './utils'
1010

1111
export async function gateway(
@@ -14,14 +14,14 @@ export async function gateway(
1414
ctx: ExecutionContext,
1515
options: GatewayOptions,
1616
): Promise<Response> {
17-
const providerMatch = /^\/([^/]+)\/(.*)$/.exec(proxyPath)
18-
if (!providerMatch) {
17+
const apiTypeMatch = /^\/([^/]+)\/(.*)$/.exec(proxyPath)
18+
if (!apiTypeMatch) {
1919
return textResponse(404, 'Path not found')
2020
}
21-
const [, provider, restOfPath] = providerMatch as unknown as [string, string, string]
21+
const [, apiType, restOfPath] = apiTypeMatch as unknown as [string, string, string]
2222

23-
if (!guardProviderID(provider)) {
24-
return textResponse(400, `Invalid provider '${provider}', should be one of ${providerIdArray.join(', ')}`)
23+
if (!guardAPIType(apiType)) {
24+
return textResponse(400, `Invalid API type '${apiType}', should be one of ${apiTypesArray.join(', ')}`)
2525
}
2626

2727
const apiKeyInfo = await apiKeyAuth(request, ctx, options)
@@ -30,7 +30,12 @@ export async function gateway(
3030
return textResponse(403, `Unauthorized - Key ${apiKeyInfo.status}`)
3131
}
3232

33-
let providerProxies = apiKeyInfo.providers.filter((p) => p.providerId === provider)
33+
let providerProxies = apiKeyInfo.providers.filter((p) => p.apiTypes.includes(apiType))
34+
35+
const routingGroup = request.headers.get('pydantic-ai-gateway-routing-group')
36+
if (routingGroup !== null) {
37+
providerProxies = providerProxies.filter((p) => p.routingGroup === routingGroup)
38+
}
3439

3540
const profile = request.headers.get('pydantic-ai-gateway-profile')
3641
if (profile !== null) {

gateway/src/types.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,22 @@ export interface ApiKeyInfo {
3535
export type ProviderID = 'groq' | 'openai' | 'google-vertex' | 'anthropic' | 'test' | 'bedrock'
3636
// TODO | 'azure' | 'fireworks' | 'mistral' | 'cohere'
3737

38-
const providerIds: Record<ProviderID, boolean> = {
39-
groq: true,
40-
openai: true,
41-
'google-vertex': true,
38+
export type APIType = 'chat' | 'responses' | 'converse' | 'anthropic' | 'gemini' | 'groq' | 'test'
39+
40+
const apiTypes: Record<APIType, boolean> = {
41+
chat: true,
42+
responses: true,
43+
converse: true,
4244
anthropic: true,
45+
gemini: true,
46+
groq: true,
4347
test: true,
44-
bedrock: true,
4548
}
4649

47-
export const providerIdArray = Object.keys(providerIds).filter((id) => id !== 'test') as ProviderID[]
50+
export const apiTypesArray = Object.keys(apiTypes) as APIType[]
4851

49-
export function guardProviderID(id: string): id is ProviderID {
50-
return id in providerIds
52+
export function guardAPIType(type: string): type is APIType {
53+
return type in apiTypes
5154
}
5255

5356
export interface ProviderProxy {
@@ -67,6 +70,12 @@ export interface ProviderProxy {
6770
priority?: number
6871
/** @disableKey: weather to disable the key in case of error, if missing defaults to True. */
6972
disableKey?: boolean
73+
74+
/** @apiTypes: the APIs that the provider supports. Example: ['chat', 'responses'] */
75+
apiTypes: APIType[]
76+
/** @routingGroup: a grouping of APIs that serve the same models.
77+
* @example: 'anthropic' would route the requests to Anthropic, Bedrock and Vertex AI. */
78+
routingGroup?: string
7079
}
7180

7281
export interface OtelSettings {

gateway/test/gateway.spec.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ import { buildGatewayEnv, type DisableEvent, IDS } from './worker'
1515

1616
describe('invalid request', () => {
1717
test('401 on no auth header', async ({ gateway }) => {
18-
const response = await gateway.fetch('https://example.com/openai/gpt-5')
18+
const response = await gateway.fetch('https://example.com/chat/gpt-5')
1919
const text = await response.text()
2020
expect(response.status, `got ${response.status} response: ${text}`).toBe(401)
2121
expect(text).toMatchInlineSnapshot(`"Unauthorized - Missing Authorization Header"`)
2222
})
2323
test('401 on unknown auth header', async ({ gateway }) => {
24-
const response = await gateway.fetch('https://example.com/openai/gpt-5', {
24+
const response = await gateway.fetch('https://example.com/chat/gpt-5', {
2525
headers: { Authorization: 'unknown-token' },
2626
})
2727
const text = await response.text()
@@ -35,7 +35,7 @@ describe('invalid request', () => {
3535
const text = await response.text()
3636
expect(response.status, `got ${response.status} response: ${text}`).toBe(400)
3737
expect(text).toMatchInlineSnapshot(
38-
`"Invalid provider 'wrong', should be one of groq, openai, google-vertex, anthropic, bedrock"`,
38+
`"Invalid API type 'wrong', should be one of chat, responses, converse, anthropic, gemini, groq, test"`,
3939
)
4040
})
4141
})
@@ -66,7 +66,7 @@ describe('key status', () => {
6666
test('should block request if key is disabled', async ({ gateway }) => {
6767
const { fetch } = gateway
6868

69-
const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'disabled' } })
69+
const response = await fetch('https://example.com/chat/xxx', { headers: { Authorization: 'disabled' } })
7070
const text = await response.text()
7171
expect(response.status, `got response: ${response.status} ${text}`).toBe(403)
7272
expect(text).toMatchInlineSnapshot(`"Unauthorized - Key disabled"`)
@@ -132,7 +132,7 @@ describe('key status', () => {
132132
expect(Math.abs(keyStatusQuery.results[0]!.expiresAtDiff - disableEvents[0]!.expirationTtl!)).toBeLessThan(2)
133133

134134
{
135-
const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'tiny-limit' } })
135+
const response = await fetch('https://example.com/chat/xxx', { headers: { Authorization: 'tiny-limit' } })
136136
const text = await response.text()
137137
expect(response.status, `got ${response.status} response: ${text}`).toBe(403)
138138
expect(text).toMatchInlineSnapshot(`"Unauthorized - Key limit-exceeded"`)
@@ -215,7 +215,7 @@ describe('custom proxyPrefixLength', () => {
215215
const disableEvents: DisableEvent[] = []
216216
const mockFetch = mockFetchFactory(disableEvents)
217217

218-
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/proxy/openai', fetch: mockFetch })
218+
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/proxy/chat', fetch: mockFetch })
219219

220220
const completion = await client.chat.completions.create({
221221
model: 'gpt-5',
@@ -249,7 +249,7 @@ describe('custom middleware', () => {
249249
)[] = []
250250

251251
const ctx = createExecutionContext()
252-
const request = new Request<unknown, IncomingRequestCfProperties>('https://example.com/openai/gpt-5', {
252+
const request = new Request<unknown, IncomingRequestCfProperties>('https://example.com/chat/gpt-5', {
253253
headers: { Authorization: 'healthy' },
254254
})
255255

gateway/test/gateway.spec.ts.snap

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ exports[`key status > should change key status if limit is exceeded > kv-value 1
9393
"projectSpendingLimitMonthly": 4,
9494
"providers": [
9595
{
96+
"apiTypes": [
97+
"test",
98+
],
9699
"baseUrl": "http://test.example.com/test",
97100
"credentials": "test",
98101
"injectCost": true,

gateway/test/providers/bedrock.spec.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ describe('bedrock', () => {
55
test('should call bedrock via gateway', async ({ gateway }) => {
66
const { fetch, otelBatch } = gateway
77

8-
const result = await fetch('https://example.com/bedrock/model/amazon.nova-micro-v1%3A0/converse', {
8+
const result = await fetch('https://example.com/converse/model/amazon.nova-micro-v1%3A0/converse', {
99
method: 'POST',
1010
headers: { 'Content-Type': 'application/json', Authorization: 'healthy' },
1111
body: JSON.stringify({

gateway/test/providers/openai.spec.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ describe('openai', () => {
99
test('openai chat', async ({ gateway }) => {
1010
const { fetch, otelBatch } = gateway
1111

12-
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/openai', fetch })
12+
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/chat', fetch })
1313

1414
const completion = await client.chat.completions.create({
1515
model: 'gpt-5',
@@ -103,7 +103,7 @@ describe('openai', () => {
103103
test('openai responses', async ({ gateway }) => {
104104
const { fetch, otelBatch } = gateway
105105

106-
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/openai', fetch })
106+
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/chat', fetch })
107107

108108
const completion = await client.responses.create({
109109
model: 'gpt-5',
@@ -118,7 +118,7 @@ describe('openai', () => {
118118
test('openai responses with builtin tools', async ({ gateway }) => {
119119
const { fetch, otelBatch } = gateway
120120

121-
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/openai', fetch })
121+
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/chat', fetch })
122122

123123
const completion = await client.responses.create({
124124
model: 'gpt-5',
@@ -150,7 +150,7 @@ describe('openai', () => {
150150
test('openai chat stream', async ({ gateway }) => {
151151
const { fetch, otelBatch } = gateway
152152

153-
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/openai', fetch })
153+
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/chat', fetch })
154154

155155
const stream = await client.chat.completions.create({
156156
stream: true,

gateway/test/providers/openai.spec.ts.snap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ exports[`openai > openai chat stream > span 1`] = `
547547
{
548548
"key": "url.full",
549549
"value": {
550-
"stringValue": "https://example.com/openai/chat/completions",
550+
"stringValue": "https://example.com/chat/chat/completions",
551551
},
552552
},
553553
{

0 commit comments

Comments
 (0)