1717 from groq import AsyncGroq
1818 from openai import AsyncOpenAI
1919
20- from pydantic_ai .models import Model
2120 from pydantic_ai .models .anthropic import AsyncAnthropicClient
2221 from pydantic_ai .providers import Provider
2322
2625
2726@overload
2827def gateway_provider (
29- api_type : Literal ['chat' , 'responses' ],
28+ upstream_provider : Literal ['openai' , 'openai-chat' , 'openai-responses' , 'chat' , 'responses' ],
3029 / ,
3130 * ,
32- routing_group : str | None = None ,
33- profile : str | None = None ,
31+ route : str | None = None ,
3432 api_key : str | None = None ,
3533 base_url : str | None = None ,
3634 http_client : httpx .AsyncClient | None = None ,
@@ -39,11 +37,10 @@ def gateway_provider(
3937
4038@overload
4139def gateway_provider (
42- api_type : Literal ['groq' ],
40+ upstream_provider : Literal ['groq' ],
4341 / ,
4442 * ,
45- routing_group : str | None = None ,
46- profile : str | None = None ,
43+ route : str | None = None ,
4744 api_key : str | None = None ,
4845 base_url : str | None = None ,
4946 http_client : httpx .AsyncClient | None = None ,
@@ -52,11 +49,10 @@ def gateway_provider(
5249
5350@overload
5451def gateway_provider (
55- api_type : Literal ['anthropic' ],
52+ upstream_provider : Literal ['anthropic' ],
5653 / ,
5754 * ,
58- routing_group : str | None = None ,
59- profile : str | None = None ,
55+ route : str | None = None ,
6056 api_key : str | None = None ,
6157 base_url : str | None = None ,
6258 http_client : httpx .AsyncClient | None = None ,
@@ -65,23 +61,21 @@ def gateway_provider(
6561
# Bedrock / Converse overload: unlike the other providers there is no
# `http_client` parameter, because the Bedrock provider does not use an HTTPX client.
@overload
def gateway_provider(
    upstream_provider: Literal['bedrock', 'converse'],
    /,
    *,
    route: str | None = None,
    api_key: str | None = None,
    base_url: str | None = None,
) -> Provider[BaseClient]: ...
7671
7772
7873@overload
7974def gateway_provider (
80- api_type : Literal ['gemini' ],
75+ upstream_provider : Literal ['gemini' , 'google-vertex ' ],
8176 / ,
8277 * ,
83- routing_group : str | None = None ,
84- profile : str | None = None ,
78+ route : str | None = None ,
8579 api_key : str | None = None ,
8680 base_url : str | None = None ,
8781 http_client : httpx .AsyncClient | None = None ,
@@ -90,26 +84,37 @@ def gateway_provider(
9084
# Catch-all overload: an arbitrary provider string cannot be narrowed
# statically, so the returned provider is typed as `Provider[Any]`.
@overload
def gateway_provider(
    upstream_provider: str,
    /,
    *,
    route: str | None = None,
    api_key: str | None = None,
    base_url: str | None = None,
) -> Provider[Any]: ...
10194
10295
103- APIType = Literal ['chat' , 'responses' , 'gemini' , 'converse' , 'anthropic' , 'groq' ]
# Names accepted as the first argument of `gateway_provider`.
UpstreamProvider = Literal[
    'openai',
    'groq',
    'anthropic',
    'bedrock',
    'google-vertex',
    # These are only API formats, but we still support them for convenience.
    'openai-chat',
    'openai-responses',
    'chat',
    'responses',
    'converse',
    'gemini',
]
104110
105111
106112def gateway_provider (
107- api_type : APIType | str ,
113+ upstream_provider : UpstreamProvider | str ,
108114 / ,
109115 * ,
110116 # Every provider
111- routing_group : str | None = None ,
112- profile : str | None = None ,
117+ route : str | None = None ,
113118 api_key : str | None = None ,
114119 base_url : str | None = None ,
115120 # OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
@@ -118,11 +123,9 @@ def gateway_provider(
118123 """Create a new Gateway provider.
119124
120125 Args:
121- api_type: Determines the API type to use.
122- routing_group: The group of APIs that support the same models - the idea is that you can route the requests to
123- any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added.
124- profile: A provider may have a profile, which is a unique identifier for the provider.
125- The `pydantic-ai-gateway-profile` header will be added.
126+ upstream_provider: The upstream provider to use.
127+ route: The name of the provider or routing group to use to handle the request. If not provided, the default
128+ routing group for the API format will be used.
126129 api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
127130 environment variable will be used if available.
128131 base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -137,54 +140,45 @@ def gateway_provider(
137140 )
138141
139142 base_url = base_url or os .getenv ('PYDANTIC_AI_GATEWAY_BASE_URL' , GATEWAY_BASE_URL )
140- http_client = http_client or cached_async_http_client (provider = f'gateway/{ api_type } ' )
143+ http_client = http_client or cached_async_http_client (provider = f'gateway/{ upstream_provider } ' )
141144 http_client .event_hooks = {'request' : [_request_hook (api_key )]}
142145
143- if profile is not None :
144- http_client .headers .setdefault ('pydantic-ai-gateway-profile' , profile )
146+ if route is None :
147+ # Use the implied providerId as the default route.
148+ route = normalize_gateway_provider (upstream_provider )
145149
146- if routing_group is not None :
147- http_client .headers .setdefault ('pydantic-ai-gateway-routing-group' , routing_group )
150+ base_url = _merge_url_path (base_url , route )
148151
149- if api_type in ('chat' , 'responses' ):
152+ if upstream_provider in ('openai' , 'openai-chat' , 'openai-responses' , 'chat' , 'responses' ):
150153 from .openai import OpenAIProvider
151154
152- return OpenAIProvider (api_key = api_key , base_url = _merge_url_path ( base_url , api_type ) , http_client = http_client )
153- elif api_type == 'groq' :
155+ return OpenAIProvider (api_key = api_key , base_url = base_url , http_client = http_client )
156+ elif upstream_provider == 'groq' :
154157 from .groq import GroqProvider
155158
156- return GroqProvider (api_key = api_key , base_url = _merge_url_path ( base_url , 'groq' ) , http_client = http_client )
157- elif api_type == 'anthropic' :
159+ return GroqProvider (api_key = api_key , base_url = base_url , http_client = http_client )
160+ elif upstream_provider == 'anthropic' :
158161 from anthropic import AsyncAnthropic
159162
160163 from .anthropic import AnthropicProvider
161164
162165 return AnthropicProvider (
163- anthropic_client = AsyncAnthropic (
164- auth_token = api_key ,
165- base_url = _merge_url_path (base_url , 'anthropic' ),
166- http_client = http_client ,
167- )
166+ anthropic_client = AsyncAnthropic (auth_token = api_key , base_url = base_url , http_client = http_client )
168167 )
169- elif api_type == ' converse' :
168+ elif upstream_provider in ( 'bedrock' , ' converse') :
170169 from .bedrock import BedrockProvider
171170
172171 return BedrockProvider (
173172 api_key = api_key ,
174- base_url = _merge_url_path ( base_url , api_type ) ,
173+ base_url = base_url ,
175174 region_name = 'pydantic-ai-gateway' , # Fake region name to avoid NoRegionError
176175 )
177- elif api_type == ' gemini' :
176+ elif upstream_provider in ( 'google-vertex' , ' gemini') :
178177 from .google import GoogleProvider
179178
180- return GoogleProvider (
181- vertexai = True ,
182- api_key = api_key ,
183- base_url = _merge_url_path (base_url , 'gemini' ),
184- http_client = http_client ,
185- )
179+ return GoogleProvider (vertexai = True , api_key = api_key , base_url = base_url , http_client = http_client )
186180 else :
187- raise UserError (f'Unknown API type : { api_type } ' )
181+ raise UserError (f'Unknown upstream provider : { upstream_provider } ' )
188182
189183
190184def _request_hook (api_key : str ) -> Callable [[httpx .Request ], Awaitable [httpx .Request ]]:
@@ -218,31 +212,18 @@ def _merge_url_path(base_url: str, path: str) -> str:
218212 return base_url .rstrip ('/' ) + '/' + path .lstrip ('/' )
219213
220214
# Maps each accepted alias (provider names and bare API-format names) to the
# provider id it implies. Names not listed here are passed through unchanged.
_GATEWAY_PROVIDER_ALIASES = {
    'openai': 'openai',
    'openai-chat': 'openai',
    'chat': 'openai',
    'openai-responses': 'openai-responses',
    'responses': 'openai-responses',
    'gemini': 'google-vertex',
    'google-vertex': 'google-vertex',
    'bedrock': 'bedrock',
    'converse': 'bedrock',
}


def normalize_gateway_provider(provider: str) -> str:
    """Normalize a gateway provider name.

    Aliases that only name an API format are collapsed onto the provider id
    they imply: `'chat'` and `'openai-chat'` become `'openai'`, `'responses'`
    becomes `'openai-responses'`, `'gemini'` becomes `'google-vertex'` and
    `'converse'` becomes `'bedrock'`.

    Args:
        provider: The provider name to normalize.

    Returns:
        The normalized provider name, or `provider` unchanged when it is not a
        known alias (for example `'groq'` or `'anthropic'`).
    """
    return _GATEWAY_PROVIDER_ALIASES.get(provider, provider)
0 commit comments