@@ -5,8 +5,7 @@ import { currentScopeIntervals, type ExceededScope, endOfMonth, endOfWeek, type
55import { OtelTrace } from './otel'
66import { genAiOtelAttributes } from './otel/attributes'
77import { getProvider } from './providers'
8- import type { APIType } from './types'
9- import { type ApiKeyInfo , apiTypesArray , guardAPIType } from './types'
8+ import type { ApiKeyInfo , ProviderProxy } from './types'
109import { runAfter , textResponse } from './utils'
1110
1211export async function gateway (
@@ -15,23 +14,19 @@ export async function gateway(
1514 ctx : ExecutionContext ,
1615 options : GatewayOptions ,
1716) : Promise < Response > {
18- const apiTypeMatch = / ^ \/ ( [ ^ / ] + ) \/ ( .* ) $ / . exec ( proxyPath )
19- if ( ! apiTypeMatch ) {
17+ const routeMatch = / ^ \/ ( [ ^ / ] + ) \/ ( .* ) $ / . exec ( proxyPath )
18+ if ( ! routeMatch ) {
2019 return textResponse ( 404 , 'Path not found' )
2120 }
22- let [ , apiType , restOfPath ] = apiTypeMatch as unknown as [ string , string , string ]
23-
24- // support for other common names for openai api types
25- if ( apiType === 'openai' || apiType === 'openai-chat' ) {
26- apiType = 'chat'
27- } else if ( apiType === 'openai-responses' ) {
28- apiType = 'responses'
29- } else if ( apiType === 'google-vertex' ) {
30- apiType = 'gemini'
31- }
32-
33- if ( ! guardAPIType ( apiType ) ) {
34- return textResponse ( 400 , `Invalid API type '${ apiType } ', should be one of ${ apiTypesArray . join ( ', ' ) } ` )
21+ let [ , route , restOfPath ] = routeMatch as unknown as [ string , string , string ]
22+
23+ // Backwards compatibility with the old route format.
24+ if ( route === 'openai-responses' || route === 'openai-chat' || route === 'chat' || route === 'responses' ) {
25+ route = 'openai'
26+ } else if ( route === 'gemini' ) {
27+ route = 'google-vertex'
28+ } else if ( route === 'converse' ) {
29+ route = 'bedrock'
3530 }
3631
3732 const rateLimiter = options . rateLimiter ?? noopLimiter
@@ -41,16 +36,43 @@ export async function gateway(
4136 }
4237 const apiKeyInfo = authResult
4338 try {
44- return await gatewayWithLimiter ( request , restOfPath , apiType , apiKeyInfo , ctx , options )
39+ return await gatewayWithLimiter ( request , restOfPath , route , apiKeyInfo , ctx , options )
4540 } finally {
4641 runAfter ( ctx , 'options.rateLimiter.requestFinish' , rateLimiter . requestFinish ( ) )
4742 }
4843}
4944
45+ const getProviderProxies = (
46+ route : string ,
47+ providerProxyMapping : Record < string , ProviderProxy > ,
48+ routingGroups : ApiKeyInfo [ 'routingGroups' ] ,
49+ ) : ProviderProxy [ ] | { status : number ; message : string } => {
50+ if ( route in providerProxyMapping ) {
51+ return [ providerProxyMapping [ route ] ! ]
52+ }
53+ const routingGroup = routingGroups ?. [ route ]
54+ if ( ! routingGroup ) {
55+ const supportedValues = [ ...new Set ( [ ...Object . keys ( providerProxyMapping ) , ...Object . keys ( routingGroups ?? { } ) ] ) ]
56+ . sort ( )
57+ . join ( ', ' )
58+ return { status : 404 , message : `Route not found: ${ route } . Supported values: ${ supportedValues } ` }
59+ }
60+ const providerProxies = routingGroup
61+ . map ( ( { key } ) => providerProxyMapping [ key ] )
62+ . filter ( ( x ) : x is ProviderProxy & { key : string } => ! ! x )
63+ if ( providerProxies . length === 0 ) {
64+ return {
65+ status : 400 ,
66+ message : `No providers included in routing group '${ route } '. Add one or more providers to this routing group in the Pydantic AI Gateway console.` ,
67+ }
68+ }
69+ return providerProxies
70+ }
71+
5072export async function gatewayWithLimiter (
5173 request : Request ,
5274 restOfPath : string ,
53- apiType : APIType ,
75+ route : string ,
5476 apiKeyInfo : ApiKeyInfo ,
5577 ctx : ExecutionContext ,
5678 options : GatewayOptions ,
@@ -59,23 +81,13 @@ export async function gatewayWithLimiter(
5981 return textResponse ( 403 , `Unauthorized - Key ${ apiKeyInfo . status } ` )
6082 }
6183
62- let providerProxies = apiKeyInfo . providers . filter ( ( p ) => p . apiTypes . includes ( apiType ) )
63-
64- const routingGroup = request . headers . get ( 'pydantic-ai-gateway-routing-group' )
65- if ( routingGroup !== null ) {
66- providerProxies = providerProxies . filter ( ( p ) => p . routingGroup === routingGroup )
67- }
68-
69- const profile = request . headers . get ( 'pydantic-ai-gateway-profile' )
70- if ( profile !== null ) {
71- providerProxies = providerProxies . filter ( ( p ) => p . profile === profile )
72- }
73-
74- // sort providers on priority, highest first
75- providerProxies . sort ( ( a , b ) => ( b . priority ?? 0 ) - ( a . priority ?? 0 ) )
76-
77- if ( providerProxies . length === 0 ) {
78- return textResponse ( 403 , 'Forbidden - Provider not supported by this API Key' )
84+ const { routingGroups } = apiKeyInfo
85+ const providerProxyMapping : Record < string , ProviderProxy > = Object . fromEntries (
86+ apiKeyInfo . providers . map ( ( p ) => [ p . key , p ] ) ,
87+ )
88+ const providerProxies = getProviderProxies ( route , providerProxyMapping , routingGroups )
89+ if ( ! Array . isArray ( providerProxies ) ) {
90+ return textResponse ( providerProxies . status , providerProxies . message )
7991 }
8092
8193 const otel = new OtelTrace ( request , apiKeyInfo . otelSettings , options )
@@ -102,10 +114,7 @@ export async function gatewayWithLimiter(
102114 try {
103115 result = await proxy . dispatch ( )
104116 } catch ( error ) {
105- logfire . reportError ( 'Connection error' , error as Error , {
106- providerId : providerProxy . providerId ,
107- routingGroup : providerProxy . routingGroup ,
108- } )
117+ logfire . reportError ( 'Connection error' , error as Error , { providerId : providerProxy . providerId , route } )
109118 continue
110119 }
111120
@@ -120,7 +129,7 @@ export async function gatewayWithLimiter(
120129 logfire . info ( 'Provider failed with retryable error, trying next provider' , {
121130 providerId : providerProxy . providerId ,
122131 status : result . unexpectedStatus ,
123- routingGroup : providerProxy . routingGroup ,
132+ route ,
124133 } )
125134 continue
126135 }
0 commit comments