Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/api/models/grok.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# `pydantic_ai.models.grok`

## Setup

For details on how to set up authentication with this model, see [model configuration for Grok](../../models/grok.md).

::: pydantic_ai.models.grok
77 changes: 77 additions & 0 deletions docs/models/grok.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Groq

## Install

To use `GroqModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `groq` optional group:

```bash
pip/uv-add "pydantic-ai-slim[groq]"
```

## Configuration

To use [Groq](https://groq.com/) through their API, go to [console.groq.com/keys](https://console.groq.com/keys) and follow your nose until you find the place to generate an API key.

`GroqModelName` contains a list of available Groq models.

## Environment variable

Once you have the API key, you can set it as an environment variable:

```bash
export GROQ_API_KEY='your-api-key'
```

You can then use `GroqModel` by name:

```python
from pydantic_ai import Agent

agent = Agent('groq:llama-3.3-70b-versatile')
...
```

Or initialise the model directly with just the model name:

```python
from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel

model = GroqModel('llama-3.3-70b-versatile')
agent = Agent(model)
...
```

## `provider` argument

You can provide a custom `Provider` via the `provider` argument:

```python
from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel
from pydantic_ai.providers.groq import GroqProvider

model = GroqModel(
'llama-3.3-70b-versatile', provider=GroqProvider(api_key='your-api-key')
)
agent = Agent(model)
...
```

You can also customize the `GroqProvider` with a custom `httpx.AsyncClient`:

```python
from httpx import AsyncClient

from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel
from pydantic_ai.providers.groq import GroqProvider

custom_http_client = AsyncClient(timeout=30)
model = GroqModel(
'llama-3.3-70b-versatile',
provider=GroqProvider(api_key='your-api-key', http_client=custom_http_client),
)
agent = Agent(model)
...
```
1 change: 1 addition & 0 deletions docs/models/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Pydantic AI is model-agnostic and has built-in support for multiple model provid
* [OpenAI](openai.md)
* [Anthropic](anthropic.md)
* [Gemini](google.md) (via two different APIs: Generative Language API and VertexAI API)
* [Grok](grok.md)
* [Groq](groq.md)
* [Mistral](mistral.md)
* [Cohere](cohere.md)
Expand Down
248 changes: 248 additions & 0 deletions examples/pydantic_ai_examples/flight_booking_grok.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
"""Example of a multi-agent flow where one agent delegates work to another.

In this scenario, a group of agents work together to find flights for a user.
"""

import datetime
from dataclasses import dataclass
from typing import Literal
import os
import logfire
from pydantic import BaseModel, Field
from rich.prompt import Prompt

from pydantic_ai import Agent, ModelRetry, RunContext, RunUsage, UsageLimits
from pydantic_ai.messages import ModelMessage

# Import local GrokModel
from pydantic_ai.models.grok import GrokModel

# Set up logfire and instrument pydantic_ai and httpx before any agent runs.
logfire.configure()
logfire.instrument_pydantic_ai()
logfire.instrument_httpx()

# The xAI API key is read from the environment; fail at import time when it is
# missing or empty rather than on the first model request.
xai_api_key = os.environ.get("XAI_API_KEY")
if not xai_api_key:
    raise ValueError("XAI_API_KEY environment variable is required")


# Single GrokModel instance shared by every agent defined in this module.
model = GrokModel("grok-4-fast-non-reasoning", api_key=xai_api_key)


# Structured output for a single flight. Used as one of search_agent's output
# types and as the list element type produced by extraction_agent.
# NOTE(review): pydantic normally exposes the class docstring as the schema
# description, so treat the docstring below as prompt text — confirm before
# rewording it.
class FlightDetails(BaseModel):
    """Details of the most suitable flight."""

    # Flight identifier as printed on the page, e.g. "SFO-AK123".
    flight_number: str
    # Ticket price as a whole number; the sample page quotes prices in dollars.
    price: int
    # Compared against Deps.req_origin / Deps.req_destination by validate_output.
    origin: str = Field(description="Three-letter airport code")
    destination: str = Field(description="Three-letter airport code")
    # Compared against Deps.req_date by validate_output.
    date: datetime.date


# Field-less marker output: search_agent returns it instead of FlightDetails,
# validate_output passes it through unchanged, and main() stops searching on it.
class NoFlightFound(BaseModel):
    """When no valid flight is found."""


# Dependencies handed to search_agent runs and read through ctx.deps by its
# tool (extract_flights) and output validator (validate_output).
@dataclass
class Deps:
    # Raw listing text that extract_flights forwards to extraction_agent.
    web_page_text: str
    # Requested origin / destination; compared for equality with the
    # FlightDetails.origin / .destination values in validate_output.
    req_origin: str
    req_destination: str
    # Requested travel date; compared with FlightDetails.date.
    req_date: datetime.date


# Top-level agent that drives the flight search. It answers with either a
# FlightDetails or a NoFlightFound; the extract_flights tool and the
# validate_output validator defined further down are registered on it.
search_agent = Agent[Deps, FlightDetails | NoFlightFound](
    model=model,
    system_prompt=("Your job is to find the cheapest flight for the user on the given date. "),
    output_type=FlightDetails | NoFlightFound,  # type: ignore
    retries=4,
)


# Helper agent that turns raw page text into a list of FlightDetails; it is
# invoked from the extract_flights tool rather than directly by main().
extraction_agent = Agent(
    model=model,
    system_prompt="Extract all the flight details from the given text.",
    output_type=list[FlightDetails],
)


@search_agent.tool
async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]:
    """Get details of all flights."""
    # Hand the calling run's usage object to the extraction agent so the
    # requests made by this delegated run are counted together with it.
    extraction = await extraction_agent.run(ctx.deps.web_page_text, usage=ctx.usage)
    flights = extraction.output
    logfire.info("found {flight_count} flights", flight_count=len(flights))
    return flights


@search_agent.output_validator
async def validate_output(
    ctx: RunContext[Deps], output: FlightDetails | NoFlightFound
) -> FlightDetails | NoFlightFound:
    """Procedural validation that the flight meets the constraints."""
    # "Nothing found" carries no fields to check, so it is always accepted.
    if isinstance(output, NoFlightFound):
        return output

    # Each entry pairs "did this constraint fail?" with the message to report.
    checks = (
        (
            output.origin != ctx.deps.req_origin,
            f"Flight should have origin {ctx.deps.req_origin}, not {output.origin}",
        ),
        (
            output.destination != ctx.deps.req_destination,
            f"Flight should have destination {ctx.deps.req_destination}, not {output.destination}",
        ),
        (
            output.date != ctx.deps.req_date,
            f"Flight should be on {ctx.deps.req_date}, not {output.date}",
        ),
    )
    problems: list[str] = [message for failed, message in checks if failed]

    if not problems:
        return output
    # Report every violated constraint at once, one per line.
    raise ModelRetry("\n".join(problems))


# Structured seat choice produced by seat_preference_agent and returned by
# find_seat(). Documented with comments only: no docstring is added because a
# pydantic model's docstring may end up in the schema sent to the model.
class SeatPreference(BaseModel):
    # Row number, constrained to 1..30 inclusive.
    row: int = Field(ge=1, le=30)
    # Seat letter within the row; the agent's system prompt says A and F are
    # window seats.
    seat: Literal["A", "B", "C", "D", "E", "F"]


# Field-less marker output: seat_preference_agent returns it instead of a
# SeatPreference, in which case find_seat() asks the user again.
class Failed(BaseModel):
    """Unable to extract a seat selection."""


# Agent that parses the user's free-text answer into a SeatPreference, or
# returns Failed when it cannot; it takes no dependencies.
seat_preference_agent = Agent[None, SeatPreference | Failed](
    model=model,
    system_prompt=(
        "Extract the user's seat preference. "
        "Seats A and F are window seats. "
        "Row 1 is the front row and has extra leg room. "
        "Rows 14, and 20 also have extra leg room. "
    ),
    output_type=SeatPreference | Failed,
)


# Stand-in for a scraped booking page. In reality this text would be
# downloaded from a booking site, potentially using another agent to navigate
# the site. main() passes it to the agents as Deps.web_page_text.
# NOTE(review): flight 4 is numbered "NYC-LA101" but lists SFO -> ANC on
# January 10 at the lowest price for that route — this looks deliberate
# (the flight number should not be trusted over the fields); confirm.
flights_web_page = """
1. Flight SFO-AK123
- Price: $350
- Origin: San Francisco International Airport (SFO)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 10, 2025

2. Flight SFO-AK456
- Price: $370
- Origin: San Francisco International Airport (SFO)
- Destination: Fairbanks International Airport (FAI)
- Date: January 10, 2025

3. Flight SFO-AK789
- Price: $400
- Origin: San Francisco International Airport (SFO)
- Destination: Juneau International Airport (JNU)
- Date: January 20, 2025

4. Flight NYC-LA101
- Price: $250
- Origin: San Francisco International Airport (SFO)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 10, 2025

5. Flight CHI-MIA202
- Price: $200
- Origin: Chicago O'Hare International Airport (ORD)
- Destination: Miami International Airport (MIA)
- Date: January 12, 2025

6. Flight BOS-SEA303
- Price: $120
- Origin: Boston Logan International Airport (BOS)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 12, 2025

7. Flight DFW-DEN404
- Price: $150
- Origin: Dallas/Fort Worth International Airport (DFW)
- Destination: Denver International Airport (DEN)
- Date: January 10, 2025

8. Flight ATL-HOU505
- Price: $180
- Origin: Hartsfield-Jackson Atlanta International Airport (ATL)
- Destination: George Bush Intercontinental Airport (IAH)
- Date: January 10, 2025
"""

# Restrict how many requests this app can make to the LLM. The same limits
# object is passed to the search_agent run in main() and to the
# seat_preference_agent run in find_seat().
usage_limits = UsageLimits(request_limit=15)


async def main():
    """Search for a flight interactively until the user buys one or none is found.

    Each round runs ``search_agent`` and shows the proposed flight. The user
    either buys it (which triggers seat selection and the purchase step) or
    asks to keep searching, in which case the conversation so far is fed back
    into the next round.
    """
    deps = Deps(
        web_page_text=flights_web_page,
        req_origin="SFO",
        req_destination="ANC",
        req_date=datetime.date(2025, 1, 10),
    )
    # The request text never changes between rounds, so build it once.
    request = f"Find me a flight from {deps.req_origin} to {deps.req_destination} on {deps.req_date}"
    history: list[ModelMessage] | None = None
    # One usage tracker is shared by every agent run started from here.
    usage: RunUsage = RunUsage()

    while True:
        result = await search_agent.run(
            request,
            deps=deps,
            usage=usage,
            message_history=history,
            usage_limits=usage_limits,
        )
        flight = result.output
        if isinstance(flight, NoFlightFound):
            print("No flight found")
            return

        print(f"Flight found: {flight}")
        answer = Prompt.ask(
            "Do you want to buy this flight, or keep searching? (buy/*search)",
            choices=["buy", "search", ""],
            show_choices=False,
        )
        if answer == "buy":
            seat = await find_seat(usage)
            await buy_tickets(flight, seat)
            return

        # Any other answer means "keep searching": carry the conversation
        # forward and ask for a different suggestion.
        history = result.all_messages(
            output_tool_return_content="Please suggest another flight"
        )


async def find_seat(usage: RunUsage) -> SeatPreference:
    """Keep asking the user for a seat until the agent returns a SeatPreference.

    ``usage`` is the caller's usage tracker; it is passed to every
    ``seat_preference_agent`` run together with the module-level
    ``usage_limits``.
    """
    history: list[ModelMessage] | None = None
    while True:
        reply = Prompt.ask("What seat would you like?")

        result = await seat_preference_agent.run(
            reply,
            message_history=history,
            usage=usage,
            usage_limits=usage_limits,
        )
        if not isinstance(result.output, SeatPreference):
            print("Could not understand seat preference. Please try again.")
            # Keep the failed exchange in the history for the next attempt.
            history = result.all_messages()
            continue
        return result.output


async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference):
    """Placeholder purchase step: only prints the chosen flight and seat."""
    summary = f"Purchasing flight {flight_details=!r} {seat=!r}..."
    print(summary)


# Script entry point: run the interactive flow only when executed directly.
if __name__ == "__main__":
    # asyncio is imported here because nothing else in the module uses it.
    import asyncio

    asyncio.run(main())
Loading
Loading