Skip to content

Commit cde4ef1

Browse files
committed
add pydantic ai memory examples
1 parent 6d58f2f commit cde4ef1

File tree

5 files changed

+239
-3
lines changed

5 files changed

+239
-3
lines changed

pai-memory/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Pydantic AI memory examples
2+
3+
Run postgres in docker with
4+
5+
```bash
6+
docker run -e POSTGRES_HOST_AUTH_METHOD=trust --rm -it --name pg -p 5432:5432 -d postgres
7+
```

pai-memory/with_messages.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import AsyncIterator
5+
from contextlib import asynccontextmanager
6+
from typing import TYPE_CHECKING
7+
8+
import asyncpg
9+
from pydantic_ai import Agent
10+
from pydantic_ai.messages import ModelMessage, ModelMessagesTypeAdapter
11+
12+
# hack to get around asyncpg's poor typing support
13+
if TYPE_CHECKING:
14+
DbConn = asyncpg.Connection[asyncpg.Record]
15+
else:
16+
DbConn = asyncpg.Connection
17+
18+
19+
import logfire
20+
21+
logfire.configure(service_name='mem-msgs')
22+
logfire.instrument_pydantic_ai()
23+
logfire.instrument_asyncpg()
24+
25+
26+
@asynccontextmanager
27+
async def db() -> AsyncIterator[DbConn]:
28+
conn = await asyncpg.connect('postgresql://postgres@localhost:5432')
29+
await conn.execute("""
30+
create table if not exists messages(
31+
id serial primary key,
32+
ts timestamp not null default now(),
33+
user_id integer not null,
34+
messages json not null
35+
)
36+
""")
37+
38+
try:
39+
yield conn
40+
finally:
41+
await conn.close()
42+
43+
44+
agent = Agent(
45+
'openai:gpt-4o',
46+
instructions='You are a helpful assistant.',
47+
)
48+
49+
50+
@logfire.instrument
51+
async def run_agent(prompt: str, user_id: int):
52+
async with db() as conn:
53+
with logfire.span('retrieve messages'):
54+
messages: list[ModelMessage] = []
55+
for row in await conn.fetch('SELECT messages FROM messages WHERE user_id = $1 order by ts', user_id):
56+
messages += ModelMessagesTypeAdapter.validate_json(row[0])
57+
58+
result = await agent.run(prompt, message_history=messages)
59+
print(result.output)
60+
61+
with logfire.span('record messages'):
62+
msgs = result.new_messages_json().decode()
63+
await conn.execute('INSERT INTO messages(user_id, messages) VALUES($1, $2)', user_id, msgs)
64+
65+
66+
async def memory_messages():
67+
await run_agent('My name is Samuel.', 123)
68+
69+
await run_agent('What is my name?', 123)
70+
71+
72+
if __name__ == '__main__':
73+
asyncio.run(memory_messages())

pai-memory/with_tools.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import AsyncIterator
5+
from contextlib import asynccontextmanager
6+
from dataclasses import dataclass
7+
from typing import TYPE_CHECKING
8+
9+
import asyncpg
10+
from pydantic_ai import Agent, RunContext
11+
12+
# hack to get around asyncpg's poor typing support
13+
if TYPE_CHECKING:
14+
DbConn = asyncpg.Connection[asyncpg.Record]
15+
else:
16+
DbConn = asyncpg.Connection
17+
18+
19+
import logfire
20+
21+
logfire.configure(service_name='mem-tool')
22+
logfire.instrument_pydantic_ai()
23+
logfire.instrument_asyncpg()
24+
25+
26+
@asynccontextmanager
27+
async def db(reset: bool = False) -> AsyncIterator[DbConn]:
28+
conn = await asyncpg.connect('postgresql://postgres@localhost:5432')
29+
if reset:
30+
await conn.execute('drop table if exists memory')
31+
await conn.execute("""
32+
create table if not exists memory(
33+
id serial primary key,
34+
user_id integer not null,
35+
value text not null,
36+
unique(user_id, value)
37+
)
38+
""")
39+
40+
try:
41+
yield conn
42+
finally:
43+
await conn.close()
44+
45+
46+
@dataclass
47+
class Deps:
48+
user_id: int
49+
conn: DbConn
50+
51+
52+
agent = Agent(
53+
'openai:gpt-4o',
54+
deps_type=Deps,
55+
instructions='You are a helpful assistant.',
56+
)
57+
58+
59+
@agent.tool
60+
async def record_memory(ctx: RunContext[Deps], value: str) -> str:
61+
"""Use this tool to store information in memory."""
62+
await ctx.deps.conn.execute(
63+
'insert into memory(user_id, value) values($1, $2) on conflict do nothing',
64+
ctx.deps.user_id,
65+
value,
66+
)
67+
return 'Value added to memory.'
68+
69+
70+
@agent.tool
71+
async def retrieve_memories(ctx: RunContext[Deps], memory_contains: str) -> str:
72+
"""Get all memories about the user."""
73+
rows = await ctx.deps.conn.fetch(
74+
'select value from memory where user_id = $1 and value ilike $2',
75+
ctx.deps.user_id,
76+
f'%{memory_contains}%',
77+
)
78+
return '\n'.join(row[0] for row in rows)
79+
80+
81+
async def memory_tools():
82+
async with db(True) as conn:
83+
deps = Deps(123, conn)
84+
result = await agent.run('My name is Samuel.', deps=deps)
85+
print(result.output)
86+
87+
# time goes by...
88+
89+
async with db() as conn:
90+
deps = Deps(123, conn)
91+
result = await agent.run('What is my name?', deps=deps)
92+
print(result.output)
93+
94+
95+
if __name__ == '__main__':
96+
asyncio.run(memory_tools())

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ description = "Add your description here"
55
readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
8+
"asyncpg>=0.30.0",
9+
"asyncpg-stubs>=0.30.2",
810
"devtools>=0.12.2",
911
"fastapi>=0.115.14",
10-
"logfire[fastapi,httpx]>=3.21.1",
12+
"logfire[asyncpg,fastapi,httpx]>=3.21.1",
1113
"pydantic-ai>=0.3.4",
1214
]
1315

uv.lock

Lines changed: 60 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)