|
1 | 1 | import re |
2 | 2 | from pathlib import Path |
3 | 3 |
|
4 | | -from mcp import SamplingMessage |
| 4 | +import logfire |
5 | 5 | from mcp.server.fastmcp import Context, FastMCP |
6 | 6 | from mcp.server.session import ServerSessionT |
7 | 7 | from mcp.shared.context import LifespanContextT, RequestT |
8 | | -from mcp.types import TextContent |
| 8 | +from pydantic_ai import Agent |
| 9 | +from pydantic_ai.models.mcp_sampling import MCPSamplingModel |
9 | 10 |
|
10 | | -app = FastMCP(log_level='WARNING') |
| 11 | +logfire.configure(service_name='mcp-server', environment='mcp-sampling', console=False) |
| 12 | +logfire.instrument_mcp() |
11 | 13 |
|
12 | | -import logfire |
| 14 | +app = FastMCP(log_level='WARNING') |
13 | 15 |
|
14 | | -logfire.configure(service_name='mcp-sampling-server', console=False) |
15 | | -logfire.instrument_mcp() |
| 16 | +svg_agent = Agent(instructions='Generate an SVG image as per the user input. Return the SVG data only as a string.') |
16 | 17 |
|
17 | 18 |
|
18 | 19 | @app.tool() |
19 | 20 | async def image_generator(ctx: Context[ServerSessionT, LifespanContextT, RequestT], subject: str, style: str) -> str: |
20 | | - prompt = f'{subject=} {style=}' |
21 | | - # `ctx.session.create_message` is the sampling call |
22 | | - result = await ctx.session.create_message( |
23 | | - [SamplingMessage(role='user', content=TextContent(type='text', text=prompt))], |
24 | | - max_tokens=1_024, |
25 | | - system_prompt='Generate an SVG image as per the user input', |
26 | | - ) |
27 | | - assert isinstance(result.content, TextContent) |
| 21 | + # run the agent, using MCPSamplingModel to proxy the LLM call through the client. |
| 22 | + svg_result = await svg_agent.run(f'{subject=} {style=}', model=MCPSamplingModel(ctx.session)) |
28 | 23 |
|
29 | 24 | path = Path(f'{subject}_{style}.svg') |
30 | 25 | # remove triple backticks if the svg was returned within markdown |
31 | | - if m := re.search(r'^```\w*$(.+?)```$', result.content.text, re.S | re.M): |
| 26 | + if m := re.search(r'^```\w*$(.+?)```$', svg_result.output, re.S | re.M): |
32 | 27 | path.write_text(m.group(1)) |
33 | 28 | else: |
34 | | - path.write_text(result.content.text) |
| 29 | + path.write_text(svg_result.output) |
35 | 30 | return f'See {path}' |
36 | 31 |
|
37 | 32 |
|
|
0 commit comments