diff --git a/.python-version b/.python-version index e4fba21..24ee5b1 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.12 +3.13 diff --git a/human-seeded-evals/app/agent.py b/human-seeded-evals/app/agent.py index 246e4b0..7f57d0d 100644 --- a/human-seeded-evals/app/agent.py +++ b/human-seeded-evals/app/agent.py @@ -1,11 +1,13 @@ from __future__ import annotations as _annotations +import os from dataclasses import dataclass from datetime import datetime from pydantic_ai import Agent, RunContext from .models import TimeRangeInputs, TimeRangeResponse +from .self_improving_agent import Coach, SelfImprovingAgentModel @dataclass @@ -13,17 +15,22 @@ class TimeRangeDeps: now: datetime -instrunctions = "Convert the user's request into a structured time range." +system_prompt = "Convert the user's request into a structured time range." time_range_agent = Agent[TimeRangeDeps, TimeRangeResponse]( 'anthropic:claude-sonnet-4-0', output_type=TimeRangeResponse, # type: ignore # we can't yet annotate something as receiving a TypeForm deps_type=TimeRangeDeps, - instructions=instrunctions, + system_prompt=system_prompt, retries=1, ) -@time_range_agent.instructions +def get_coach() -> Coach: + logfire_read_token = os.environ['LOGFIRE_READ_TOKEN'] + return Coach('time_range_agent', logfire_read_token) + + +@time_range_agent.tool def inject_current_time(ctx: RunContext[TimeRangeDeps]) -> str: """Add the user's current time and timezone in the format 'Friday, November 22, 2024 11:15:14 PST' to context.""" return f"The user's current time is {ctx.deps.now:%A, %B %d, %Y %H:%M:%S %Z}." @@ -31,5 +38,6 @@ def inject_current_time(ctx: RunContext[TimeRangeDeps]) -> str: async def infer_time_range(inputs: TimeRangeInputs) -> TimeRangeResponse: """Infer a time range from a user prompt.""" - result = await time_range_agent.run(inputs.prompt, deps=TimeRangeDeps(now=inputs.now)) + model = SelfImprovingAgentModel('anthropic:claude-sonnet-4-0') + result = await time_range_agent.run(inputs.prompt, deps=TimeRangeDeps(now=inputs.now), model=model) return result.output diff --git a/human-seeded-evals/app/main.py b/human-seeded-evals/app/main.py index f04baa1..6b2f244 100644 --- a/human-seeded-evals/app/main.py +++ b/human-seeded-evals/app/main.py @@ -1,11 +1,19 @@ +from datetime import datetime, timezone + import logfire from fastapi import FastAPI +from pydantic import BaseModel -from .agent import infer_time_range +from .agent import get_coach, infer_time_range from .models import TimeRangeInputs, TimeRangeResponse +from .self_improving_agent import ModelContextPatch logfire.configure(environment='dev') + logfire.instrument_pydantic_ai() +logfire.instrument_httpx(capture_all=True) +coach = get_coach() + app = FastAPI() logfire.instrument_fastapi(app) @@ -14,3 +22,36 @@ @app.post('/api/timerange') async def convert_time_range(time_range_inputs: TimeRangeInputs) -> TimeRangeResponse: return await infer_time_range(time_range_inputs) + + +class Field(BaseModel): + id: str + text: str + + +@app.get('/api/context') +def get_agent_context() -> list[Field]: + coach_fields = coach.get_fields() or [] + fields = [Field(id=f.key, text=f.current_prompt or '') for f in coach_fields] + + if patch := coach.get_patch(): + for field in fields: + if new_text := patch.context_patch.get(field.id): + field.text = new_text + + return fields + + +class PostFields(BaseModel): + fields: list[Field] + + +@app.post('/api/context') +def post_agent_context(m: PostFields): + context_patch = {f.id: f.text for f in m.fields if f.text} + 
coach.update_patch(ModelContextPatch(context_patch=context_patch, timestamp=datetime.now(tz=timezone.utc)))
+
+
+@app.post('/api/context/update')
+async def post_update_agent_context():
+    await coach.run()
diff --git a/human-seeded-evals/app/self_improving_agent.py b/human-seeded-evals/app/self_improving_agent.py
new file mode 100644
index 0000000..95e3813
--- /dev/null
+++ b/human-seeded-evals/app/self_improving_agent.py
@@ -0,0 +1,504 @@
+from __future__ import annotations as _annotations
+
+import csv
+import io
+from copy import deepcopy
+from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+from typing import Annotated, Any, Iterable, Literal, Protocol, TypeAlias, TypedDict, TypeGuard, cast
+
+from annotated_types import Ge, Le
+from logfire import Logfire
+from logfire.experimental.query_client import AsyncLogfireQueryClient
+from pydantic import AwareDatetime, BaseModel, Field, TypeAdapter
+from pydantic_ai import Agent, format_as_xml
+from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, SystemPromptPart
+from pydantic_ai.models import KnownModelName, Model, ModelRequestParameters, infer_model
+from pydantic_ai.models.wrapper import WrapperModel
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.tools import ObjectJsonSchema, ToolDefinition
+
+logfire = Logfire(otel_scope='self-improving-agent')
+FieldsPatch = dict[str, str]
+
+
+@dataclass
+class FieldDetails:
+    key: str
+    description: str
+    current_prompt: str | None = None
+
+
+fields_path = Path('agent_context_fields.json')
+fields_schema = TypeAdapter(list[FieldDetails])
+
+
+class ModelContextPatch(BaseModel):
+    context_patch: FieldsPatch
+    timestamp: AwareDatetime
+
+
+class AbstractCoachOutput(Protocol):
+    context_patch: FieldsPatch
+    developer_suggestions: str | None
+    overall_context_score: int
+
+
+patch_path = Path('agent_context_patches.json')
+
+
+@dataclass(init=False)
+class SelfImprovingAgentModel(WrapperModel):
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        # snapshot the editable context fields, then apply the latest patch (if any) before the real request
+        fields = list(context_patch_fields(messages, model_request_parameters))
+        fields_path.write_bytes(fields_schema.dump_json(fields, indent=2))
+
+        if patch := Coach.get_patch():
+            messages, model_request_parameters = apply_patch(messages, model_request_parameters, patch.context_patch)
+
+        return await super().request(messages, model_settings, model_request_parameters)
+
+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self.wrapped.model_name
+
+    @property
+    def system(self) -> str:
+        """The model provider / system."""
+        return self.wrapped.system
+
+
+@dataclass
+class Coach:
+    agent_name: str
+    logfire_read_token: str
+    logfire_environment: str | None = None
+    """Name of the environment in logfire where the main agent is running, improves query performance"""
+    logfire_filter: str | None = None
+    """Additional logfire filter when looking for agent run traces to improve performance"""
+
+    coach_model: Model | KnownModelName = 'anthropic:claude-opus-4-0'
+    """Model used for the coach agent"""
+
+    async def run(self):
+        fields = self.get_fields()
+        last_patch = self.get_patch()
+        runs, last_run = await self._get_runs(last_patch and last_patch.timestamp)
+
+        prompt_data: dict[str, Any] = {
+            'default_model_context': {f.key: f.current_prompt for f in fields if f.current_prompt}
+        }
+        if last_patch:
+            prompt_data['previous_context_patch'] = last_patch.context_patch
+        if runs:
+            prompt_data['recent_agent_runs'] = runs
+        prompt = format_as_xml(prompt_data, include_root_tag=False)
+        coach_agent = self._get_coach_agent(fields)
+        r = await coach_agent.run(prompt)
+        run_count = len(runs) if runs is not None else None
+        if r.output.overall_context_score < 5:
+            logfire.warning(
+                'Coach run with quality warning, score={output.overall_context_score} {run_count=}',
+                output=r.output,
+                run_count=run_count,
+            )
+        else:
+            logfire.info(
+                'Coach run with quality ok, score={output.overall_context_score} {run_count=}',
+                output=r.output,
+                run_count=run_count,
+            )
+        patch = ModelContextPatch(context_patch=r.output.context_patch, timestamp=last_run)
+        self.update_patch(patch)
+
+    def get_fields(self) -> list[FieldDetails]:
+        return fields_schema.validate_json(fields_path.read_bytes())
+
+    @staticmethod
+    def get_patch() -> ModelContextPatch | None:
+        try:
+            return ModelContextPatch.model_validate_json(patch_path.read_bytes())
+        except FileNotFoundError:
+            return None
+
+    def update_patch(self, patch: ModelContextPatch) -> None:
+        patch_path.write_text(patch.model_dump_json(indent=2))
+
+    async def _get_runs(self, min_timestamp: datetime | None) -> tuple[list[AgentRunSummary] | None, datetime]:
+        async with AsyncLogfireQueryClient(self.logfire_read_token) as client:
+            runs_where = ["otel_scope_name='pydantic-ai'", f"message='{self.agent_name} run'"]
+            if self.logfire_environment:
+                runs_where.append(f"deployment_environment='{self.logfire_environment}'")
+            if self.logfire_filter:
+                runs_where.append(self.logfire_filter)
+            sql = runs_query.format(where=' AND '.join(runs_where))
+            if min_timestamp is None:
+                # no previous patch to anchor on, so look back two hours by default
+                min_timestamp = datetime.now(tz=timezone.utc) - timedelta(hours=2)
+            r = await client.query_json_rows(sql=sql, min_timestamp=min_timestamp)
+            runs_rows = r['rows']
+            count = len(runs_rows)
+            if not count:
+                logfire.info('Found {run_count} runs', run_count=count)
+                return None, datetime.now(tz=timezone.utc) - timedelta(seconds=30)
+
+            created_ats = datetime_list_schema.validate_python([row['created_at'] for row in runs_rows])
+            last_run = max(created_ats)
+
+            r = await client.query_json_rows(sql=feedback_query, min_timestamp=min_timestamp)
+            feedback_lookup: dict[str, RunFeedback] = {
+                '{trace_id}-{parent_span_id}'.format(**row): RunFeedback(**row) for row in r['rows']
+            }
+
+            runs: list[AgentRunSummary] = []
+            feedback_count = 0
+            for row in runs_rows:
+                if feedback := feedback_lookup.get('{trace_id}-{span_id}'.format(**row)):
+                    row['feedback'] = feedback
+                    feedback_count += 1
+                run = AgentRunSummary.model_validate(row)
+                if run.prompt is not None:
+                    runs.append(run)
+
+            logfire.info(
+                'Found {run_count} runs, {feedback_count} with feedback, running coach',
+                run_count=count,
+                feedback_count=feedback_count,
+            )
+            return runs, last_run
+
+    def _get_coach_agent(self, fields: list[FieldDetails]) -> Agent[None, AbstractCoachOutput]:
+        fields_dict = {f.key: Annotated[str, Field(description=f.description)] for f in fields}
+        ModelRequestFields = TypedDict(
+            'ModelRequestFields',
+            fields_dict,  # type: ignore
+            total=False,
+        )
+
+        class CoachOutput(BaseModel, use_attribute_docstrings=True):
+            context_patch: ModelRequestFields
+            """Patch to update context fields to improve the agent's performance."""
+            developer_suggestions: str | None = None
+            """Suggestions to the developer about how to improve the agent code."""
+            overall_context_score: Annotated[int, Ge(0), Le(10)]
+            """Overall quality of the context, on a scale from zero to ten, zero
+            being the worst, ten being the best.
+
+            Any value below 5 will trigger a warning to the agent maintainers.
+            """
+
+        coach_model = infer_model(self.coach_model)
+        self._coach_agent = agent = cast(
+            Agent[None, AbstractCoachOutput],
+            Agent(coach_model, output_type=CoachOutput, instructions=coach_instructions),
+        )
+        return agent
+
+
+class RunFeedback(BaseModel):
+    reaction: Literal['positive', 'negative'] | None
+    comment: str | None
+
+
+class AgentRunSummary(BaseModel):
+    prompt: str | None
+    output: Any
+    feedback: RunFeedback | None = None
+
+
+datetime_list_schema = TypeAdapter(list[AwareDatetime])
+
+runs_query = """
+select
+    created_at,
+    trace_id,
+    span_id,
+    attributes->'all_messages_events'->1->>'content' as prompt,
+    attributes->'final_result' as output
+from records
+where {where}
+order by created_at desc
+limit 20
+"""
+feedback_query = """
+select
+    trace_id,
+    parent_span_id,
+    attributes->>'Annotation' as reaction,
+    attributes->>'logfire.feedback.comment' as comment
+from records
+where
+    kind='annotation' and
+    attributes->>'logfire.feedback.name'='Annotation'
+order by created_at desc
+-- bigger limit to get all feedback linked to relevant runs
+limit 200
+"""
+# this is a rough prompt, it can almost certainly be improved
+coach_instructions = """\
+Your job is to improve the performance of an AI agent by analyzing the context provided to the model
+and the agent's behavior (inputs, outputs and feedback where available),
+then rewriting context where you are confident that it will improve the agent's performance.
+To do this, return a patch of model context prompts and descriptions where appropriate.
+
+Pay special attention to the `instructions` or `system_prompt` fields as they have the most significant impact on the agent's behavior.
+
+Be concise and clear: increasing text length will increase token usage and thereby cost.
+
+If you identify shortcomings in the context provided to the model that cannot be solved by adjusting the instructions
+and tool descriptions, please suggest improvements the developer should make.
+
+YOU SHOULD ONLY INCLUDE SUGGESTIONS IF THERE ARE SHORTCOMINGS IN THE CONTEXT PROVIDED TO THE MODEL
+THAT CANNOT BE SOLVED BY ADJUSTING THE INSTRUCTIONS AND TOOL DESCRIPTIONS.
+"""
+
+
+def context_patch_fields(
+    messages: list[ModelMessage], model_request_parameters: ModelRequestParameters
+) -> Iterable[FieldDetails]:
+    found_sys_prompt = False
+    if system_prompt := get_system_prompt(messages):
+        found_sys_prompt = True
+        yield FieldDetails('system_prompt', 'System prompt', system_prompt)
+
+    instructions = get_instructions(messages)
+    if instructions or not found_sys_prompt:
+        yield FieldDetails('instructions', 'Instructions', instructions)
+
+    yield from get_tools_fields(model_request_parameters.function_tools, 'function_tools', 'Function tool description')
+
+    yield from get_tools_fields(model_request_parameters.output_tools, 'output_tools', 'Output tool description')
+
+
+def get_system_prompt(messages: list[ModelMessage]) -> str | None:
+    """Get the first system prompt from messages, other system prompts are ignored."""
+    for message in messages:
+        if isinstance(message, ModelRequest):
+            for part in message.parts:
+                if isinstance(part, SystemPromptPart):
+                    return part.content
+
+
+def get_instructions(messages: list[ModelMessage]) -> str | None:
+    """Get the first instructions from messages, other instructions are ignored."""
+    for message in messages:
+        if isinstance(message, ModelRequest):
+            if message.instructions:
+                return message.instructions
+
+
+def get_tools_fields(tools: list[ToolDefinition], prefix: str, description: str) -> Iterable[FieldDetails]:
+    for t in tools:
+        tool_prefix = f'{prefix}.{escape_key(t.name)}'
+        yield FieldDetails(f'{tool_prefix}.description', description, t.description)
+        yield from json_schema_fields(t.parameters_json_schema, f'{tool_prefix}.parameters')
+
+
+JsonSchema = dict[str, Any]
+
+
+def json_schema_fields(schema: JsonSchema, prefix: str) -> Iterable[FieldDetails]:
+    yield FieldDetails(f'{prefix}.description', 'JSON schema field description', schema.get('description'))
+
+    type_ = schema.get('type')
+    if type_ == 'object':
+        yield from _js_object(schema, prefix)
+    elif type_ == 'array':
+        yield from _js_array(schema, prefix)
+    elif type_ is None:
+        yield from _js_union(schema, prefix, 'anyOf')
+        yield from _js_union(schema, prefix, 'oneOf')
+
+
+def _js_object(schema: ObjectJsonSchema, prefix: str) -> Iterable[FieldDetails]:
+    if properties := schema.get('properties'):
+        for key, value in properties.items():
+            yield from json_schema_fields(value, f'{prefix}.properties.{escape_key(key)}')
+
+    if additional_properties := schema.get('additionalProperties'):
+        if _is_json_schema(additional_properties):
+            yield from json_schema_fields(additional_properties, f'{prefix}.additionalProperties')
+
+    if pattern_properties := schema.get('patternProperties'):
+        for key, value in pattern_properties.items():
+            yield from json_schema_fields(value, f'{prefix}.patternProperties.{escape_key(key)}')
+
+
+def _js_array(schema: ObjectJsonSchema, prefix: str) -> Iterable[FieldDetails]:
+    if prefix_items := schema.get('prefixItems'):
+        assert isinstance(prefix_items, list), f'Expected list for prefixItems, got {type(prefix_items)}'
+        for i, item in enumerate(cast(list[Any], prefix_items)):
+            if _is_json_schema(item):
+                yield from json_schema_fields(item, f'{prefix}.prefixItems.{i}')
+
+    if items := schema.get('items'):
+        if _is_json_schema(items):
+            yield from json_schema_fields(items, f'{prefix}.items')
+
+
+def _js_union(schema: JsonSchema, prefix: str, union_kind: Literal['anyOf', 'oneOf']) -> Iterable[FieldDetails]:
+    members = schema.get(union_kind)
+    if not members:
+        return
+
+    # index each member so its key round-trips through unflatten/update_json_schema
+    for i, member in enumerate(members):
+        if _is_json_schema(member):
+            yield from json_schema_fields(member, f'{prefix}.{union_kind}.{i}')
+
+
+def escape_key(s: str) -> str:
+    if '.' in s:
+        # doubled double quotes match how the csv module parses strings
+        return '"' + s.replace('"', '""') + '"'
+    else:
+        return s
+
+
+def apply_patch(
+    messages: list[ModelMessage], model_request_parameters: ModelRequestParameters, patch: FieldsPatch
+) -> tuple[list[ModelMessage], ModelRequestParameters]:
+    if not patch:
+        return messages, model_request_parameters
+
+    messages = deepcopy(messages)
+    model_request_parameters = deepcopy(model_request_parameters)
+    changes = 0
+
+    nested_patch = unflatten(patch)
+    if system_prompt := nested_patch.get('system_prompt'):
+        assert isinstance(system_prompt, str), f'Expected str for system_prompt, got {type(system_prompt)}'
+        if set_system_prompt(messages, system_prompt):
+            changes += 1
+        else:
+            logfire.warning('No system prompt found to replace')
+
+    if instructions := nested_patch.get('instructions'):
+        assert isinstance(instructions, str), f'Expected str for instructions, got {type(instructions)}'
+        if set_instructions(messages, instructions):
+            changes += 1
+        else:
+            logfire.warning('No instructions found to replace')
+
+    changes += set_tools_fields(model_request_parameters.function_tools, 'function_tools', nested_patch)
+    changes += set_tools_fields(model_request_parameters.output_tools, 'output_tools', nested_patch)
+
+    logfire.info('updated {changes} fields in messages and model request parameters', changes=changes)
+    return messages, model_request_parameters
+
+
+UnflattenedPatch: TypeAlias = 'dict[str, UnflattenedPatch | str]'
+
+
+def unflatten(patch: FieldsPatch) -> UnflattenedPatch:
+    d: UnflattenedPatch = {}
+    for key, value in patch.items():
+        local_d = d
+        *parts, last = split_key(key)
+        for part in parts:
+            local_d = local_d.setdefault(part, {})
+            assert isinstance(local_d, dict), f'Expected dict at {part}, got {type(local_d)}'
+        local_d[last] = value
+    return d
+
+
+def set_system_prompt(messages: list[ModelMessage], system_prompt: str) -> bool:
+    for message in messages:
+        if isinstance(message, ModelRequest):
+            for part in message.parts:
+                if isinstance(part, SystemPromptPart):
+                    part.content = system_prompt
+                    return True
+    return False
+
+
+def set_instructions(messages: list[ModelMessage], instructions: str) -> bool:
+    for message in messages:
+        if isinstance(message, ModelRequest):
+            if message.instructions is not None:
+                message.instructions = instructions
+                return True
+
+    # if we didn't find existing instructions to replace, set instructions on the first model request
+    for message in messages:
+        if isinstance(message, ModelRequest):
+            message.instructions = instructions
+            return True
+
+    return False
+
+
+def set_tools_fields(tools: list[ToolDefinition], key: str, patch: UnflattenedPatch) -> int:
+    tools_patch = patch.get(key)
+    changes = 0
+    if not tools_patch:
+        return changes
+    assert isinstance(tools_patch, dict), f'Expected dict at {key}, got {type(tools_patch)}'
+    for tool_name, tool_patch in tools_patch.items():
+        assert isinstance(tool_name, str), f'Expected str at {key}.{tool_name}, got {type(tool_name)}'
+        assert isinstance(tool_patch, dict), f'Expected dict at {key}.{tool_name}, got {type(tool_patch)}'
+        tool = next((t for t in tools if t.name == tool_name), None)
+        assert tool is not None, f'Unable to find tool {key}.{tool_name}'
+
+        if description := tool_patch.get('description'):
+            assert isinstance(description, str), (
+                f'Expected str at {key}.{tool_name}.description, got {type(description)}'
+            )
+            tool.description = description
+            changes += 1
+
+        if parameters := tool_patch.get('parameters'):
+            assert isinstance(parameters, dict), (
+                f'Expected dict at {key}.{tool_name}.parameters, got {type(parameters)}'
+            )
+            changes += update_json_schema(tool.parameters_json_schema, parameters, [key, tool_name, 'parameters'])
+    return changes
+
+
+def update_json_schema(schema: JsonSchema, patch: UnflattenedPatch, path: list[str]) -> int:
+    changes = 0
+    patch_copy = patch.copy()
+    if description := patch_copy.pop('description', None):
+        assert isinstance(description, str), f'Expected str at {".".join(path)}.description, got {type(description)}'
+        schema['description'] = description
+        changes += 1
+
+    for k, v in patch_copy.items():
+        sub_path = path + [k]
+        assert isinstance(v, dict), f'Expected dict at {".".join(sub_path)}, got {type(v)}'
+
+        sub_schema = schema.get(k)
+        if not sub_schema:
+            logfire.warning('Schema key {key} not found', key='.'.join(sub_path))
+            continue
+
+        if _is_json_schema(sub_schema):
+            changes += update_json_schema(sub_schema, v, sub_path)
+        else:
+            assert isinstance(sub_schema, list), (
+                f'Expected dict or list at {".".join(sub_path)}, got {type(sub_schema)}'
+            )
+
+            # list schemas (e.g. anyOf members) are patched by integer index
+            for k2, v2 in v.items():
+                sub_sub_path = sub_path + [k2]
+                array_schema = cast(JsonSchema, sub_schema[int(k2)])
+                assert isinstance(v2, dict), f'Expected dict at {".".join(sub_sub_path)}, got {type(v2)}'
+                changes += update_json_schema(array_schema, v2, sub_sub_path)
+
+    return changes
+
+
+def split_key(s: str) -> list[str]:
+    if '"' in s:
+        # quotes in the string mean we have to parse it properly
+        return next(csv.reader(io.StringIO(s), delimiter='.'))
+    else:
+        return s.split('.')
+
+
+def _is_json_schema(obj: Any) -> TypeGuard[JsonSchema]:
+    return isinstance(obj, dict)
diff --git a/human-seeded-evals/frontend/src/App.tsx b/human-seeded-evals/frontend/src/App.tsx
index b7654e6..a89768e 100644
--- a/human-seeded-evals/frontend/src/App.tsx
+++ b/human-seeded-evals/frontend/src/App.tsx
@@ -1,7 +1,32 @@
+import { useState, useEffect } from 'react';
 import TimeConverter from './components/TimeConverter';
+import { PromptView } from './components/PromptView';
 
 function App() {
-  return <TimeConverter />;
+  const [currentView, setCurrentView] = useState('');
+
+  useEffect(() => {
+    const handleHashChange = () => {
+      setCurrentView(window.location.hash.slice(1));
+    };
+
+    // Set initial view
+    handleHashChange();
+
+    // Listen for hash changes
+    window.addEventListener('hashchange', handleHashChange);
+
+    return () => {
+      window.removeEventListener('hashchange', handleHashChange);
+    };
+  }, []);
+
+  if (currentView === 'agent-context') {
+    return <PromptView />;
+  } else {
+    return <TimeConverter />;
+  }
 }
 
 export default App
diff --git a/human-seeded-evals/frontend/src/api.ts b/human-seeded-evals/frontend/src/api.ts
index a5d0a52..5daa2f3 100644
--- a/human-seeded-evals/frontend/src/api.ts
+++ b/human-seeded-evals/frontend/src/api.ts
@@ -8,6 +8,11 @@ export interface ConversionError {
   error: string;
 }
 
+export interface Field {
+  id: string; // The label text
+  text: string; // Default input value
+}
+
 export async function convertTimeInterval(prompt: string): Promise<TimeInterval | ConversionError> {
   try {
     const response = await fetch('/api/timerange', {
@@ -33,3 +38,40 @@ export async function convertTimeInterval(prompt: string): Promise<TimeInterval | ConversionError> {
+
+export async function getFields(): Promise<Field[]> {
+  const response = await fetch('/api/context');
+
+  if (!response.ok) {
+    const errorData = await response.json().catch(() => ({}));
+    throw new Error(errorData.error || `Server error: ${response.status}`);
+  }
+  return await response.json();
+}
+
+export
async function submitContext(formData: Record<string, string>): Promise<void> {
+  const fields: Field[] = Object.entries(formData).map(([id, text]) => ({ id, text }));
+  const response = await fetch('/api/context', {
+    method: 'POST',
+    headers: {
+      'Content-Type': 'application/json',
+    },
+    body: JSON.stringify({ fields }),
+  });
+
+  if (!response.ok) {
+    const errorData = await response.json().catch(() => ({}));
+    throw new Error(errorData.error || `Server error: ${response.status}`);
+  }
+}
+
+export async function updateContext(): Promise<void> {
+  const response = await fetch('/api/context/update', {
+    method: 'POST'
+  });
+
+  if (!response.ok) {
+    const errorData = await response.json().catch(() => ({}));
+    throw new Error(errorData.error || `Server error: ${response.status}`);
+  }
+}
diff --git a/human-seeded-evals/frontend/src/components/Card.tsx b/human-seeded-evals/frontend/src/components/Card.tsx
deleted file mode 100644
index 2d32d37..0000000
--- a/human-seeded-evals/frontend/src/components/Card.tsx
+++ /dev/null
@@ -1,12 +0,0 @@
-interface CardProps {
-  children: React.ReactNode;
-  className?: string;
-}
-
-export default function Card({ children, className = '' }: CardProps) {
-  return (
-    <div className={className}>
-      {children}
-    </div>
-  );
-}
\ No newline at end of file
diff --git a/human-seeded-evals/frontend/src/components/Input.tsx b/human-seeded-evals/frontend/src/components/Input.tsx
index d563170..2bd6e04 100644
--- a/human-seeded-evals/frontend/src/components/Input.tsx
+++ b/human-seeded-evals/frontend/src/components/Input.tsx
@@ -65,4 +65,4 @@ export default function Input({
   );
-}
\ No newline at end of file
+}
diff --git a/human-seeded-evals/frontend/src/components/PromptView.tsx b/human-seeded-evals/frontend/src/components/PromptView.tsx
new file mode 100644
index 0000000..fe8664b
--- /dev/null
+++ b/human-seeded-evals/frontend/src/components/PromptView.tsx
@@ -0,0 +1,165 @@
+import React, { useState, useEffect } from 'react';
+import { getFields, submitContext, updateContext, type Field } from '../api';
+
+export function PromptView() {
+  const [fields, setFields] = useState<Field[]>([]);
+  const [formData, setFormData] = useState<Record<string, string>>({});
+  const [loading, setLoading] = useState(true);
+  const [submitting, setSubmitting] = useState(false);
+  const [improving, setImproving] = useState(false);
+
+  const loadFields = async () => {
+    setLoading(true);
+    try {
+      const fieldsData = await getFields();
+      setFields(fieldsData);
+      const initialData: Record<string, string> = {};
+      fieldsData.forEach(field => {
+        initialData[field.id] = field.text;
+      });
+      setFormData(initialData);
+    } finally {
+      setLoading(false);
+    }
+  };
+
+  useEffect(() => {
+    document.title = 'Agent Context Form';
+    loadFields();
+  }, []);
+
+  const handleInputChange = (fieldId: string, value: string) => {
+    setFormData(prev => ({
+      ...prev,
+      [fieldId]: value
+    }));
+  };
+
+  const handleSubmit = async (e: React.FormEvent) => {
+    e.preventDefault();
+    setSubmitting(true);
+
+    try {
+      await submitContext(formData);
+      console.log('Form submitted successfully');
+      // Reload data after successful submission
+      await loadFields();
+    } catch (error) {
+      console.error('Error submitting form:', error);
+      // Handle error (e.g., show error message)
+    } finally {
+      setSubmitting(false);
+    }
+  };
+
+  const handleImproveContext = async () => {
+    setImproving(true);
+
+    try {
+      await updateContext();
+      console.log('Context updated successfully');
+      // Reload data after successful update
+      await loadFields();
+    } catch (error) {
+      console.error('Error updating context:', error);
+      // Handle error (e.g., show error message)
+    } finally {
+      setImproving(false);
+    }
+  };
+
+  const handleBack = () => {
+    window.location.href = '/';
+  };
+
+  if (loading) {
+    return (
+      <div>
+        Loading...
+      </div>
+    );
+  }
+
+  return (
+    <div>
+      <div>
+        <button type="button" onClick={handleBack}>
+          Back
+        </button>
+        <h1>Agent Context Form</h1>
+        <button type="button" onClick={handleImproveContext} disabled={improving}>
+          {improving ? (
+            <>
+              Improving...
+            </>
+          ) : (
+            <>
+              Improve Agent Context
+            </>
+          )}
+        </button>
+      </div>
+
+      {fields.length === 0 ? (
+        <div>
+          No agent context available
+        </div>
+      ) : (
+        <form onSubmit={handleSubmit}>
+          {fields.map(field => (
+            <div key={field.id}>
+              <label>
+                {field.id}
+              </label>
+              <textarea
+                value={formData[field.id] || ''}
+                onChange={(e) => handleInputChange(field.id, e.target.value)}
+                rows={4}
+                className="w-full px-4 py-3 bg-gray-700 border border-gray-600 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent resize-vertical"
+              />
+            </div>
+          ))}
+
+          <button type="submit" disabled={submitting}>
+            {submitting ? (
+              <>
+                Submitting...
+              </>
+            ) : (
+              'Submit'
+            )}
+          </button>
+        </form>
+      )}
+    </div>
+  );
+}
diff --git a/human-seeded-evals/frontend/src/components/TimeConverter.tsx b/human-seeded-evals/frontend/src/components/TimeConverter.tsx
index 0ce1862..5c0f061 100644
--- a/human-seeded-evals/frontend/src/components/TimeConverter.tsx
+++ b/human-seeded-evals/frontend/src/components/TimeConverter.tsx
@@ -1,6 +1,5 @@
 import { useState } from 'react';
 import Input from './Input';
-import Card from './Card';
 import { convertTimeInterval } from '../api';
 import type { TimeInterval, ConversionError } from '../api';
 
@@ -46,7 +45,7 @@ export default function TimeConverter() {
       />
 
       {result && (
-        <Card>
+        <div>
           {'error' in result ? (
            <div>
              <h3>Error</h3>
@@ -80,7 +79,7 @@ export default function TimeConverter() {
            </div>
          )}
-        </Card>
+        </div>
       )}
diff --git a/human-seeded-evals/pyproject.toml b/human-seeded-evals/pyproject.toml
deleted file mode 100644
index e5a3c83..0000000
--- a/human-seeded-evals/pyproject.toml
+++ /dev/null
@@ -1,21 +0,0 @@
-[project]
-name = "human-seeded-evals"
-version = "0.1.0"
-readme = "README.md"
-
-[tool.ruff]
-line-length = 120
-target-version = "py312"
-
-[tool.ruff.lint]
-extend-select = ["Q", "RUF100", "C90", "UP", "I"]
-flake8-quotes = { inline-quotes = "single", multiline-quotes = "double" }
-isort = { combine-as-imports = true }
-mccabe = { max-complexity = 15 }
-pydocstyle = { convention = "google" }
-
-[tool.ruff.format]
-quote-style = "single"
-
-[dependency-groups]
-dev = ["pyright>=1.1.402", "ruff>=0.12.1", "watchfiles>=1.1.0"]
diff --git a/human-seeded-evals/run_coach.py b/human-seeded-evals/run_coach.py
new file mode 100644
index 0000000..2535f51
--- /dev/null
+++ b/human-seeded-evals/run_coach.py
@@ -0,0 +1,12 @@
+import asyncio
+
+import logfire
+from app.agent import get_coach
+
+logfire.configure(environment='evals')
+
+logfire.instrument_pydantic_ai()
+
+
+if __name__ == '__main__':
+    asyncio.run(get_coach().run())
diff --git a/pyproject.toml b/pyproject.toml
index cb825ad..093c386 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,6 +5,7 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
+    "cloudkv>=0.3.0",
     "devtools>=0.12.2",
     "fastapi>=0.115.14",
     "httpx>=0.28.1",
@@ -15,8 +16,8 @@ dependencies = [
     "uvicorn>=0.34.3",
 ]
 
-[tool.uv.workspace]
-members = ["human-seeded-evals"]
+[dependency-groups]
+dev = ["pyright>=1.1.402", "ruff>=0.12.1", "watchfiles>=1.1.0"]
 
 [tool.ruff]
 line-length = 120
diff --git a/uv.lock b/uv.lock
index f057940..6532d81 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2,12 +2,6 @@ version = 1
 revision = 2
 requires-python = ">=3.12"
 
-[manifest]
-members = [
-    "human-seeded-evals",
-    "pydantic-demo",
-]
-
 [[package]]
 name = "annotated-types"
 version = "0.7.0"
@@ -181,6 +175,20 @@ wheels = [
    { url =
"https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, ] +[[package]] +name = "cloudkv" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "eval-type-backport" }, + { name = "httpx" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/cc/7a77d60cb1fb2760ea4a35f71d7a3b15bef4a584803c6730f91b382f739b/cloudkv-0.3.0.tar.gz", hash = "sha256:9e3ca1d79a4fe0a04d8a17deeddce3fafed5646570ba097c86a64d118b568bc5", size = 131191, upload-time = "2025-06-08T09:31:57.437Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/2a/edb0131f554079d679b3cbac9aadd8895981d1b78f3ff6513c07b06ecccc/cloudkv-0.3.0-py3-none-any.whl", hash = "sha256:a1881107d1be92d0ce052eb85877c30e478a231a56d9825be348caa24cdc8447", size = 12993, upload-time = "2025-06-08T09:31:56.468Z" }, +] + [[package]] name = "cohere" version = "5.15.0" @@ -476,27 +484,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/fb/5307bd3612eb0f0e62c3a916ae531d3a31e58fb5c82b58e3ebf7fd6f47a1/huggingface_hub-0.33.1-py3-none-any.whl", hash = "sha256:ec8d7444628210c0ba27e968e3c4c973032d44dcea59ca0d78ef3f612196f095", size = 515377, upload-time = "2025-06-25T12:02:55.611Z" }, ] -[[package]] -name = "human-seeded-evals" -version = "0.1.0" -source = { virtual = "human-seeded-evals" } - -[package.dev-dependencies] -dev = [ - { name = "pyright" }, - { name = "ruff" }, - { name = "watchfiles" }, -] - -[package.metadata] - -[package.metadata.requires-dev] -dev = [ - { name = "pyright", specifier = ">=1.1.402" }, - { name = "ruff", specifier = ">=0.12.1" }, - { name = "watchfiles", specifier = ">=1.1.0" }, -] - [[package]] name = "idna" version = "3.10" @@ -1066,6 +1053,7 @@ name = "pydantic-demo" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "cloudkv" }, { name = "devtools" }, { name = "fastapi" }, { name = "httpx" }, @@ -1076,8 +1064,16 @@ dependencies = [ { name = "uvicorn" }, ] +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "ruff" }, + { name = "watchfiles" }, +] + [package.metadata] requires-dist = [ + { name = "cloudkv", specifier = ">=0.3.0" }, { name = "devtools", specifier = ">=0.12.2" }, { name = "fastapi", specifier = ">=0.115.14" }, { name = "httpx", specifier = ">=0.28.1" }, @@ -1088,6 +1084,13 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.34.3" }, ] +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.402" }, + { name = "ruff", specifier = ">=0.12.1" }, + { name = "watchfiles", specifier = ">=1.1.0" }, +] + [[package]] name = "pydantic-evals" version = "0.3.4"