Skip to content

Commit 52be67d

Browse files
committed
make automated tools async all the way down, address comments, revamp profile system to be easier to use for agents
1 parent 397a144 commit 52be67d

File tree

6 files changed

+117
-2311
lines changed

6 files changed

+117
-2311
lines changed

src/fenic/api/mcp/tool_generation.py

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,28 @@
1515
import functools
1616
import hashlib
1717
import inspect
18-
import json
1918
import re
20-
from dataclasses import dataclass, asdict
21-
from typing import Callable, Dict, List, Literal, Optional, TypedDict, Union, Coroutine, Any
19+
from dataclasses import dataclass
20+
from inspect import iscoroutinefunction
21+
from typing import (
22+
Any,
23+
Callable,
24+
Coroutine,
25+
Dict,
26+
List,
27+
Literal,
28+
Optional,
29+
Union,
30+
)
2231

23-
from fastmcp.server.context import Context
2432
import polars as pl
33+
from fastmcp.server.context import Context
2534
from typing_extensions import Annotated
2635

2736
from fenic.api.dataframe.dataframe import DataFrame
2837
from fenic.api.functions import (
2938
avg,
3039
col,
31-
count,
3240
stddev,
3341
)
3442
from fenic.api.functions import max as max_
@@ -47,6 +55,8 @@
4755
StringType,
4856
)
4957

58+
PROFILE_MAX_SAMPLE_SIZE = 10_000
59+
5060

5161
@dataclass
5262
class DatasetSpec:
@@ -101,18 +111,28 @@ def fenic_tool(
101111
tool_name: str,
102112
tool_description: str,
103113
max_result_limit: Optional[int] = None,
114+
client_limit_parameter: bool = True,
104115
default_table_format: TableFormat = "markdown",
105116
read_only: bool = True,
106117
idempotent: bool = True,
107118
destructive: bool = False,
108119
open_world: bool = False,
109-
) -> Callable[[Callable[..., Coroutine[Any, Any, DataFrame]]], DynamicToolDefinition]:
120+
) -> Callable[[
121+
Union[
122+
Callable[..., Coroutine[Any, Any, DataFrame]],
123+
Callable[..., DataFrame]
124+
]], DynamicToolDefinition]:
110125
"""Decorator to bind a DataFrame to a user-authored tool function.
111126
127+
Can be added to a synchronous or asynchronous (recommended) tool function.
128+
Function based tools (dynamic tools) cannot be persisted to the catalog.
129+
See the (Fenic MCP documentation)[https://fenic.ai/docs/topics/fenic-mcp] for more details.
130+
112131
Args:
113132
tool_name: The name of the tool.
114133
tool_description: The description of the tool.
115-
max_result_limit: The maximum number of results to return.
134+
max_result_limit: The maximum number of results to return. If omitted, no limit will be enforced.
135+
client_limit_parameter: Whether to add a client-side limit parameter to the tool.
116136
default_table_format: The default table format to return.
117137
read_only: A hint to provide to the model that the tool does not modify its environment.
118138
idempotent: A hint to provide to the model that calling the tool multiple times with the same input will always return the same result (redundant if read_only is True).
@@ -136,10 +156,10 @@ def find_rust(
136156
137157
Example: Creating an open-world tool that reaches out to an external API. The open_world flag indicates to the model that the tool may interact with an "open world" of external entities
138158
@fenic_tool(tool_name="search_knowledge_base", tool_description="...", open_world=True)
139-
def search_knowledge_base(
159+
async def search_knowledge_base(
140160
query: Annotated[str, "Knowledge base search query"],
141161
) -> DataFrame:
142-
results = requests.get(...)
162+
results = await requests.get(...)
143163
return fc.create_dataframe(results)
144164
145165
Notes:
@@ -149,20 +169,26 @@ def search_knowledge_base(
149169
- The returned object is a DynamicTool ready for registration.
150170
- A `limit` parameter is automatically added to the function signature, which can be used to limit the number of rows returned up to the tool's `max_result_limit`.
151171
- A `table_format` parameter is automatically added to the function signature, which can be used to specify the format of the returned data (markdown, structured)
172+
- The `add_limit_parameter` flag can be used to control whether the client is allowed to specify a limit parameter.
152173
"""
153174

154-
def decorator(func: Callable[..., Coroutine[Any, Any, DataFrame]]) -> DynamicToolDefinition:
175+
def decorator(
176+
func: Union[Callable[..., Coroutine[Any, Any, DataFrame]], Callable[..., DataFrame]]) -> DynamicToolDefinition:
155177
_ensure_no_var_args(func, func_label=tool_name)
156178

157179
@functools.wraps(func)
158180
async def wrapper(*args, **kwargs) -> LogicalPlan:
159-
result_df = await func(*args, **kwargs)
181+
if iscoroutinefunction(func):
182+
result_df = await func(*args, **kwargs)
183+
else:
184+
result_df = await asyncio.to_thread(lambda: func(*args, **kwargs))
160185
return result_df._logical_plan
161186

162187
return DynamicToolDefinition(
163188
name=tool_name,
164189
description=tool_description,
165190
max_result_limit=max_result_limit,
191+
add_limit_parameter=client_limit_parameter,
166192
default_table_format=default_table_format,
167193
read_only=read_only,
168194
idempotent=idempotent,
@@ -543,7 +569,7 @@ def _apply_paging(
543569

544570

545571
@dataclass
546-
class ProfileRow:
572+
class _ProfileRow:
547573
dataset_name: str
548574
column_name: str
549575
data_type: str
@@ -584,26 +610,9 @@ def _auto_generate_profile_tool(
584610
raise ValueError("Cannot create profile tool: no datasets provided.")
585611
tool_key = _sanitize_name(tool_name)
586612

587-
async def _materialize_dataset_description(df: DataFrame, dataset_name: str, view_name: str) -> None:
588-
profile_rows = await _compute_profile_rows(df, dataset_name, topk_distinct)
589-
pl_df = pl.DataFrame(profile_rows)
590-
plan = InMemorySource.from_session_state(pl_df, session._session_state)
591-
catalog = session._session_state.catalog
592-
catalog.drop_view(view_name, ignore_if_not_exists=True)
593-
catalog.create_view(view_name, plan)
594-
595-
async def _ensure_profile_view_for_dataset(spec: DatasetSpec, refresh: bool) -> LogicalPlan:
596-
schema_hash = _schema_fingerprint(spec.df)
597-
view_name = f"__fenic_profile__{tool_key}__{_sanitize_name(spec.table_name)}__{schema_hash}"
598-
catalog = session._session_state.catalog
599-
if refresh or not catalog.does_view_exist(view_name):
600-
await _materialize_dataset_description(spec.df, spec.table_name, view_name)
601-
return catalog.get_view_plan(view_name)
602-
603613
async def profile_func(
604614
df_name: Annotated[
605615
str | None, "Optional DataFrame name to return a single profile for. To return profiles for all datasets, omit this parameter."] = None,
606-
refresh: Annotated[bool, "Recompute and refresh cached profile view(s)"] = False,
607616
) -> LogicalPlan:
608617
# sometimes the models get...very confused, and pass the null string instead of `null` or omitting the field entirely
609618
if not df_name or df_name == "null":
@@ -614,13 +623,12 @@ async def profile_func(
614623
if spec is None:
615624
raise ValidationError(
616625
f"Unknown dataset '{df_name}'. Available: {', '.join(d.table_name for d in datasets)}")
617-
return await _ensure_profile_view_for_dataset(spec, refresh)
626+
return await _ensure_profile_view_for_dataset(session, tool_key, spec, topk_distinct)
618627

619628
# Multi-dataset: concatenate cached views (or compute & cache if missing)
620629
profile_df = None
621630
for spec in datasets:
622-
# Ensure view exists and read it, then convert to polars for concatenation
623-
plan = await _ensure_profile_view_for_dataset(spec, refresh)
631+
plan = await _ensure_profile_view_for_dataset(session, tool_key, spec, topk_distinct)
624632
df = DataFrame._from_logical_plan(plan, session_state=session._session_state)
625633
if not profile_df:
626634
profile_df = df
@@ -636,21 +644,42 @@ async def profile_func(
636644
max_result_limit=None,
637645
)
638646

647+
async def _ensure_profile_view_for_dataset(
648+
session: Session,
649+
tool_key: str,
650+
spec: DatasetSpec,
651+
topk_distinct: int,
652+
) -> LogicalPlan:
653+
schema_hash = _schema_fingerprint(spec.df)
654+
view_name = f"__fenic_profile__{tool_key}__{_sanitize_name(spec.table_name)}__{schema_hash}"
655+
catalog = session._session_state.catalog
656+
if not catalog.does_view_exist(view_name):
657+
profile_rows = await _compute_profile_rows(
658+
spec.df,
659+
spec.table_name,
660+
topk_distinct,
661+
)
662+
view_plan = InMemorySource.from_session_state(
663+
pl.DataFrame(profile_rows), session._session_state,
664+
)
665+
catalog.create_view(view_name, view_plan)
666+
return catalog.get_view_plan(view_name)
667+
639668
async def _compute_profile_rows(
640669
df: DataFrame,
641670
dataset_name: str,
642-
topk_distinct: int
643-
) -> List[ProfileRow]:
671+
topk_distinct: int,
672+
) -> List[_ProfileRow]:
644673
pl_df = df.to_polars()
645674
total_rows = pl_df.height
646-
sampled_df = pl_df.sample(10000)
647-
rows_list: List[ProfileRow] = []
675+
sampled_df = pl_df.sample(min(total_rows, PROFILE_MAX_SAMPLE_SIZE))
676+
rows_list: List[_ProfileRow] = []
648677
for field in df.schema.column_fields:
649678
col_name = field.name
650679
dtype_str = str(field.data_type)
651680
null_count = sampled_df.select(pl.col(col_name).is_null().sum()).item()
652681
non_null_count = sampled_df.height - null_count
653-
stats: ProfileRow = ProfileRow(
682+
stats = _ProfileRow(
654683
dataset_name=dataset_name,
655684
column_name=col_name,
656685
data_type=dtype_str,

src/fenic/core/mcp/_server.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -147,30 +147,24 @@ def _handle_result_set(
147147
original_result_count = len(pl_df)
148148
if effective_limit and original_result_count > effective_limit:
149149
pl_df = pl_df.limit(effective_limit)
150+
schema_fields = [{"name": name, "type": str(dtype)} for name, dtype in pl_df.schema.items()]
151+
rows_list = pl_df.to_dicts()
152+
returned_result_count = len(rows_list)
150153
if table_format == "structured":
151-
rows_list = pl_df.to_dicts()
152-
schema_fields = [{"name": name, "type": str(dtype)} for name, dtype in pl_df.schema.items()]
153154
result_set = MCPResultSet(
154155
table_schema=schema_fields,
155156
rows=rows_list,
156-
returned_result_count=len(rows_list),
157+
returned_result_count=returned_result_count,
157158
total_result_count=original_result_count,
158159
)
159160
else:
160-
with pl.Config(
161-
tbl_hide_dataframe_shape=True,
162-
tbl_cols=-1,
163-
tbl_rows=-1,
164-
tbl_width_chars=-1,
165-
fmt_str_lengths=25000 #TODO(bcallender): make this configurable
166-
):
167-
rows = repr(pl_df)
168-
result_set = MCPResultSet(
169-
table_schema=None,
170-
rows=rows,
171-
returned_result_count=len(pl_df),
172-
total_result_count=original_result_count,
173-
)
161+
rows = _render_markdown_preview(rows_list)
162+
result_set = MCPResultSet(
163+
table_schema=schema_fields,
164+
rows=rows,
165+
returned_result_count=returned_result_count,
166+
total_result_count=original_result_count,
167+
)
174168
return result_set
175169

176170
def _build_parameterized_tool(self, tool: ParameterizedToolDefinition):

src/fenic/core/mcp/types.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Exported Types related to Parameterized View/MCP Tool Generation."""
22
from __future__ import annotations
33

4-
from typing import Annotated, Callable, List, Optional, Union, Coroutine, Any
4+
from typing import Annotated, Any, Callable, Coroutine, List, Optional, Union
55

66
from pydantic import BaseModel, ConfigDict, model_validator
77
from pydantic.dataclasses import dataclass
@@ -80,16 +80,17 @@ class ParameterizedToolDefinition:
8080
class DynamicToolDefinition:
8181
"""A tool implemented as a regular Python function with explicit parameters.
8282
83-
The function must be a `Callable[..., LogicalPlan]`. Collection/formatting is handled by
83+
The function must be a `Callable[..., Coroutine[Any, Any, LogicalPlan]]`
84+
(a function defined with `async def`). Collection/formatting is handled by
8485
the MCP generator wrapper.
8586
"""
8687
name: str
8788
description: str
8889
max_result_limit: Optional[int]
8990
func: Callable[..., Coroutine[Any, Any, LogicalPlan]]
90-
add_limit_parameter: bool = True
91+
add_limit_parameter: bool = True
9192
default_table_format: TableFormat = "markdown"
92-
read_only: Annotated[bool, "A hint to provide to the model that the tool is read-only."] = True
93-
idempotent: Annotated[bool, "A hint to provide to the model that the tool is idempotent."] = True
94-
destructive: Annotated[bool, "A hint to provide to the model that the tool is destructive."] = False
95-
open_world: Annotated[bool, "A hint to provide to the model that the tool reaches out to external endpoints/knowledge bases."] = False
93+
read_only: Annotated[bool, "A hint to provide to the client that the tool is read-only."] = True
94+
idempotent: Annotated[bool, "A hint to provide to the client that the tool is idempotent."] = True
95+
destructive: Annotated[bool, "A hint to provide to the client that the tool is destructive."] = False
96+
open_world: Annotated[bool, "A hint to provide to the client that the tool reaches out to external endpoints/knowledge bases."] = False

tests/api/mcp/test_tool_generation.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import inspect
23

34
import pytest
@@ -55,7 +56,32 @@ def test_auto_generate_core_tools_from_tables_builds_tools(local_session):
5556
assert set(pl_df.columns) == {"dataset", "schema"}
5657
assert sorted(pl_df.get_column("dataset").to_list()) == ["t1", "t2"]
5758

58-
def test_fenic_tool_decorator(local_session: Session):
59+
def test_fenic_tool_decorator_sync(local_session: Session):
60+
61+
@fenic_tool(tool_name="test", tool_description="test", max_result_limit=100, default_table_format="markdown")
62+
def test_sync(numbers: list[int]):
63+
return local_session.create_dataframe({"numbers": numbers})
64+
65+
assert test_sync.max_result_limit == 100
66+
assert test_sync.default_table_format == "markdown"
67+
assert isinstance(test_sync, DynamicToolDefinition)
68+
assert callable(test_sync.func)
69+
# underlying function is synchronous, but we want the mcp
70+
# function wrapping it to be async
71+
assert inspect.iscoroutinefunction(test_sync.func)
72+
func_signature = inspect.signature(test_sync.func)
73+
assert len(func_signature.parameters) == 1
74+
assert "numbers" in func_signature.parameters
75+
# limit/table_format are added by the MCP server wrapper, so should not be in the raw function signature
76+
assert "limit" not in func_signature.parameters
77+
assert "table_format" not in func_signature.parameters
78+
79+
test_sync = asyncio.run(test_sync.func(list(range(100))))
80+
pl_df, _ = local_session._session_state.execution.collect(test_sync)
81+
assert pl_df.get_column("numbers").to_list() == list(range(100))
82+
83+
def test_fenic_tool_decorator_async(local_session: Session):
84+
5985
@fenic_tool(tool_name="test", tool_description="test", max_result_limit=100, default_table_format="markdown")
6086
async def test(numbers: list[int]):
6187
return local_session.create_dataframe({"numbers": numbers})
@@ -64,9 +90,14 @@ async def test(numbers: list[int]):
6490
assert test.default_table_format == "markdown"
6591
assert isinstance(test, DynamicToolDefinition)
6692
assert callable(test.func)
93+
assert inspect.iscoroutinefunction(test.func)
6794
func_signature = inspect.signature(test.func)
6895
assert len(func_signature.parameters) == 1
6996
assert "numbers" in func_signature.parameters
7097
# limit/table_format are added by the MCP server wrapper, so should not be in the raw function signature
7198
assert "limit" not in func_signature.parameters
7299
assert "table_format" not in func_signature.parameters
100+
101+
test_async = asyncio.run(test.func(list(range(100))))
102+
pl_df, _ = local_session._session_state.execution.collect(test_async)
103+
assert pl_df.get_column("numbers").to_list() == list(range(100))

0 commit comments

Comments
 (0)