Skip to content

Commit 8fa1d40

Browse files
committed
self improving agents
1 parent dc1f3aa commit 8fa1d40

File tree

8 files changed

+835
-28
lines changed

8 files changed

+835
-28
lines changed

human-seeded-evals/app/agent.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from __future__ import annotations as _annotations
22

3+
import os
4+
from contextlib import asynccontextmanager
35
from dataclasses import dataclass
46
from datetime import datetime
7+
from typing import AsyncIterator
58

9+
from cloudkv import AsyncCloudKV
610
from pydantic_ai import Agent, RunContext
11+
from pydantic_ai.models import Model
712

813
from .models import TimeRangeInputs, TimeRangeResponse
14+
from .self_improving_agent import SelfImprovingAgentModel
15+
from .self_improving_agent_storage import CloudKVStorage
916

1017

1118
@dataclass
@@ -23,13 +30,24 @@ class TimeRangeDeps:
2330
)
2431

2532

33+
@asynccontextmanager
async def self_improving_model() -> AsyncIterator[SelfImprovingAgentModel]:
    """Yield a `SelfImprovingAgentModel` backed by CloudKV storage.

    Configuration is read from the environment:

    * `CLOUDKV_TOKEN` is split on `'.'` into a read token and a write token.
    * `LOGFIRE_READ_TOKEN` is passed through to the model unchanged.

    The CloudKV client stays open for as long as the context is active. On
    exit, `wait_for_coach()` is awaited before the client is closed.

    Raises:
        KeyError: If either environment variable is unset.
        ValueError: If `CLOUDKV_TOKEN` does not split into exactly two parts.
    """
    cloudkv_read_token, cloudkv_write_token = os.environ['CLOUDKV_TOKEN'].split('.')
    logfire_read_token = os.environ['LOGFIRE_READ_TOKEN']
    async with AsyncCloudKV(cloudkv_read_token, cloudkv_write_token) as cloudkv:
        storage = CloudKVStorage(cloudkv)
        m = SelfImprovingAgentModel('anthropic:claude-sonnet-4-0', storage, logfire_read_token, 'time_range_agent')
        try:
            yield m
        finally:
            # An exception raised inside the caller's `async with` body is
            # re-raised at the `yield`; without `finally` the wait would be
            # skipped and the CloudKV client closed straight away.
            await m.wait_for_coach()
42+
43+
2644
@time_range_agent.instructions
def inject_current_time(ctx: RunContext[TimeRangeDeps]) -> str:
    """Return an instruction sentence stating the user's current time.

    `ctx.deps.now` is rendered as weekday, month, day, year, 24-hour time and
    timezone name, e.g. 'Friday, November 22, 2024 11:15:14 PST'.
    """
    user_now = ctx.deps.now
    return f"The user's current time is {user_now:%A, %B %d, %Y %H:%M:%S %Z}."
3048

3149

32-
async def infer_time_range(inputs: TimeRangeInputs, *, model: Model | None = None) -> TimeRangeResponse:
    """Run `time_range_agent` on the user's prompt and return its output.

    Args:
        inputs: Supplies the prompt text and the user's current time (`now`).
        model: Keyword-only; forwarded unchanged to `time_range_agent.run`.
            NOTE(review): presumably `None` means the agent's own default
            model is used -- confirm against `Agent.run`.

    Returns:
        The `output` attribute of the agent run result.
    """
    deps = TimeRangeDeps(now=inputs.now)
    run_result = await time_range_agent.run(inputs.prompt, deps=deps, model=model)
    return run_result.output

human-seeded-evals/app/main.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,30 @@
1+
from contextlib import asynccontextmanager
2+
from typing import cast
3+
14
import logfire
2-
from fastapi import FastAPI
5+
from fastapi import FastAPI, Request
36

4-
from .agent import infer_time_range
7+
from .agent import infer_time_range, self_improving_model
58
from .models import TimeRangeInputs, TimeRangeResponse
9+
from .self_improving_agent import SelfImprovingAgentModel
610

711
logfire.configure(environment='dev')
12+
813
logfire.instrument_pydantic_ai()
914

10-
app = FastAPI()
15+
16+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Keep a self-improving model open while the application is running.

    The model yielded by `self_improving_model()` is stored on
    `app.state.model`; control is then handed back until the context exits.
    """
    async with self_improving_model() as improving_model:
        # Stored on application state so request handlers can retrieve it.
        app.state.model = improving_model
        yield
21+
22+
23+
app = FastAPI(lifespan=lifespan)
1124
logfire.instrument_fastapi(app)
1225

1326

1427
@app.post('/api/timerange')
async def convert_time_range(request: Request, time_range_inputs: TimeRangeInputs) -> TimeRangeResponse:
    """Handle POST /api/timerange by inferring a time range from the inputs.

    The model is read from `request.app.state.model` and passed to
    `infer_time_range` together with the request payload.
    """
    app_state = request.app.state
    # `cast` only informs the type checker; it performs no runtime check.
    improving_model = cast(SelfImprovingAgentModel, app_state.model)
    response = await infer_time_range(time_range_inputs, model=improving_model)
    return response

0 commit comments

Comments
 (0)