From f10635641bb198f079b20e625450350800857b8b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 2 Jul 2025 13:01:11 +0100 Subject: [PATCH 1/6] self improving agents --- human-seeded-evals/app/agent.py | 22 +- human-seeded-evals/app/main.py | 24 +- .../app/self_improving_agent.py | 638 ++++++++++++++++++ .../app/self_improving_agent_storage.py | 32 + human-seeded-evals/pyproject.toml | 21 - human-seeded-evals/self_improving_demo.py | 71 ++ pyproject.toml | 5 + uv.lock | 50 ++ 8 files changed, 835 insertions(+), 28 deletions(-) create mode 100644 human-seeded-evals/app/self_improving_agent.py create mode 100644 human-seeded-evals/app/self_improving_agent_storage.py delete mode 100644 human-seeded-evals/pyproject.toml create mode 100644 human-seeded-evals/self_improving_demo.py diff --git a/human-seeded-evals/app/agent.py b/human-seeded-evals/app/agent.py index 246e4b0..a4a8984 100644 --- a/human-seeded-evals/app/agent.py +++ b/human-seeded-evals/app/agent.py @@ -1,11 +1,18 @@ from __future__ import annotations as _annotations +import os +from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime +from typing import AsyncIterator +from cloudkv import AsyncCloudKV from pydantic_ai import Agent, RunContext +from pydantic_ai.models import Model from .models import TimeRangeInputs, TimeRangeResponse +from .self_improving_agent import SelfImprovingAgentModel +from .self_improving_agent_storage import CloudKVStorage @dataclass @@ -23,13 +30,24 @@ class TimeRangeDeps: ) +@asynccontextmanager +async def self_improving_model() -> AsyncIterator[SelfImprovingAgentModel]: + cloudkv_read_token, cloudkv_write_token = os.environ['CLOUDKV_TOKEN'].split('.') + logfire_read_token = os.environ['LOGFIRE_READ_TOKEN'] + async with AsyncCloudKV(cloudkv_read_token, cloudkv_write_token) as cloudkv: + storage = CloudKVStorage(cloudkv) + m = SelfImprovingAgentModel('anthropic:claude-sonnet-4-0', storage, logfire_read_token, 'time_range_agent') + yield m + await m.wait_for_coach() + + @time_range_agent.instructions def inject_current_time(ctx: RunContext[TimeRangeDeps]) -> str: """Add the user's current time and timezone in the format 'Friday, November 22, 2024 11:15:14 PST' to context.""" return f"The user's current time is {ctx.deps.now:%A, %B %d, %Y %H:%M:%S %Z}." 
-async def infer_time_range(inputs: TimeRangeInputs) -> TimeRangeResponse: +async def infer_time_range(inputs: TimeRangeInputs, *, model: Model | None = None) -> TimeRangeResponse: """Infer a time range from a user prompt.""" - result = await time_range_agent.run(inputs.prompt, deps=TimeRangeDeps(now=inputs.now)) + result = await time_range_agent.run(inputs.prompt, deps=TimeRangeDeps(now=inputs.now), model=model) return result.output diff --git a/human-seeded-evals/app/main.py b/human-seeded-evals/app/main.py index f04baa1..8157ea4 100644 --- a/human-seeded-evals/app/main.py +++ b/human-seeded-evals/app/main.py @@ -1,16 +1,30 @@ +from contextlib import asynccontextmanager +from typing import cast + import logfire -from fastapi import FastAPI +from fastapi import FastAPI, Request -from .agent import infer_time_range +from .agent import infer_time_range, self_improving_model from .models import TimeRangeInputs, TimeRangeResponse +from .self_improving_agent import SelfImprovingAgentModel logfire.configure(environment='dev') + logfire.instrument_pydantic_ai() -app = FastAPI() + +@asynccontextmanager +async def lifespan(app: FastAPI): + async with self_improving_model() as model: + app.state.model = model + yield + + +app = FastAPI(lifespan=lifespan) logfire.instrument_fastapi(app) @app.post('/api/timerange') -async def convert_time_range(time_range_inputs: TimeRangeInputs) -> TimeRangeResponse: - return await infer_time_range(time_range_inputs) +async def convert_time_range(request: Request, time_range_inputs: TimeRangeInputs) -> TimeRangeResponse: + model = cast(SelfImprovingAgentModel, request.app.state.model) + return await infer_time_range(time_range_inputs, model=model) diff --git a/human-seeded-evals/app/self_improving_agent.py b/human-seeded-evals/app/self_improving_agent.py new file mode 100644 index 0000000..85c7243 --- /dev/null +++ b/human-seeded-evals/app/self_improving_agent.py @@ -0,0 +1,638 @@ +from __future__ import annotations as _annotations + +import asyncio +import csv +import io +from abc import ABC, abstractmethod +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import ( + Annotated, + Any, + AsyncContextManager, + Iterable, + Iterator, + Literal, + Protocol, + TypeAlias, + TypedDict, + TypeGuard, + cast, +) + +from annotated_types import Ge, Le +from logfire import Logfire +from logfire.experimental.query_client import AsyncLogfireQueryClient +from pydantic import AwareDatetime, BaseModel, Field, TypeAdapter +from pydantic_ai import Agent, format_as_xml +from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, SystemPromptPart +from pydantic_ai.models import KnownModelName, Model, ModelRequestParameters, infer_model +from pydantic_ai.models.wrapper import WrapperModel +from pydantic_ai.settings import ModelSettings +from pydantic_ai.tools import ObjectJsonSchema, ToolDefinition + +logfire = Logfire(otel_scope='self-improving-agent') +FieldsPatch = dict[str, str] + + +class ModelContextPatch(BaseModel): + context_patch: FieldsPatch + timestamp: AwareDatetime + + +class SelfImprovingAgentStorage(ABC): + @abstractmethod + async def get_patch(self, agent_name: str) -> ModelContextPatch | None: + """Get the patch for the given agent.""" + + @abstractmethod + async def set_patch(self, agent_name: str, patch: ModelContextPatch, expires: timedelta) -> None: + """Set the patch for the given agent.""" + + @abstractmethod + def lock(self, 
agent_name: str) -> AsyncContextManager[bool]:
+        """Try to obtain a lock for a coach run.
+
+        This doesn't have to be perfect, but being relatively reliable will reduce the chances of duplicate
+        concurrent coach runs.
+
+        Returns:
+            Async context manager that yields `True` if the lock was obtained, `False` otherwise, and
+            releases the lock when exited.
+        """
+
+
+class AbstractCoachOutput(Protocol):
+    context_patch: FieldsPatch
+    developer_suggestions: str | None
+    overall_context_score: int
+
+
+@dataclass(init=False)
+class SelfImprovingAgentModel(WrapperModel):
+    wrapped_model: Model
+    storage: SelfImprovingAgentStorage
+    logfire_read_token: str
+    agent_name: str
+    """Used to filter logfire data to runs of this agent"""
+    logfire_environment: str | None
+    """Name of the environment in logfire where the main agent is running, improves query performance"""
+    logfire_filter: str | None
+    """Additional logfire filter when looking for agent run traces to improve performance"""
+
+    minimum_run_interval: timedelta
+    """Minimum interval between coach runs"""
+
+    minimum_new_runs: int = 1
+    """Minimum number of new runs required to trigger a coach run"""
+
+    force_blocking: bool = False
+    """Whether to block the main agent until the coach has completed its run; also forces the coach to run"""
+
+    coach_model: Model
+    """Model used for the coach agent"""
+
+    _patch: ModelContextPatch | None = None
+    _coach_task: asyncio.Task[None] | None = None
+    _coach_agent: Agent[None, AbstractCoachOutput] | None = None
+
+    def __init__(
+        self,
+        wrapped_model: Model | KnownModelName,
+        storage: SelfImprovingAgentStorage,
+        logfire_read_token: str,
+        agent_name: str,
+        logfire_environment: str | None = None,
+        logfire_filter: str | None = None,
+        minimum_run_interval: timedelta = timedelta(minutes=30),
+        minimum_new_runs: int = 1,
+        coach_model: Model | KnownModelName = 'anthropic:claude-opus-4-0',
+    ):
+        super().__init__(wrapped_model)
+        self.storage = storage
+        self.logfire_read_token = logfire_read_token
+        self.agent_name = agent_name
+        self.logfire_environment = logfire_environment
+        self.logfire_filter = logfire_filter
+        self.minimum_run_interval = minimum_run_interval
+        self.minimum_new_runs = minimum_new_runs
+        self.coach_model = infer_model(coach_model)
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        if self.force_blocking:
+            await self._blocking_coach(messages, model_settings, model_request_parameters)
+        else:
+            await self._deferred_coach(messages, model_settings, model_request_parameters)
+
+        if patch := self._patch:
+            messages, model_request_parameters = apply_patch(messages, model_request_parameters, patch.context_patch)
+
+        return await super().request(messages, model_settings, model_request_parameters)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self.inner_model.model_name
+
+    @property
+    def system(self) -> str:
+        """The system prompt."""
+        return self.inner_model.system
+
+    @contextmanager
+    def blocking_context(self) -> Iterator[None]:
+        force_blocking = self.force_blocking
+        self.force_blocking = True
+        try:
+            yield
+        finally:
+            self.force_blocking = force_blocking
+
+    async def wait_for_coach(self):
+        if self._coach_task is not None:
+            logfire.info('waiting for existing coach run to finish')
+            await self._coach_task
+            self._coach_task = None
+
+    async def _blocking_coach(
+        self,
+        messages: list[ModelMessage],
+        
_model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ):
+        if self._coach_task is not None:
+            logfire.info('waiting for existing coach run to finish')
+            await self._coach_task
+            self._coach_task = None
+
+        if self._patch is None:
+            self._patch = await self.storage.get_patch(self.agent_name)
+
+        fields = list(context_patch_fields(messages, model_request_parameters))
+        await self._run_coach(fields, force=True)
+
+    async def _deferred_coach(
+        self,
+        messages: list[ModelMessage],
+        _model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ):
+        # if the previous coach task has finished, clear it; any errors were logged in _run_coach_wrapper
+        if self._coach_task is not None and self._coach_task.done():
+            self._coach_task = None
+
+        if self._coach_task is not None:
+            # coach is already running
+            return
+
+        if self._patch is None:
+            self._patch = await self.storage.get_patch(self.agent_name)
+
+        if (
+            self._patch is not None
+            and datetime.now(tz=timezone.utc) - self._patch.timestamp <= self.minimum_run_interval
+        ):
+            logfire.info('Got patch, not yet expired {patch_timestamp=}', patch_timestamp=self._patch.timestamp)
+            return
+
+        # should we wait to build fields until we've checked the coach is going to run?
+        # in theory messages and model_request_parameters might be altered while the task is running
+        fields = list(context_patch_fields(messages, model_request_parameters))
+        self._coach_task = asyncio.create_task(self._run_coach_wrapper(fields))
+
+    async def _run_coach_wrapper(self, fields: list[FieldDetails]):
+        try:
+            await self._run_coach(fields)
+        except Exception:
+            logfire.exception('Error running coach')
+
+    async def _run_coach(self, fields: list[FieldDetails], *, force: bool = False):
+        if (
+            not force
+            and self._patch is not None
+            and datetime.now(tz=timezone.utc) - self._patch.timestamp <= self.minimum_run_interval
+        ):
+            logfire.info('Got patch, not yet expired {patch_timestamp=}', patch_timestamp=self._patch.timestamp)
+            return
+
+        async with self.storage.lock(self.agent_name) as got_lock:
+            if not got_lock:
+                logfire.info('another agent is already running the coach')
+                return
+
+            runs, last_run = await self._get_runs(self._patch and self._patch.timestamp)
+            if not force and runs is None and self._patch is not None:
+                logfire.info("no new runs and we have a patch, so we're not running the coach")
+                return
+
+            # build the coach prompt from the current context, any previous patch, and recent runs
+            prompt_data: dict[str, Any] = {
+                'default_model_context': {f.key: f.current_prompt for f in fields if f.current_prompt}
+            }
+            if self._patch:
+                prompt_data['previous_context_patch'] = self._patch.context_patch
+            if runs:
+                prompt_data['recent_agent_runs'] = runs
+            prompt = format_as_xml(prompt_data, include_root_tag=False)
+            coach_agent = self._get_coach_agent(fields)
+            r = await coach_agent.run(prompt)
+            run_count = len(runs) if runs is not None else None
+            if r.output.overall_context_score < 5:
+                logfire.warning(
+                    'Coach run with quality warning, score={output.overall_context_score} {run_count=}',
+                    output=r.output,
+                    run_count=run_count,
+                )
+            else:
+                logfire.info(
+                    'Coach run with quality ok, score={output.overall_context_score} {run_count=}',
+                    output=r.output,
+                    run_count=run_count,
+                )
+            self._patch = patch = ModelContextPatch(context_patch=r.output.context_patch, timestamp=last_run)
+            await self.storage.set_patch(self.agent_name, patch, timedelta(days=1))
+
+    async def _get_runs(self, min_timestamp: datetime | None) -> 
tuple[list[AgentRunSummary] | None, datetime]: + async with AsyncLogfireQueryClient(self.logfire_read_token) as client: + runs_where = ["otel_scope_name='pydantic-ai'", f"message='{self.agent_name} run'"] + if self.logfire_environment: + runs_where.append(f"deployment_environment='{self.logfire_environment}'") + if self.logfire_filter: + runs_where.append(self.logfire_filter) + sql = runs_query.format(where=' AND '.join(runs_where)) + r = await client.query_json_rows(sql=sql, min_timestamp=min_timestamp) + runs_rows = r['rows'] + count = len(runs_rows) + if count < self.minimum_new_runs: + logfire.info('Found {run_count} runs, not enough to run coach', run_count=count) + return None, datetime.now(tz=timezone.utc) - timedelta(seconds=30) + + created_ats = datetime_list_schema.validate_python([row['created_at'] for row in runs_rows]) + last_run = max(created_ats) + + r = await client.query_json_rows(sql=feedback_query, min_timestamp=min_timestamp) + feedback_lookup: dict[str, Any] = { + '{trace_id}-{parent_span_id}'.format(**row): RunFeedback(**row) for row in r['rows'] + } + + runs: list[AgentRunSummary] = [] + feedback_count = 0 + for row in runs_rows: + if feedback := feedback_lookup.get('{trace_id}-{span_id}'.format(**row)): + row['feedback'] = feedback + feedback_count += 1 + run = AgentRunSummary.model_validate(row) + if run.prompt is not None: + runs.append(run) + + logfire.info( + 'Found {run_count} runs, {feedback_count} with feedback, running coach', + run_count=count, + feedback_count=feedback_count, + ) + return runs, last_run + + def _get_coach_agent(self, fields: list[FieldDetails]) -> Agent[None, AbstractCoachOutput]: + if self._coach_agent: + return self._coach_agent + + fields_dict = {f.key: Annotated[str, Field(description=f.description)] for f in fields} + ModelRequestFields = TypedDict( + 'ModelRequestFields', + fields_dict, # type: ignore + total=False, + ) + + class CoachOutput(BaseModel, use_attribute_docstrings=True): + context_patch: ModelRequestFields + """Patch to update context fields to improve the agent's performance.""" + developer_suggestions: str | None = None + """Suggestions to the developer about how to improve the agent code.""" + overall_context_score: Annotated[int, Ge(0), Le(10)] + """Overall quality of the context, on a scale from zero to ten, zero being the worst, ten being the best. + + Any value below 5 will trigger a warning to the agent maintainers. 
+ """ + + self._coach_agent = agent = cast( + Agent[None, AbstractCoachOutput], + Agent(self.coach_model, output_type=CoachOutput, instructions=coach_instrunctions), + ) + return agent + + +class RunFeedback(BaseModel): + reaction: Literal['positive', 'negative'] | None + comment: str | None + + +class AgentRunSummary(BaseModel): + prompt: str | None + output: Any + feedback: RunFeedback | None = None + + +datetime_list_schema = TypeAdapter(list[AwareDatetime]) + +runs_query = """ +select + created_at, + trace_id, + span_id, + attributes->'all_messages_events'->1->>'content' as prompt, + attributes->'final_result' as output +from records +where {where} +order by created_at desc +limit 20 +""" +feedback_query = """ +select + trace_id, + parent_span_id, + attributes->>'Annotation' as reaction, + attributes->>'logfire.feedback.comment' as comment +from records +where + kind='annotation' and + attributes->>'logfire.feedback.name'='Annotation' +order by created_at desc +-- bigger limit to get all feedback linked to relevant runs +limit 200 +""" +# this is a rough prompt, can almost certainly be improved +coach_instrunctions = """\ +Your job is to improve the performance of an AI agent by analyzing the context provided to the model +and the agent's behavior (inputs, outputs and feedback where available), +then rewriting context where you are confident that it will improve the agent's performance. +To do this return a patch of model context prompts and descriptions where appropriate. + +Pay special attention to the `instructions` or `system_prompt` fields as they have the most significant impact on the agent's behavior. + +Be concise and clear: increasing text length will increase token usage and thereby cost. + +If you identify shortcomings in the context provided to the model that cannot be solved by adjusting the instructions +and tool descriptions, please suggest improvements the developer should make. + +YOU SHOULD ONLY INCLUDE SUGGESTIONS IF THERE ARE SHORTCOMINGS IN THE CONTEXT PROVIDED TO THE MODEL +THAT CANNOT BE SOLVED BY ADJUSTING THE INSTRUCTIONS AND TOOL DESCRIPTIONS. 
+""" + + +@dataclass +class FieldDetails: + key: str + description: str + current_prompt: str | None = None + + +def context_patch_fields( + messages: list[ModelMessage], model_request_parameters: ModelRequestParameters +) -> Iterable[FieldDetails]: + found_sys_prompt = False + if system_prompt := get_system_prompt(messages): + found_sys_prompt = True + yield FieldDetails('system_prompt', 'System prompt', system_prompt) + + instructions = get_instrunctions(messages) + if instructions or not found_sys_prompt: + yield FieldDetails('instructions', 'Instructions', instructions) + + yield from get_tools_fields(model_request_parameters.function_tools, 'function_tools', 'Function tool description') + + yield from get_tools_fields(model_request_parameters.output_tools, 'output_tools', 'Output tool description') + + +def get_system_prompt(messages: list[ModelMessage]) -> str | None: + """Get the first system prompt from messages, other system prompts are ignored.""" + for message in messages: + if isinstance(message, ModelRequest): + for part in message.parts: + if isinstance(part, SystemPromptPart): + return part.content + + +def get_instrunctions(messages: list[ModelMessage]) -> str | None: + """Get the first instruction from messages, other instructions are ignored.""" + for message in messages: + if isinstance(message, ModelRequest): + if message.instructions: + return message.instructions + + +def get_tools_fields(tools: list[ToolDefinition], prefix: str, description: str) -> Iterable[FieldDetails]: + for t in tools: + prefix = f'{prefix}.{escape_key(t.name)}' + yield FieldDetails(f'{prefix}.description', description, t.description) + yield from json_schema_fields(t.parameters_json_schema, f'{prefix}.parameters') + + +JsonSchema = dict[str, Any] + + +def json_schema_fields(schema: JsonSchema, prefix: str) -> Iterable[FieldDetails]: + yield FieldDetails(f'{prefix}.description', 'JSON schema field description', schema.get('description')) + + type_ = schema.get('type') + if type_ == 'object': + yield from _js_object(schema, prefix) + elif type_ == 'array': + yield from _js_array(schema, prefix) + elif type_ is None: + yield from _js_union(schema, prefix, 'anyOf') + yield from _js_union(schema, prefix, 'oneOf') + + +def _js_object(schema: ObjectJsonSchema, prefix: str) -> Iterable[FieldDetails]: + if properties := schema.get('properties'): + for key, value in properties.items(): + yield from json_schema_fields(value, f'{prefix}.properties.{escape_key(key)}') + + if additional_properties := schema.get('additionalProperties'): + if _is_json_schema(additional_properties): + yield from json_schema_fields(additional_properties, f'{prefix}.additionalProperties') + + if pattern_properties := schema.get('patternProperties'): + for key, value in pattern_properties.items(): + yield from json_schema_fields(value, f'{prefix}.patternProperties.{escape_key(key)}') + + +def _js_array(schema: ObjectJsonSchema, prefix: str) -> Iterable[FieldDetails]: + if prefix_items := schema.get('prefixItems'): + assert isinstance(prefix_items, list), f'Expected list for prefixItems, got {type(prefix_items)}' + for i, item in enumerate(cast(list[Any], prefix_items)): + if _is_json_schema(item): + yield from json_schema_fields(item, f'{prefix}.prefixItems.{i}') + + if items := schema.get('items'): + if _is_json_schema(items): + yield from json_schema_fields(items, f'{prefix}.items') + + +def _js_union(schema: JsonSchema, prefix: str, union_kind: Literal['anyOf', 'oneOf']) -> Iterable[FieldDetails]: + members = 
schema.get(union_kind) + if not members: + return + + for member in members: + if _is_json_schema(member): + yield from json_schema_fields(member, f'{prefix}.{union_kind}') + + +def escape_key(s: str) -> str: + if '.' in s: + # double double quotes matches how the csv module parses strings + return '"' + s.replace('"', '""') + '"' + else: + return s + + +def apply_patch( + messages: list[ModelMessage], model_request_parameters: ModelRequestParameters, patch: FieldsPatch +) -> tuple[list[ModelMessage], ModelRequestParameters]: + if not patch: + return messages, model_request_parameters + + messages = deepcopy(messages) + model_request_parameters = deepcopy(model_request_parameters) + changes = 0 + + nested_patch = unflatten(patch) + if system_prompt := nested_patch.get('system_prompt'): + assert isinstance(system_prompt, str), f'Expected str for system_prompt, got {type(system_prompt)}' + if set_system_prompt(messages, system_prompt): + changes += 1 + else: + logfire.warning('No system prompt found to replace') + + if instructions := nested_patch.get('instructions'): + assert isinstance(instructions, str), f'Expected str for instructions, got {type(instructions)}' + if set_instructions(messages, instructions): + changes += 1 + else: + logfire.warning('No instructions found to replace') + + changes += set_tools_fields(model_request_parameters.function_tools, 'function_tools', nested_patch) + changes += set_tools_fields(model_request_parameters.output_tools, 'output_tools', nested_patch) + + logfire.info('updated {changes} fields in messages and model request parameters', changes=changes) + return messages, model_request_parameters + + +UnflattenedPatch: TypeAlias = 'dict[str, UnflattenedPatch | str]' + + +def unflatten(patch: FieldsPatch) -> UnflattenedPatch: + d: UnflattenedPatch = {} + for key, value in patch.items(): + local_d = d + *parts, last = split_key(key) + for part in parts: + local_d = local_d.setdefault(part, {}) + assert isinstance(local_d, dict), f'Expected dict at {part}, got {type(local_d)}' + local_d[last] = value + return d + + +def set_system_prompt(messages: list[ModelMessage], system_prompt: str) -> bool: + for message in messages: + if isinstance(message, ModelRequest): + for part in message.parts: + if isinstance(part, SystemPromptPart): + part.content = system_prompt + return True + return False + + +def set_instructions(messages: list[ModelMessage], instructions: str) -> bool: + for message in messages: + if isinstance(message, ModelRequest): + if message.instructions is not None: + message.instructions = instructions + return True + + # if we didn't find existing instructions to replace, set instructions on the first model request + for message in messages: + if isinstance(message, ModelRequest): + message.instructions = instructions + return True + + return False + + +def set_tools_fields(tools: list[ToolDefinition], key: str, patch: UnflattenedPatch) -> int: + tools_patch = patch.get(key) + changes = 0 + if not tools_patch: + return changes + assert isinstance(tools_patch, dict), f'Expected dict at {key}, got {type(tools_patch)}' + for tool_name, tool_patch in tools_patch.items(): + assert isinstance(tool_name, str), f'Expected str at {key}.{tool_name}, got {type(tool_name)}' + assert isinstance(tool_patch, dict), f'Expected dict at {key}.{tool_name}, got {type(tool_patch)}' + tool = next((t for t in tools if t.name == tool_name), None) + assert tool is not None, f'Unable to find tool {key}.{tool_name}' + + if description := tool_patch.get('description'): + 
assert isinstance(description, str), (
+                f'Expected str at {key}.{tool_name}.description, got {type(description)}'
+            )
+            tool.description = description
+            changes += 1
+
+        if parameters := tool_patch.get('parameters'):
+            assert isinstance(parameters, dict), (
+                f'Expected dict at {key}.{tool_name}.parameters, got {type(parameters)}'
+            )
+            changes += update_json_schema(tool.parameters_json_schema, parameters, [key, tool_name, 'parameters'])
+    return changes
+
+
+def update_json_schema(schema: JsonSchema, patch: UnflattenedPatch, path: list[str]) -> int:
+    changes = 0
+    patch_copy = patch.copy()
+    if description := patch_copy.pop('description', None):
+        assert isinstance(description, str), f'Expected str at {".".join(path)}.description, got {type(description)}'
+        schema['description'] = description
+        changes += 1
+
+    for k, v in patch_copy.items():
+        sub_path = path + [k]
+        assert isinstance(v, dict), f'Expected dict at {".".join(sub_path)}, got {type(v)}'
+
+        sub_schema = schema.get(k)
+        if not sub_schema:
+            logfire.warning('schema key not found at {path}', path='.'.join(sub_path))
+            continue
+        if _is_json_schema(sub_schema):
+            changes += update_json_schema(sub_schema, v, sub_path)
+        else:
+            assert isinstance(sub_schema, list), (
+                f'Expected dict or list at {".".join(sub_path)}, got {type(sub_schema)}'
+            )
+
+            for k2, v2 in v.items():
+                sub_sub_path = sub_path + [k2]
+                array_schema = cast(JsonSchema, sub_schema[int(k2)])
+                assert isinstance(v2, dict), f'Expected dict at {".".join(sub_sub_path)}, got {type(v2)}'
+                changes += update_json_schema(array_schema, v2, sub_sub_path)
+
+    return changes
+
+
+def split_key(s: str) -> list[str]:
+    if '"' in s:
+        # quotes in the string mean we have to parse it properly
+        return next(csv.reader(io.StringIO(s), delimiter='.'))
+    else:
+        return s.split('.')
+
+
+def _is_json_schema(obj: Any) -> TypeGuard[JsonSchema]:
+    return isinstance(obj, dict)
diff --git a/human-seeded-evals/app/self_improving_agent_storage.py b/human-seeded-evals/app/self_improving_agent_storage.py
new file mode 100644
index 0000000..f1c872a
--- /dev/null
+++ b/human-seeded-evals/app/self_improving_agent_storage.py
@@ -0,0 +1,32 @@
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from typing import AsyncIterator
+
+from cloudkv import AsyncCloudKV
+
+from .self_improving_agent import ModelContextPatch, SelfImprovingAgentStorage
+
+
+@dataclass
+class CloudKVStorage(SelfImprovingAgentStorage):
+    cloud_kv: AsyncCloudKV
+
+    async def get_patch(self, agent_name: str) -> ModelContextPatch | None:
+        return await self.cloud_kv.get_as(agent_name, ModelContextPatch)
+
+    async def set_patch(self, agent_name: str, patch: ModelContextPatch, expires: timedelta) -> None:
+        await self.cloud_kv.set(agent_name, patch, expires=expires)
+
+    @asynccontextmanager
+    async def lock(self, agent_name: str) -> AsyncIterator[bool]:
+        key = f'lock:{agent_name}'
+        r = await self.cloud_kv.get(key)
+        if r is None:
+            await self.cloud_kv.set(key, True, expires=3600)
+            try:
+                yield True
+            finally:
+                await self.cloud_kv.delete(key)
+        else:
+            yield False
diff --git a/human-seeded-evals/pyproject.toml b/human-seeded-evals/pyproject.toml
deleted file mode 100644
index e5a3c83..0000000
--- a/human-seeded-evals/pyproject.toml
+++ /dev/null
@@ -1,21 +0,0 @@
-[project]
-name = "human-seeded-evals"
-version = "0.1.0"
-readme = "README.md"
-
-[tool.ruff]
-line-length = 120
-target-version = "py312"
-
-[tool.ruff.lint]
-extend-select = ["Q", "RUF100", "C90", "UP", "I"]
-flake8-quotes = { inline-quotes = 
"single", multiline-quotes = "double" } -isort = { combine-as-imports = true } -mccabe = { max-complexity = 15 } -pydocstyle = { convention = "google" } - -[tool.ruff.format] -quote-style = "single" - -[dependency-groups] -dev = ["pyright>=1.1.402", "ruff>=0.12.1", "watchfiles>=1.1.0"] diff --git a/human-seeded-evals/self_improving_demo.py b/human-seeded-evals/self_improving_demo.py new file mode 100644 index 0000000..275083d --- /dev/null +++ b/human-seeded-evals/self_improving_demo.py @@ -0,0 +1,71 @@ +import asyncio +import os +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import AsyncIterator + +import logfire +from cloudkv import AsyncCloudKV +from pydantic import BaseModel +from pydantic_ai import Agent + +from app.self_improving_agent import ModelContextPatch, SelfImprovingAgentModel, SelfImprovingAgentStorage + +logfire.configure() +logfire.instrument_pydantic_ai() +logfire.instrument_httpx(capture_all=True) + + +class City(BaseModel, use_attribute_docstrings=True): + """Details about a city.""" + + city_name: str + """The city name.""" + country: str + """The country name.""" + + +cloudkv_read_token = 'r2WqFgs0tQBv4jUH9basLcjT' +cloudkv_write_token = 'Cpnj8XazGk9oPDHREA76680UPaV8juHfZ5eWDJadiijkQQnz' + + +@dataclass +class CloudKVStorage(SelfImprovingAgentStorage): + cloud_kv: AsyncCloudKV + + async def get_patch(self, agent_name: str) -> ModelContextPatch | None: + return await self.cloud_kv.get_as(agent_name, ModelContextPatch) + + async def set_patch(self, agent_name: str, patch: ModelContextPatch, expires: timedelta) -> None: + await self.cloud_kv.set(agent_name, patch, expires=expires) + + @asynccontextmanager + async def lock(self, agent_name: str) -> AsyncIterator[bool]: + key = f'lock:{agent_name}' + r = await self.cloud_kv.get(key) + if r is None: + await self.cloud_kv.set(key, True, expires=3600) + try: + yield True + finally: + await self.cloud_kv.delete(key) + else: + yield False + + +city_agent = Agent(output_type=City) + + +async def main(): + async with AsyncCloudKV(cloudkv_read_token, cloudkv_write_token) as cloudkv: + storage = CloudKVStorage(cloudkv) + model = SelfImprovingAgentModel('openai:gpt-4o', storage, os.environ['LOGFIRE_READ_TOKEN'], 'city_agent') + # with model.blocking_context(): + result = await city_agent.run('The windy city in the US of A.', model=model) + debug(result.output) + await model.wait_for_coach() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index cb825ad..bacbbd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.12" dependencies = [ + "cloudkv>=0.3.0", "devtools>=0.12.2", "fastapi>=0.115.14", "httpx>=0.28.1", @@ -15,6 +16,10 @@ dependencies = [ "uvicorn>=0.34.3", ] +[dependency-groups] +dev = ["pyright>=1.1.402", "ruff>=0.12.1", "watchfiles>=1.1.0"] + + [tool.uv.workspace] members = ["human-seeded-evals"] diff --git a/uv.lock b/uv.lock index f057940..0de9ecc 100644 --- a/uv.lock +++ b/uv.lock @@ -181,6 +181,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, ] +[[package]] +name = "cloudkv" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } 
+dependencies = [ + { name = "eval-type-backport" }, + { name = "httpx" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/cc/7a77d60cb1fb2760ea4a35f71d7a3b15bef4a584803c6730f91b382f739b/cloudkv-0.3.0.tar.gz", hash = "sha256:9e3ca1d79a4fe0a04d8a17deeddce3fafed5646570ba097c86a64d118b568bc5", size = 131191, upload-time = "2025-06-08T09:31:57.437Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/2a/edb0131f554079d679b3cbac9aadd8895981d1b78f3ff6513c07b06ecccc/cloudkv-0.3.0-py3-none-any.whl", hash = "sha256:a1881107d1be92d0ce052eb85877c30e478a231a56d9825be348caa24cdc8447", size = 12993, upload-time = "2025-06-08T09:31:56.468Z" }, +] + [[package]] name = "cohere" version = "5.15.0" @@ -480,6 +494,16 @@ wheels = [ name = "human-seeded-evals" version = "0.1.0" source = { virtual = "human-seeded-evals" } +dependencies = [ + { name = "cloudkv" }, + { name = "devtools" }, + { name = "fastapi" }, + { name = "httpx" }, + { name = "logfire", extra = ["fastapi", "httpx"] }, + { name = "pydantic" }, + { name = "pydantic-ai" }, + { name = "uvicorn" }, +] [package.dev-dependencies] dev = [ @@ -489,6 +513,16 @@ dev = [ ] [package.metadata] +requires-dist = [ + { name = "cloudkv", specifier = ">=0.3.0" }, + { name = "devtools", specifier = ">=0.12.2" }, + { name = "fastapi", specifier = ">=0.115.14" }, + { name = "httpx", specifier = ">=0.28.1" }, + { name = "logfire", extras = ["fastapi", "httpx"], specifier = ">=3.21.1" }, + { name = "pydantic", specifier = ">=2.11.7" }, + { name = "pydantic-ai", specifier = ">=0.3.4" }, + { name = "uvicorn", specifier = ">=0.34.3" }, +] [package.metadata.requires-dev] dev = [ @@ -1066,6 +1100,7 @@ name = "pydantic-demo" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "cloudkv" }, { name = "devtools" }, { name = "fastapi" }, { name = "httpx" }, @@ -1076,8 +1111,16 @@ dependencies = [ { name = "uvicorn" }, ] +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, + { name = "watchfiles" }, +] + [package.metadata] requires-dist = [ + { name = "cloudkv", specifier = ">=0.3.0" }, { name = "devtools", specifier = ">=0.12.2" }, { name = "fastapi", specifier = ">=0.115.14" }, { name = "httpx", specifier = ">=0.28.1" }, @@ -1088,6 +1131,13 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.34.3" }, ] +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.402" }, + { name = "ruff", specifier = ">=0.12.1" }, + { name = "watchfiles", specifier = ">=1.1.0" }, +] + [[package]] name = "pydantic-evals" version = "0.3.4" From 9ce319f6370e8cbf561dd2fe2b759be238bead08 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 2 Jul 2025 13:47:16 +0100 Subject: [PATCH 2/6] working self-improving agents --- .python-version | 2 +- human-seeded-evals/app/agent.py | 16 +++---- human-seeded-evals/app/main.py | 1 + .../app/self_improving_agent.py | 6 +-- .../app/self_improving_agent_storage.py | 45 +++++++++++++++++- human-seeded-evals/update_sia.py | 20 ++++++++ pyproject.toml | 4 -- uv.lock | 47 ------------------- 8 files changed, 77 insertions(+), 64 deletions(-) create mode 100644 human-seeded-evals/update_sia.py diff --git a/.python-version b/.python-version index e4fba21..24ee5b1 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.12 +3.13 diff --git a/human-seeded-evals/app/agent.py b/human-seeded-evals/app/agent.py index a4a8984..5202291 100644 --- a/human-seeded-evals/app/agent.py +++ b/human-seeded-evals/app/agent.py @@ -6,13 +6,12 
@@ from datetime import datetime from typing import AsyncIterator -from cloudkv import AsyncCloudKV from pydantic_ai import Agent, RunContext from pydantic_ai.models import Model from .models import TimeRangeInputs, TimeRangeResponse from .self_improving_agent import SelfImprovingAgentModel -from .self_improving_agent_storage import CloudKVStorage +from .self_improving_agent_storage import LocalStorage @dataclass @@ -32,13 +31,14 @@ class TimeRangeDeps: @asynccontextmanager async def self_improving_model() -> AsyncIterator[SelfImprovingAgentModel]: - cloudkv_read_token, cloudkv_write_token = os.environ['CLOUDKV_TOKEN'].split('.') logfire_read_token = os.environ['LOGFIRE_READ_TOKEN'] - async with AsyncCloudKV(cloudkv_read_token, cloudkv_write_token) as cloudkv: - storage = CloudKVStorage(cloudkv) - m = SelfImprovingAgentModel('anthropic:claude-sonnet-4-0', storage, logfire_read_token, 'time_range_agent') - yield m - await m.wait_for_coach() + # cloudkv_read_token, cloudkv_write_token = os.environ['CLOUDKV_TOKEN'].split('.') + # async with AsyncCloudKV(cloudkv_read_token, cloudkv_write_token) as cloudkv: + # storage = CloudKVStorage(cloudkv) + storage = LocalStorage() + m = SelfImprovingAgentModel('anthropic:claude-sonnet-4-0', storage, logfire_read_token, 'time_range_agent') + yield m + await m.wait_for_coach() @time_range_agent.instructions diff --git a/human-seeded-evals/app/main.py b/human-seeded-evals/app/main.py index 8157ea4..847288d 100644 --- a/human-seeded-evals/app/main.py +++ b/human-seeded-evals/app/main.py @@ -11,6 +11,7 @@ logfire.configure(environment='dev') logfire.instrument_pydantic_ai() +logfire.instrument_httpx() @asynccontextmanager diff --git a/human-seeded-evals/app/self_improving_agent.py b/human-seeded-evals/app/self_improving_agent.py index 85c7243..0e95056 100644 --- a/human-seeded-evals/app/self_improving_agent.py +++ b/human-seeded-evals/app/self_improving_agent.py @@ -429,9 +429,9 @@ def get_instrunctions(messages: list[ModelMessage]) -> str | None: def get_tools_fields(tools: list[ToolDefinition], prefix: str, description: str) -> Iterable[FieldDetails]: for t in tools: - prefix = f'{prefix}.{escape_key(t.name)}' - yield FieldDetails(f'{prefix}.description', description, t.description) - yield from json_schema_fields(t.parameters_json_schema, f'{prefix}.parameters') + tool_prefix = f'{prefix}.{escape_key(t.name)}' + yield FieldDetails(f'{tool_prefix}.description', description, t.description) + yield from json_schema_fields(t.parameters_json_schema, f'{tool_prefix}.parameters') JsonSchema = dict[str, Any] diff --git a/human-seeded-evals/app/self_improving_agent_storage.py b/human-seeded-evals/app/self_improving_agent_storage.py index f1c872a..d3f33fd 100644 --- a/human-seeded-evals/app/self_improving_agent_storage.py +++ b/human-seeded-evals/app/self_improving_agent_storage.py @@ -1,7 +1,10 @@ +import asyncio from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta -from typing import AsyncIterator +from functools import partial +from pathlib import Path +from typing import AsyncIterator, Callable, ParamSpec, TypeVar from cloudkv import AsyncCloudKV @@ -30,3 +33,43 @@ async def lock(self, agent_name: str) -> AsyncIterator[bool]: await self.cloud_kv.delete(key) else: yield False + + +@dataclass +class LocalStorage(SelfImprovingAgentStorage): + directory: Path = Path('.self-improving-agent') + + def __post_init__(self): + self.directory.mkdir(exist_ok=True) + + async def get_patch(self, agent_name: str) -> 
ModelContextPatch | None: + file = self.directory / f'{agent_name}.json' + if file.exists(): + content = await asyncify(file.read_bytes) + return ModelContextPatch.model_validate_json(content) + + async def set_patch(self, agent_name: str, patch: ModelContextPatch, expires: timedelta) -> None: + # note we're ignoring expiry here + file = self.directory / f'{agent_name}.json' + content = patch.model_dump_json(indent=2) + await asyncify(file.write_text, content) + + @asynccontextmanager + async def lock(self, agent_name: str) -> AsyncIterator[bool]: + file = self.directory / f'lock:{agent_name}' + if not await asyncify(file.exists): + await asyncify(file.touch) + try: + yield True + finally: + await asyncify(file.unlink) + else: + yield False + + +P = ParamSpec('P') +R = TypeVar('R') + + +async def asyncify(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + return await asyncio.get_event_loop().run_in_executor(None, partial(func, *args, **kwargs)) diff --git a/human-seeded-evals/update_sia.py b/human-seeded-evals/update_sia.py new file mode 100644 index 0000000..bf23dc7 --- /dev/null +++ b/human-seeded-evals/update_sia.py @@ -0,0 +1,20 @@ +import asyncio + +import logfire +from app.agent import infer_time_range, self_improving_model +from app.models import TimeRangeInputs + +logfire.configure(environment='evals') + +logfire.instrument_pydantic_ai() + + +async def main(): + async with self_improving_model() as model: + with model.blocking_context(): + with logfire.span('running infer_time_range with blocking coach'): + await infer_time_range(TimeRangeInputs(prompt='yesterday'), model=model) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index bacbbd2..093c386 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,6 @@ dependencies = [ [dependency-groups] dev = ["pyright>=1.1.402", "ruff>=0.12.1", "watchfiles>=1.1.0"] - -[tool.uv.workspace] -members = ["human-seeded-evals"] - [tool.ruff] line-length = 120 target-version = "py39" diff --git a/uv.lock b/uv.lock index 0de9ecc..6532d81 100644 --- a/uv.lock +++ b/uv.lock @@ -2,12 +2,6 @@ version = 1 revision = 2 requires-python = ">=3.12" -[manifest] -members = [ - "human-seeded-evals", - "pydantic-demo", -] - [[package]] name = "annotated-types" version = "0.7.0" @@ -490,47 +484,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/fb/5307bd3612eb0f0e62c3a916ae531d3a31e58fb5c82b58e3ebf7fd6f47a1/huggingface_hub-0.33.1-py3-none-any.whl", hash = "sha256:ec8d7444628210c0ba27e968e3c4c973032d44dcea59ca0d78ef3f612196f095", size = 515377, upload-time = "2025-06-25T12:02:55.611Z" }, ] -[[package]] -name = "human-seeded-evals" -version = "0.1.0" -source = { virtual = "human-seeded-evals" } -dependencies = [ - { name = "cloudkv" }, - { name = "devtools" }, - { name = "fastapi" }, - { name = "httpx" }, - { name = "logfire", extra = ["fastapi", "httpx"] }, - { name = "pydantic" }, - { name = "pydantic-ai" }, - { name = "uvicorn" }, -] - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, - { name = "watchfiles" }, -] - -[package.metadata] -requires-dist = [ - { name = "cloudkv", specifier = ">=0.3.0" }, - { name = "devtools", specifier = ">=0.12.2" }, - { name = "fastapi", specifier = ">=0.115.14" }, - { name = "httpx", specifier = ">=0.28.1" }, - { name = "logfire", extras = ["fastapi", "httpx"], specifier = ">=3.21.1" }, - { name = "pydantic", specifier = ">=2.11.7" }, - { name = "pydantic-ai", specifier = ">=0.3.4" }, - { name = "uvicorn", 
specifier = ">=0.34.3" }, -] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.402" }, - { name = "ruff", specifier = ">=0.12.1" }, - { name = "watchfiles", specifier = ">=1.1.0" }, -] - [[package]] name = "idna" version = "3.10" From 06fdca0487bbf714a0a2d8f32b541bd0d602c3fa Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 3 Jul 2025 08:37:36 +0100 Subject: [PATCH 3/6] convert inject_current_time to a tool --- human-seeded-evals/app/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/human-seeded-evals/app/agent.py b/human-seeded-evals/app/agent.py index 5202291..1692110 100644 --- a/human-seeded-evals/app/agent.py +++ b/human-seeded-evals/app/agent.py @@ -41,7 +41,7 @@ async def self_improving_model() -> AsyncIterator[SelfImprovingAgentModel]: await m.wait_for_coach() -@time_range_agent.instructions +@time_range_agent.tool def inject_current_time(ctx: RunContext[TimeRangeDeps]) -> str: """Add the user's current time and timezone in the format 'Friday, November 22, 2024 11:15:14 PST' to context.""" return f"The user's current time is {ctx.deps.now:%A, %B %d, %Y %H:%M:%S %Z}." From f604a8534b8b25108e3d935c7d0411530321426b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 15 Jul 2025 08:22:31 -0700 Subject: [PATCH 4/6] new agent context view --- human-seeded-evals/app/agent.py | 6 +- human-seeded-evals/app/main.py | 20 +++- human-seeded-evals/frontend/src/App.tsx | 27 ++++- human-seeded-evals/frontend/src/api.ts | 15 +++ .../frontend/src/components/Card.tsx | 12 -- .../frontend/src/components/Input.tsx | 2 +- .../frontend/src/components/PromptView.tsx | 108 ++++++++++++++++++ .../frontend/src/components/TimeConverter.tsx | 5 +- human-seeded-evals/update_sia.py | 2 +- 9 files changed, 176 insertions(+), 21 deletions(-) delete mode 100644 human-seeded-evals/frontend/src/components/Card.tsx create mode 100644 human-seeded-evals/frontend/src/components/PromptView.tsx diff --git a/human-seeded-evals/app/agent.py b/human-seeded-evals/app/agent.py index 1692110..058da37 100644 --- a/human-seeded-evals/app/agent.py +++ b/human-seeded-evals/app/agent.py @@ -30,14 +30,16 @@ class TimeRangeDeps: @asynccontextmanager -async def self_improving_model() -> AsyncIterator[SelfImprovingAgentModel]: +async def self_improving_model() -> AsyncIterator[tuple[LocalStorage, SelfImprovingAgentModel]]: logfire_read_token = os.environ['LOGFIRE_READ_TOKEN'] # cloudkv_read_token, cloudkv_write_token = os.environ['CLOUDKV_TOKEN'].split('.') # async with AsyncCloudKV(cloudkv_read_token, cloudkv_write_token) as cloudkv: # storage = CloudKVStorage(cloudkv) storage = LocalStorage() m = SelfImprovingAgentModel('anthropic:claude-sonnet-4-0', storage, logfire_read_token, 'time_range_agent') - yield m + + yield storage, m + await m.wait_for_coach() diff --git a/human-seeded-evals/app/main.py b/human-seeded-evals/app/main.py index 847288d..07a800d 100644 --- a/human-seeded-evals/app/main.py +++ b/human-seeded-evals/app/main.py @@ -3,10 +3,12 @@ import logfire from fastapi import FastAPI, Request +from pydantic import BaseModel from .agent import infer_time_range, self_improving_model from .models import TimeRangeInputs, TimeRangeResponse from .self_improving_agent import SelfImprovingAgentModel +from .self_improving_agent_storage import LocalStorage logfire.configure(environment='dev') @@ -16,7 +18,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): - async with self_improving_model() as model: + async with self_improving_model() as (storage, model): + 
app.state.storage = storage
+        app.state.model = model
         yield
@@ -29,3 +32,18 @@ async def lifespan(app: FastAPI):
 async def convert_time_range(request: Request, time_range_inputs: TimeRangeInputs) -> TimeRangeResponse:
     model = cast(SelfImprovingAgentModel, request.app.state.model)
     return await infer_time_range(time_range_inputs, model=model)
+
+
+class Field(BaseModel):
+    id: str
+    text: str
+
+
+@app.get('/api/context')
+async def get_agent_context(request: Request) -> list[Field]:
+    storage = cast(LocalStorage, request.app.state.storage)
+    patch = await storage.get_patch('time_range_agent')
+    if not patch:
+        return []
+    else:
+        return [Field(id=key, text=value) for key, value in patch.context_patch.items()]
diff --git a/human-seeded-evals/frontend/src/App.tsx b/human-seeded-evals/frontend/src/App.tsx
index b7654e6..a89768e 100644
--- a/human-seeded-evals/frontend/src/App.tsx
+++ b/human-seeded-evals/frontend/src/App.tsx
@@ -1,7 +1,32 @@
+import { useState, useEffect } from 'react';
 import TimeConverter from './components/TimeConverter';
+import { PromptView } from './components/PromptView';
 
 function App() {
-  return <TimeConverter />;
+  const [currentView, setCurrentView] = useState('');
+
+  useEffect(() => {
+    const handleHashChange = () => {
+      setCurrentView(window.location.hash.slice(1));
+    };
+
+    // Set initial view
+    handleHashChange();
+
+    // Listen for hash changes
+    window.addEventListener('hashchange', handleHashChange);
+
+    return () => {
+      window.removeEventListener('hashchange', handleHashChange);
+    };
+  }, []);
+
+  if (currentView === 'agent-context') {
+    return <PromptView />;
+  } else {
+    return <TimeConverter />;
+  }
+
 }
 
 export default App
diff --git a/human-seeded-evals/frontend/src/api.ts b/human-seeded-evals/frontend/src/api.ts
index a5d0a52..3f07f05 100644
--- a/human-seeded-evals/frontend/src/api.ts
+++ b/human-seeded-evals/frontend/src/api.ts
@@ -8,6 +8,11 @@ export interface ConversionError {
   error: string;
 }
 
+export interface Field {
+  id: string; // The label text
+  text: string; // Default input value
+}
+
 export async function convertTimeInterval(prompt: string): Promise<TimeRangeResponse | ConversionError> {
   try {
     const response = await fetch('/api/timerange', {
@@ -33,3 +38,13 @@ export async function convertTimeInterval(prompt: string): Promise<TimeRangeResponse | ConversionError> {
+
+export async function getFields(): Promise<Field[]> {
+  const response = await fetch('/api/context');
+
+  if (!response.ok) {
+    const errorData = await response.json().catch(() => ({}));
+    throw new Error(errorData.error || `Server error: ${response.status}`);
+  }
+  return await response.json();
+}
diff --git a/human-seeded-evals/frontend/src/components/Card.tsx b/human-seeded-evals/frontend/src/components/Card.tsx
deleted file mode 100644
index 2d32d37..0000000
--- a/human-seeded-evals/frontend/src/components/Card.tsx
+++ /dev/null
@@ -1,12 +0,0 @@
-interface CardProps {
-  children: React.ReactNode;
-  className?: string;
-}
-
-export default function Card({ children, className = '' }: CardProps) {
-  return (
-    <div className={className}>
-      {children}
-    </div>
-  );
-}
\ No newline at end of file
diff --git a/human-seeded-evals/frontend/src/components/Input.tsx b/human-seeded-evals/frontend/src/components/Input.tsx
index d563170..2bd6e04 100644
--- a/human-seeded-evals/frontend/src/components/Input.tsx
+++ b/human-seeded-evals/frontend/src/components/Input.tsx
@@ -65,4 +65,4 @@ export default function Input({
     </div>
   );
-}
\ No newline at end of file
+}
diff --git a/human-seeded-evals/frontend/src/components/PromptView.tsx b/human-seeded-evals/frontend/src/components/PromptView.tsx
new file mode 100644
index 0000000..8a4c6ad
--- /dev/null
+++ b/human-seeded-evals/frontend/src/components/PromptView.tsx
@@ -0,0 +1,108 @@
+import React, { useState, useEffect } from 'react';
+import { getFields, type Field } from '../api';
+
+export function PromptView() {
+  const [fields, setFields] = useState<Field[]>([]);
+  const [formData, setFormData] = useState<Record<string, string>>({});
+  const [loading, setLoading] = useState(true);
+
+  useEffect(() => {
+    async function loadFields() {
+      try {
+        const fieldsData = await getFields();
+        setFields(fieldsData);
+        const initialData: Record<string, string> = {};
+        fieldsData.forEach(field => {
+          initialData[field.id] = field.text;
+        });
+        setFormData(initialData);
+      } finally {
+        setLoading(false);
+      }
+    }
+    loadFields();
+  }, []);
+
+  const handleInputChange = (fieldId: string, value: string) => {
+    setFormData(prev => ({
+      ...prev,
+      [fieldId]: value
+    }));
+  };
+
+  const handleSubmit = (e: React.FormEvent) => {
+    e.preventDefault();
+    console.log('Form submitted:', formData);
+    // Handle form submission here
+  };
+
+  const handleBack = () => {
+    window.location.href = '/';
+  };
+
+  if (loading) {
+    return (
+      <div>
+        <div>Loading...</div>
+      </div>
+    );
+  }
+
+  return (
+    <div>
+      <div>
+        <button onClick={handleBack}>Back</button>
+        <h1>Agent Context Form</h1>
+      </div>
+      <form onSubmit={handleSubmit}>
+        {fields.length === 0 ? (
+          <div>No agent context available</div>
+        ) : (
+          <div>
+            {fields.map(field => (