Skip to content

Commit 48849b3

Browse files
authored
refactor(gateway): add upstream_provider back (#3391)
1 parent 365b67b commit 48849b3

13 files changed

+142
-158
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -780,9 +780,10 @@ def infer_model( # noqa: C901
780780

781781
model_kind = provider_name
782782
if model_kind.startswith('gateway/'):
783-
from ..providers.gateway import infer_gateway_model
783+
from ..providers.gateway import normalize_gateway_provider
784784

785-
return infer_gateway_model(model_kind.removeprefix('gateway/'), model_name=model_name)
785+
model_kind = provider_name.removeprefix('gateway/')
786+
model_kind = normalize_gateway_provider(model_kind)
786787
if model_kind in (
787788
'openai',
788789
'azure',

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def __init__(
240240
self._model_name = model_name
241241

242242
if isinstance(provider, str):
243-
provider = infer_provider('gateway/converse' if provider == 'gateway' else provider)
243+
provider = infer_provider('gateway/bedrock' if provider == 'gateway' else provider)
244244
self._provider = provider
245245
self.client = cast('BedrockRuntimeClient', provider.client)
246246

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(
204204
self._model_name = model_name
205205

206206
if isinstance(provider, str):
207-
provider = infer_provider('gateway/gemini' if provider == 'gateway' else provider)
207+
provider = infer_provider('gateway/google-vertex' if provider == 'gateway' else provider)
208208
self._provider = provider
209209
self.client = provider.client
210210

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def __init__(
375375
self._model_name = model_name
376376

377377
if isinstance(provider, str):
378-
provider = infer_provider('gateway/chat' if provider == 'gateway' else provider)
378+
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
379379
self._provider = provider
380380
self.client = provider.client
381381

@@ -944,7 +944,7 @@ def __init__(
944944
self._model_name = model_name
945945

946946
if isinstance(provider, str):
947-
provider = infer_provider('gateway/responses' if provider == 'gateway' else provider)
947+
provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
948948
self._provider = provider
949949
self.client = provider.client
950950

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def infer_provider(provider: str) -> Provider[Any]:
158158
if provider.startswith('gateway/'):
159159
from .gateway import gateway_provider
160160

161-
api_type = provider.removeprefix('gateway/')
162-
return gateway_provider(api_type)
161+
upstream_provider = provider.removeprefix('gateway/')
162+
return gateway_provider(upstream_provider)
163163
elif provider in ('google-vertex', 'google-gla'):
164164
from .google import GoogleProvider
165165

pydantic_ai_slim/pydantic_ai/providers/gateway.py

Lines changed: 61 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from groq import AsyncGroq
1818
from openai import AsyncOpenAI
1919

20-
from pydantic_ai.models import Model
2120
from pydantic_ai.models.anthropic import AsyncAnthropicClient
2221
from pydantic_ai.providers import Provider
2322

@@ -26,11 +25,10 @@
2625

2726
@overload
2827
def gateway_provider(
29-
api_type: Literal['chat', 'responses'],
28+
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'chat', 'responses'],
3029
/,
3130
*,
32-
routing_group: str | None = None,
33-
profile: str | None = None,
31+
route: str | None = None,
3432
api_key: str | None = None,
3533
base_url: str | None = None,
3634
http_client: httpx.AsyncClient | None = None,
@@ -39,11 +37,10 @@ def gateway_provider(
3937

4038
@overload
4139
def gateway_provider(
42-
api_type: Literal['groq'],
40+
upstream_provider: Literal['groq'],
4341
/,
4442
*,
45-
routing_group: str | None = None,
46-
profile: str | None = None,
43+
route: str | None = None,
4744
api_key: str | None = None,
4845
base_url: str | None = None,
4946
http_client: httpx.AsyncClient | None = None,
@@ -52,11 +49,10 @@ def gateway_provider(
5249

5350
@overload
5451
def gateway_provider(
55-
api_type: Literal['anthropic'],
52+
upstream_provider: Literal['anthropic'],
5653
/,
5754
*,
58-
routing_group: str | None = None,
59-
profile: str | None = None,
55+
route: str | None = None,
6056
api_key: str | None = None,
6157
base_url: str | None = None,
6258
http_client: httpx.AsyncClient | None = None,
@@ -65,23 +61,21 @@ def gateway_provider(
6561

6662
@overload
6763
def gateway_provider(
68-
api_type: Literal['converse'],
64+
upstream_provider: Literal['bedrock', 'converse'],
6965
/,
7066
*,
71-
routing_group: str | None = None,
72-
profile: str | None = None,
67+
route: str | None = None,
7368
api_key: str | None = None,
7469
base_url: str | None = None,
7570
) -> Provider[BaseClient]: ...
7671

7772

7873
@overload
7974
def gateway_provider(
80-
api_type: Literal['gemini'],
75+
upstream_provider: Literal['gemini', 'google-vertex'],
8176
/,
8277
*,
83-
routing_group: str | None = None,
84-
profile: str | None = None,
78+
route: str | None = None,
8579
api_key: str | None = None,
8680
base_url: str | None = None,
8781
http_client: httpx.AsyncClient | None = None,
@@ -90,26 +84,37 @@ def gateway_provider(
9084

9185
@overload
9286
def gateway_provider(
93-
api_type: str,
87+
upstream_provider: str,
9488
/,
9589
*,
96-
routing_group: str | None = None,
97-
profile: str | None = None,
90+
route: str | None = None,
9891
api_key: str | None = None,
9992
base_url: str | None = None,
10093
) -> Provider[Any]: ...
10194

10295

103-
APIType = Literal['chat', 'responses', 'gemini', 'converse', 'anthropic', 'groq']
96+
UpstreamProvider = Literal[
97+
'openai',
98+
'groq',
99+
'anthropic',
100+
'bedrock',
101+
'google-vertex',
102+
# These are only API formats, but we still support them for convenience.
103+
'openai-chat',
104+
'openai-responses',
105+
'chat',
106+
'responses',
107+
'converse',
108+
'gemini',
109+
]
104110

105111

106112
def gateway_provider(
107-
api_type: APIType | str,
113+
upstream_provider: UpstreamProvider | str,
108114
/,
109115
*,
110116
# Every provider
111-
routing_group: str | None = None,
112-
profile: str | None = None,
117+
route: str | None = None,
113118
api_key: str | None = None,
114119
base_url: str | None = None,
115120
# OpenAI, Groq, Anthropic & Gemini - Only Bedrock doesn't have an HTTPX client.
@@ -118,11 +123,9 @@ def gateway_provider(
118123
"""Create a new Gateway provider.
119124
120125
Args:
121-
api_type: Determines the API type to use.
122-
routing_group: The group of APIs that support the same models - the idea is that you can route the requests to
123-
any provider in a routing group. The `pydantic-ai-gateway-routing-group` header will be added.
124-
profile: A provider may have a profile, which is a unique identifier for the provider.
125-
The `pydantic-ai-gateway-profile` header will be added.
126+
upstream_provider: The upstream provider to use.
127+
route: The name of the provider or routing group to use to handle the request. If not provided, the default
128+
routing group for the API format will be used.
126129
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
127130
environment variable will be used if available.
128131
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
@@ -137,54 +140,45 @@ def gateway_provider(
137140
)
138141

139142
base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
140-
http_client = http_client or cached_async_http_client(provider=f'gateway/{api_type}')
143+
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
141144
http_client.event_hooks = {'request': [_request_hook(api_key)]}
142145

143-
if profile is not None:
144-
http_client.headers.setdefault('pydantic-ai-gateway-profile', profile)
146+
if route is None:
147+
# Default the route to the normalized upstream provider name (the implied provider ID).
148+
route = normalize_gateway_provider(upstream_provider)
145149

146-
if routing_group is not None:
147-
http_client.headers.setdefault('pydantic-ai-gateway-routing-group', routing_group)
150+
base_url = _merge_url_path(base_url, route)
148151

149-
if api_type in ('chat', 'responses'):
152+
if upstream_provider in ('openai', 'openai-chat', 'openai-responses', 'chat', 'responses'):
150153
from .openai import OpenAIProvider
151154

152-
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, api_type), http_client=http_client)
153-
elif api_type == 'groq':
155+
return OpenAIProvider(api_key=api_key, base_url=base_url, http_client=http_client)
156+
elif upstream_provider == 'groq':
154157
from .groq import GroqProvider
155158

156-
return GroqProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'groq'), http_client=http_client)
157-
elif api_type == 'anthropic':
159+
return GroqProvider(api_key=api_key, base_url=base_url, http_client=http_client)
160+
elif upstream_provider == 'anthropic':
158161
from anthropic import AsyncAnthropic
159162

160163
from .anthropic import AnthropicProvider
161164

162165
return AnthropicProvider(
163-
anthropic_client=AsyncAnthropic(
164-
auth_token=api_key,
165-
base_url=_merge_url_path(base_url, 'anthropic'),
166-
http_client=http_client,
167-
)
166+
anthropic_client=AsyncAnthropic(auth_token=api_key, base_url=base_url, http_client=http_client)
168167
)
169-
elif api_type == 'converse':
168+
elif upstream_provider in ('bedrock', 'converse'):
170169
from .bedrock import BedrockProvider
171170

172171
return BedrockProvider(
173172
api_key=api_key,
174-
base_url=_merge_url_path(base_url, api_type),
173+
base_url=base_url,
175174
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
176175
)
177-
elif api_type == 'gemini':
176+
elif upstream_provider in ('google-vertex', 'gemini'):
178177
from .google import GoogleProvider
179178

180-
return GoogleProvider(
181-
vertexai=True,
182-
api_key=api_key,
183-
base_url=_merge_url_path(base_url, 'gemini'),
184-
http_client=http_client,
185-
)
179+
return GoogleProvider(vertexai=True, api_key=api_key, base_url=base_url, http_client=http_client)
186180
else:
187-
raise UserError(f'Unknown API type: {api_type}')
181+
raise UserError(f'Unknown upstream provider: {upstream_provider}')
188182

189183

190184
def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
@@ -218,31 +212,18 @@ def _merge_url_path(base_url: str, path: str) -> str:
218212
return base_url.rstrip('/') + '/' + path.lstrip('/')
219213

220214

221-
def infer_gateway_model(api_type: APIType | str, *, model_name: str) -> Model:
222-
"""Infer the model class for a given API type."""
223-
if api_type == 'chat':
224-
from pydantic_ai.models.openai import OpenAIChatModel
225-
226-
return OpenAIChatModel(model_name=model_name, provider='gateway')
227-
elif api_type == 'groq':
228-
from pydantic_ai.models.groq import GroqModel
229-
230-
return GroqModel(model_name=model_name, provider='gateway')
231-
elif api_type == 'responses':
232-
from pydantic_ai.models.openai import OpenAIResponsesModel
233-
234-
return OpenAIResponsesModel(model_name=model_name, provider='gateway')
235-
elif api_type == 'gemini':
236-
from pydantic_ai.models.google import GoogleModel
215+
def normalize_gateway_provider(provider: str) -> str:
216+
"""Normalize a gateway provider name.
237217
238-
return GoogleModel(model_name=model_name, provider='gateway')
239-
elif api_type == 'converse':
240-
from pydantic_ai.models.bedrock import BedrockConverseModel
241-
242-
return BedrockConverseModel(model_name=model_name, provider='gateway')
243-
elif api_type == 'anthropic':
244-
from pydantic_ai.models.anthropic import AnthropicModel
245-
246-
return AnthropicModel(model_name=model_name, provider='gateway')
247-
else:
248-
raise ValueError(f'Unknown API type: {api_type}') # pragma: no cover
218+
Args:
219+
provider: The provider name to normalize.
220+
"""
221+
if provider in ('openai', 'openai-chat', 'chat'):
222+
return 'openai'
223+
elif provider in ('openai-responses', 'responses'):
224+
return 'openai-responses'
225+
elif provider in ('gemini', 'google-vertex'):
226+
return 'google-vertex'
227+
elif provider in ('bedrock', 'converse'):
228+
return 'bedrock'
229+
return provider

tests/providers/cassettes/test_gateway/test_gateway_provider_with_anthropic.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ interactions:
88
connection:
99
- keep-alive
1010
content-length:
11-
- '166'
11+
- '159'
1212
content-type:
1313
- application/json
1414
host:
@@ -32,6 +32,8 @@ interactions:
3232
- application/json
3333
pydantic-ai-gateway-price-estimate:
3434
- 0.0002USD
35+
retry-after:
36+
- '34'
3537
strict-transport-security:
3638
- max-age=31536000; includeSubDomains; preload
3739
transfer-encoding:
@@ -40,7 +42,7 @@ interactions:
4042
content:
4143
- text: The capital of France is Paris.
4244
type: text
43-
id: msg_0116L5r52AZ42YhvvdUuHEsk
45+
id: msg_015jjU4Q5dqhSc9vyfCdoujx
4446
model: claude-sonnet-4-5-20250929
4547
role: assistant
4648
stop_reason: end_turn

tests/providers/cassettes/test_gateway/test_gateway_provider_with_bedrock.yaml

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ interactions:
55
headers:
66
amz-sdk-invocation-id:
77
- !!binary |
8-
MmEzMzkzMGUtNzI3YS00YzFhLWFmYWQtYzFhYWMyMTI3NDlj
8+
MTU4ODY4OTctOGU4MC00YzJlLWEyZTctMDA2ZmM0NTZjMmYy
99
amz-sdk-request:
1010
- !!binary |
1111
YXR0ZW1wdD0x
@@ -15,35 +15,34 @@ interactions:
1515
- !!binary |
1616
YXBwbGljYXRpb24vanNvbg==
1717
method: POST
18-
uri: http://localhost:8787/converse/model/amazon.nova-micro-v1%3A0/converse
18+
uri: http://localhost:8787/bedrock/model/amazon.nova-micro-v1%3A0/converse
1919
response:
2020
headers:
2121
content-length:
22-
- '741'
22+
- '631'
2323
content-type:
2424
- application/json
2525
pydantic-ai-gateway-price-estimate:
2626
- 0.0000USD
2727
parsed_body:
2828
metrics:
29-
latencyMs: 682
29+
latencyMs: 717
3030
output:
3131
message:
3232
content:
3333
- text: The capital of France is Paris. Paris is not only the capital city but also the most populous city in France,
3434
and it is a major center for culture, commerce, fashion, and international diplomacy. The city is known for
35-
its historical and architectural landmarks, including the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral,
36-
and the Champs-Élysées. Paris plays a significant role in the global arts, fashion, research, technology, education,
37-
and entertainment scenes.
35+
its historical landmarks, such as the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, and the Champs-Élysées,
36+
among many other attractions.
3837
role: assistant
3938
stopReason: end_turn
4039
usage:
4140
inputTokens: 7
42-
outputTokens: 96
41+
outputTokens: 78
4342
pydantic_ai_gateway:
44-
cost_estimate: 1.3685000000000002e-05
43+
cost_estimate: 1.1165000000000002e-05
4544
serverToolUsage: {}
46-
totalTokens: 103
45+
totalTokens: 85
4746
status:
4847
code: 200
4948
message: OK

0 commit comments

Comments
 (0)